Commit 1d518ef1 authored by kupper

Test/Train split and finetuning config

parent 9282b007
@@ -30,23 +30,30 @@ def get_labels(dataset):
         raise Exception
 
-def get_annotated_sentences(dataset, test_instances=10):
+def get_annotated_sentences(dataset, test_instances=10, train_test_split=1.0, train=False):
     """
     Returns a list of annotated sentences reproducibly picked at random from the dataset.
     """
     random.seed(42)
     if dataset == "CoNLL":
-        return random.sample(conll_interface.get_annotated_sentences(), min(test_instances, len(conll_interface.get_annotated_sentences())))
+        sentences = random.sample(conll_interface.get_annotated_sentences(), min(test_instances, len(conll_interface.get_annotated_sentences())))
     elif dataset == "Pile-NER-type":
-        return random.sample(pile_interface.get_annotated_sentences(), min(test_instances, len(pile_interface.get_annotated_sentences())))
+        sentences = random.sample(pile_interface.get_annotated_sentences(), min(test_instances, len(pile_interface.get_annotated_sentences())))
     elif dataset == "FIGER-coarse":
-        return random.sample(figer_interface.get_annotated_sentences_coarse(), min(test_instances, len(figer_interface.get_annotated_sentences_coarse())))
+        sentences = random.sample(figer_interface.get_annotated_sentences_coarse(), min(test_instances, len(figer_interface.get_annotated_sentences_coarse())))
     elif dataset == "FIGER-fine":
-        return random.sample(figer_interface.get_annotated_sentences_fine(), min(test_instances, len(figer_interface.get_annotated_sentences_fine())))
+        sentences = random.sample(figer_interface.get_annotated_sentences_fine(), min(test_instances, len(figer_interface.get_annotated_sentences_fine())))
     else:
         raise Exception
+
+    split_pos = max(0, min(len(sentences), int(train_test_split * len(sentences))))
+    if train:
+        return sentences[split_pos:]
+    else:
+        return sentences[:split_pos]
 
 def get_label_dict(dataset, test_instances=10):
     """
     Returns a dictionary of example entities for each label.
@@ -63,3 +70,4 @@ def get_label_dict(dataset, test_instances=10):
             label_dict[labeled_entity[1]].add(labeled_entity[0])
     print(f"Label dict: {label_dict}")
     return label_dict
@@ -3,7 +3,7 @@ from src.models.T5_MLM_label import finetune_model
 def finetune_t5(dataset):
     print("start")
 
-    annotated_sentences = data_manager.get_annotated_sentences(dataset, 1000)
+    annotated_sentences = data_manager.get_annotated_sentences(dataset, 2000, train_test_split=0.5, train=True)
     labels = data_manager.get_labels(dataset)
 
     sentences = []
@@ -23,8 +23,8 @@ def finetune_t5(dataset):
     for i in range(min(len(sentences), 50)):
         print(f"sentence: {sentences[i]}, entity: {entities[i]}, label: {labels[i]}")
 
-    epochs = 20
+    epochs = 150
     finetune_model(sentences, entities, labels, output_dir=f"./src/models/t5_mlm_finetuned_model/pretrained_{dataset}_epoch{epochs}", epochs=epochs)
 
-finetune_t5("CoNLL")
+finetune_t5("FIGER-coarse")
@@ -3,7 +3,7 @@ from src.models.T5_NLI import finetune_model
 def finetune_t5(dataset):
     print("start")
 
-    annotated_sentences = data_manager.get_annotated_sentences(dataset, 1000)
+    annotated_sentences = data_manager.get_annotated_sentences(dataset, 2000, train_test_split=0.5, train=True)
     labels = data_manager.get_labels(dataset)
 
     premises = []
@@ -24,8 +24,9 @@ def finetune_t5(dataset):
     for i in range(min(len(premises), 50)):
         print(f"premise: {premises[i]}, hypothesis: {hypotheses[i]}, entailment: {entailment[i]}")
 
-    epochs = 20
+    epochs = 150
     finetune_model(premises, hypotheses, entailment, output_dir=f"./src/models/t5_nli_finetuned_model/pretrained_{dataset}_epoch{epochs}", epochs=epochs)
 
-finetune_t5("CoNLL")
+finetune_t5("FIGER-coarse")
@@ -4,7 +4,7 @@ from torch.nn.functional import softmax
 from transformers import T5ForConditionalGeneration, T5Tokenizer, Trainer, TrainingArguments, DataCollatorForSeq2Seq
 from datasets import Dataset, DatasetDict
 
-model_name = "google-t5/t5-base"
+model_name = "google-t5/t5-large"
 
 print("Loading model: T5 MLM")
 tokenizer = T5Tokenizer.from_pretrained(model_name)
@@ -62,14 +62,14 @@ def finetune_model(sentences, entities, labels, output_dir, epochs=10):
     dataset.set_format(type="torch", columns=["input_ids", "attention_mask", "labels"])
 
     training_args = TrainingArguments(
-        output_dir="./src/models/t5_nli_finetuned_model",
+        output_dir="./src/models/t5_nli_finetuned_model/checkpoints/",
         eval_strategy="epoch",
         learning_rate=5e-5,
         per_device_train_batch_size=8,
         per_device_eval_batch_size=8,
         num_train_epochs=epochs,
         weight_decay=0.01,
-        save_strategy="no",
+        save_strategy="epoch",
         push_to_hub=False,
         logging_dir="./logs",
     )
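Switching save_strategy from "no" to "epoch" makes the Hugging Face Trainer write a checkpoint-<global_step> folder under the new checkpoints/ output_dir after every epoch, which at 150 epochs with a t5-large model can consume substantial disk space. A sketch of loading one such checkpoint afterwards; the step number is illustrative, since the actual folder name depends on dataset size and batch size:

from transformers import T5ForConditionalGeneration, T5Tokenizer

checkpoint = "./src/models/t5_nli_finetuned_model/checkpoints/checkpoint-500"  # illustrative path
model = T5ForConditionalGeneration.from_pretrained(checkpoint)
tokenizer = T5Tokenizer.from_pretrained("google-t5/t5-large")  # tokenizer comes from the base model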
@@ -5,8 +5,8 @@ from datasets import Dataset, DatasetDict
 label_map = {True: "entailment", False: "contradiction"}
 
-# Use t5-base for testing because it is smaller
-model_name = "google-t5/t5-base"
+# Use t5-large for testing because it is smaller
+model_name = "google-t5/t5-large"
 # model_name = "google/t5_xxl_true_nli_mixture"
@@ -110,14 +110,14 @@ def finetune_model(premises, hypotheses, entailment, output_dir, epochs=10):
     tokenized_datasets.set_format(type="torch", columns=["input_ids", "attention_mask", "labels"])
 
     training_args = TrainingArguments(
-        output_dir="./src/models/t5_nli_finetuned_model",
+        output_dir="./src/models/t5_nli_finetuned_model/checkpoints/",
         eval_strategy="epoch",
         learning_rate=5e-5,
         per_device_train_batch_size=8,
         per_device_eval_batch_size=8,
         num_train_epochs=epochs,
         weight_decay=0.01,
-        save_strategy="no",
+        save_strategy="epoch",
         push_to_hub=False,
         logging_dir="./logs",
     )