From 1d518ef1cee1fd69ad7ffc78c55f46d7fad8cd61 Mon Sep 17 00:00:00 2001
From: kupper <kupper@login.cl.uni-heidelberg.de>
Date: Thu, 27 Mar 2025 16:23:19 +0100
Subject: [PATCH] Test/Train split and finetuning config

---
 data/data_manager.py                           | 18 +++++++++++++-----
 src/experiments/finetune_T5/finetune_T5_MLM.py |  6 +++---
 src/experiments/finetune_T5/finetune_T5_NLI.py |  7 ++++---
 src/models/T5_MLM_label.py                     |  6 +++---
 src/models/T5_NLI.py                           |  8 ++++----
 5 files changed, 27 insertions(+), 18 deletions(-)

diff --git a/data/data_manager.py b/data/data_manager.py
index e218dd9..29406c5 100644
--- a/data/data_manager.py
+++ b/data/data_manager.py
@@ -30,23 +30,30 @@ def get_labels(dataset):
         raise Exception
 
 
-def get_annotated_sentences(dataset, test_instances=10):
+def get_annotated_sentences(dataset, test_instances=10, train_test_split=1.0, train=False):
     """
     Returns a list of annotated sentences reproducibly picked at random from the dataset.
     """
     random.seed(42)
 
     if dataset == "CoNLL":
-        return random.sample(conll_interface.get_annotated_sentences(), min(test_instances, len(conll_interface.get_annotated_sentences())))
+        sentences = random.sample(conll_interface.get_annotated_sentences(), min(test_instances, len(conll_interface.get_annotated_sentences())))
     elif dataset == "Pile-NER-type":
-        return random.sample(pile_interface.get_annotated_sentences(), min(test_instances, len(pile_interface.get_annotated_sentences())))
+        sentences = random.sample(pile_interface.get_annotated_sentences(), min(test_instances, len(pile_interface.get_annotated_sentences())))
     elif dataset == "FIGER-coarse":
-        return random.sample(figer_interface.get_annotated_sentences_coarse(), min(test_instances, len(figer_interface.get_annotated_sentences_coarse())))
+        sentences = random.sample(figer_interface.get_annotated_sentences_coarse(), min(test_instances, len(figer_interface.get_annotated_sentences_coarse())))
     elif dataset == "FIGER-fine":
-        return random.sample(figer_interface.get_annotated_sentences_fine(), min(test_instances, len(figer_interface.get_annotated_sentences_fine())))
+        sentences = random.sample(figer_interface.get_annotated_sentences_fine(), min(test_instances, len(figer_interface.get_annotated_sentences_fine())))
     else:
         raise Exception
     
