Commit 71855c8d authored by umlauf

Comments/Doc CEL

parent 1bf3cb79
@@ -139,34 +139,37 @@ def train(model, name, imdb, seed,mixup,lambda_value, mixepoch, tmix, mixlayer,
     return evaluation_test, evaluation_train
 
-def cross_entropy(logits, target, l):
+def cross_entropy(logits, target):
+    """
+    Computes the cross-entropy loss between the predicted logits and the target labels.
+
+    Args:
+    - logits (torch.Tensor): A tensor of shape (batch_size, num_classes) representing the predicted logits for each input example.
+    - target (torch.Tensor): A tensor of shape (batch_size,) representing the target labels for each input example.
+
+    Returns:
+    - batch_loss (torch.Tensor): A scalar tensor representing the average cross-entropy loss across the batch.
+    """
     results = torch.tensor([], device='cuda')
     for i in range(logits.shape[0]):
-        lg = logits[i:i+1,:] #comment to explain the process in this Code Line
+        lg = logits[i:i+1,:] #extract the logits row for the i-th input example in the batch
         t = target[i]
-        #makes the logits in log (base e) probabilities
-        logprobs = torch.nn.functional.log_softmax(lg, dim=1)
-        value = t.item() #gets Item (0. or 1.)
-        if value == 1 or value == 0:
+        logprobs = torch.nn.functional.log_softmax(lg, dim=1) #logits in log (base e) probabilities
+        value = t.item() #get the scalar label value
+        if value == 1 or value == 0: #check for a non-mixed label
             one_hot = torch.tensor([1-value, value], device='cuda:0') #creating a one-hot vector, e.g. [0., 1.]
-            #class 1 and 2 mixed
             loss_clear_labels = -((one_hot[0] * logprobs[0][0]) + (one_hot[1] * logprobs[0][1]))
-            #calculation with indexing (- 1-label * )
             results = torch.cat((loss_clear_labels.view(1), results), dim=0)
         else:
             value_r = round(value, 1) #round to make it equal to lambda_value, e.g. 0.4
-            #Wert mit Flag (value with flag)
-            mixed_vec = torch.tensor([value_r, 1-value_r])
-            print("Mixed Vec: ", mixed_vec)
-            logprobs = torch.nn.functional.log_softmax(lg, dim=1)
-            print("Log:", logprobs)
-            #loss_mixed_labels = -torch.mul(mixed_vec, logprobs).sum()
+            mixed_vec = torch.tensor([value_r, 1-value_r]) #creating the mixed (soft) label vector
+            logprobs = torch.nn.functional.log_softmax(lg, dim=1) #logits in log probabilities
             loss_mixed_labels = -((mixed_vec[0] * logprobs[0][0]) + (mixed_vec[1] * logprobs[0][1]))
-            print("Loss Mixed Lables l: ", loss_mixed_labels)
-            results = torch.cat((loss_mixed_labels.view(1), results), dim=0)
-            print("Results Mixed 1: ", results)
+            #calculation for mixed labels, with indexing
+            results = torch.cat((loss_mixed_labels.view(1), results), dim=0) #append the result to the results tensor
     print("ALL BATCH Results: ", results)
-    batch_loss = results.mean() #compute average
-    #print("Batch Loss: ", batch_loss)
+    batch_loss = results.mean() #compute the average of all results
     return batch_loss
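
For reference, the per-example loop in cross_entropy can be collapsed into a single vectorized soft-label cross-entropy. The following is a minimal sketch, not part of the commit: it assumes two classes and the target encoding used above (0. or 1. for hard labels, a mixup weight such as 0.4 in between), mirrors the branch-dependent vector order from the loop ([1-value, value] for hard labels, [value_r, 1-value_r] for mixed ones), and omits the round(value, 1) step; the name soft_label_cross_entropy is hypothetical.

import torch
import torch.nn.functional as F

def soft_label_cross_entropy(logits, target):
    # logits: (batch_size, 2); target: (batch_size,) with the scalar label encoding above
    logprobs = F.log_softmax(logits, dim=1)             #logits in log (base e) probabilities
    value = target.float()
    hard = (value == 0) | (value == 1)                  #mask of non-mixed labels
    hard_vec = torch.stack([1 - value, value], dim=1)   #one-hot rows, e.g. [0., 1.]
    mixed_vec = torch.stack([value, 1 - value], dim=1)  #mixed rows, e.g. [0.4, 0.6]
    soft = torch.where(hard.unsqueeze(1), hard_vec, mixed_vec)
    per_example = -(soft * logprobs).sum(dim=1)         #one loss value per example
    return per_example.mean()                           #average over the batch

# Example: one hard label and one mixup-weighted label
logits = torch.tensor([[2.0, -1.0], [0.3, 0.8]])
target = torch.tensor([1.0, 0.4])
print(soft_label_cross_entropy(logits, target))

Because the weighted sum runs over whole rows at once, the results tensor and the Python loop disappear, which also avoids rebuilding a tensor with torch.cat on every iteration.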