Skip to content
Snippets Groups Projects
Unverified Commit fde56af7 authored by JulianFP's avatar JulianFP
Browse files

Add T5 MLM with entity masking, both classification and finetuning code

parent 1d518ef1
No related branches found
No related tags found
No related merge requests found
...@@ -116,3 +116,6 @@ train.json ...@@ -116,3 +116,6 @@ train.json
# slurm output # slurm output
logs logs
# evaluation output
results
...@@ -2,29 +2,15 @@ ...@@ -2,29 +2,15 @@
let let
myPythonPackages = ps: with ps; [ myPythonPackages = ps: with ps; [
pip pip
python-dotenv
numpy
ollama
datasets
protobuf
requests
tqdm
#transformers
#gensim
#accelerate
#gliner missing
]; ];
fhs = pkgs.buildFHSEnv { in pkgs.mkShell {
name = "Python NER environment with pip"; buildInputs = with pkgs; [
(python3.withPackages myPythonPackages)
targetPkgs = _: [ ];
(pkgs.python3.withPackages myPythonPackages) shellHook = ''
]; export LD_LIBRARY_PATH=$NIX_LD_LIBRARY_PATH
python -m venv .venv
profile ='' source .venv/bin/activate
python -m venv .venv zsh
source .venv/bin/activate '';
zsh }
'';
};
in fhs.env
...@@ -7,6 +7,8 @@ from src.models.GLiNER import find_entities as find_entities_gliner ...@@ -7,6 +7,8 @@ from src.models.GLiNER import find_entities as find_entities_gliner
from src.models.GLiNER import classify_entity as classify_entity_gliner from src.models.GLiNER import classify_entity as classify_entity_gliner
from src.models.T5_NLI import classify_entity as classify_entity_t5_nli from src.models.T5_NLI import classify_entity as classify_entity_t5_nli
from src.models.T5_MLM_label import classify_entity as classify_entity_t5_mlm_label from src.models.T5_MLM_label import classify_entity as classify_entity_t5_mlm_label
from src.models.T5_MLM_entity import classify_entity as classify_entity_t5_mlm_entity
from src.models.T5_MLM_entity import set_label_dict as set_label_dict_t5_mlm_entity
from src.models.Word2Vec import classify_entity as classify_entity_word2vec from src.models.Word2Vec import classify_entity as classify_entity_word2vec
from src.models.Word2Vec import set_label_dict as set_label_dict_word2vec from src.models.Word2Vec import set_label_dict as set_label_dict_word2vec
from src.experiments.NER_with_LLMs.NER_with_LLMs import find_entities as find_entities_llm from src.experiments.NER_with_LLMs.NER_with_LLMs import find_entities as find_entities_llm
...@@ -21,6 +23,8 @@ def classify_entity(model_name, sentence, entity, labels): ...@@ -21,6 +23,8 @@ def classify_entity(model_name, sentence, entity, labels):
return classify_entity_t5_nli(sentence, entity, labels) return classify_entity_t5_nli(sentence, entity, labels)
elif model_name == "T5-MLM-label": elif model_name == "T5-MLM-label":
return classify_entity_t5_mlm_label(sentence, entity, labels) return classify_entity_t5_mlm_label(sentence, entity, labels)
elif model_name == "T5-MLM-entity":
return classify_entity_t5_mlm_entity(sentence, entity, labels)
elif model_name == "GLiNER": elif model_name == "GLiNER":
return classify_entity_gliner(sentence, entity, labels) return classify_entity_gliner(sentence, entity, labels)
elif model_name == "Word2Vec": elif model_name == "Word2Vec":
...@@ -49,5 +53,7 @@ def set_label_dict(model_name, label_dict): ...@@ -49,5 +53,7 @@ def set_label_dict(model_name, label_dict):
""" """
if model_name == "Word2Vec": if model_name == "Word2Vec":
return set_label_dict_word2vec(label_dict) return set_label_dict_word2vec(label_dict)
elif model_name == "T5-MLM-entity":
return set_label_dict_t5_mlm_entity(label_dict)
else: else:
print(f"set_label_dict not implemented for {model_name}") print(f"set_label_dict not implemented for {model_name}")
...@@ -27,7 +27,7 @@ def run_NEC_tests(model_name, dataset, results_dir, test_instances=10): ...@@ -27,7 +27,7 @@ def run_NEC_tests(model_name, dataset, results_dir, test_instances=10):
labels = data_manager.get_labels(dataset) labels = data_manager.get_labels(dataset)
data = data_manager.get_annotated_sentences(dataset, test_instances) data = data_manager.get_annotated_sentences(dataset, test_instances)
if (model_name == "Word2Vec"): if (model_name == "Word2Vec" or model_name == "T5-MLM-entity"):
label_dict = data_manager.get_label_dict(dataset, test_instances) label_dict = data_manager.get_label_dict(dataset, test_instances)
print(label_dict) print(label_dict)
set_label_dict(model_name, label_dict) set_label_dict(model_name, label_dict)
...@@ -58,7 +58,7 @@ def run_NEC_tests(model_name, dataset, results_dir, test_instances=10): ...@@ -58,7 +58,7 @@ def run_NEC_tests(model_name, dataset, results_dir, test_instances=10):
def run_NEC_tests_all(): def run_NEC_tests_all():
models = ["GLiNER", "Llama-3.1-8B", "T5-NLI", "T5-MLM-label", "Word2Vec", "DeepSeek-R1-Distill-Qwen-32B"] models = ["GLiNER", "Llama-3.1-8B", "T5-NLI", "T5-MLM-label", "T5-MLM-entity", "Word2Vec", "DeepSeek-R1-Distill-Qwen-32B"]
datasets = ["CoNLL", "FIGER-coarse", "FIGER-fine"] # "Pile-NER-type"] datasets = ["CoNLL", "FIGER-coarse", "FIGER-fine"] # "Pile-NER-type"]
for model in models: for model in models:
for dataset in datasets: for dataset in datasets:
......
import data.data_manager as data_manager
from src.models.T5_MLM_entity import finetune_model, set_label_dict
def finetune_t5(dataset):
    """Finetune the T5 MLM entity model on annotated sentences from `dataset`.

    Builds one (sentence, entity, label) training triple per annotation,
    installs the label -> representatives dict used for target construction,
    then delegates to finetune_model.
    """
    print("start")
    annotated_sentences = data_manager.get_annotated_sentences(dataset, 1000)
    # NOTE(review): removed a dead `labels = data_manager.get_labels(dataset)`
    # that was immediately overwritten by the empty list below.
    sentences = []
    entities = []
    labels = []
    for annotated_sentence in annotated_sentences:
        sentence = annotated_sentence[0]
        # each annotation is (entity_text, label) — flatten so every
        # annotation becomes its own training example
        for annotation in annotated_sentence[1]:
            sentences.append(sentence)
            entities.append(annotation[0])
            labels.append(annotation[1])
    print(f"Finetuning on {len(sentences)} examples")
    # print a small sample for manual sanity checking
    for i in range(min(len(sentences), 50)):
        print(f"sentence: {sentences[i]}, entity: {entities[i]}, label: {labels[i]}")
    epochs = 20
    # the representatives dict must be installed before finetune_model reads it
    label_dict = data_manager.get_label_dict(dataset, 1000)
    set_label_dict(label_dict)
    finetune_model(sentences, entities, labels, output_dir=f"./src/models/t5_mlm_entity_finetuned_model/pretrained_{dataset}_epoch{epochs}", epochs=epochs)

