import torch
import numpy as np
import matplotlib.pyplot as plt
import pathlib

hist_path = pathlib.Path('hist')
# Validation ROUGE histories saved by two training runs (actor-only vs.
# cross-entropy), plotted one value per training epoch below.
# map_location="cpu" lets them load on a machine without a GPU.
# NOTE(review): assumed to be a sequence of floats / 0-d tensors or a 1-D
# tensor — confirm against the code that writes these files.
actor_only_hist = torch.load(hist_path/'model_actor_only_val_rouge_hist.pt', map_location="cpu")
CE_hist = torch.load(hist_path/'model_CE_val_rouge_hist.pt', map_location="cpu")

# Normalise to plain Python floats so the list concatenation and min/max
# below work the same way whether the files hold floats, 0-d tensors, or a
# 1-D tensor (for which `+` would add elementwise instead of concatenating).
actor_only_hist = [float(score) for score in actor_only_hist]
CE_hist = [float(score) for score in CE_hist]

# Derive the epoch count from the recorded data rather than hard-coding it,
# so the x-axis always matches the number of stored scores.
num_epochs = max(len(actor_only_hist), len(CE_hist))

plt.title("Validation Rouge vs. Number of Training Epochs")
plt.xlabel("Training Epochs")
plt.ylabel("Validation Rouge")
# Plot each series against its own length so runs of different lengths
# do not raise a shape-mismatch error.
plt.plot(range(1, len(actor_only_hist) + 1), actor_only_hist, label="Actor Only")
plt.plot(range(1, len(CE_hist) + 1), CE_hist, label="Cross Entropy")

all_scores = actor_only_hist + CE_hist
if all_scores:
    y_min, y_max = min(all_scores), max(all_scores)
    # Pad the limits so extreme points are not clipped by the frame, and
    # avoid identical low/high limits when every score is the same.
    span = y_max - y_min
    pad = 0.05 * span if span > 0 else 0.01
    plt.ylim((y_min - pad, y_max + pad))

# One tick per epoch for short runs, thinned to roughly ten ticks for long
# runs. Ticks start at 1 to match the plotted x-range.
tick_step = max(1, num_epochs // 10)
plt.xticks(np.arange(1, num_epochs + 1, tick_step))
plt.legend()
plt.show()