Skip to content
Snippets Groups Projects
Commit 32c4f9fd authored by Thomas Wolf's avatar Thomas Wolf
Browse files

Merge remote-tracking branch 'origin/master'

parents ca5ce224 3d39d2cc
No related branches found
No related tags found
No related merge requests found
import os
import re
import pandas as pd
import matplotlib.pyplot as plt
def plot_loss_curve(logfile, title):
eval_losses = []
pattern = re.compile(r"'eval_loss': ([\d\.e-]+)")
with open(logfile, 'r', errors='ignore') as file:
for line in file:
match = pattern.search(line)
if match:
eval_losses.append(float(match.group(1)))
plt.figure(figsize=(10,5))
plt.plot(eval_losses)
plt.xlabel("Epoch")
plt.ylabel("Eval Loss")
plt.title(title)
plt.legend()
plt.grid(True)
plt.savefig(f"eval_loss_{os.path.basename(logfile)}.pdf")
plot_loss_curve("logs/finetune_T5_MLM_entity_427082.txt", "T5 Finetuning - MLM Entity Masking")
plot_loss_curve("logs/finetune_T5_MLM_label_427081.txt", "T5 Finetuning - MLM Label Masking")
plot_loss_curve("logs/finetune_T5_NLI_427080.txt", "T5 Finetuning - NLI")
import pandas as pd
import matplotlib.pyplot as plt
df = pd.read_csv("eval_loss.csv")
print(df.head())
plt.figure(figsize=(10,5))
plt.plot(df["epoch"], df["eval_loss"])
plt.xlabel("Epoch")
plt.ylabel("Eval Loss")
plt.title("T5 finetuning training curve")
plt.legend()
plt.grid(True)
plt.savefig("eval_loss.pdf")
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment