Skip to content
Snippets Groups Projects
plot_training_curves.py 571 B
Newer Older
wu's avatar
wu committed

import torch
import numpy as np
import matplotlib.pyplot as plt

# Accuracy histories to plot, one value per epoch. Both are filled with
# uniform random numbers here -- presumably stand-ins for validation
# accuracies recorded during real training runs (TODO confirm).
hist = torch.rand(10)
scratch_hist = torch.rand(10)

# Derive the epoch count from the data instead of hard-coding it, so the
# x-axis always has exactly as many positions as there are plotted points.
num_epochs = len(hist)

# Copy every entry to host memory and convert it to NumPy before handing it
# to matplotlib.
ohist = [h.cpu().numpy() for h in hist]
shist = [h.cpu().numpy() for h in scratch_hist]

# Epochs are numbered starting from 1; both curves share the same x values.
epochs = range(1, num_epochs + 1)

plt.title("Validation Accuracy vs. Number of Training Epochs")
plt.xlabel("Training Epochs")
plt.ylabel("Validation Accuracy")
plt.plot(epochs, ohist, label="Pretrained")
plt.plot(epochs, shist, label="Scratch")
plt.ylim((0, 1.))  # fix the y-axis to the range 0..1
plt.xticks(np.arange(1, num_epochs + 1, 1.0))  # one tick per epoch
plt.legend()
plt.show()