Commit 93d60fed authored by Thomas Wolf

T5 NEC

parent 4d8d5696
@@ -5,7 +5,7 @@ Makes evaluating models easier.
from src.models.llms_interface import available_models as llms
from src.models.GLiNER import find_entities as find_entities_gliner
from src.models.GLiNER import classify_entity as classify_entity_gliner
from src.experiments.NER_with_T5 import classify_entity as classify_entity_t5
from src.models.T5 import classify_entity as classify_entity_t5
from src.experiments.NER_with_LLMs.NER_with_LLMs import find_entities as find_entities_llm
......
from src.models.T5 import infer_nli
def classify_entity(sentence, entity, labels):
    print("classify entity")
    for label in labels:
        print(f"Label: {label}")
        hypothesis = f"{entity} is a {label}"
        result = infer_nli(sentence, hypothesis)
        print(f"Hypothesis: {hypothesis}, Result: {result}")
    # TODO: determine highest confidence prediction
    return labels[0]
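A minimal sketch, not part of this commit, of how the TODO above could be resolved: instead of always returning labels[0], score each label by the model's log-likelihood of generating "entailment" and return the best-scoring one. The helper name nli_entailment_score and the reuse of the t5-base stand-in checkpoint are assumptions for illustration.

import torch
from transformers import T5Tokenizer, T5ForConditionalGeneration

model_name = "google-t5/t5-base"  # stand-in checkpoint, mirroring src/models/T5.py
tokenizer = T5Tokenizer.from_pretrained(model_name)
model = T5ForConditionalGeneration.from_pretrained(model_name)

def nli_entailment_score(premise, hypothesis):
    # Teacher-forced forward pass: the returned loss is the mean negative
    # log-likelihood of the target tokens, so -loss ranks how strongly the
    # model favours generating "entailment" for this premise/hypothesis pair.
    input_ids = tokenizer(
        f"nli hypothesis: {hypothesis} premise: {premise}", return_tensors="pt"
    ).input_ids
    target_ids = tokenizer("entailment", return_tensors="pt").input_ids
    with torch.no_grad():
        loss = model(input_ids=input_ids, labels=target_ids).loss
    return -loss.item()

def classify_entity(sentence, entity, labels):
    # Highest-confidence prediction rather than a fixed labels[0] fallback.
    scores = {
        label: nli_entailment_score(sentence, f"{entity} is a {label}")
        for label in labels
    }
    return max(scores, key=scores.get)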
@@ -3,7 +3,8 @@ from datasets import Dataset, DatasetDict
label_map = {True: "entailment", False: "contradiction"}
model_name = "google/t5_xxl_true_nli_mixture"
# Use t5-base for testing because it is smaller
model_name = "google-t5/t5-base" # google/t5_xxl_true_nli_mixture
print("Loading model: T5 NLI")
tokenizer = T5Tokenizer.from_pretrained(model_name)
@@ -14,18 +15,33 @@ print("Finished loading model: T5 NLI")
def infer_nli(premise, hypothesis):
    input_text = f"nli hypothesis: {hypothesis} premise: {premise}"
    print("tokenize")
    # print("tokenize")
    inputs = tokenizer(input_text, return_tensors="pt")
    print("generate")
    # print("generate")
    output_ids = model.generate(**inputs)
    print("decode")
    # print("decode")
    result = tokenizer.decode(output_ids[0], skip_special_tokens=True)
    return result
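# Hypothetical usage of infer_nli (the premise and hypothesis below are invented
# for illustration, not from the repository); classify_entity further down
# compares the returned string against "entailment":
#     verdict = infer_nli(
#         premise="Barack Obama visited Paris last week.",
#         hypothesis="Barack Obama is a person",
#     )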
def classify_entity(sentence, entity, labels):
    best_label = None
    for label in labels:
        # print(f"Label: {label}")
        hypothesis = f"{entity} is {label}"
        result = infer_nli(sentence, hypothesis)
        # print(f"Hypothesis: {hypothesis}, Result: {result}")
        if result == "entailment":
            return label  # Return immediately if entailment is found
        elif result == "neutral" and best_label is None:
            best_label = label  # Store the first neutral label as a fallback
    return best_label if best_label else labels[0]
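# Hypothetical call showing how the new classify_entity might be driven from the
# evaluation code (the sentence, entity, and label set are made up for illustration):
#     label = classify_entity(
#         "Barack Obama visited Paris last week.",
#         "Paris",
#         ["person", "location", "organization"],
#     )
# It returns the first label judged "entailment", otherwise the first "neutral"
# label seen, otherwise labels[0].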
def preprocess_data(sample):
    input_text = f"nli hypothesis: {sample['hypothesis']} premise: {sample['premise']}"
    target_text = label_map[bool(sample['entailment'])]
......
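For reference, a sample shaped the way preprocess_data appears to expect it, together with the strings it would produce; the field values are invented for illustration, and the target follows the label_map defined above.

sample = {
    "premise": "Barack Obama visited Paris last week.",
    "hypothesis": "Paris is a location",
    "entailment": True,
}
# input_text  -> "nli hypothesis: Paris is a location premise: Barack Obama visited Paris last week."
# target_text -> "entailment"   (via label_map[bool(sample['entailment'])])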