From 0522f82aa24682d0ea8133e9088756e3938a8b17 Mon Sep 17 00:00:00 2001
From: umlauf <>
Date: Fri, 24 Feb 2023 13:45:43 +0100
Subject: [PATCH] CEL small change

---
 Code/train.py | 11 +++++++----
 1 file changed, 7 insertions(+), 4 deletions(-)

diff --git a/Code/train.py b/Code/train.py
index e0ce706..1dd7ad4 100644
--- a/Code/train.py
+++ b/Code/train.py
@@ -144,11 +144,14 @@ def cross_entropy(logits, target):
     Computes the cross-entropy loss between the predicted logits and the target labels.
     
     Args:
-    - logits (torch.Tensor): A tensor of shape (batch_size, num_classes) representing the predicted logits for each input example.
-    - target (torch.Tensor): A tensor of shape (batch_size,) representing the target labels for each input example.
+    - logits (torch.Tensor): A tensor of shape (batch_size, num_classes) 
+	  representing the predicted logits for each input example.
+    - target (torch.Tensor): A tensor of shape (batch_size,) 
+	  representing the target labels for each input example.
     
     Returns:
-    - batch_loss (torch.Tensor): A scalar tensor representing the average cross-entropy loss across the batch.
+    - batch_loss (torch.Tensor): A scalar tensor representing 
+	  the average cross-entropy loss across the batch.
     """
 	results = torch.tensor([], device='cuda')
 	for i in range (logits.shape[0]):
@@ -168,7 +171,7 @@ def cross_entropy(logits, target):
 			loss_mixed_labels = -((mixed_vec[0] * logprobs[0][0]) + (mixed_vec[1] * logprobs[0][1]))
 			#calculation for mixed with indexing
 			results = torch.cat((loss_mixed_labels.view(1), results), dim=0)#append resultts to result tensor
-	print("ALL BATCH Results: ", results)
+	print("LOSS BATCH (Results): ", results)
 	batch_loss = results.mean() #compute average of all results
 	return batch_loss
 
-- 
GitLab