diff --git a/Code/train.py b/Code/train.py index 3236419e81d6e54573761b3430585cd394c0c215..e45ea60e4eb4779e0c7fbbce3354b58b42d3fc29 100644 --- a/Code/train.py +++ b/Code/train.py @@ -186,7 +186,7 @@ def cross_entropy(logits, target, l): if value == 1 or value == 0: #check if non-mixed label one_hot = torch.tensor([1-value,value], device='cuda:0') #creating one-hot vector e.g. [0. ,1.] loss_clear_labels = -((one_hot[0] * logprobs[0][0]) + (one_hot[1] * logprobs[0][1])) - #calculation with indexing (- 1-label * ) + #calculation with indexing results = torch.cat((loss_clear_labels.view(1), results), dim=0) else: mixed_vec = torch.tensor([l, 1-l]) #creating on-hot mixed vec.