# plot_training_curves.py (last committed by wu)
"""Plot validation ROUGE per training epoch for two saved training runs.

Loads the validation ROUGE histories saved under ``hist/`` for the
actor-only model and the cross-entropy (CE) model, then draws both curves on
one figure with epochs on the x-axis (1-based).
"""
import pathlib

import matplotlib.pyplot as plt
import numpy as np
import torch

hist_path = pathlib.Path('hist')
# NOTE: torch.load unpickles the file contents; only load history files you
# produced yourself.
actor_only_hist = torch.load(hist_path/'model_actor_only_val_rouge_hist.pt')
CE_hist = torch.load(hist_path/'model_CE_val_rouge_hist.pt')

# Convert to float tensors. Presumably each file holds one scalar validation
# ROUGE value per epoch (a 1-D sequence) -- TODO confirm against the training
# script that writes these files.
ahist = torch.Tensor(actor_only_hist)
chist = torch.Tensor(CE_hist)

# Derive the epoch count from the data rather than hard-coding it: a fixed
# count that disagrees with a history's length makes plt.plot raise
# ValueError ("x and y must have same first dimension").
num_epochs = max(len(ahist), len(chist))
if num_epochs == 0:
    raise ValueError("validation ROUGE histories are empty; nothing to plot")

# torch.Tensor(...) always builds a CPU tensor, so .numpy() is safe directly.
ohist = ahist.numpy()
shist = chist.numpy()

plt.title("Validation Rouge vs. Number of Training Epochs")
plt.xlabel("Training Epochs")
plt.ylabel("Validation Rouge")
# Plot each run against its own (1-based) epoch range so that runs with
# different lengths still line up epoch-for-epoch.
plt.plot(range(1, len(ohist) + 1), ohist, label="Actor Only")
plt.plot(range(1, len(shist) + 1), shist, label="Cross Entropy")

# Fit the y-axis to the combined data of both runs. Bounds are taken from the
# converted tensors instead of `actor_only_hist + CE_hist`, which would add
# element-wise (not concatenate) if a history had been saved as a tensor.
all_vals = torch.cat((ahist.flatten(), chist.flatten()))
y_min, y_max = all_vals.min().item(), all_vals.max().item()
if y_min < y_max:
    plt.ylim((y_min, y_max))
# else: every value is identical; leave matplotlib's autoscaling in place
# rather than setting singular limits (which triggers a warning).

plt.xticks(np.arange(1, num_epochs + 1, 1.0))
plt.legend()
plt.show()