From 61c49c3e93d4fc096f076986aca66ae3d8ce99f3 Mon Sep 17 00:00:00 2001
From: JulianFP <julian@partanengroup.de>
Date: Thu, 13 Mar 2025 17:56:19 +0000
Subject: [PATCH] Fix T5_MLM_label import, input_ids typo and loss extraction

---
 src/common_interface.py    | 2 +-
 src/models/T5_MLM_label.py | 8 ++++----
 2 files changed, 5 insertions(+), 5 deletions(-)

diff --git a/src/common_interface.py b/src/common_interface.py
index 956591f..934fb08 100644
--- a/src/common_interface.py
+++ b/src/common_interface.py
@@ -6,7 +6,7 @@ from src.models.llms_interface import available_models as llms
 from src.models.GLiNER import find_entities as find_entities_gliner
 from src.models.GLiNER import classify_entity as classify_entity_gliner
 from src.models.T5 import classify_entity as classify_entity_t5
-from src.models.T5 import classify_entity as classify_entity_t5_mlm_label
+from src.models.T5_MLM_label import classify_entity as classify_entity_t5_mlm_label
 from src.experiments.NER_with_LLMs.NER_with_LLMs import find_entities as find_entities_llm
 
 
diff --git a/src/models/T5_MLM_label.py b/src/models/T5_MLM_label.py
index 0b875f4..39bd94d 100644
--- a/src/models/T5_MLM_label.py
+++ b/src/models/T5_MLM_label.py
@@ -10,13 +10,13 @@ model = T5ForConditionalGeneration.from_pretrained(model_name)
 print("Finished loading model: T5 MLM")
 
 def classify_entity(sentence, entity, labels):
-    sentence_with_masked_hypothesis = f"{sentence} {entity} is a <extra_id_0>"
-    inputs_ids = tokenizer(sentence_with_masked_hypothesis, return_tensors="pt").inputs_ids
+    sentence_with_masked_hypothesis = f"{sentence} {entity} is a <extra_id_0>."
+    input_ids = tokenizer(sentence_with_masked_hypothesis, return_tensors="pt").input_ids
 
     results = {}
     for label in labels:
-        label_ids = tokenizer(f"<extra_id_0> {label}", return_tensors="pt").inputs_ids
-        loss = model(inputs_ids=inputs_ids, labels=label_ids)
+        label_ids = tokenizer(f"<extra_id_0> {label}", return_tensors="pt").input_ids
+        loss = model(input_ids=input_ids, labels=label_ids).loss.item()
         results[loss] = label
 
     min_loss = min(results.keys())
-- 
GitLab