From 7bf8383faab67eeadd342a9020588e46e1c45e0a Mon Sep 17 00:00:00 2001
From: JulianFP <julian@partanengroup.de>
Date: Thu, 27 Mar 2025 18:17:09 +0000
Subject: [PATCH] Apply the finetuning config previously made by kupper to the
 MLM entity model as well
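
Mirror kupper's earlier finetuning setup in the MLM entity code: bump the
model from t5-base to t5-large, raise the epoch count from 20 to 150, draw
the training data as the train half of a 50/50 split over 2000 annotated
sentences, save a checkpoint after every epoch, and point the experiment at
FIGER-coarse instead of CoNLL.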

---
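Note: a minimal sketch of loading the finetuned model after this patch,
assuming finetune_model saves the final model and tokenizer to its
output_dir argument (that save call sits outside the hunks below); the path
mirrors the output_dir f-string in finetune_T5_MLM_entity.py.

    from transformers import T5ForConditionalGeneration, T5Tokenizer

    # Hypothetical final-model directory: the output_dir f-string with
    # dataset="FIGER-coarse" and epochs=150.
    path = "./src/models/t5_mlm_entity_finetuned_model/pretrained_FIGER-coarse_epoch150"
    tokenizer = T5Tokenizer.from_pretrained(path)
    model = T5ForConditionalGeneration.from_pretrained(path)

With save_strategy="epoch" the Trainer also writes checkpoint-* directories
under ./src/models/t5_nli_finetuned_model/checkpoints/ (the hardcoded path
in T5_MLM_entity.py still says t5_nli), so an interrupted run can be
resumed with trainer.train(resume_from_checkpoint=True).
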
 src/experiments/finetune_T5/finetune_T5_MLM_entity.py | 6 +++---
 src/models/T5_MLM_entity.py                           | 6 +++---
 2 files changed, 6 insertions(+), 6 deletions(-)

diff --git a/src/experiments/finetune_T5/finetune_T5_MLM_entity.py b/src/experiments/finetune_T5/finetune_T5_MLM_entity.py
index c7ac094..3366c8e 100644
--- a/src/experiments/finetune_T5/finetune_T5_MLM_entity.py
+++ b/src/experiments/finetune_T5/finetune_T5_MLM_entity.py
@@ -3,7 +3,7 @@ from src.models.T5_MLM_entity import finetune_model, set_label_dict
 
 def finetune_t5(dataset):
     print("start")
-    annotated_sentences = data_manager.get_annotated_sentences(dataset, 1000)
+    annotated_sentences = data_manager.get_annotated_sentences(dataset, 2000, train_test_split=0.5, train=True)
     labels = data_manager.get_labels(dataset)
 
     sentences = []
@@ -23,11 +23,11 @@ def finetune_t5(dataset):
     for i in range(min(len(sentences), 50)):
         print(f"sentence: {sentences[i]}, entity: {entities[i]}, label: {labels[i]}")
 
-    epochs = 20
+    epochs = 150
 
 
     label_dict = data_manager.get_label_dict(dataset, 1000)
     set_label_dict(label_dict)
     finetune_model(sentences, entities, labels, output_dir=f"./src/models/t5_mlm_entity_finetuned_model/pretrained_{dataset}_epoch{epochs}", epochs=epochs)
 
-finetune_t5("CoNLL")
+finetune_t5("FIGER-coarse")
diff --git a/src/models/T5_MLM_entity.py b/src/models/T5_MLM_entity.py
index f9604cb..f07add7 100644
--- a/src/models/T5_MLM_entity.py
+++ b/src/models/T5_MLM_entity.py
@@ -5,7 +5,7 @@ from torch.nn.functional import softmax
 from transformers import T5ForConditionalGeneration, T5Tokenizer, Trainer, TrainingArguments, DataCollatorForSeq2Seq
 from datasets import Dataset, DatasetDict
 
-model_name = "google-t5/t5-base"
+model_name = "google-t5/t5-large"
 
 print("Loading model: T5 MLM entity")
 tokenizer = T5Tokenizer.from_pretrained(model_name)
@@ -86,14 +86,14 @@ def finetune_model(sentences, entities, labels, output_dir, epochs=10):
     dataset.set_format(type="torch", columns=["input_ids", "attention_mask", "labels"])
 
     training_args = TrainingArguments(
-        output_dir="./src/models/t5_nli_finetuned_model",
+        output_dir="./src/models/t5_nli_finetuned_model/checkpoints/",
         eval_strategy="epoch",
         learning_rate=5e-5,
         per_device_train_batch_size=8,
         per_device_eval_batch_size=8,
         num_train_epochs=epochs,
         weight_decay=0.01,
-        save_strategy="no",
+        save_strategy="epoch",
         push_to_hub=False,
         logging_dir="./logs",
     )
-- 
GitLab