From fde56af7a347fb7ad0ea913ece19a995208c3a0c Mon Sep 17 00:00:00 2001
From: JulianFP <julian@partanengroup.de>
Date: Thu, 27 Mar 2025 18:09:29 +0000
Subject: [PATCH] Add T5 MLM with entity masking, both classification and
 finetuning code

---
 .gitignore                                    |   3 +
 shell.nix                                     |  36 ++----
 src/common_interface.py                       |   6 +
 src/experiments/NEC_evaluation/evaluation.py  |   4 +-
 .../finetune_T5/finetune_T5_MLM_entity.py     |  32 +++++
 ...une_T5_MLM.py => finetune_T5_MLM_label.py} |   0
 src/models/T5_MLM_entity.py                   | 115 ++++++++++++++++++
 src/models/T5_MLM_label.py                    |   4 +-
 8 files changed, 171 insertions(+), 29 deletions(-)
 create mode 100644 src/experiments/finetune_T5/finetune_T5_MLM_entity.py
 rename src/experiments/finetune_T5/{finetune_T5_MLM.py => finetune_T5_MLM_label.py} (100%)
 create mode 100644 src/models/T5_MLM_entity.py

diff --git a/.gitignore b/.gitignore
index 9d2190d..f2ca767 100644
--- a/.gitignore
+++ b/.gitignore
@@ -116,3 +116,6 @@ train.json
 
 # slurm output
 logs
+
+# evaluation output
+results
diff --git a/shell.nix b/shell.nix
index 1fb8620..90bd06e 100644
--- a/shell.nix
+++ b/shell.nix
@@ -2,29 +2,15 @@
 let
   myPythonPackages = ps: with ps; [
     pip
-    python-dotenv
-    numpy
-    ollama
-    datasets
-    protobuf
-    requests
-    tqdm
-    #transformers
-    #gensim
-    #accelerate
-    #gliner missing
   ];
-  fhs = pkgs.buildFHSEnv {
-    name = "Python NER environment with pip";
-
-    targetPkgs = _: [
-      (pkgs.python3.withPackages myPythonPackages)
-    ];
-
-    profile = ''
-      python -m venv .venv
-      source .venv/bin/activate
-      zsh
-    '';
-  };
-in fhs.env
+in pkgs.mkShell {
+  buildInputs = with pkgs; [
+    (python3.withPackages myPythonPackages)
+  ];
+  shellHook = ''
+    export LD_LIBRARY_PATH=$NIX_LD_LIBRARY_PATH
+    python -m venv .venv
+    source .venv/bin/activate
+    zsh
+  '';
+}
diff --git a/src/common_interface.py b/src/common_interface.py
index dc6c73c..e5c6a70 100644
--- a/src/common_interface.py
+++ b/src/common_interface.py
@@ -7,6 +7,8 @@ from src.models.GLiNER import find_entities as find_entities_gliner
 from src.models.GLiNER import classify_entity as classify_entity_gliner
 from src.models.T5_NLI import classify_entity as classify_entity_t5_nli
 from src.models.T5_MLM_label import classify_entity as classify_entity_t5_mlm_label
+from src.models.T5_MLM_entity import classify_entity as classify_entity_t5_mlm_entity
+from src.models.T5_MLM_entity import set_label_dict as set_label_dict_t5_mlm_entity
 from src.models.Word2Vec import classify_entity as classify_entity_word2vec
 from src.models.Word2Vec import set_label_dict as set_label_dict_word2vec
 from src.experiments.NER_with_LLMs.NER_with_LLMs import find_entities as find_entities_llm
@@ -21,6 +23,8 @@ def classify_entity(model_name, sentence, entity, labels):
         return classify_entity_t5_nli(sentence, entity, labels)
     elif model_name == "T5-MLM-label":
         return classify_entity_t5_mlm_label(sentence, entity, labels)
+    elif model_name == "T5-MLM-entity":
+        return classify_entity_t5_mlm_entity(sentence, entity, labels)
     elif model_name == "GLiNER":
         return classify_entity_gliner(sentence, entity, labels)
     elif model_name == "Word2Vec":
@@ -49,5 +53,7 @@ def set_label_dict(model_name, label_dict):
     """
     if model_name == "Word2Vec":
         return set_label_dict_word2vec(label_dict)
+    elif model_name == "T5-MLM-entity":
+        return set_label_dict_t5_mlm_entity(label_dict)
     else:
         print(f"set_label_dict not implemented for {model_name}")
diff --git a/src/experiments/NEC_evaluation/evaluation.py b/src/experiments/NEC_evaluation/evaluation.py
index 748e56f..5447780 100644
--- a/src/experiments/NEC_evaluation/evaluation.py
+++ b/src/experiments/NEC_evaluation/evaluation.py
@@ -27,7 +27,7 @@ def run_NEC_tests(model_name, dataset, results_dir, test_instances=10):
     labels = data_manager.get_labels(dataset)
     data = data_manager.get_annotated_sentences(dataset, test_instances)
 
-    if (model_name == "Word2Vec"):
+    if (model_name == "Word2Vec" or model_name == "T5-MLM-entity"):
         label_dict = data_manager.get_label_dict(dataset, test_instances)
         print(label_dict)
         set_label_dict(model_name, label_dict)
@@ -58,7 +58,7 @@ def run_NEC_tests(model_name, dataset, results_dir, test_instances=10):
 
 
 def run_NEC_tests_all():
-    models = ["GLiNER", "Llama-3.1-8B", "T5-NLI", "T5-MLM-label", "Word2Vec", "DeepSeek-R1-Distill-Qwen-32B"]
+    models = ["GLiNER", "Llama-3.1-8B", "T5-NLI", "T5-MLM-label", "T5-MLM-entity", "Word2Vec", "DeepSeek-R1-Distill-Qwen-32B"]
     datasets = ["CoNLL", "FIGER-coarse", "FIGER-fine"] # "Pile-NER-type"]
     for model in models:
         for dataset in datasets:
diff --git a/src/experiments/finetune_T5/finetune_T5_MLM_entity.py b/src/experiments/finetune_T5/finetune_T5_MLM_entity.py
new file mode 100644
index 0000000..c7ac094
--- /dev/null
+++ b/src/experiments/finetune_T5/finetune_T5_MLM_entity.py
@@ -0,0 +1,32 @@
+import data.data_manager as data_manager
+from src.models.T5_MLM_entity import finetune_model, set_label_dict
+
+def finetune_t5(dataset):
+    print("start")
+    annotated_sentences = data_manager.get_annotated_sentences(dataset, 1000)
+
+    sentences = []
+    entities = []
+    labels = []
+
+    for annotated_sentence in annotated_sentences:
+        sentence = annotated_sentence[0]
+
+        for annotation in annotated_sentence[1]:
+            sentences.append(sentence)
+            entities.append(annotation[0])
+            labels.append(annotation[1])
+
+    print(f"Finetuning on {len(sentences)} examples")
+
+    for i in range(min(len(sentences), 50)):
+        print(f"sentence: {sentences[i]}, entity: {entities[i]}, label: {labels[i]}")
+
+    epochs = 20
+
+
+    label_dict = data_manager.get_label_dict(dataset, 1000)
+    set_label_dict(label_dict)
+    finetune_model(sentences, entities, labels, output_dir=f"./src/models/t5_mlm_entity_finetuned_model/pretrained_{dataset}_epoch{epochs}", epochs=epochs)
+
+finetune_t5("CoNLL")
diff --git a/src/experiments/finetune_T5/finetune_T5_MLM.py b/src/experiments/finetune_T5/finetune_T5_MLM_label.py
similarity index 100%
rename from src/experiments/finetune_T5/finetune_T5_MLM.py
rename to src/experiments/finetune_T5/finetune_T5_MLM_label.py
diff --git a/src/models/T5_MLM_entity.py b/src/models/T5_MLM_entity.py
new file mode 100644
index 0000000..f9604cb
--- /dev/null
+++ b/src/models/T5_MLM_entity.py
@@ -0,0 +1,115 @@
+import torch
+import random
+import numpy as np
+from torch.nn.functional import softmax
+from transformers import T5ForConditionalGeneration, T5Tokenizer, Trainer, TrainingArguments, DataCollatorForSeq2Seq
+from datasets import Dataset, DatasetDict
+
+model_name = "google-t5/t5-base"
+
+print("Loading model: T5 MLM entity")
+tokenizer = T5Tokenizer.from_pretrained(model_name)
+model = T5ForConditionalGeneration.from_pretrained(model_name)
+print("Finished loading model: T5 MLM entity")
+
+def set_label_dict(label_dict):
+    global label_representatives
+    label_representatives = label_dict
+
+def classify_entity(sentence, entity, labels):
+    global label_representatives
+
+    masked_sentence = sentence.replace(entity, "<extra_id_0>", 1)
+    input_ids = tokenizer(masked_sentence, return_tensors="pt").input_ids
+
+    best_label = None
+    best_label_loss = float("inf")
+
+    for label in labels:
+        best_rep_loss = float("inf")
+        representatives = label_representatives[label]
+
+        for representative in representatives:
+            # do not cheat if entity is one of the representatives already
+            if entity == representative:
+                continue
+
+            representative_ids = tokenizer(f"<extra_id_0> {representative} <extra_id_1>", return_tensors="pt").input_ids
+            loss = model(input_ids=input_ids, labels=representative_ids).loss.item()
+            best_rep_loss = min(best_rep_loss, loss)
+
+        if best_rep_loss < best_label_loss:
+            best_label_loss = best_rep_loss
+            best_label = label
+
+    return best_label if best_label else labels[0]
+
+
+def finetune_model(sentences, entities, labels, output_dir, epochs=10):
+    input_texts = []
+    target_texts = []
+
+    random.seed(42)
+
+    for i in range(len(sentences)):
+        sentence = sentences[i]
+        entity = entities[i]
+        label = labels[i]
+        representatives = label_representatives[label]
+        representative = random.choice(list(representatives))
+        input_texts.append(sentence.replace(entity, "<extra_id_0>", 1))
+        target_texts.append(f"<extra_id_0> {representative} <extra_id_1>")
+
+    model_input = tokenizer(input_texts, return_tensors="pt", padding="max_length", truncation=True, max_length=128)
+    targets = tokenizer(target_texts, return_tensors="pt", padding="max_length", truncation=True, max_length=128)
+
+    model_input["input_ids"] = np.array(model_input["input_ids"])
+    model_input["attention_mask"] = np.array(model_input["attention_mask"])
+    model_input["labels"] = np.array(targets["input_ids"])
+
+    dataset = Dataset.from_dict({
+        "input_ids": model_input["input_ids"],
+        "attention_mask": model_input["attention_mask"],
+        "labels": model_input["labels"]
+    })
+
+    print(dataset)
+
+    # split into training and validation data (20% used for validation)
+    train_test_split = dataset.train_test_split(test_size=0.2, shuffle=True, seed=0)
+
+    dataset = DatasetDict({
+        "train": train_test_split["train"],
+        "validation": train_test_split["test"]
+    })
+
+    dataset.set_format(type="torch", columns=["input_ids", "attention_mask", "labels"])
+
+    training_args = TrainingArguments(
+        output_dir="./src/models/t5_mlm_entity_finetuned_model",
+        eval_strategy="epoch",
+        learning_rate=5e-5,
+        per_device_train_batch_size=8,
+        per_device_eval_batch_size=8,
+        num_train_epochs=epochs,
+        weight_decay=0.01,
+        save_strategy="no",
+        push_to_hub=False,
+        logging_dir="./logs",
+    )
+
+    data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)
+
+    trainer = Trainer(
+        model=model,
+        args=training_args,
+        train_dataset=dataset["train"],
+        eval_dataset=dataset["validation"],
+        tokenizer=tokenizer,
+        data_collator=data_collator,
+    )
+
+    trainer.train()
+
+    model.save_pretrained(output_dir)
+    tokenizer.save_pretrained(output_dir)
diff --git a/src/models/T5_MLM_label.py b/src/models/T5_MLM_label.py
index 279416f..d5fa033 100644
--- a/src/models/T5_MLM_label.py
+++ b/src/models/T5_MLM_label.py
@@ -6,10 +6,10 @@ from datasets import Dataset, DatasetDict
 
 model_name = "google-t5/t5-large"
 
-print("Loading model: T5 MLM")
+print("Loading model: T5 MLM label")
 tokenizer = T5Tokenizer.from_pretrained(model_name)
 model = T5ForConditionalGeneration.from_pretrained(model_name)
-print("Finished loading model: T5 MLM")
+print("Finished loading model: T5 MLM label")
 
 def classify_entity(sentence, entity, labels):
     sentence_with_masked_hypothesis = f"{sentence} {entity} is a <extra_id_0>."
--
GitLab
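
For reference, a minimal usage sketch of the new "T5-MLM-entity" path through src/common_interface.py. The function names and signatures come from the diff above; the label dictionary, sentence, and labels are illustrative assumptions (the code in T5_MLM_entity.py implies each label maps to a collection of representative entity strings):

    from src.common_interface import classify_entity, set_label_dict

    # Illustrative label dictionary: label -> representative surface forms.
    # The exact shape is an assumption inferred from set_label_dict/classify_entity.
    label_dict = {
        "person": {"John Smith", "Marie Curie"},
        "location": {"Mount Everest", "Paris"},
        "organization": {"UNICEF", "Siemens"},
    }

    set_label_dict("T5-MLM-entity", label_dict)

    # The model masks "Berlin" with <extra_id_0> and picks the label whose
    # representative fills the mask with the lowest T5 loss.
    predicted = classify_entity(
        "T5-MLM-entity",
        "Angela Merkel spoke in Berlin on Tuesday.",
        "Berlin",
        ["person", "location", "organization"],
    )
    print(predicted)  # expected to be "location"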
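
The scoring idea behind classify_entity can also be seen in isolation: T5 treats <extra_id_0>/<extra_id_1> as the sentinel tokens of a span-corruption pair, and the seq2seq loss against a candidate target measures how plausible that candidate is as the masked span. A self-contained sketch (model name as in the patch; sentence and candidates are illustrative):

    import torch
    from transformers import T5ForConditionalGeneration, T5Tokenizer

    tokenizer = T5Tokenizer.from_pretrained("google-t5/t5-base")
    model = T5ForConditionalGeneration.from_pretrained("google-t5/t5-base")

    input_ids = tokenizer("<extra_id_0> spoke in Berlin on Tuesday.",
                          return_tensors="pt").input_ids

    # Lower loss = more plausible filling of the masked span.
    for candidate in ["Marie Curie", "Mount Everest"]:
        target_ids = tokenizer(f"<extra_id_0> {candidate} <extra_id_1>",
                               return_tensors="pt").input_ids
        with torch.no_grad():  # inference only, no gradients needed
            loss = model(input_ids=input_ids, labels=target_ids).loss.item()
        print(candidate, loss)  # "Marie Curie" should come out lower

Note that classify_entity above runs this forward pass without torch.no_grad(); wrapping it as shown avoids building a gradient graph and reduces memory use during classification.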
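
Correspondingly, finetune_model builds standard T5 span-corruption training pairs: the entity is masked out of the input and the target names a randomly drawn representative of the gold label. One such pair, constructed the same way as in the loop above (all data illustrative):

    import random

    sentence = "Angela Merkel spoke in Berlin on Tuesday."
    entity, gold_label = "Angela Merkel", "person"
    label_dict = {"person": {"John Smith", "Marie Curie"}}

    representative = random.choice(list(label_dict[gold_label]))
    input_text = sentence.replace(entity, "<extra_id_0>", 1)
    target_text = f"<extra_id_0> {representative} <extra_id_1>"
    # input_text:  "<extra_id_0> spoke in Berlin on Tuesday."
    # target_text: e.g. "<extra_id_0> Marie Curie <extra_id_1>"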