diff --git a/src/models/T5.py b/src/models/T5.py
index 08ee89c5930a31615fe83326f0595b623a762701..8e4cfd9958e6d4c95b853f866ed4401314865d27 100644
--- a/src/models/T5.py
+++ b/src/models/T5.py
@@ -1,3 +1,5 @@
+import torch
+from torch.nn.functional import softmax
 from transformers import T5ForConditionalGeneration, T5Tokenizer, Trainer, TrainingArguments, DataCollatorForSeq2Seq
 from datasets import Dataset, DatasetDict
 
@@ -14,33 +16,55 @@ print("Finished loading model: T5 NLI")
 
 def infer_nli(premise, hypothesis):
     input_text = f"nli hypothesis: {hypothesis} premise: {premise}"
-
-    # print("tokenize")
     inputs = tokenizer(input_text, return_tensors="pt")
-
-    # print("generate")
     output_ids = model.generate(**inputs)
-
-    # print("decode")
     result = tokenizer.decode(output_ids[0], skip_special_tokens=True)
 
     return result
 
+def infer_nli_probabilities(premise, hypothesis):
+    input_text = f"nli hypothesis: {hypothesis} premise: {premise}"
+
+    inputs = tokenizer(input_text, return_tensors="pt")
+
+    decoder_input_ids = torch.tensor([[tokenizer.pad_token_id]])
+
+    with torch.no_grad():
+        outputs = model(**inputs, decoder_input_ids=decoder_input_ids)
+
+    logits = outputs.logits
+
+    # Use this code if we also want to get the classification:
+    # predicted_ids = torch.argmax(logits, dim=-1)
+    # predicted_label = tokenizer.decode(predicted_ids[0], skip_special_tokens=True)
+
+    probabilities = softmax(logits, dim=-1)
+
+    label_tokens = ["entailment", "contradiction"]
+    label_ids = [tokenizer.encode(label, add_special_tokens=False)[0] for label in label_tokens]
+
+    label_probs = {label: probabilities[0, 0, token_id].item() for label, token_id in zip(label_tokens, label_ids)}
+
+    return label_probs
+
 
 def classify_entity(sentence, entity, labels):
     best_label = None
+    best_prob = 0
+    # Find label with highest probability of entailment
     for label in labels:
-        # print(f"Label: {label}")
         hypothesis = f"{entity} is {label}"
-        result = infer_nli(sentence, hypothesis)
-        # print(f"Hypothesis: {hypothesis}, Result: {result}")
-        if result == "entailment":
-            return label  # Return immediately if entailment is found
-        elif result == "neutral" and best_label is None:
-            best_label = label  # Store the first neutral label as a fallback
+        label_probs = infer_nli_probabilities(sentence, hypothesis)
 
-    return best_label if best_label else labels[0]
+        # print(f"label: {label}, entailment prob: {label_probs["entailment"]}, contradiction prob: {label_probs["contradiction"]}")
+
+        entailment_prob = label_probs["entailment"]
 
+        if entailment_prob > best_prob:
+            best_prob = entailment_prob
+            best_label = label
+
+    return best_label if best_label else labels[0]
 
 def preprocess_data(sample):
     input_text = f"nli hypothesis: {sample['hypothesis']} premise: {sample['premise']}"