Commit 8c7c8819 authored by kupper
Loading of finetuned models and minor fixes

parent 32c4f9fd
@@ -62,7 +62,7 @@ def run_NEC_tests(model_name, dataset, results_dir, test_instances=10):
             txt_file.write(f"\nSentence: {sentence}")
             txt_file.write(f"\nEntity: {entity_name}")
             txt_file.write(f"\nTrue Label(s): {', '.join(true_labels)}")
-            txt_file.write(f"\nPredicted Label: {predicted_label}")
+            txt_file.write(f"\nPredicted Label: {predicted_label}\n")
             txt_file.write("-" * 50 + "\n\n")
     print(f"Results saved to:\n CSV: {csv_filename}\n TXT: {txt_filename}")
@@ -108,5 +108,5 @@ def read_NEC_metrics(directory):
         print(f"Model: {model}, Dataset: {dataset}, Accuracy: {avg_accuracy:.2f}%")
 
-#run_NEC_tests_all()
+# run_NEC_tests_all()
 read_NEC_metrics("results")
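Note: the hunk above only shows the reporting line of read_NEC_metrics. As a minimal sketch of the aggregation behind it, assuming each results file is named "<model>_<dataset>.csv" and stores one accuracy value per row (a hypothetical layout, not shown in this commit):

import os
import csv
from collections import defaultdict

def read_NEC_metrics_sketch(directory):
    # Collect per-run accuracies keyed by (model, dataset).
    accuracies = defaultdict(list)
    for filename in os.listdir(directory):
        if not filename.endswith(".csv"):
            continue
        model, dataset = os.path.splitext(filename)[0].split("_", 1)
        with open(os.path.join(directory, filename), newline="") as f:
            for row in csv.reader(f):
                accuracies[(model, dataset)].append(float(row[0]))
    # Average per (model, dataset) pair and report, mirroring the print above.
    for (model, dataset), values in accuracies.items():
        avg_accuracy = sum(values) / len(values)
        print(f"Model: {model}, Dataset: {dataset}, Accuracy: {avg_accuracy:.2f}%")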
@@ -15,7 +15,7 @@ def plot_loss_curve(logfile, title):
             eval_losses.append(float(match.group(1)))
     plt.figure(figsize=(10,5))
-    plt.plot(eval_losses)
+    plt.plot(eval_losses, label="Eval Loss")
     plt.xlabel("Epoch")
     plt.ylabel("Eval Loss")
@@ -23,7 +23,7 @@ def plot_loss_curve(logfile, title):
     plt.legend()
     plt.grid(True)
-    plt.savefig(f"eval_loss_{os.path.basename(logfile)}.pdf")
+    plt.savefig(f"eval_loss_{os.path.basename(logfile)}.svg")
 plot_loss_curve("logs/finetune_T5_MLM_entity_427082.txt", "T5 Finetuning - MLM Entity Masking")
 plot_loss_curve("logs/finetune_T5_MLM_label_427081.txt", "T5 Finetuning - MLM Label Masking")
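For context, the full plot_loss_curve can be pieced together from the two hunks above; a sketch, assuming the logs contain Hugging Face Trainer-style lines such as {'eval_loss': 0.1234, ...} (the regex is an assumption, not shown in the diff):

import os
import re
import matplotlib.pyplot as plt

def plot_loss_curve(logfile, title):
    eval_losses = []
    with open(logfile) as f:
        for line in f:
            # Assumed log format: one Trainer eval dict per evaluation step.
            match = re.search(r"'eval_loss': ([0-9.]+)", line)
            if match:
                eval_losses.append(float(match.group(1)))
    plt.figure(figsize=(10,5))
    plt.plot(eval_losses, label="Eval Loss")  # named so plt.legend() has an entry
    plt.xlabel("Epoch")
    plt.ylabel("Eval Loss")
    plt.title(title)
    plt.legend()
    plt.grid(True)
    # SVG output keeps the curve scalable in papers and web pages.
    plt.savefig(f"eval_loss_{os.path.basename(logfile)}.svg")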
@@ -7,10 +7,23 @@ from datasets import Dataset, DatasetDict
 model_name = "google-t5/t5-base"
 
-print("Loading model: T5 MLM entity")
-tokenizer = T5Tokenizer.from_pretrained(model_name)
-model = T5ForConditionalGeneration.from_pretrained(model_name)
-print("Finished loading model: T5 MLM entity")
+def load_base():
+    global model
+    global tokenizer
+    print("Loading model: T5 MLM entity")
+    tokenizer = T5Tokenizer.from_pretrained(model_name)
+    model = T5ForConditionalGeneration.from_pretrained(model_name)
+    print("Finished loading model: T5 MLM entity")
+
+def load_finetuned(input_dir):
+    global model
+    global tokenizer
+    print(f"Loading model: T5 MLM entity finetuned ({input_dir})")
+    tokenizer = T5Tokenizer.from_pretrained(input_dir)
+    model = T5ForConditionalGeneration.from_pretrained(input_dir)
+    print(f"Finished loading model: T5 MLM entity finetuned")
 
 def set_label_dict(label_dict):
     global label_representatives
@@ -113,3 +126,6 @@ def finetune_model(sentences, entities, labels, output_dir, epochs=10):
     model.save_pretrained(output_dir)
     tokenizer.save_pretrained(output_dir)
+
+load_base()
+# load_finetuned("./src/models/t5_mlm_entity_finetuned_model/checkpoints/checkpoint-12200")
@@ -6,10 +6,23 @@ from datasets import Dataset, DatasetDict
 model_name = "google-t5/t5-base"
 
-print("Loading model: T5 MLM label")
-tokenizer = T5Tokenizer.from_pretrained(model_name)
-model = T5ForConditionalGeneration.from_pretrained(model_name)
-print("Finished loading model: T5 MLM label")
+def load_base():
+    global model
+    global tokenizer
+    print("Loading model: T5 MLM label")
+    tokenizer = T5Tokenizer.from_pretrained(model_name)
+    model = T5ForConditionalGeneration.from_pretrained(model_name)
+    print("Finished loading model: T5 MLM label")
+
+def load_finetuned(input_dir):
+    global model
+    global tokenizer
+    print(f"Loading model: T5 MLM label finetuned ({input_dir})")
+    tokenizer = T5Tokenizer.from_pretrained(input_dir)
+    model = T5ForConditionalGeneration.from_pretrained(input_dir)
+    print(f"Finished loading model: T5 MLM label finetuned")
 
 def classify_entity(sentence, entity, labels):
     sentence_with_masked_hypothesis = f"{sentence} {entity} is a <extra_id_0>."
@@ -89,3 +102,6 @@ def finetune_model(sentences, entities, labels, output_dir, epochs=10):
     model.save_pretrained(output_dir)
     tokenizer.save_pretrained(output_dir)
+
+load_base()
+# load_finetuned("./src/models/t5_mlm_label_finetuned_model/checkpoints/checkpoint-9638")
@@ -139,4 +139,4 @@ def finetune_model(premises, hypotheses, entailment, output_dir, epochs=10):
     tokenizer.save_pretrained(output_dir)
 
 load_base()
-# load_finetuned("./src/models/t5_nli_finetuned_model/pretrained_CoNLL_epoch20")
+# load_finetuned("./src/models/t5_nli_finetuned_model/checkpoints/checkpoint-85500")