From 1d518ef1cee1fd69ad7ffc78c55f46d7fad8cd61 Mon Sep 17 00:00:00 2001 From: kupper <kupper@login.cl.uni-heidelberg.de> Date: Thu, 27 Mar 2025 16:23:19 +0100 Subject: [PATCH] Test/Train split and finetuning config --- data/data_manager.py | 18 +++++++++++++----- src/experiments/finetune_T5/finetune_T5_MLM.py | 6 +++--- src/experiments/finetune_T5/finetune_T5_NLI.py | 7 ++++--- src/models/T5_MLM_label.py | 6 +++--- src/models/T5_NLI.py | 8 ++++---- 5 files changed, 27 insertions(+), 18 deletions(-) diff --git a/data/data_manager.py b/data/data_manager.py index e218dd9..29406c5 100644 --- a/data/data_manager.py +++ b/data/data_manager.py @@ -30,23 +30,30 @@ def get_labels(dataset): raise Exception -def get_annotated_sentences(dataset, test_instances=10): +def get_annotated_sentences(dataset, test_instances=10, train_test_split=1.0, train=False): """ Returns a list of annotated sentences reproducibly picked at random from the dataset. """ random.seed(42) if dataset == "CoNLL": - return random.sample(conll_interface.get_annotated_sentences(), min(test_instances, len(conll_interface.get_annotated_sentences()))) + sentences = random.sample(conll_interface.get_annotated_sentences(), min(test_instances, len(conll_interface.get_annotated_sentences()))) elif dataset == "Pile-NER-type": - return random.sample(pile_interface.get_annotated_sentences(), min(test_instances, len(pile_interface.get_annotated_sentences()))) + sentences = random.sample(pile_interface.get_annotated_sentences(), min(test_instances, len(pile_interface.get_annotated_sentences()))) elif dataset == "FIGER-coarse": - return random.sample(figer_interface.get_annotated_sentences_coarse(), min(test_instances, len(figer_interface.get_annotated_sentences_coarse()))) + sentences = random.sample(figer_interface.get_annotated_sentences_coarse(), min(test_instances, len(figer_interface.get_annotated_sentences_coarse()))) elif dataset == "FIGER-fine": - return 
random.sample(figer_interface.get_annotated_sentences_fine(), min(test_instances, len(figer_interface.get_annotated_sentences_fine()))) + sentences = random.sample(figer_interface.get_annotated_sentences_fine(), min(test_instances, len(figer_interface.get_annotated_sentences_fine()))) else: raise Exception + split_pos = max(0, min(len(sentences), int(train_test_split * len(sentences)))) + + if train: + return sentences[split_pos:] + else: + return sentences[:split_pos] + def get_label_dict(dataset, test_instances=10): """ Returns a dictionary of example entities for each label. @@ -63,3 +70,4 @@ def get_label_dict(dataset, test_instances=10): label_dict[labeled_entity[1]].add(labeled_entity[0]) print(f"Label dict: {label_dict}") return label_dict + diff --git a/src/experiments/finetune_T5/finetune_T5_MLM.py b/src/experiments/finetune_T5/finetune_T5_MLM.py index e70bfc6..f244ed0 100644 --- a/src/experiments/finetune_T5/finetune_T5_MLM.py +++ b/src/experiments/finetune_T5/finetune_T5_MLM.py @@ -3,7 +3,7 @@ from src.models.T5_MLM_label import finetune_model def finetune_t5(dataset): print("start") - annotated_sentences = data_manager.get_annotated_sentences(dataset, 1000) + annotated_sentences = data_manager.get_annotated_sentences(dataset, 2000, train_test_split=0.5, train=True) labels = data_manager.get_labels(dataset) sentences = [] @@ -23,8 +23,8 @@ def finetune_t5(dataset): for i in range(min(len(sentences), 50)): print(f"sentence: {sentences[i]}, entity: {entities[i]}, label: {labels[i]}") - epochs = 20 + epochs = 150 finetune_model(sentences, entities, labels, output_dir=f"./src/models/t5_mlm_finetuned_model/pretrained_{dataset}_epoch{epochs}", epochs=epochs) -finetune_t5("CoNLL") +finetune_t5("FIGER-coarse") diff --git a/src/experiments/finetune_T5/finetune_T5_NLI.py b/src/experiments/finetune_T5/finetune_T5_NLI.py index b2a6927..994f425 100644 --- a/src/experiments/finetune_T5/finetune_T5_NLI.py +++ b/src/experiments/finetune_T5/finetune_T5_NLI.py @@ -3,7 
+3,7 @@ from src.models.T5_NLI import finetune_model def finetune_t5(dataset): print("start") - annotated_sentences = data_manager.get_annotated_sentences(dataset, 1000) + annotated_sentences = data_manager.get_annotated_sentences(dataset, 2000, train_test_split=0.5, train=True) labels = data_manager.get_labels(dataset) premises = [] @@ -24,8 +24,9 @@ def finetune_t5(dataset): for i in range(min(len(premises), 50)): print(f"premise: {premises[i]}, hypothesis: {hypotheses[i]}, entailment: {entailment[i]}") - epochs = 20 + + epochs = 150 finetune_model(premises, hypotheses, entailment, output_dir=f"./src/models/t5_nli_finetuned_model/pretrained_{dataset}_epoch{epochs}", epochs=epochs) -finetune_t5("CoNLL") +finetune_t5("FIGER-coarse") diff --git a/src/models/T5_MLM_label.py b/src/models/T5_MLM_label.py index c0711ad..279416f 100644 --- a/src/models/T5_MLM_label.py +++ b/src/models/T5_MLM_label.py @@ -4,7 +4,7 @@ from torch.nn.functional import softmax from transformers import T5ForConditionalGeneration, T5Tokenizer, Trainer, TrainingArguments, DataCollatorForSeq2Seq from datasets import Dataset, DatasetDict -model_name = "google-t5/t5-base" +model_name = "google-t5/t5-large" print("Loading model: T5 MLM") tokenizer = T5Tokenizer.from_pretrained(model_name) @@ -62,14 +62,14 @@ def finetune_model(sentences, entities, labels, output_dir, epochs=10): dataset.set_format(type="torch", columns=["input_ids", "attention_mask", "labels"]) training_args = TrainingArguments( - output_dir="./src/models/t5_nli_finetuned_model", + output_dir="./src/models/t5_nli_finetuned_model/checkpoints/", eval_strategy="epoch", learning_rate=5e-5, per_device_train_batch_size=8, per_device_eval_batch_size=8, num_train_epochs=epochs, weight_decay=0.01, - save_strategy="no", + save_strategy="epoch", push_to_hub=False, logging_dir="./logs", ) diff --git a/src/models/T5_NLI.py b/src/models/T5_NLI.py index f24058a..d062c7b 100644 --- a/src/models/T5_NLI.py +++ b/src/models/T5_NLI.py @@ -5,8 +5,8 @@ 
from datasets import Dataset, DatasetDict label_map = {True: "entailment", False: "contradiction"} -# Use t5-base for testing because it is smaller -model_name = "google-t5/t5-base" +# Use t5-large for better quality (t5-base was used earlier for testing because it is smaller) +model_name = "google-t5/t5-large" # model_name = "google/t5_xxl_true_nli_mixture" @@ -110,14 +110,14 @@ def finetune_model(premises, hypotheses, entailment, output_dir, epochs=10): tokenized_datasets.set_format(type="torch", columns=["input_ids", "attention_mask", "labels"]) training_args = TrainingArguments( - output_dir="./src/models/t5_nli_finetuned_model", + output_dir="./src/models/t5_nli_finetuned_model/checkpoints/", eval_strategy="epoch", learning_rate=5e-5, per_device_train_batch_size=8, per_device_eval_batch_size=8, num_train_epochs=epochs, weight_decay=0.01, - save_strategy="no", + save_strategy="epoch", push_to_hub=False, logging_dir="./logs", ) -- GitLab