From 4d8d5696eef8cb671c201b38d8a0139d1a055f96 Mon Sep 17 00:00:00 2001
From: Thomas Wolf <thomas_wolf@online.de>
Date: Mon, 10 Mar 2025 09:56:40 +0100
Subject: [PATCH] Add NEC with GLiNER

---
 src/common_interface.py                         |  7 ++++---
 src/experiments/GLiNER_evalutaion/evaluation.py |  2 +-
 src/experiments/NEC_evaluation/evaluation.py    |  7 +++++++
 src/experiments/NER_with_T5.py                  |  1 +
 src/models/GLiNER.py                            | 10 ++++++++++
 src/models/T5.py                                | 12 ++++++------
 src/plotter.py                                  |  3 +++
 tests/test_NEC.py                               | 15 ++++++++-------
 8 files changed, 40 insertions(+), 17 deletions(-)
 create mode 100644 src/experiments/NEC_evaluation/evaluation.py

diff --git a/src/common_interface.py b/src/common_interface.py
index f8824bb..4bd6871 100644
--- a/src/common_interface.py
+++ b/src/common_interface.py
@@ -4,18 +4,19 @@ Makes evaluating models easier.
 """
 from src.models.llms_interface import available_models as llms
 from src.models.GLiNER import find_entities as find_entities_gliner
+from src.models.GLiNER import classify_entity as classify_entity_gliner
 from src.experiments.NER_with_T5 import classify_entity as classify_entity_t5
 from src.experiments.NER_with_LLMs.NER_with_LLMs import find_entities as find_entities_llm
 
 
 def classify_entity(model_name, sentence, entity, labels):
     """
-    Entity Classification
+    NEC. Returns label (string) for entity.
     """
     if model_name == "T5":
         return classify_entity_t5(sentence, entity, labels)
     elif model_name == "GLiNER":
-        pass  # todo
+        return classify_entity_gliner(sentence, entity, labels)
 
 
 def predict_mask_mlm(model_name, masked_sentence, labels):
@@ -32,7 +33,7 @@ def predict_mask_nli(model_name, masked_sentence, labels):
 
 def find_entities(model_name, sentence, labels):
     """
-    NER
+    NER. Returns list of pairs [(entity, label), ...]
     """
     if model_name in llms:
         return find_entities_llm(model_name, sentence, labels)
diff --git a/src/experiments/GLiNER_evalutaion/evaluation.py b/src/experiments/GLiNER_evalutaion/evaluation.py
index 043c9a3..eb8cc48 100644
--- a/src/experiments/GLiNER_evalutaion/evaluation.py
+++ b/src/experiments/GLiNER_evalutaion/evaluation.py
@@ -4,6 +4,6 @@ Evaluates GLiNER as SotA and plots results using reusable functions in plotter.
 
 from src.metrics import NER_metrics, read_NER_metrics
 
-NER_metrics("GLiNER", "CoNLL", "results", test_instances=100)
+# NER_metrics("GLiNER", "CoNLL", "results", test_instances=100)
 
 read_NER_metrics("results")
diff --git a/src/experiments/NEC_evaluation/evaluation.py b/src/experiments/NEC_evaluation/evaluation.py
new file mode 100644
index 0000000..6ed1292
--- /dev/null
+++ b/src/experiments/NEC_evaluation/evaluation.py
@@ -0,0 +1,7 @@
+"""
+This file evaluates all NEC approaches.
+""" + +# todo: perform tests on datasets and store results + +# todo read results and compare the models / plot the results diff --git a/src/experiments/NER_with_T5.py b/src/experiments/NER_with_T5.py index 8f9e5c1..bf0754c 100644 --- a/src/experiments/NER_with_T5.py +++ b/src/experiments/NER_with_T5.py @@ -1,5 +1,6 @@ from src.models.T5 import infer_nli + def classify_entity(sentence, entity, labels): print("classify entity") for label in labels: diff --git a/src/models/GLiNER.py b/src/models/GLiNER.py index ffa9cd0..891b357 100644 --- a/src/models/GLiNER.py +++ b/src/models/GLiNER.py @@ -16,3 +16,13 @@ def find_entities(sentence, labels): entity_list.append((entity["text"], entity["label"])) return entity_list + + +def classify_entity(sentence, entity, labels): + entity_list = find_entities(sentence, labels) + for e in entity_list: + if e[0] == entity: + return e[1] # Return label + + return "" + diff --git a/src/models/T5.py b/src/models/T5.py index 3fe7934..84cbb9d 100644 --- a/src/models/T5.py +++ b/src/models/T5.py @@ -10,9 +10,10 @@ tokenizer = T5Tokenizer.from_pretrained(model_name) model = T5ForConditionalGeneration.from_pretrained(model_name) print("Finished loading model: T5 NLI") + def infer_nli(premise, hypothesis): input_text = f"nli hypothesis: {hypothesis} premise: {premise}" - + print("tokenize") inputs = tokenizer(input_text, return_tensors="pt") @@ -24,19 +25,21 @@ def infer_nli(premise, hypothesis): return result + def preprocess_data(sample): input_text = f"nli hypothesis: {sample['hypothesis']} premise: {sample['premise']}" target_text = label_map[bool(sample['entailment'])] - + tokenized_input = tokenizer(input_text, padding="max_length", truncation=True, max_length=512) tokenized_target = tokenizer(target_text, padding="max_length", truncation=True, max_length=10) tokenized_input["labels"] = tokenized_target["input_ids"] return tokenized_input + def finetune_model(premises, hypotheses, entailment): # TODO: should we use dataset on a higher level as well? - data_dict = { "premise": premises, "hypothesis": hypotheses, "entailment": entailment} + data_dict = {"premise": premises, "hypothesis": hypotheses, "entailment": entailment} dataset = Dataset.from_dict(data_dict) print(dataset) @@ -78,6 +81,3 @@ def finetune_model(premises, hypotheses, entailment): ) trainer.train() - - - diff --git a/src/plotter.py b/src/plotter.py index 9636398..a3ba001 100644 --- a/src/plotter.py +++ b/src/plotter.py @@ -14,3 +14,6 @@ def plot_bars(data, x_column, y_column, grouping, title, ylabel, xlabel): """ Reusable barchart plotting function """ + + +# todo: bar chart with three grouped columns for each model: precision, recall, f1-score \ No newline at end of file diff --git a/tests/test_NEC.py b/tests/test_NEC.py index dea133b..deeedde 100644 --- a/tests/test_NEC.py +++ b/tests/test_NEC.py @@ -1,17 +1,18 @@ from src.common_interface import classify_entity -tested_models = ["T5"] +tested_models = ["GLiNER", "T5"] -test_sentences = ["Barack Obama was the president of the United States."] +test_sentence = "Barack Obama was the president of the United States." 
-test_entity = ["Barack Obama"] +test_entities = ["Barack Obama", "United States"] -true_labels = ["person"] +labels = ["person", "organization", "time", "location", "miscellaneous"] -labels = ["person", "organization", "time", "location"] print("Test NEC") for model in tested_models: - for index in range(len(test_sentences)): - classify_entity(model, test_sentences[index], test_entity[index], labels) + print("\n") + for test_entity in test_entities: + print(f"{model} prediction for {test_entity}:") + print(classify_entity(model, test_sentence, test_entity, labels)) -- GitLab