finetune_t5("CoNLL")
import torch
import random
import numpy as np
from torch.nn.functional import softmax
from transformers import T5ForConditionalGeneration, T5Tokenizer, Trainer, TrainingArguments, DataCollatorForSeq2Seq
from datasets import Dataset, DatasetDict
# Base T5 checkpoint; both classify_entity and finetune_model share this
# single module-level tokenizer/model pair (loaded once at import time).
model_name = "google-t5/t5-base"
print("Loading model: T5 MLM entity")
tokenizer = T5Tokenizer.from_pretrained(model_name)
model = T5ForConditionalGeneration.from_pretrained(model_name)
print("Finished loading model: T5 MLM entity")
def set_label_dict(label_dict):
    """Install the label -> representative-entities mapping.

    The mapping is stored in the module-level `label_representatives`
    global, which classify_entity and finetune_model read.
    """
    global label_representatives
    label_representatives = label_dict
def classify_entity(sentence, entity, labels):
    """Pick the label whose best representative fills the masked entity
    span with the lowest seq2seq loss.

    Masks the first occurrence of `entity` in `sentence` with T5's
    <extra_id_0> sentinel and scores each representative as the target.
    Requires set_label_dict() to have been called first. Falls back to
    labels[0] when no representative could be scored (e.g. every
    representative of every label equals the entity itself).
    """
    global label_representatives
    # only the first occurrence of the entity is masked
    masked_sentence = sentence.replace(entity, "<extra_id_0>", 1)
    input_ids = tokenizer(masked_sentence, return_tensors="pt").input_ids
    best_label = None
    best_label_loss = float("inf")
    # inference only: disable autograd so no gradient graph is built for
    # every representative forward pass (saves memory and time)
    with torch.no_grad():
        for label in labels:
            best_rep_loss = float("inf")
            representatives = label_representatives[label]
            for representative in representatives:
                # do not cheat if entity is one of the representatives already
                if entity == representative:
                    continue
                representative_ids = tokenizer(f"<extra_id_0> {representative} <extra_id_1>", return_tensors="pt").input_ids
                loss = model(input_ids=input_ids, labels=representative_ids).loss.item()
                best_rep_loss = min(best_rep_loss, loss)
            # a label scores as well as its single best representative
            if best_rep_loss < best_label_loss:
                best_label_loss = best_rep_loss
                best_label = label
    return best_label if best_label else labels[0]
