Commit 9607cd90 authored by umlauf

CEL results

parent 61a33621
@@ -40,9 +40,9 @@ for i in range(logits.shape[0]):
     if value == 1 or value == 0:
         one_hot = torch.tensor([1-value,value]) #creating one-hot vector e.g. [0. ,1.]
         loss_clear_labels = -((one_hot[0] * logprobs[0][0]) + (one_hot[1] * logprobs[0][1]))
-        print("Clear Labels: ", loss_clear_labels)
+        print("Clear Labels Loss: ", loss_clear_labels)
         results = torch.cat((loss_mixed_labels.view(1), results), dim=0)
-        print("Results: ", results)
+        print("Results Clear Labels: ", results)
     else:
         if value_r == l:
             mixed_vec = torch.tensor([l, 1-l])
@@ -53,7 +53,7 @@ for i in range(logits.shape[0]):
             loss_mixed_labels = -((mixed_vec[0] * logprobs[0][0]) + (mixed_vec[1] * logprobs[0][1]))
             print("Loss Mixed Lables l: ", loss_mixed_labels)
             results = torch.cat((loss_mixed_labels.view(1), results), dim=0)
-            print("Results: ", results)
+            print("Results Mixed 1: ", results)
         else:
             mixed_vec = torch.tensor([1-l, l])
             logprobs = torch.nn.functional.log_softmax(lg, dim=1)
@@ -61,7 +61,7 @@ for i in range(logits.shape[0]):
             loss_mixed_labels = -((mixed_vec[0] * logprobs[0][0]) + (mixed_vec[1] * logprobs[0][1]))
             print("Mixed Labels 1-l: ", loss_mixed_labels)
             results = torch.cat((loss_mixed_labels.view(1), results), dim=0)
-            print("Results: ", results)
+            print("Results Mixed 2: ", results)
 print("Results all", results)
 average_loss = results.mean() #compute the average
 print("Avg", average_loss)
@@ -160,16 +160,17 @@ def cross_entropy(logits, target, l):
         if value_r == l:
             #create vector: e.g. [l, 1-l]
             mixed_vec = torch.tensor([l, 1-l], device='cuda')
-            loss_mixed_labels = -((mixed_vec[0] * logprobs[0][0]) + (mixed_vec[1] * logprobs[0][1]))
-            print("Mixed Labels: ", loss_mixed_labels)
-            results = torch.cat((loss_mixed_labels.view(1), results), dim=0)
-            print("Results Mixed Label: ", results)
+            loss_mixed_labels_1 = -((mixed_vec[0] * logprobs[0][0]) + (mixed_vec[1] * logprobs[0][1]))
+            print("Mixed Labels1: ", loss_mixed_labels_1)
+            results = torch.cat((loss_mixed_labels_1.view(1), results), dim=0)
+            print("Results Mixed Label1: ", results)
         else:
             mixed_vec = torch.tensor([1-l, l], device='cuda')
-            loss_mixed_labels = -((mixed_vec[0] * logprobs[0][0]) + (mixed_vec[1] * logprobs[0][1]))
-            print("Mixed Labels: ", loss_mixed_labels)
-            results = torch.cat((loss_mixed_labels.view(1), results), dim=0)
-            print("Results Mixed Label: ", results)
+            loss_mixed_labels_2 = -((mixed_vec[0] * logprobs[0][0]) + (mixed_vec[1] * logprobs[0][1]))
+            print("Mixed Labels2: ", loss_mixed_labels_2)
+            results = torch.cat((loss_mixed_labels_2.view(1), results), dim=0)
+            print("Results Mixed Label2: ", results)
+    print("ALL BATCH Results: ", results)
     batch_loss = results.mean() #compute average
     print("Batch Loss: ", batch_loss)
     return batch_loss
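
As a cross-check for batch_loss, the same average can be computed without the Python loop once the per-example target vectors are stacked into a single tensor. The sketch below is an assumption, not code from this commit: soft_targets is taken to be an (N, 2) tensor whose rows are the same one-hot / [l, 1-l] / [1-l, l] vectors built above, on the same device as logits.

import torch

def batch_cross_entropy(logits, soft_targets):
    # logits, soft_targets: both of shape (N, 2); each row of soft_targets sums to 1 (assumed)
    logprobs = torch.nn.functional.log_softmax(logits, dim=1)
    per_example = -(soft_targets * logprobs).sum(dim=1)  # one loss value per example
    return per_example.mean()                            # same reduction as results.mean()

On PyTorch 1.10 or newer, torch.nn.functional.cross_entropy(logits, soft_targets) also accepts probability targets and should give the same mean value.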