diff --git a/src/experiments/finetune_T5/finetune_T5_MLM_entity.py b/src/experiments/finetune_T5/finetune_T5_MLM_entity.py
index c7ac094afa25b2269c700b71a2949c05a8effbd5..3366c8ebe0574b46b0de71d83730f206a637d2fe 100644
--- a/src/experiments/finetune_T5/finetune_T5_MLM_entity.py
+++ b/src/experiments/finetune_T5/finetune_T5_MLM_entity.py
@@ -3,7 +3,7 @@ from src.models.T5_MLM_entity import finetune_model, set_label_dict
 
 def finetune_t5(dataset):
     print("start")
-    annotated_sentences = data_manager.get_annotated_sentences(dataset, 1000)
+    annotated_sentences = data_manager.get_annotated_sentences(dataset, 2000, train_test_split=0.5, train=True)
     labels = data_manager.get_labels(dataset)
 
     sentences = []
@@ -23,11 +23,11 @@ def finetune_t5(dataset):
     for i in range(min(len(sentences), 50)):
        print(f"sentence: {sentences[i]}, entity: {entities[i]}, label: {labels[i]}")
 
-    epochs = 20
+    epochs = 150
     label_dict = data_manager.get_label_dict(dataset, 1000)
     set_label_dict(label_dict)
 
     finetune_model(sentences, entities, labels, output_dir=f"./src/models/t5_mlm_entity_finetuned_model/pretrained_{dataset}_epoch{epochs}", epochs=epochs)
 
 
-finetune_t5("CoNLL")
+finetune_t5("FIGER-coarse")
diff --git a/src/models/T5_MLM_entity.py b/src/models/T5_MLM_entity.py
index f9604cbc343be57f4533f7d6620c6beed49dc933..f07add7716b90d0767ed0355c211a4869a6fdcf0 100644
--- a/src/models/T5_MLM_entity.py
+++ b/src/models/T5_MLM_entity.py
@@ -5,7 +5,7 @@ from torch.nn.functional import softmax
 from transformers import T5ForConditionalGeneration, T5Tokenizer, Trainer, TrainingArguments, DataCollatorForSeq2Seq
 from datasets import Dataset, DatasetDict
 
-model_name = "google-t5/t5-base"
+model_name = "google-t5/t5-large"
 
 print("Loading model: T5 MLM entity")
 tokenizer = T5Tokenizer.from_pretrained(model_name)
@@ -86,14 +86,14 @@ def finetune_model(sentences, entities, labels, output_dir, epochs=10):
     dataset.set_format(type="torch", columns=["input_ids", "attention_mask", "labels"])
 
     training_args = TrainingArguments(
-        output_dir="./src/models/t5_nli_finetuned_model",
+        output_dir="./src/models/t5_nli_finetuned_model/checkpoints/",
         eval_strategy="epoch",
         learning_rate=5e-5,
         per_device_train_batch_size=8,
         per_device_eval_batch_size=8,
         num_train_epochs=epochs,
         weight_decay=0.01,
-        save_strategy="no",
+        save_strategy="epoch",
         push_to_hub=False,
         logging_dir="./logs",
     )
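
Note (not part of the patch): with save_strategy="epoch", the Hugging Face Trainer writes a checkpoint-<global_step> subdirectory under the hardcoded output_dir after each epoch. A minimal sketch of how one of those checkpoints could be reloaded for inference follows; the step number 500 is hypothetical, and the base-tokenizer fallback assumes no tokenizer was passed to the Trainer (checkpoints only include tokenizer files when one was).

    from transformers import T5ForConditionalGeneration, T5Tokenizer

    # Hypothetical checkpoint path: Trainer names checkpoints "checkpoint-<global_step>".
    checkpoint_dir = "./src/models/t5_nli_finetuned_model/checkpoints/checkpoint-500"
    model = T5ForConditionalGeneration.from_pretrained(checkpoint_dir)

    # Fall back to the base tokenizer used at training time, since a checkpoint
    # only contains tokenizer files if a tokenizer was passed to the Trainer.
    tokenizer = T5Tokenizer.from_pretrained("google-t5/t5-large")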