Skip to content
Snippets Groups Projects
Unverified Commit 61c49c3e authored by JulianFP's avatar JulianFP
Browse files

Fix T5_MLM_label

parent f558c9e4
No related branches found
No related tags found
No related merge requests found
......@@ -6,7 +6,7 @@ from src.models.llms_interface import available_models as llms
from src.models.GLiNER import find_entities as find_entities_gliner
from src.models.GLiNER import classify_entity as classify_entity_gliner
from src.models.T5 import classify_entity as classify_entity_t5
from src.models.T5 import classify_entity as classify_entity_t5_mlm_label
from src.models.T5_MLM_label import classify_entity as classify_entity_t5_mlm_label
from src.experiments.NER_with_LLMs.NER_with_LLMs import find_entities as find_entities_llm
......
......@@ -10,13 +10,13 @@ model = T5ForConditionalGeneration.from_pretrained(model_name)
print("Finished loading model: T5 MLM")
def classify_entity(sentence, entity, labels):
sentence_with_masked_hypothesis = f"{sentence} {entity} is a <extra_id_0>"
inputs_ids = tokenizer(sentence_with_masked_hypothesis, return_tensors="pt").inputs_ids
sentence_with_masked_hypothesis = f"{sentence} {entity} is a <extra_id_0>."
input_ids = tokenizer(sentence_with_masked_hypothesis, return_tensors="pt").input_ids
results = {}
for label in labels:
label_ids = tokenizer(f"<extra_id_0> {label}", return_tensors="pt").inputs_ids
loss = model(inputs_ids=inputs_ids, labels=label_ids)
label_ids = tokenizer(f"<extra_id_0> {label}", return_tensors="pt").input_ids
loss = model(input_ids=input_ids, labels=label_ids).loss.item()
results[loss] = label
min_loss = min(results.keys())
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment