Skip to content
Snippets Groups Projects
Commit 8ac5bdf0 authored by wu's avatar wu
Browse files

Update plot_training_curves.py

parent c2fae59f
Branches master
No related tags found
No related merge requests found
......@@ -4,27 +4,19 @@ import matplotlib.pyplot as plt
import pathlib
hist_path = pathlib.Path('hist')
actor_only_hist = torch.load(hist_path/'model_actor_only_val_rouge_hist.pt')
CE_hist = torch.load(hist_path/'model_CE_val_rouge_hist.pt')
ahist = torch.Tensor(actor_only_hist)
chist = torch.Tensor(CE_hist)
# list of loss, loss of type float in training and in validation / test
actor_only_hist = torch.load(hist_path/'model_actor_only_val_rouge_hist.pt', map_location="cpu")
CE_hist = torch.load(hist_path/'model_CE_val_rouge_hist.pt', map_location="cpu")
num_epochs = 2
ohist = []
shist = []
ohist = [h.cpu().numpy() for h in ahist]
shist = [h.cpu().numpy() for h in chist]
plt.title("Validation Rouge vs. Number of Training Epochs")
plt.xlabel("Training Epochs")
plt.ylabel("Validation Rouge")
plt.plot(range(1,num_epochs+1),ohist,label="Actor Only")
plt.plot(range(1,num_epochs+1),shist,label="Cross Entropy")
plt.plot(range(1,num_epochs+1),actor_only_hist,label="Actor Only")
plt.plot(range(1,num_epochs+1),CE_hist,label="Cross Entropy")
plt.ylim((min(actor_only_hist + CE_hist), max(actor_only_hist + CE_hist)))
plt.xticks(np.arange(1, num_epochs+1, 1.0))
plt.xticks(np.arange(0, num_epochs, 1000)) #
plt.legend()
plt.show()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment