diff --git a/src/common_interface.py b/src/common_interface.py index 821ea2fecdc5c1e9cefcd3aed66e2382b6cb5619..956591f233a169475d14b695e7acf5e27ce9d56d 100644 --- a/src/common_interface.py +++ b/src/common_interface.py @@ -6,6 +6,7 @@ from src.models.llms_interface import available_models as llms from src.models.GLiNER import find_entities as find_entities_gliner from src.models.GLiNER import classify_entity as classify_entity_gliner from src.models.T5 import classify_entity as classify_entity_t5 +from src.models.T5 import classify_entity as classify_entity_t5_mlm_label from src.experiments.NER_with_LLMs.NER_with_LLMs import find_entities as find_entities_llm @@ -15,6 +16,8 @@ def classify_entity(model_name, sentence, entity, labels): """ if model_name == "T5": return classify_entity_t5(sentence, entity, labels) + elif model_name == "T5-MLM-label": + return classify_entity_t5_mlm_label(sentence, entity, labels) elif model_name == "GLiNER": return classify_entity_gliner(sentence, entity, labels) diff --git a/src/models/T5_MLM_label.py b/src/models/T5_MLM_label.py new file mode 100644 index 0000000000000000000000000000000000000000..0b875f4156c0328baa17a7844c64bcabc45b2a53 --- /dev/null +++ b/src/models/T5_MLM_label.py @@ -0,0 +1,23 @@ +import torch +from torch.nn.functional import softmax +from transformers import T5ForConditionalGeneration, T5Tokenizer, Trainer, TrainingArguments, DataCollatorForSeq2Seq + +model_name = "google-t5/t5-base" + +print("Loading model: T5 MLM") +tokenizer = T5Tokenizer.from_pretrained(model_name) +model = T5ForConditionalGeneration.from_pretrained(model_name) +print("Finished loading model: T5 MLM") + +def classify_entity(sentence, entity, labels): + sentence_with_masked_hypothesis = f"{sentence} {entity} is a <extra_id_0>" + inputs_ids = tokenizer(sentence_with_masked_hypothesis, return_tensors="pt").inputs_ids + + results = {} + for label in labels: + label_ids = tokenizer(f"<extra_id_0> {label}", return_tensors="pt").inputs_ids + loss = model(inputs_ids=inputs_ids, labels=label_ids) + results[loss] = label + + min_loss = min(results.keys()) + return results[min_loss] diff --git a/tests/test_NEC.py b/tests/test_NEC.py index deeedde8c09df48087016ee8bbd40dc6e7a08dff..87f0eced1d36d75ae5f1f7b4d2bfcd096ce282c4 100644 --- a/tests/test_NEC.py +++ b/tests/test_NEC.py @@ -1,6 +1,6 @@ from src.common_interface import classify_entity -tested_models = ["GLiNER", "T5"] +tested_models = ["GLiNER", "T5", "T5-MLM-label"] test_sentence = "Barack Obama was the president of the United States."