+    split_pos = max(0, min(len(sentences), int(train_test_split * len(sentences))))
+
+    if train:
+        return sentences[split_pos:]
+    else:
+        return sentences[:split_pos]
+
 def get_label_dict(dataset, test_instances=10):
     """
     Returns a dictionary of example entities for each label.
@@ -63,3 +70,4 @@ def get_label_dict(dataset, test_instances=10):
             label_dict[labeled_entity[1]].add(labeled_entity[0])
     print(f"Label dict: {label_dict}")
     return label_dict
+
diff --git a/src/experiments/finetune_T5/finetune_T5_MLM.py b/src/experiments/finetune_T5/finetune_T5_MLM.py
index e70bfc6..f244ed0 100644
--- a/src/experiments/finetune_T5/finetune_T5_MLM.py
+++ b/src/experiments/finetune_T5/finetune_T5_MLM.py
@@ -3,7 +3,7 @@ from src.models.T5_MLM_label import finetune_model
 
 def finetune_t5(dataset):
     print("start")
-    annotated_sentences = data_manager.get_annotated_sentences(dataset, 1000)
+    annotated_sentences = data_manager.get_annotated_sentences(dataset, 2000, train_test_split=0.5, train=True)
     labels = data_manager.get_labels(dataset)
 
     sentences = []
@@ -23,8 +23,8 @@ def finetune_t5(dataset):
     for i in range(min(len(sentences), 50)):
         print(f"sentence: {sentences[i]}, entity: {entities[i]}, label: {labels[i]}")
 
-    epochs = 20
+    epochs = 150
 
     finetune_model(sentences, entities, labels, output_dir=f"./src/models/t5_mlm_finetuned_model/pretrained_{dataset}_epoch{epochs}", epochs=epochs)
 
-finetune_t5("CoNLL")
+finetune_t5("FIGER-coarse")
diff --git a/src/experiments/finetune_T5/finetune_T5_NLI.py b/src/experiments/finetune_T5/finetune_T5_NLI.py
index b2a6927..994f425 100644
--- a/src/experiments/finetune_T5/finetune_T5_NLI.py
+++ b/src/experiments/finetune_T5/finetune_T5_NLI.py
@@ -3,7 +3,7 @@ from src.models.T5_NLI import finetune_model
 
 def finetune_t5(dataset):
     print("start")
-    annotated_sentences = data_manager.get_annotated_sentences(dataset, 1000)
+    annotated_sentences = data_manager.get_annotated_sentences(dataset, 2000, train_test_split=0.5, train=True)
     labels = data_manager.get_labels(dataset)
 
     premises = []
@@ -24,8 +24,9 @@ def finetune_t5(dataset):
 
     for i in range(min(len(premises), 50)):
         print(f"premise: {premises[i]}, hypothesis: {hypotheses[i]}, entailment: {entailment[i]}")
-    epochs = 20
+
+    epochs = 150
 
     finetune_model(premises, hypotheses, entailment, output_dir=f"./src/models/t5_nli_finetuned_model/pretrained_{dataset}_epoch{epochs}", epochs=epochs)
 
-finetune_t5("CoNLL")
+finetune_t5("FIGER-coarse")
diff --git a/src/models/T5_MLM_label.py b/src/models/T5_MLM_label.py
index c0711ad..279416f 100644
--- a/src/models/T5_MLM_label.py
+++ b/src/models/T5_MLM_label.py
@@ -4,7 +4,7 @@ from torch.nn.functional import softmax
 from transformers import T5ForConditionalGeneration, T5Tokenizer, Trainer, TrainingArguments, DataCollatorForSeq2Seq
 from datasets import Dataset, DatasetDict
 
-model_name = "google-t5/t5-base"
+model_name = "google-t5/t5-large"
 
 print("Loading model: T5 MLM")
 tokenizer = T5Tokenizer.from_pretrained(model_name)
@@ -62,14 +62,14 @@ def finetune_model(sentences, entities, labels, output_dir, epochs=10):
     dataset.set_format(type="torch", columns=["input_ids", "attention_mask", "labels"])
 
     training_args = TrainingArguments(
-        output_dir="./src/models/t5_nli_finetuned_model",
+        output_dir="./src/models/t5_nli_finetuned_model/checkpoints/",
         eval_strategy="epoch",
         learning_rate=5e-5,
         per_device_train_batch_size=8,
         per_device_eval_batch_size=8,
         num_train_epochs=epochs,
         weight_decay=0.01,
-        save_strategy="no",
+        save_strategy="epoch",
         push_to_hub=False,
         logging_dir="./logs",
     )
diff --git a/src/models/T5_NLI.py b/src/models/T5_NLI.py
index f24058a..d062c7b 100644
--- a/src/models/T5_NLI.py
+++ b/src/models/T5_NLI.py
@@ -5,8 +5,8 @@ from datasets import Dataset, DatasetDict
 
 label_map = {True: "entailment", False: "contradiction"}
 
-# Use t5-base for testing because it is smaller
-model_name = "google-t5/t5-base"
+# Use t5-large as the base model (larger than the previously used t5-base)
+model_name = "google-t5/t5-large"
 
 # model_name = "google/t5_xxl_true_nli_mixture"
 
@@ -110,14 +110,14 @@ def finetune_model(premises, hypotheses, entailment, output_dir, epochs=10):
     tokenized_datasets.set_format(type="torch", columns=["input_ids", "attention_mask", "labels"])
 
     training_args = TrainingArguments(
-        output_dir="./src/models/t5_nli_finetuned_model",
+        output_dir="./src/models/t5_nli_finetuned_model/checkpoints/",
         eval_strategy="epoch",
         learning_rate=5e-5,
         per_device_train_batch_size=8,
         per_device_eval_batch_size=8,
         num_train_epochs=epochs,
         weight_decay=0.01,
-        save_strategy="no",
+        save_strategy="epoch",
         push_to_hub=False,
         logging_dir="./logs",
     )
-- 
GitLab