Skip to content
Snippets Groups Projects
Commit b67c061b authored by kulcsar's avatar kulcsar
Browse files
parents b7a9bc0a 715c3621
No related branches found
No related tags found
No related merge requests found
......@@ -152,7 +152,6 @@ def train(model, name, imdb, seed,mixup,lambda_value, mixepoch, tmix, mixlayer,
return evaluation_test, evaluation_train
#BINARY
def cross_entropy(logits, target, l):
results = torch.tensor([], device='cuda')
for i in range (logits.shape[0]):
......@@ -165,9 +164,7 @@ def cross_entropy(logits, target, l):
one_hot = torch.tensor([1-value,value], device='cuda:0') #creating one-hot vector e.g. [0. ,1.]
#class 1 and 2 mixed
loss_clear_labels = -((one_hot[0] * logprobs[0][0]) + (one_hot[1] * logprobs[0][1]))
#print("Clear Labels: ", loss_clear_labels)
results = torch.cat((loss_clear_labels.view(1), results), dim=0)
#print("Results Clear Label: ", results)
else:
value_r = round(value, 1) #to make it equal to lambda_value e.g. 0.4
#Wert mit Flag
......@@ -185,13 +182,6 @@ def cross_entropy(logits, target, l):
#print("Batch Loss: ", batch_loss)
return batch_loss
#noch in eine if schleife (beide mixed labels)
#evt. nicht immer torch.cat aufrufen
#Matrix -> 2 dim
#pro Zeile Label wenn gleich -> Standard instanz
# wenn ungleich mixed
def mixup_function(batch_of_matrices, batch_of_labels, l):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment