Skip to content
Snippets Groups Projects
Commit 9f712c09 authored by Thomas Wolf's avatar Thomas Wolf
Browse files

Merge remote-tracking branch 'origin/master'

parents 8774a58f 8c7c8819
No related branches found
No related tags found
No related merge requests found
......@@ -108,5 +108,5 @@ def read_NEC_metrics(directory):
print(f"Model: {model}, Dataset: {dataset}, Accuracy: {avg_accuracy:.2f}%")
#run_NEC_tests_all()
# run_NEC_tests_all()
read_NEC_metrics("results")
......@@ -15,7 +15,7 @@ def plot_loss_curve(logfile, title):
eval_losses.append(float(match.group(1)))
plt.figure(figsize=(10,5))
plt.plot(eval_losses)
plt.plot(eval_losses, label="Eval Loss")
plt.xlabel("Epoch")
plt.ylabel("Eval Loss")
......@@ -23,7 +23,7 @@ def plot_loss_curve(logfile, title):
plt.legend()
plt.grid(True)
plt.savefig(f"eval_loss_{os.path.basename(logfile)}.pdf")
plt.savefig(f"eval_loss_{os.path.basename(logfile)}.svg")
plot_loss_curve("logs/finetune_T5_MLM_entity_427082.txt", "T5 Finetuning - MLM Entity Masking")
plot_loss_curve("logs/finetune_T5_MLM_label_427081.txt", "T5 Finetuning - MLM Label Masking")
......
......@@ -7,10 +7,23 @@ from datasets import Dataset, DatasetDict
model_name = "google-t5/t5-base"
print("Loading model: T5 MLM entity")
tokenizer = T5Tokenizer.from_pretrained(model_name)
model = T5ForConditionalGeneration.from_pretrained(model_name)
print("Finished loading model: T5 MLM entity")
def load_base():
    """Load the base `model_name` checkpoint, replacing the module-level
    `model` and `tokenizer` globals used by the rest of this module."""
    global model, tokenizer
    print("Loading model: T5 MLM entity")
    model = T5ForConditionalGeneration.from_pretrained(model_name)
    tokenizer = T5Tokenizer.from_pretrained(model_name)
    print("Finished loading model: T5 MLM entity")
def load_finetuned(input_dir):
    """Load a fine-tuned checkpoint into the module-level globals.

    Parameters
    ----------
    input_dir : str
        Directory previously written by ``save_pretrained`` (contains both
        the tokenizer files and the model weights).
    """
    global model, tokenizer
    print(f"Loading model: T5 MLM entity finetuned ({input_dir})")
    tokenizer = T5Tokenizer.from_pretrained(input_dir)
    model = T5ForConditionalGeneration.from_pretrained(input_dir)
    # Fix: original used an f-string with no placeholders (Ruff F541).
    print("Finished loading model: T5 MLM entity finetuned")
def set_label_dict(label_dict):
global label_representatives
......@@ -113,3 +126,6 @@ def finetune_model(sentences, entities, labels, output_dir, epochs=10):
model.save_pretrained(output_dir)
tokenizer.save_pretrained(output_dir)
load_base()
# load_finetuned("./src/models/t5_mlm_entity_finetuned_model/checkpoints/checkpoint-12200")
......@@ -6,10 +6,23 @@ from datasets import Dataset, DatasetDict
model_name = "google-t5/t5-base"
print("Loading model: T5 MLM label")
tokenizer = T5Tokenizer.from_pretrained(model_name)
model = T5ForConditionalGeneration.from_pretrained(model_name)
print("Finished loading model: T5 MLM label")
def load_base():
    """Load the base `model_name` checkpoint, replacing the module-level
    `model` and `tokenizer` globals used by the rest of this module."""
    global model, tokenizer
    print("Loading model: T5 MLM label")
    model = T5ForConditionalGeneration.from_pretrained(model_name)
    tokenizer = T5Tokenizer.from_pretrained(model_name)
    print("Finished loading model: T5 MLM label")
def load_finetuned(input_dir):
    """Load a fine-tuned checkpoint into the module-level globals.

    Parameters
    ----------
    input_dir : str
        Directory previously written by ``save_pretrained`` (contains both
        the tokenizer files and the model weights).
    """
    global model, tokenizer
    print(f"Loading model: T5 MLM label finetuned ({input_dir})")
    tokenizer = T5Tokenizer.from_pretrained(input_dir)
    model = T5ForConditionalGeneration.from_pretrained(input_dir)
    # Fix: original used an f-string with no placeholders (Ruff F541).
    print("Finished loading model: T5 MLM label finetuned")
def classify_entity(sentence, entity, labels):
sentence_with_masked_hypothesis = f"{sentence} {entity} is a <extra_id_0>."
......@@ -89,3 +102,6 @@ def finetune_model(sentences, entities, labels, output_dir, epochs=10):
model.save_pretrained(output_dir)
tokenizer.save_pretrained(output_dir)
load_base()
# load_finetuned("./src/models/t5_mlm_label_finetuned_model/checkpoints/checkpoint-9638")
......@@ -139,4 +139,4 @@ def finetune_model(premises, hypotheses, entailment, output_dir, epochs=10):
tokenizer.save_pretrained(output_dir)
load_base()
# load_finetuned("./src/models/t5_nli_finetuned_model/pretrained_CoNLL_epoch20")
# load_finetuned("./src/models/t5_nli_finetuned_model/checkpoints/checkpoint-85500")
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment