From cf8f776494e122f323f67851e3e59d3937138144 Mon Sep 17 00:00:00 2001 From: umlauf <> Date: Fri, 24 Feb 2023 21:32:29 +0100 Subject: [PATCH] Update CEL --- Code/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Code/train.py b/Code/train.py index a215b15..acd07f2 100644 --- a/Code/train.py +++ b/Code/train.py @@ -178,7 +178,7 @@ def cross_entropy(logits, target): if value == 1 or value == 0: #check if non-mixed label one_hot = torch.tensor([1-value,value], device='cuda:0') #creating one-hot vector e.g. [0. ,1.] loss_clear_labels = -((one_hot[0] * logprobs[0][0]) + (one_hot[1] * logprobs[0][1])) - #calculation with indexing (- 1-label * ) + #calculation with indexing results = torch.cat((loss_clear_labels.view(1), results), dim=0) else: mixed_vec = torch.tensor([value, 1-value]) #creating on-hot mixed vec. -- GitLab