Commit 16fd7ca2 authored by friebolin
Parents: 6bd4faf5 0522f82a
@@ -139,34 +139,40 @@ def train(model, name, imdb, seed, mixup, lambda_value, mixepoch, tmix, mixlayer, ...):
    return evaluation_test, evaluation_train


def cross_entropy(logits, target):
    """
    Computes the cross-entropy loss between the predicted logits and the target
    labels, supporting both hard (0/1) labels and mixup-interpolated labels.

    Args:
        logits (torch.Tensor): A tensor of shape (batch_size, num_classes)
            representing the predicted logits for each input example.
        target (torch.Tensor): A tensor of shape (batch_size,)
            representing the target label for each input example.

    Returns:
        batch_loss (torch.Tensor): A scalar tensor representing
            the average cross-entropy loss across the batch.
    """
    results = torch.tensor([], device='cuda')
    for i in range(logits.shape[0]):
        lg = logits[i:i+1, :]  # extract the logits row for the i-th example in the batch
        t = target[i]
        logprobs = torch.nn.functional.log_softmax(lg, dim=1)  # turn logits into log (base e) probabilities
        value = t.item()  # get the scalar label value
        if value == 1 or value == 0:  # non-mixed (hard) label
            one_hot = torch.tensor([1 - value, value], device='cuda:0')  # one-hot vector, e.g. [0., 1.]
            loss_clear_labels = -((one_hot[0] * logprobs[0][0]) + (one_hot[1] * logprobs[0][1]))
            results = torch.cat((loss_clear_labels.view(1), results), dim=0)
        else:  # mixed label produced by mixup
            value_r = round(value, 1)  # round so it equals lambda_value, e.g. 0.4
            # note: the mixed weights are ordered [value_r, 1 - value_r],
            # the reverse of the hard-label ordering above
            mixed_vec = torch.tensor([value_r, 1 - value_r], device='cuda:0')  # soft (mixed) label vector
            loss_mixed_labels = -((mixed_vec[0] * logprobs[0][0]) + (mixed_vec[1] * logprobs[0][1]))
            results = torch.cat((loss_mixed_labels.view(1), results), dim=0)  # prepend loss to the results tensor
    print("LOSS BATCH (Results): ", results)  # debug output: per-example losses
    batch_loss = results.mean()  # average the per-example losses over the batch
    return batch_loss
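
For context, a minimal usage sketch (not part of the commit): it assumes binary classification, an available CUDA device (the function hard-codes 'cuda'), and a mixup interpolation factor lambda_value = 0.4; the tensor values below are illustrative only.

import torch

# hypothetical batch: two hard labels (1, 0) and one mixed label (0.4)
logits = torch.randn(3, 2, device='cuda')               # (batch_size=3, num_classes=2)
target = torch.tensor([1.0, 0.0, 0.4], device='cuda')   # 0.4 = lambda_value from mixup

loss = cross_entropy(logits, target)  # assumes the function above is in scope
print(loss)  # 0-dim tensor on cuda:0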
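
The per-example loop can also be expressed as a single vectorized pass. The sketch below is an assumption-laden equivalent, not the committed implementation: it reproduces the same weighting (including the reversed ordering for mixed labels) but skips the round(value, 1) step, so it matches exactly only when targets already equal lambda_value.

def cross_entropy_vectorized(logits, target):
    # sketch of a loop-free equivalent; hard labels (0/1) get weights [1-t, t],
    # mixed labels keep the commit's reversed [t, 1-t] ordering
    logprobs = torch.nn.functional.log_softmax(logits, dim=1)
    hard = (target == 0) | (target == 1)
    w1 = torch.where(hard, target, 1 - target)       # weight placed on class 1
    soft_labels = torch.stack((1 - w1, w1), dim=1)   # (batch_size, 2)
    return -(soft_labels * logprobs).sum(dim=1).mean()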