From 93d60fed339e4c17ccc2be39105e14cfaccadf93 Mon Sep 17 00:00:00 2001
From: Thomas Wolf <thomas_wolf@online.de>
Date: Mon, 10 Mar 2025 11:06:00 +0100
Subject: [PATCH] T5 NEC

---
 src/common_interface.py        |  2 +-
 src/experiments/NER_with_T5.py | 12 ------------
 src/models/T5.py               | 24 ++++++++++++++++++++----
 3 files changed, 21 insertions(+), 17 deletions(-)
 delete mode 100644 src/experiments/NER_with_T5.py

diff --git a/src/common_interface.py b/src/common_interface.py
index 4bd6871..821ea2f 100644
--- a/src/common_interface.py
+++ b/src/common_interface.py
@@ -5,7 +5,7 @@ Makes evaluating models easier.
 from src.models.llms_interface import available_models as llms
 from src.models.GLiNER import find_entities as find_entities_gliner
 from src.models.GLiNER import classify_entity as classify_entity_gliner
-from src.experiments.NER_with_T5 import classify_entity as classify_entity_t5
+from src.models.T5 import classify_entity as classify_entity_t5
 from src.experiments.NER_with_LLMs.NER_with_LLMs import find_entities as find_entities_llm
 
 
diff --git a/src/experiments/NER_with_T5.py b/src/experiments/NER_with_T5.py
deleted file mode 100644
index bf0754c..0000000
--- a/src/experiments/NER_with_T5.py
+++ /dev/null
@@ -1,12 +0,0 @@
-from src.models.T5 import infer_nli
-
-
-def classify_entity(sentence, entity, labels):
-    print("classify entity")
-    for label in labels:
-        print(f"Label: {label}")
-        hypothesis = f"{entity} is a {label}"
-        result = infer_nli(sentence, hypothesis)
-        print(f"Hypothesis: {hypothesis}, Result: {result}")
-    # TODO: determine highest confidence prediction
-    return labels[0]
diff --git a/src/models/T5.py b/src/models/T5.py
index 84cbb9d..022a396 100644
--- a/src/models/T5.py
+++ b/src/models/T5.py
@@ -3,7 +3,8 @@ from datasets import Dataset, DatasetDict
 
 label_map = {True: "entailment", False: "contradiction"}
 
-model_name = "google/t5_xxl_true_nli_mixture"
+# Use t5-base for testing because it is smaller
+model_name = "google-t5/t5-base"  # google/t5_xxl_true_nli_mixture
 
 print("Loading model: T5 NLI")
 tokenizer = T5Tokenizer.from_pretrained(model_name)
@@ -14,18 +15,33 @@ print("Finished loading model: T5 NLI")
 def infer_nli(premise, hypothesis):
     input_text = f"nli hypothesis: {hypothesis} premise: {premise}"
 
-    print("tokenize")
+    # print("tokenize")
     inputs = tokenizer(input_text, return_tensors="pt")
 
-    print("generate")
+    # print("generate")
     output_ids = model.generate(**inputs)
 
-    print("decode")
+    # print("decode")
     result = tokenizer.decode(output_ids[0], skip_special_tokens=True)
 
     return result
 
 
+def classify_entity(sentence, entity, labels):
+    best_label = None
+    for label in labels:
+        # print(f"Label: {label}")
+        hypothesis = f"{entity} is {label}"
+        result = infer_nli(sentence, hypothesis)
+        # print(f"Hypothesis: {hypothesis}, Result: {result}")
+        if result == "entailment":
+            return label  # Return immediately if entailment is found
+        elif result == "neutral" and best_label is None:
+            best_label = label  # Store the first neutral label as a fallback
+
+    return best_label if best_label else labels[0]
+
+
 def preprocess_data(sample):
     input_text = f"nli hypothesis: {sample['hypothesis']} premise: {sample['premise']}"
     target_text = label_map[bool(sample['entailment'])]
-- 
GitLab