diff --git a/Code/train.py b/Code/train.py
index e0ce70646185116ce049971e5778019eede7e5dd..1dd7ad4a02b6b0c74990ae2742d31d24246afced 100644
--- a/Code/train.py
+++ b/Code/train.py
@@ -144,11 +144,14 @@ def cross_entropy(logits, target):
     Computes the cross-entropy loss between the predicted logits and the target labels.
 
     Args:
-    - logits (torch.Tensor): A tensor of shape (batch_size, num_classes) representing the predicted logits for each input example.
-    - target (torch.Tensor): A tensor of shape (batch_size,) representing the target labels for each input example.
+    - logits (torch.Tensor): A tensor of shape (batch_size, num_classes)
+      representing the predicted logits for each input example.
+    - target (torch.Tensor): A tensor of shape (batch_size,)
+      representing the target labels for each input example.
 
     Returns:
-    - batch_loss (torch.Tensor): A scalar tensor representing the average cross-entropy loss across the batch.
+    - batch_loss (torch.Tensor): A scalar tensor representing
+      the average cross-entropy loss across the batch.
     """
     results = torch.tensor([], device='cuda')
     for i in range(logits.shape[0]):
@@ -168,7 +171,7 @@ def cross_entropy(logits, target):
             loss_mixed_labels = -((mixed_vec[0] * logprobs[0][0]) + (mixed_vec[1] * logprobs[0][1]))  # loss for a mixed (soft) label, computed with explicit indexing
             results = torch.cat((loss_mixed_labels.view(1), results), dim=0)  # prepend this example's loss to the results tensor
 
-    print("ALL BATCH Results: ", results)
+    print("LOSS BATCH (Results): ", results)
     batch_loss = results.mean()  # average over all per-example losses
     return batch_loss
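
For reference, a minimal vectorized sketch of the same soft-label cross-entropy. This is an illustration only, not part of the patch: it assumes the per-example mixed_vec weights are stacked into a (batch_size, num_classes) targets tensor and that logprobs comes from log_softmax, mirroring the per-example loop above.

import torch
import torch.nn.functional as F

def soft_label_cross_entropy(logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
    # targets holds per-class soft label weights (the stacked mixed_vec values);
    # this layout is an assumption for illustration, not taken from train.py.
    logprobs = F.log_softmax(logits, dim=-1)           # (batch_size, num_classes)
    per_example = -(targets * logprobs).sum(dim=-1)    # -(sum_c p_c * log q_c) for each example
    return per_example.mean()                          # same averaging as batch_loss above

# Example usage with one hard and one mixed label:
logits = torch.tensor([[2.0, -1.0], [0.5, 0.5]])
targets = torch.tensor([[1.0, 0.0], [0.3, 0.7]])
loss = soft_label_cross_entropy(logits, targets)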