Skip to content
Snippets Groups Projects
Unverified Commit f558c9e4 authored by JulianFP's avatar JulianFP
Browse files

Add T5 MLM approach where the label is being masked

parent 602041ce
No related branches found
No related tags found
No related merge requests found
...@@ -6,6 +6,7 @@ from src.models.llms_interface import available_models as llms ...@@ -6,6 +6,7 @@ from src.models.llms_interface import available_models as llms
from src.models.GLiNER import find_entities as find_entities_gliner from src.models.GLiNER import find_entities as find_entities_gliner
from src.models.GLiNER import classify_entity as classify_entity_gliner from src.models.GLiNER import classify_entity as classify_entity_gliner
from src.models.T5 import classify_entity as classify_entity_t5 from src.models.T5 import classify_entity as classify_entity_t5
from src.models.T5 import classify_entity as classify_entity_t5_mlm_label
from src.experiments.NER_with_LLMs.NER_with_LLMs import find_entities as find_entities_llm from src.experiments.NER_with_LLMs.NER_with_LLMs import find_entities as find_entities_llm
...@@ -15,6 +16,8 @@ def classify_entity(model_name, sentence, entity, labels): ...@@ -15,6 +16,8 @@ def classify_entity(model_name, sentence, entity, labels):
""" """
if model_name == "T5": if model_name == "T5":
return classify_entity_t5(sentence, entity, labels) return classify_entity_t5(sentence, entity, labels)
elif model_name == "T5-MLM-label":
return classify_entity_t5_mlm_label(sentence, entity, labels)
elif model_name == "GLiNER": elif model_name == "GLiNER":
return classify_entity_gliner(sentence, entity, labels) return classify_entity_gliner(sentence, entity, labels)
......
import torch
from torch.nn.functional import softmax
from transformers import T5ForConditionalGeneration, T5Tokenizer, Trainer, TrainingArguments, DataCollatorForSeq2Seq
# Identifier of the pretrained T5 checkpoint passed to `from_pretrained` below.
model_name = "google-t5/t5-base"
# The tokenizer and model are created once, at import time, and are reused by
# every call to `classify_entity`; the prints bracket the loading step so that
# its start and end are visible on stdout.
# NOTE(review): presumably the first import fetches the checkpoint from the
# Hugging Face Hub if it is not cached locally — confirm for offline runs.
print("Loading model: T5 MLM")
tokenizer = T5Tokenizer.from_pretrained(model_name)
model = T5ForConditionalGeneration.from_pretrained(model_name)
print("Finished loading model: T5 MLM")
def classify_entity(sentence, entity, labels):
    """Return the label that T5 finds most plausible for *entity* in *sentence*.

    The sentence is extended with the hypothesis ``"<entity> is a <extra_id_0>"``,
    where ``<extra_id_0>`` is the masked span. Each candidate label is then
    scored as the target for that span, and the label with the lowest model
    loss is returned.

    Args:
        sentence: Text in which the entity occurs.
        entity: Surface form of the entity to classify.
        labels: Iterable of candidate label strings.

    Returns:
        The candidate label with the smallest loss. If several labels tie,
        the first one in iteration order wins.

    Raises:
        ValueError: If *labels* is empty.
    """
    # Materialise once so generators work and emptiness can be checked
    # before any tokenizer or model work is done.
    labels = list(labels)
    if not labels:
        raise ValueError("labels must contain at least one candidate label")

    masked_hypothesis = f"{sentence} {entity} is a <extra_id_0>"
    # The tokenizer output exposes the ids as `input_ids` (not `inputs_ids`).
    input_ids = tokenizer(masked_hypothesis, return_tensors="pt").input_ids

    losses = {}
    # Pure inference: no gradients are needed, so skip autograd bookkeeping.
    with torch.no_grad():
        for label in labels:
            label_ids = tokenizer(
                f"<extra_id_0> {label}", return_tensors="pt"
            ).input_ids
            # The forward pass returns an output object; the scalar loss lives
            # in its `.loss` field. `.item()` converts it to a plain float so
            # that labels can be ranked with ordinary comparisons.
            # NOTE(review): the loss is presumably averaged over the target
            # tokens, so multi-token labels are not penalised for length —
            # confirm against the transformers documentation.
            output = model(input_ids=input_ids, labels=label_ids)
            losses[label] = output.loss.item()

    # Key the dict by label and take the minimum over loss values.
    return min(losses, key=losses.get)
from src.common_interface import classify_entity from src.common_interface import classify_entity
tested_models = ["GLiNER", "T5"] tested_models = ["GLiNER", "T5", "T5-MLM-label"]
test_sentence = "Barack Obama was the president of the United States." test_sentence = "Barack Obama was the president of the United States."
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment