From 93d60fed339e4c17ccc2be39105e14cfaccadf93 Mon Sep 17 00:00:00 2001
From: Thomas Wolf <thomas_wolf@online.de>
Date: Mon, 10 Mar 2025 11:06:00 +0100
Subject: [PATCH] T5 NEC

---
 src/common_interface.py        |  2 +-
 src/experiments/NER_with_T5.py | 12 ------------
 src/models/T5.py               | 24 ++++++++++++++++++++----
 3 files changed, 21 insertions(+), 17 deletions(-)
 delete mode 100644 src/experiments/NER_with_T5.py

diff --git a/src/common_interface.py b/src/common_interface.py
index 4bd6871..821ea2f 100644
--- a/src/common_interface.py
+++ b/src/common_interface.py
@@ -5,7 +5,7 @@ Makes evaluating models easier.
 from src.models.llms_interface import available_models as llms
 from src.models.GLiNER import find_entities as find_entities_gliner
 from src.models.GLiNER import classify_entity as classify_entity_gliner
-from src.experiments.NER_with_T5 import classify_entity as classify_entity_t5
+from src.models.T5 import classify_entity as classify_entity_t5
 from src.experiments.NER_with_LLMs.NER_with_LLMs import find_entities as find_entities_llm
 
 
diff --git a/src/experiments/NER_with_T5.py b/src/experiments/NER_with_T5.py
deleted file mode 100644
index bf0754c..0000000
--- a/src/experiments/NER_with_T5.py
+++ /dev/null
@@ -1,12 +0,0 @@
-from src.models.T5 import infer_nli
-
-
-def classify_entity(sentence, entity, labels):
-    print("classify entity")
-    for label in labels:
-        print(f"Label: {label}")
-        hypothesis = f"{entity} is a {label}"
-        result = infer_nli(sentence, hypothesis)
-        print(f"Hypothesis: {hypothesis}, Result: {result}")
-    # TODO: determine highest confidence prediction
-    return labels[0]
diff --git a/src/models/T5.py b/src/models/T5.py
index 84cbb9d..022a396 100644
--- a/src/models/T5.py
+++ b/src/models/T5.py
@@ -3,7 +3,8 @@ from datasets import Dataset, DatasetDict
 
 label_map = {True: "entailment", False: "contradiction"}
 
-model_name = "google/t5_xxl_true_nli_mixture"
+# Use t5-base for testing because it is smaller
+model_name = "google-t5/t5-base" # google/t5_xxl_true_nli_mixture
 
 print("Loading model: T5 NLI")
 tokenizer = T5Tokenizer.from_pretrained(model_name)
@@ -14,18 +15,33 @@ print("Finished loading model: T5 NLI")
 def infer_nli(premise, hypothesis):
     input_text = f"nli hypothesis: {hypothesis} premise: {premise}"
 
-    print("tokenize")
+    # print("tokenize")
     inputs = tokenizer(input_text, return_tensors="pt")
 
-    print("generate")
+    # print("generate")
     output_ids = model.generate(**inputs)
 
-    print("decode")
+    # print("decode")
     result = tokenizer.decode(output_ids[0], skip_special_tokens=True)
 
     return result
 
 
+def classify_entity(sentence, entity, labels):
+    best_label = None
+    for label in labels:
+        # rint(f"Label: {label}")
+        hypothesis = f"{entity} is {label}"
+        result = infer_nli(sentence, hypothesis)
+        # print(f"Hypothesis: {hypothesis}, Result: {result}")
+        if result == "entailment":
+            return label # Return immediately if entailment is found
+        elif result == "neutral" and best_label is None:
+            best_label = label # Store the first neutral label as a fallback
+
+    return best_label if best_label else labels[0]
+
+
 def preprocess_data(sample):
     input_text = f"nli hypothesis: {sample['hypothesis']} premise: {sample['premise']}"
     target_text = label_map[bool(sample['entailment'])]
-- 
GitLab