Commit 935a2b71 authored by umlauf

CEL

parent 69c97729
@@ -152,7 +152,6 @@ def train(model, name, imdb, seed,mixup,lambda_value, mixepoch, tmix, mixlayer,
    return evaluation_test, evaluation_train
#BINARY
def cross_entropy(logits, target, l):
    results = torch.tensor([], device='cuda')
    for i in range(logits.shape[0]):
@@ -165,9 +164,7 @@ def cross_entropy(logits, target, l):
            one_hot = torch.tensor([1-value,value], device='cuda:0') #creating the one-hot vector, e.g. [0., 1.]
            #class 1 and 2 mixed
            loss_clear_labels = -((one_hot[0] * logprobs[0][0]) + (one_hot[1] * logprobs[0][1]))
            #print("Clear Labels: ", loss_clear_labels)
            results = torch.cat((loss_clear_labels.view(1), results), dim=0)
            #print("Results Clear Label: ", results)
        else:
            value_r = round(value, 1) #to make it equal to lambda_value, e.g. 0.4
            #value with flag
@@ -185,13 +182,6 @@ def cross_entropy(logits, target, l):
#print("Batch Loss: ", batch_loss)
return batch_loss
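
The per-example loop above builds a soft target [1-value, value] and accumulates each loss term with torch.cat before reducing to batch_loss. As a point of reference, a vectorised sketch of the same soft-label cross-entropy idea (the function name and the mean reduction are illustrative assumptions, not taken from the collapsed hunks):

import torch
import torch.nn.functional as F

def soft_label_cross_entropy(logits, soft_targets):
    # logits: (N, 2) raw scores; soft_targets: (N, 2) rows such as [1 - lambda, lambda]
    logprobs = F.log_softmax(logits, dim=-1)               # log-probabilities per class
    per_example = -(soft_targets * logprobs).sum(dim=-1)   # negative dot product per row
    return per_example.mean()                              # mean reduction assumed for batch_loss
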
#still to be wrapped in an if branch (both labels mixed)
#possibly avoid calling torch.cat on every iteration
#matrix -> 2-dimensional
#per row: if the labels are equal -> standard instance
#         if they differ -> mixed
def mixup_function(batch_of_matrices, batch_of_labels, l):
......
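
The body of mixup_function is collapsed in this view. Reading the comments above literally (rows whose labels agree stay as standard instances, rows whose labels differ get mixed), one possible sketch of that per-row behaviour is below; the pairing by random permutation and the fractional labels it returns are assumptions for illustration, not taken from the repository:

import torch

def mixup_rows_sketch(batch_of_matrices, batch_of_labels, l):
    # Pair each row with a randomly permuted partner row (assumed pairing scheme).
    partner = torch.randperm(batch_of_matrices.size(0))
    mixed_inputs, mixed_labels = [], []
    for i in range(batch_of_matrices.size(0)):
        j = partner[i]
        if batch_of_labels[i] == batch_of_labels[j]:
            # Equal labels -> keep the standard instance unchanged.
            mixed_inputs.append(batch_of_matrices[i])
            mixed_labels.append(batch_of_labels[i].float())
        else:
            # Unequal labels -> interpolate inputs and labels with coefficient l.
            mixed_inputs.append(l * batch_of_matrices[i] + (1 - l) * batch_of_matrices[j])
            mixed_labels.append(l * batch_of_labels[i].float() + (1 - l) * batch_of_labels[j].float())
    return torch.stack(mixed_inputs), torch.stack(mixed_labels)

Collecting into Python lists here (rather than calling torch.cat on every iteration) is one way to address the note above about avoiding repeated torch.cat calls.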