def finetune_model(sentences, entities, labels, output_dir, epochs=10):
    """Finetune T5 to fill a masked entity span with a label representative.

    For each (sentence, entity, label) triple the entity is replaced by the
    <extra_id_0> sentinel and the target is a randomly chosen representative
    of the gold label. Requires set_label_dict() to have been called first.

    sentences/entities/labels: parallel lists of equal length.
    output_dir: directory the finetuned model and tokenizer are saved to.
    epochs: number of training epochs.
    """
    input_texts = []
    target_texts = []
    random.seed(42)  # deterministic representative choice per example
    for i in range(len(sentences)):
        sentence = sentences[i]
        entity = entities[i]
        label = labels[i]
        representatives = label_representatives[label]
        representative = random.choice(list(representatives))
        input_texts.append(sentence.replace(entity, "<extra_id_0>", 1))
        target_texts.append(f"<extra_id_0> {representative} <extra_id_1>")
    # plain python lists are fine for Dataset.from_dict; the previous
    # pt-tensor -> np.array round-trip was unnecessary
    model_input = tokenizer(input_texts, padding="max_length", truncation=True, max_length=128)
    targets = tokenizer(target_texts, padding="max_length", truncation=True, max_length=128)
    # mask padding positions with -100 so the cross-entropy loss ignores
    # them (previously the model was trained to predict pad tokens)
    label_ids = [
        [(tok if tok != tokenizer.pad_token_id else -100) for tok in seq]
        for seq in targets["input_ids"]
    ]
    dataset = Dataset.from_dict({
        "input_ids": model_input["input_ids"],
        "attention_mask": model_input["attention_mask"],
        "labels": label_ids
    })
    print(dataset)
    # split into training and validation data (20% used for validation)
    train_test_split = dataset.train_test_split(test_size=0.2, shuffle=True, seed=0)
    dataset = DatasetDict({
        "train": train_test_split["train"],
        "validation": train_test_split["test"]
    })
    dataset.set_format(type="torch", columns=["input_ids", "attention_mask", "labels"])
    training_args = TrainingArguments(
        # was hard-coded to ./src/models/t5_nli_finetuned_model (copy-paste
        # from the NLI module); use the caller-supplied directory instead
        output_dir=output_dir,
        eval_strategy="epoch",
        learning_rate=5e-5,
        per_device_train_batch_size=8,
        per_device_eval_batch_size=8,
        num_train_epochs=epochs,
        weight_decay=0.01,
        save_strategy="no",
        push_to_hub=False,
        logging_dir="./logs",
    )
    data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=dataset["train"],
        eval_dataset=dataset["validation"],
        tokenizer=tokenizer,
        data_collator=data_collator,
    )
    trainer.train()
    model.save_pretrained(output_dir)
    tokenizer.save_pretrained(output_dir)
...@@ -6,10 +6,10 @@ from datasets import Dataset, DatasetDict ...@@ -6,10 +6,10 @@ from datasets import Dataset, DatasetDict
model_name = "google-t5/t5-large" model_name = "google-t5/t5-large"
print("Loading model: T5 MLM") print("Loading model: T5 MLM label")
tokenizer = T5Tokenizer.from_pretrained(model_name) tokenizer = T5Tokenizer.from_pretrained(model_name)
model = T5ForConditionalGeneration.from_pretrained(model_name) model = T5ForConditionalGeneration.from_pretrained(model_name)
print("Finished loading model: T5 MLM") print("Finished loading model: T5 MLM label")
def classify_entity(sentence, entity, labels): def classify_entity(sentence, entity, labels):
sentence_with_masked_hypothesis = f"{sentence} {entity} is a <extra_id_0>." sentence_with_masked_hypothesis = f"{sentence} {entity} is a <extra_id_0>."
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment