# plot_training_curves.py: plot per-epoch validation ROUGE for two training runs.
import torch
import numpy as np
import matplotlib.pyplot as plt
import pathlib

# Directory holding the per-epoch metric histories saved during training.
hist_path = pathlib.Path('hist')

# (legend label, file name) for each run to plot. Each file is expected to hold
# one validation ROUGE value per epoch (a list of floats) -- TODO confirm format.
RUNS = (
    ("Actor Only", "model_actor_only_val_rouge_hist.pt"),
    ("Cross Entropy", "model_CE_val_rouge_hist.pt"),
)


def load_history(path):
    """Load a saved metric history onto the CPU and return it as a list of floats.

    Converting every entry with ``float`` makes a history saved as a 1-D tensor
    (or a list of 0-d tensors) behave like a plain list of numbers. In
    particular, later concatenation/min/max then work on Python floats rather
    than doing element-wise tensor arithmetic.

    Args:
        path: File written with ``torch.save`` containing an iterable of numbers.

    Returns:
        The history as a ``list[float]``.
    """
    # NOTE: torch.load unpickles the file; only load histories you produced yourself.
    return [float(value) for value in torch.load(path, map_location="cpu")]


def epoch_ticks(num_epochs, max_ticks=10):
    """Return x-axis tick positions for epochs numbered ``1..num_epochs``.

    Ticks fall on whole epochs, starting at 1 to match the plotted x values.
    The spacing is widened so that at most *max_ticks* (must be >= 1) ticks
    are drawn, which keeps long runs readable.
    """
    step = max(1, -(-num_epochs // max_ticks))  # ceiling division
    return np.arange(1, num_epochs + 1, step)


def main():
    """Plot validation ROUGE per epoch for every run in RUNS and show the figure."""
    histories = {label: load_history(hist_path / name) for label, name in RUNS}

    plt.title("Validation Rouge vs. Number of Training Epochs")
    plt.xlabel("Training Epochs")
    plt.ylabel("Validation Rouge")
    for label, hist in histories.items():
        # Epochs are 1-based; the count comes from the data itself, so a history
        # of any length plots without an x/y length mismatch.
        plt.plot(range(1, len(hist) + 1), hist, label=label)

    values = [value for hist in histories.values() for value in hist]
    if values:
        low, high = min(values), max(values)
        # Fit the y-axis tightly to the data. Skip this when all values are
        # equal: identical limits only make matplotlib warn and expand them.
        if low < high:
            plt.ylim(low, high)

    longest = max((len(hist) for hist in histories.values()), default=0)
    plt.xticks(epoch_ticks(longest))
    plt.legend()
    plt.show()


if __name__ == "__main__":
    main()