From 71855c8d912d46fe3884d4593f28e59013df9239 Mon Sep 17 00:00:00 2001
From: umlauf <>
Date: Fri, 24 Feb 2023 13:43:55 +0100
Subject: [PATCH] Comments/Doc CEL

---
 Code/train.py | 39 +++++++++++++++++++++------------------
 1 file changed, 21 insertions(+), 18 deletions(-)

diff --git a/Code/train.py b/Code/train.py
index a0abfdb..e0ce706 100644
--- a/Code/train.py
+++ b/Code/train.py
@@ -139,34 +139,37 @@ def train(model, name, imdb, seed,mixup,lambda_value, mixepoch, tmix, mixlayer,
     return evaluation_test, evaluation_train
 
-def cross_entropy(logits, target, l):
+def cross_entropy(logits, target):
+    """
+    Computes the cross-entropy loss between the predicted logits and the target labels.
+
+    Args:
+    - logits (torch.Tensor): A tensor of shape (batch_size, num_classes) holding the predicted logits for each input example.
+    - target (torch.Tensor): A tensor of shape (batch_size,) holding the target label for each input example.
+
+    Returns:
+    - batch_loss (torch.Tensor): A scalar tensor holding the average cross-entropy loss across the batch.
+    """
     results = torch.tensor([], device='cuda')
     for i in range (logits.shape[0]):
-        lg = logits[i:i+1,:] #comment to explain the process in this Code Line
+        lg = logits[i:i+1,:] #extract the logits row for the i-th input example in the batch
         t = target[i]
-        #makes the logits in log (base e) probabilities
-        logprobs = torch.nn.functional.log_softmax(lg, dim=1)
-        value = t.item() #gets Item (0. or 1.)
-        if value == 1 or value == 0:
+        logprobs = torch.nn.functional.log_softmax(lg, dim=1) #convert logits to log (base e) probabilities
+        value = t.item() #get the label as a Python scalar
+        if value == 1 or value == 0: #check for a non-mixed (hard) label
             one_hot = torch.tensor([1-value,value], device='cuda:0') #creating one-hot vector e.g. [0. ,1.]
-            #class 1 and 2 mixed
             loss_clear_labels = -((one_hot[0] * logprobs[0][0]) + (one_hot[1] * logprobs[0][1]))
+            #negative log-likelihood computed by explicit indexing
             results = torch.cat((loss_clear_labels.view(1), results), dim=0)
         else:
             value_r = round(value, 1) #to make it equal to lambda_value e.g. 0.4
-            #value with flag
-            mixed_vec = torch.tensor([value_r, 1-value_r])
-            print("Mixed Vec: ", mixed_vec)
-            logprobs = torch.nn.functional.log_softmax(lg, dim=1)
-            print("Log:", logprobs)
-            #loss_mixed_labels = -torch.mul(mixed_vec, logprobs).sum()
+            mixed_vec = torch.tensor([value_r, 1-value_r]) #creating the mixed (soft) label vector
+            logprobs = torch.nn.functional.log_softmax(lg, dim=1) #convert logits to log probabilities
             loss_mixed_labels = -((mixed_vec[0] * logprobs[0][0]) + (mixed_vec[1] * logprobs[0][1]))
-            print("Loss Mixed Labels l: ", loss_mixed_labels)
-            results = torch.cat((loss_mixed_labels.view(1), results), dim=0)
-            print("Results Mixed 1: ", results)
+            #negative log-likelihood for the mixed label, again by explicit indexing
+            results = torch.cat((loss_mixed_labels.view(1), results), dim=0) #append the result to the results tensor
     print("ALL BATCH Results: ", results)
-    batch_loss = results.mean() #compute average
-    #print("Batch Loss: ", batch_loss)
+    batch_loss = results.mean() #compute the average loss over the batch
     return batch_loss
-- 
GitLab
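
For reference, below is a minimal CPU-only sketch of the patched cross_entropy so the loss can be exercised in isolation. It drops the hard-coded CUDA device arguments and the remaining debug printing, but keeps the branch logic: hard 0/1 labels build a one-hot vector, mixed mixup labels build a soft [lambda, 1-lambda] vector. The sample logits and targets are invented for illustration, not taken from the repository.

import torch

def cross_entropy(logits, target):
    """CPU-only sketch of the patched loss function."""
    results = torch.tensor([])
    for i in range(logits.shape[0]):
        lg = logits[i:i+1, :]                                  # logits row for example i
        logprobs = torch.nn.functional.log_softmax(lg, dim=1)  # log (base e) probabilities
        value = target[i].item()
        if value == 1 or value == 0:                           # hard label
            vec = torch.tensor([1 - value, value])             # one-hot vector, e.g. [0., 1.]
        else:                                                  # mixed label, e.g. 0.4
            value_r = round(value, 1)                          # round to match lambda_value
            vec = torch.tensor([value_r, 1 - value_r])         # soft label vector
        loss = -((vec[0] * logprobs[0][0]) + (vec[1] * logprobs[0][1]))
        results = torch.cat((loss.view(1), results), dim=0)
    return results.mean()                                      # average loss over the batch

logits = torch.tensor([[2.0, -1.0], [0.5, 1.5], [0.0, 0.0]])
target = torch.tensor([0.0, 1.0, 0.4])  # last entry simulates a mixup label with lambda = 0.4
print(cross_entropy(logits, target))    # prints the scalar batch loss

Note that, as in the patch, the hard-label branch puts the label weight on index 1 (one_hot = [1-value, value]) while the mixed branch puts the rounded label weight on index 0 (mixed_vec = [value_r, 1-value_r]); the sketch preserves that asymmetry rather than silently unifying the two conventions.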