Commit 1d518ef1 authored by kupper

Test/Train split and finetuning config

parent 9282b007
@@ -30,23 +30,30 @@ def get_labels(dataset):
         raise Exception
 
-def get_annotated_sentences(dataset, test_instances=10):
+def get_annotated_sentences(dataset, test_instances=10, train_test_split=1.0, train=False):
     """
     Returns a list of annotated sentences reproducibly picked at random from the dataset.
     """
     random.seed(42)
     if dataset == "CoNLL":
-        return random.sample(conll_interface.get_annotated_sentences(), min(test_instances, len(conll_interface.get_annotated_sentences())))
+        sentences = random.sample(conll_interface.get_annotated_sentences(), min(test_instances, len(conll_interface.get_annotated_sentences())))
     elif dataset == "Pile-NER-type":
-        return random.sample(pile_interface.get_annotated_sentences(), min(test_instances, len(pile_interface.get_annotated_sentences())))
+        sentences = random.sample(pile_interface.get_annotated_sentences(), min(test_instances, len(pile_interface.get_annotated_sentences())))
     elif dataset == "FIGER-coarse":
-        return random.sample(figer_interface.get_annotated_sentences_coarse(), min(test_instances, len(figer_interface.get_annotated_sentences_coarse())))
+        sentences = random.sample(figer_interface.get_annotated_sentences_coarse(), min(test_instances, len(figer_interface.get_annotated_sentences_coarse())))
     elif dataset == "FIGER-fine":
-        return random.sample(figer_interface.get_annotated_sentences_fine(), min(test_instances, len(figer_interface.get_annotated_sentences_fine())))
+        sentences = random.sample(figer_interface.get_annotated_sentences_fine(), min(test_instances, len(figer_interface.get_annotated_sentences_fine())))
     else:
         raise Exception
+
+    split_pos = max(0, min(len(sentences), int(train_test_split * len(sentences))))
+    if train:
+        return sentences[split_pos:]
+    else:
+        return sentences[:split_pos]
 
 def get_label_dict(dataset, test_instances=10):
     """
     Returns a dictionary of example entities for each label.

@@ -63,3 +70,4 @@ def get_label_dict(dataset, test_instances=10):
             label_dict[labeled_entity[1]].add(labeled_entity[0])
     print(f"Label dict: {label_dict}")
     return label_dict
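For reference, a minimal stand-alone sketch of the new slicing logic, using the values the finetuning scripts below pass in (2000 sampled sentences, train_test_split=0.5):

# Stand-alone sketch of the split logic above, with the values used by the
# finetuning scripts below (2000 sampled sentences, train_test_split=0.5).
sentences = list(range(2000))  # stand-in for the seeded random sample
train_test_split, train = 0.5, True
split_pos = max(0, min(len(sentences), int(train_test_split * len(sentences))))
train_part = sentences[split_pos:]   # returned when train=True  (1000 items)
test_part = sentences[:split_pos]    # returned when train=False (1000 items)
print(split_pos, len(train_part), len(test_part))  # 1000 1000 1000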
@@ -3,7 +3,7 @@ from src.models.T5_MLM_label import finetune_model
 
 def finetune_t5(dataset):
     print("start")
-    annotated_sentences = data_manager.get_annotated_sentences(dataset, 1000)
+    annotated_sentences = data_manager.get_annotated_sentences(dataset, 2000, train_test_split=0.5, train=True)
     labels = data_manager.get_labels(dataset)
 
     sentences = []

@@ -23,8 +23,8 @@ def finetune_t5(dataset):
     for i in range(min(len(sentences), 50)):
         print(f"sentence: {sentences[i]}, entity: {entities[i]}, label: {labels[i]}")
 
-    epochs = 20
+    epochs = 150
     finetune_model(sentences, entities, labels, output_dir=f"./src/models/t5_mlm_finetuned_model/pretrained_{dataset}_epoch{epochs}", epochs=epochs)
 
-finetune_t5("CoNLL")
+finetune_t5("FIGER-coarse")
@@ -3,7 +3,7 @@ from src.models.T5_NLI import finetune_model
 
 def finetune_t5(dataset):
     print("start")
-    annotated_sentences = data_manager.get_annotated_sentences(dataset, 1000)
+    annotated_sentences = data_manager.get_annotated_sentences(dataset, 2000, train_test_split=0.5, train=True)
     labels = data_manager.get_labels(dataset)
 
     premises = []

@@ -24,8 +24,9 @@ def finetune_t5(dataset):
     for i in range(min(len(premises), 50)):
         print(f"premise: {premises[i]}, hypothesis: {hypotheses[i]}, entailment: {entailment[i]}")
 
-    epochs = 20
+    epochs = 150
     finetune_model(premises, hypotheses, entailment, output_dir=f"./src/models/t5_nli_finetuned_model/pretrained_{dataset}_epoch{epochs}", epochs=epochs)
 
-finetune_t5("CoNLL")
+finetune_t5("FIGER-coarse")
@@ -4,7 +4,7 @@ from torch.nn.functional import softmax
 from transformers import T5ForConditionalGeneration, T5Tokenizer, Trainer, TrainingArguments, DataCollatorForSeq2Seq
 from datasets import Dataset, DatasetDict
 
-model_name = "google-t5/t5-base"
+model_name = "google-t5/t5-large"
 
 print("Loading model: T5 MLM")
 tokenizer = T5Tokenizer.from_pretrained(model_name)

@@ -62,14 +62,14 @@ def finetune_model(sentences, entities, labels, output_dir, epochs=10):
     dataset.set_format(type="torch", columns=["input_ids", "attention_mask", "labels"])
 
     training_args = TrainingArguments(
-        output_dir="./src/models/t5_nli_finetuned_model",
+        output_dir="./src/models/t5_nli_finetuned_model/checkpoints/",
         eval_strategy="epoch",
         learning_rate=5e-5,
         per_device_train_batch_size=8,
         per_device_eval_batch_size=8,
         num_train_epochs=epochs,
         weight_decay=0.01,
-        save_strategy="no",
+        save_strategy="epoch",
         push_to_hub=False,
         logging_dir="./logs",
     )
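With save_strategy="epoch", the Trainer writes numbered checkpoints under the new checkpoints/ directory; presumably the final model is still saved to the output_dir argument passed to finetune_model, though that part of the code is not shown here. A hedged loading sketch (the checkpoint step number below is hypothetical):

# Hedged sketch: reload a checkpoint written by the Trainer with save_strategy="epoch".
# The checkpoint-<step> folder name is hypothetical; the directory comes from the diff above.
from transformers import T5ForConditionalGeneration, T5Tokenizer

ckpt = "./src/models/t5_nli_finetuned_model/checkpoints/checkpoint-1000"
model = T5ForConditionalGeneration.from_pretrained(ckpt)
tokenizer = T5Tokenizer.from_pretrained("google-t5/t5-large")  # matches model_name above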
@@ -5,8 +5,8 @@ from datasets import Dataset, DatasetDict
 
 label_map = {True: "entailment", False: "contradiction"}
 
-# Use t5-base for testing because it is smaller
-model_name = "google-t5/t5-base"
+# Use t5-large for testing because it is smaller than t5_xxl_true_nli_mixture
+model_name = "google-t5/t5-large"
 # model_name = "google/t5_xxl_true_nli_mixture"

@@ -110,14 +110,14 @@ def finetune_model(premises, hypotheses, entailment, output_dir, epochs=10):
     tokenized_datasets.set_format(type="torch", columns=["input_ids", "attention_mask", "labels"])
 
     training_args = TrainingArguments(
-        output_dir="./src/models/t5_nli_finetuned_model",
+        output_dir="./src/models/t5_nli_finetuned_model/checkpoints/",
         eval_strategy="epoch",
         learning_rate=5e-5,
         per_device_train_batch_size=8,
         per_device_eval_batch_size=8,
         num_train_epochs=epochs,
         weight_decay=0.01,
-        save_strategy="no",
+        save_strategy="epoch",
         push_to_hub=False,
         logging_dir="./logs",
     )