Commit 329db1da authored by umlauf

back again

parent 9da4da9c
@@ -225,7 +225,7 @@ if __name__ == "__main__":
     parser.add_argument(
         "-tb",
         "--test_batch_size",
-        help="The batch size for the training process",
+        help="The batch size for the test process",
         type=int,
         default=16)
@@ -103,16 +103,16 @@ def train(model, name, seed,gradient_accumulation_steps,mixup, threshold, lambda
         #print("span output: ", span_output)
         logits=model.classifier(span_output.detach()) #target_value?
-        print("logits: ", logits)
-        print("logits shape: ", list(logits.shape))
-        print("Newlabels: ", new_labels_batch)
-        print("labels shape: ", list(new_labels_batch.shape))
+        # print("logits: ", logits)
+        # print("logits shape: ", list(logits.shape))
+        # print("Newlabels: ", new_labels_batch)
+        # print("labels shape: ", list(new_labels_batch.shape))
-        preds = logits.view(-1, 2).to("cuda")
+        logits = logits.view(-1, 2).to("cuda")
         target = new_labels_batch.view(-1).to("cuda")
-        # loss_2 = cross_entropy(preds, target)
+        loss_2 = cross_entropy(logits, target)
         #loss_2 = SoftCrossEntropyLoss(logits.view(-1, 2).to("cuda"), new_labels_batch.view(-1).to("cuda"))
-        loss_2 = torch.nn.functional.cross_entropy(preds, target.long())
+        #loss_2 = torch.nn.functional.cross_entropy(preds, target.long())
         print("MixUp Loss: ", loss_2)
         #update entire model
         loss_2.backward()
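Note, not part of the commit: the switch back from torch.nn.functional.cross_entropy to the local cross_entropy helper matters because the built-in loss expects integer class indices, while MixUp produces interpolated per-class label weights. Below is a minimal, self-contained sketch of that usage pattern; the two-class shape, the lambda value, and all tensor names are assumptions for illustration, not taken from the repository.

import torch

def soft_cross_entropy(logits, target):
    # target holds per-class probabilities (e.g. MixUp-interpolated labels),
    # so the loss sums target * log_softmax over the class dimension instead
    # of indexing with integer labels as torch.nn.functional.cross_entropy does.
    log_probs = torch.nn.functional.log_softmax(logits, dim=1)
    return -(target * log_probs).sum(dim=1).mean()

# Hypothetical two-class example: mix two hard label vectors with lambda = 0.7.
logits = torch.randn(4, 2, requires_grad=True)
labels_a = torch.tensor([0, 1, 1, 0])
labels_b = torch.tensor([1, 1, 0, 0])
lam = 0.7
one_hot_a = torch.nn.functional.one_hot(labels_a, num_classes=2).float()
one_hot_b = torch.nn.functional.one_hot(labels_b, num_classes=2).float()
target = lam * one_hot_a + (1 - lam) * one_hot_b   # soft labels, shape (4, 2)

loss = soft_cross_entropy(logits, target)
loss.backward()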
@@ -137,6 +137,11 @@ def train(model, name, seed,gradient_accumulation_steps,mixup, threshold, lambda
     return evaluation_test, evaluation_train
+
+def cross_entropy(logits, target):
+    log_softmax = torch.nn.functional.log_softmax(logits, dim=1)
+    loss = -torch.sum(target * log_softmax, dim=1).mean()
+    return loss
 
 # def SoftCrossEntropyLoss (input, target):
 #     logprobs = torch.nn.functional.log_softmax (input, dim = 1)
 #     return -(target * logprobs).sum() / input.shape[0]
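For reference, with purely one-hot targets this soft formulation reduces to the standard hard-label loss, so it can be sanity-checked against the built-in torch.nn.functional.cross_entropy. The shapes and tensor names below are illustrative assumptions, not taken from the repository.

import torch

def cross_entropy(logits, target):
    # Same definition as in the commit: soft-label cross-entropy over dim=1.
    log_softmax = torch.nn.functional.log_softmax(logits, dim=1)
    return -torch.sum(target * log_softmax, dim=1).mean()

logits = torch.randn(8, 2)
hard_labels = torch.randint(0, 2, (8,))
soft_labels = torch.nn.functional.one_hot(hard_labels, num_classes=2).float()

# With one-hot targets the two losses agree (up to floating-point error).
assert torch.allclose(
    cross_entropy(logits, soft_labels),
    torch.nn.functional.cross_entropy(logits, hard_labels))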