diff --git a/Code/train.py b/Code/train.py index 1dd7ad4a02b6b0c74990ae2742d31d24246afced..55bdd7fc099cc75e096f1fa24cfd404567e02b42 100644 --- a/Code/train.py +++ b/Code/train.py @@ -165,8 +165,7 @@ def cross_entropy(logits, target): #calculation with indexing (- 1-label * ) results = torch.cat((loss_clear_labels.view(1), results), dim=0) else: - value_r = round(value, 1) #to make it equal to lambda_value e.g. 0.4 - mixed_vec = torch.tensor([value_r, 1-value_r]) #creating on-hot mixed vec. + mixed_vec = torch.tensor([value, 1-value]) #creating on-hot mixed vec. logprobs = torch.nn.functional.log_softmax(lg, dim=1)#logits in log probabilities loss_mixed_labels = -((mixed_vec[0] * logprobs[0][0]) + (mixed_vec[1] * logprobs[0][1])) #calculation for mixed with indexing