diff --git a/Code/train.py b/Code/train.py index a215b15503c7748fe1e450158e2b7f26ccea4e22..acd07f2217d64ae2ef8565d73b0885ce1b6e61a4 100644 --- a/Code/train.py +++ b/Code/train.py @@ -178,7 +178,7 @@ def cross_entropy(logits, target): if value == 1 or value == 0: #check if non-mixed label one_hot = torch.tensor([1-value,value], device='cuda:0') #creating one-hot vector e.g. [0. ,1.] loss_clear_labels = -((one_hot[0] * logprobs[0][0]) + (one_hot[1] * logprobs[0][1])) - #calculation with indexing (- 1-label * ) + #calculation with indexing results = torch.cat((loss_clear_labels.view(1), results), dim=0) else: mixed_vec = torch.tensor([value, 1-value]) #creating on-hot mixed vec.