From 61c49c3e93d4fc096f076986aca66ae3d8ce99f3 Mon Sep 17 00:00:00 2001 From: JulianFP <julian@partanengroup.de> Date: Thu, 13 Mar 2025 17:56:19 +0000 Subject: [PATCH] Fix T5_MLM_label --- src/common_interface.py | 2 +- src/models/T5_MLM_label.py | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/common_interface.py b/src/common_interface.py index 956591f..934fb08 100644 --- a/src/common_interface.py +++ b/src/common_interface.py @@ -6,7 +6,7 @@ from src.models.llms_interface import available_models as llms from src.models.GLiNER import find_entities as find_entities_gliner from src.models.GLiNER import classify_entity as classify_entity_gliner from src.models.T5 import classify_entity as classify_entity_t5 -from src.models.T5 import classify_entity as classify_entity_t5_mlm_label +from src.models.T5_MLM_label import classify_entity as classify_entity_t5_mlm_label from src.experiments.NER_with_LLMs.NER_with_LLMs import find_entities as find_entities_llm diff --git a/src/models/T5_MLM_label.py b/src/models/T5_MLM_label.py index 0b875f4..39bd94d 100644 --- a/src/models/T5_MLM_label.py +++ b/src/models/T5_MLM_label.py @@ -10,13 +10,13 @@ model = T5ForConditionalGeneration.from_pretrained(model_name) print("Finished loading model: T5 MLM") def classify_entity(sentence, entity, labels): - sentence_with_masked_hypothesis = f"{sentence} {entity} is a <extra_id_0>" - inputs_ids = tokenizer(sentence_with_masked_hypothesis, return_tensors="pt").inputs_ids + sentence_with_masked_hypothesis = f"{sentence} {entity} is a <extra_id_0>." + input_ids = tokenizer(sentence_with_masked_hypothesis, return_tensors="pt").input_ids results = {} for label in labels: - label_ids = tokenizer(f"<extra_id_0> {label}", return_tensors="pt").inputs_ids - loss = model(inputs_ids=inputs_ids, labels=label_ids) + label_ids = tokenizer(f"<extra_id_0> {label}", return_tensors="pt").input_ids + loss = model(input_ids=input_ids, labels=label_ids).loss.item() results[loss] = label min_loss = min(results.keys()) -- GitLab