Commit bb7addfb authored by umlauf

print statements

parent aa7af7d2
@@ -146,6 +146,7 @@ def train(model, name, seed,gradient_accumulation_steps,mixup, threshold, lambda
#pytorch Forum, try with dim=1 for targets
def cross_entropy(logits, target):
    logprobs = torch.nn.functional.log_softmax(logits, dim=1)
    print("target for cross entropy: ", target)
    loss = -torch.sum(target * logprobs, dim=1)
    print("Logits Shape Index 0: ", logits.shape[0])
    return loss.div(logits.shape[0])
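
Below is a minimal, hypothetical usage sketch (not part of this commit) showing how this soft-target cross_entropy can be called on mixup-style blended one-hot targets; the batch size, class count, lam value, and target construction are assumptions for illustration only.

import torch
import torch.nn.functional as F

def cross_entropy(logits, target):
    # log-probabilities over classes (dim=1), weighted by the soft targets
    logprobs = F.log_softmax(logits, dim=1)
    loss = -torch.sum(target * logprobs, dim=1)
    # per-example losses scaled by 1 / batch size, as in the commit
    return loss.div(logits.shape[0])

# assumed shapes for illustration
batch_size, num_classes = 4, 3
logits = torch.randn(batch_size, num_classes)

# mixup-style soft targets: convex combination of two one-hot label sets
lam = 0.7
y_a = F.one_hot(torch.randint(num_classes, (batch_size,)), num_classes).float()
y_b = F.one_hot(torch.randint(num_classes, (batch_size,)), num_classes).float()
target = lam * y_a + (1 - lam) * y_b

print(cross_entropy(logits, target))  # tensor of per-example scaled losses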