Commit 0522f82a authored by umlauf

CEL small change

parent 9c9c7c43
@@ -144,11 +144,14 @@ def cross_entropy(logits, target):
     Computes the cross-entropy loss between the predicted logits and the target labels.
     Args:
-    - logits (torch.Tensor): A tensor of shape (batch_size, num_classes) representing the predicted logits for each input example.
-    - target (torch.Tensor): A tensor of shape (batch_size,) representing the target labels for each input example.
+    - logits (torch.Tensor): A tensor of shape (batch_size, num_classes)
+        representing the predicted logits for each input example.
+    - target (torch.Tensor): A tensor of shape (batch_size,)
+        representing the target labels for each input example.
     Returns:
-    - batch_loss (torch.Tensor): A scalar tensor representing the average cross-entropy loss across the batch.
+    - batch_loss (torch.Tensor): A scalar tensor representing
+        the average cross-entropy loss across the batch.
     """
     results = torch.tensor([], device='cuda')
     for i in range(logits.shape[0]):
@@ -168,7 +171,7 @@ def cross_entropy(logits, target):
     loss_mixed_labels = -((mixed_vec[0] * logprobs[0][0]) + (mixed_vec[1] * logprobs[0][1]))
     # loss calculation for mixed labels, with explicit indexing
     results = torch.cat((loss_mixed_labels.view(1), results), dim=0)  # append result to the results tensor
-    print("ALL BATCH Results: ", results)
+    print("LOSS BATCH (Results): ", results)
     batch_loss = results.mean()  # compute average of all results
     return batch_loss
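The diff above computes the mixed-label (soft-target) cross-entropy one example at a time and concatenates the per-example losses before averaging. For context, the same quantity can be sketched as a vectorized standalone function; the function name, the two-class shapes, and the example tensors below are illustrative assumptions, not code from this commit:

```python
import torch
import torch.nn.functional as F

def cross_entropy_mixed(logits, target_probs):
    """Average cross-entropy between logits and soft (mixed) label vectors.

    logits:       (batch_size, num_classes) raw model outputs
    target_probs: (batch_size, num_classes) rows summing to 1,
                  e.g. mixed labels such as [0.7, 0.3]
    """
    logprobs = F.log_softmax(logits, dim=-1)               # (B, C)
    per_example = -(target_probs * logprobs).sum(dim=-1)   # (B,)
    return per_example.mean()                              # scalar

# hypothetical two-class batch with one mixed and one hard label
logits = torch.tensor([[2.0, 0.5], [0.1, 1.2]])
targets = torch.tensor([[0.7, 0.3], [0.0, 1.0]])
loss = cross_entropy_mixed(logits, targets)
```

Averaging with `.mean()` at the end matches the patched function's `results.mean()`, but the vectorized form avoids growing a tensor with `torch.cat` inside the loop.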