Commit 56ea896e authored by kupper

Rename T5 to T5-NLI and save finetuned model

parent d641f769
@@ -110,3 +110,9 @@ venv.bak/
 *.data
 *.pickle
 train.json
+
+# finetuned
+*_finetuned_model
+
+# slurm output
+logs

+#!/bin/bash
+#SBATCH --job-name=finetune_T5
+#SBATCH --output=logs/finetune_T5_%j.txt
+#SBATCH --ntasks=1
+#SBATCH --time=1:00:00
+#SBATCH --mem=8000
+#SBATCH --mail-type=ALL
+#SBATCH --mail-user=kupper@cl.uni-heidelberg.de
+#SBATCH --partition=students
+#SBATCH --cpus-per-task=4
+#SBATCH --qos=batch
+#SBATCH --gres=gpu
+export PYTHONUNBUFFERED=1
+srun python3 -m src.experiments.finetune_T5

@@ -5,7 +5,7 @@ Makes evaluating models easier.
 from src.models.llms_interface import available_models as llms
 from src.models.GLiNER import find_entities as find_entities_gliner
 from src.models.GLiNER import classify_entity as classify_entity_gliner
-from src.models.T5 import classify_entity as classify_entity_t5
+from src.models.T5_NLI import classify_entity as classify_entity_t5_nli
 from src.models.T5_MLM_label import classify_entity as classify_entity_t5_mlm_label
 from src.experiments.NER_with_LLMs.NER_with_LLMs import find_entities as find_entities_llm
@@ -14,8 +14,8 @@ def classify_entity(model_name, sentence, entity, labels):
     """
     NEC. Returns label (string) for entity.
     """
-    if model_name == "T5":
-        return classify_entity_t5(sentence, entity, labels)
+    if model_name == "T5-NLI":
+        return classify_entity_t5_nli(sentence, entity, labels)
     elif model_name == "T5-MLM-label":
         return classify_entity_t5_mlm_label(sentence, entity, labels)
     elif model_name == "GLiNER":
...

 import data.data_manager as data_manager
-from src.models.T5 import finetune_model
+from src.models.T5_NLI import finetune_model

 def finetune_t5(dataset):
     print("start")
-    annotated_sentences = data_manager.get_annotated_sentences(dataset)
+    annotated_sentences = data_manager.get_annotated_sentences(dataset, 1000)
     labels = data_manager.get_labels(dataset)
     premises = []
@@ -11,7 +11,7 @@ def finetune_t5(dataset):
     entailment = []
     for annotated_sentence in annotated_sentences:
-        sentence = annotated_sentence[0][0]
+        sentence = annotated_sentence[0]
         for annotation in annotated_sentence[1]:
             entity = annotation[0]
@@ -20,9 +20,12 @@ def finetune_t5(dataset):
             hypotheses.append(f"{entity} is a {label}")
             entailment.append(label == annotation[1])
     print(f"Finetuning on {len(premises)} examples")
+    for i in range(min(len(premises), 50)):
+        print(f"premise: {premises[i]}, hypothesis: {hypotheses[i]}, entailment: {entailment[i]}")
+    epochs = 20
-    finetune_model(premises, hypotheses, entailment)
+    finetune_model(premises, hypotheses, entailment, output_dir=f"./src/models/t5_nli_finetuned_model/pretrained_{dataset}_epoch{epochs}", epochs=epochs)

 finetune_t5("CoNLL")

@@ -6,13 +6,26 @@ from datasets import Dataset, DatasetDict
 label_map = {True: "entailment", False: "contradiction"}

 # Use t5-base for testing because it is smaller
-model_name = "google-t5/t5-base"  # google/t5_xxl_true_nli_mixture
-print("Loading model: T5 NLI")
-tokenizer = T5Tokenizer.from_pretrained(model_name)
-model = T5ForConditionalGeneration.from_pretrained(model_name)
-print("Finished loading model: T5 NLI")
+model_name = "google-t5/t5-base"
+# model_name = "google/t5_xxl_true_nli_mixture"
+
+def load_base():
+    global model
+    global tokenizer
+    print("Loading model: T5 NLI")
+    tokenizer = T5Tokenizer.from_pretrained(model_name)
+    model = T5ForConditionalGeneration.from_pretrained(model_name)
+    print("Finished loading model: T5 NLI")
+
+def load_finetuned(input_dir):
+    global model
+    global tokenizer
+    print(f"Loading model: T5 NLI finetuned ({input_dir})")
+    tokenizer = T5Tokenizer.from_pretrained(input_dir)
+    model = T5ForConditionalGeneration.from_pretrained(input_dir)
+    print(f"Finished loading model: T5 NLI finetuned")

 def infer_nli(premise, hypothesis):
     input_text = f"nli hypothesis: {hypothesis} premise: {premise}"
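
(The body of infer_nli is collapsed in this view. A plausible completion under the usual T5 NLI convention, sketched here as an assumption rather than the committed code:)

def infer_nli(premise, hypothesis):
    input_text = f"nli hypothesis: {hypothesis} premise: {premise}"
    # Sketch: tokenize the prompt, generate, and decode the verdict,
    # which label_map suggests is "entailment" or "contradiction".
    input_ids = tokenizer(input_text, return_tensors="pt").input_ids
    output_ids = model.generate(input_ids, max_new_tokens=8)
    return tokenizer.decode(output_ids[0], skip_special_tokens=True)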
@@ -77,7 +90,7 @@ def preprocess_data(sample):
     return tokenized_input

-def finetune_model(premises, hypotheses, entailment):
+def finetune_model(premises, hypotheses, entailment, output_dir, epochs=10):
     # TODO: should we use dataset on a higher level as well?
     data_dict = {"premise": premises, "hypothesis": hypotheses, "entailment": entailment}
     dataset = Dataset.from_dict(data_dict)
@@ -97,14 +110,14 @@ def finetune_model(premises, hypotheses, entailment):
     tokenized_datasets.set_format(type="torch", columns=["input_ids", "attention_mask", "labels"])

     training_args = TrainingArguments(
-        output_dir="./t5_nli_finetuned",
-        evaluation_strategy="epoch",
+        output_dir="./src/models/t5_nli_finetuned_model",
+        eval_strategy="epoch",
         learning_rate=5e-5,
-        per_device_train_batch_size=4,
-        per_device_eval_batch_size=4,
-        num_train_epochs=3,
+        per_device_train_batch_size=8,
+        per_device_eval_batch_size=8,
+        num_train_epochs=epochs,
         weight_decay=0.01,
-        save_strategy="epoch",
+        save_strategy="no",
         push_to_hub=False,
         logging_dir="./logs",
     )
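
(Side note on the eval_strategy edit: recent transformers releases renamed the TrainingArguments parameter evaluation_strategy to eval_strategy and deprecated the old name, so this change likely tracks a library upgrade. A quick compatibility check, sketched under that assumption:)

from transformers import TrainingArguments

# Assumes a transformers version where the rename has landed.
args = TrainingArguments(output_dir="./tmp", eval_strategy="epoch", save_strategy="no")
print(args.eval_strategy)  # IntervalStrategy.EPOCH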
@@ -121,3 +134,8 @@ def finetune_model(premises, hypotheses, entailment):
     )
     trainer.train()
+    model.save_pretrained(output_dir)
+    tokenizer.save_pretrained(output_dir)
+
+
+load_base()
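
(Taken together, the new pieces give a save/load round trip. A usage sketch with dummy data; the directory name merely follows the pattern used in finetune_T5 above:)

from src.models import T5_NLI

# Tiny dummy data, only to exercise the API.
premises = ["Barack Obama was the president of the United States."] * 2
hypotheses = ["Barack Obama is a person", "Barack Obama is a location"]
entailment = [True, False]

model_dir = "./src/models/t5_nli_finetuned_model/pretrained_CoNLL_epoch20"
T5_NLI.finetune_model(premises, hypotheses, entailment, output_dir=model_dir, epochs=20)

# Later, swap the module-level model/tokenizer over to the finetuned weights:
T5_NLI.load_finetuned(model_dir)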

 from src.common_interface import classify_entity

-tested_models = ["GLiNER", "T5", "T5-MLM-label"]
+tested_models = ["GLiNER", "T5-NLI", "T5-MLM-label"]

 test_sentence = "Barack Obama was the president of the United States."
...
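
(The rest of the test script is collapsed; presumably it dispatches through classify_entity. A sketch of such a call, with an assumed label set:)

labels = ["person", "location", "organization"]
for model_name in tested_models:
    prediction = classify_entity(model_name, test_sentence, "Barack Obama", labels)
    print(f"{model_name}: {prediction}")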