Skip to content
Snippets Groups Projects
Commit ca546013 authored by wu
Browse files

print statistics, save models and statistics

parent ba3f047a
No related branches found
No related tags found
No related merge requests found
......@@ -50,24 +50,32 @@ for epoch in range(epochs):
if epoch_rouge > best_rouge:
best_rouge = epoch_rouge
best_model_wts = copy.deepcopy(m.state_dict())
# (over)write best model in file, and for Critic: save model_actor_only
torch.save(m.state_dict(), 'model_actor_only_wts.pth')
time_elapsed = time.time() - since
print('Training + validation running already {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
print('Best val rouge so far: {:4f}'.format(best_rouge))
# save model weights after each validation
torch.save(m.state_dict(), f'model_actor_only_wts{epoch}.pth')
# (over)write val_rouge_history in file
torch.save(val_rouge_history, 'model_actor_only_val_rouge_hist.pt')
# after training completed
time_elapsed = time.time() - since
print('Training + validation complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
print('Best val rouge: {:4f}'.format(best_rouge))
# write val_rouge_history in file
pathlib.Path('hist').mkdir(parents=True, exist_ok=True)
hist_path = pathlib.Path('hist')
torch.save(val_rouge_history, hist_path/'model_actor_only_val_rouge_hist.pt') # all hist in one folder hist?
# load best model weights
m.load_state_dict(best_model_wts)
# for Critic: save model_actor_only
torch.save(m.state_dict(), 'model_actor_only_wts.pth')
# set sent_vecs
train_data.compute_sent_vecs(m, 'workspace/students/kreuzer/new_train') #
val_data.compute_sent_vecs(m, 'workspace/students/kreuzer/new_val')
......
......@@ -41,28 +41,38 @@ for epoch in range(epochs):
val_rouge_history.append(epoch_rouge)
print('Validation Rouge Score: {:.4f}'.format(epoch_rouge))
# epoch completed, deep copy the best model so far
if epoch_rouge > best_rouge:
best_rouge = epoch_rouge
best_model_wts = copy.deepcopy(m.state_dict())
# (over)write best model in file, and for Critic: save model_actor_only
torch.save(m.state_dict(), 'model_CE_wts.pth')
time_elapsed = time.time() - since
print('Training + validation running already {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
print('Best val rouge so far: {:4f}'.format(best_rouge))
# save model weights after each validation
torch.save(m.state_dict(), f'model_CE_wts{epoch}.pth')
# (over)write val_rouge_history in file
torch.save(val_rouge_history, 'model_CE_val_rouge_hist.pt')
# after training completed
time_elapsed = time.time() - since
print('Training + validation complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
print('Best val rouge: {:4f}'.format(best_rouge))
# write val_rouge_history in file
pathlib.Path('hist').mkdir(parents=True, exist_ok=True)
hist_path = pathlib.Path('hist')
torch.save(val_rouge_history, hist_path/'model_CE_val_rouge_hist.pt')
# load best model weights
m.load_state_dict(best_model_wts)
# save best model
torch.save(m.state_dict(), f'model_CE_wts.pth')
# testing
since = time.time()
......
......@@ -6,18 +6,18 @@ from models import *
import time
# loads dataset cnn_dailymail
train_dataset = PreprocessedDataSet('/workspace/students/kreuzer/train')
validation_dataset = PreprocessedDataSet('/workspace/students/kreuzer/val')
test_dataset = PreprocessedDataSet('/workspace/students/kreuzer/test')
train_dataset = PreprocessedDataSet(['/workspace/students/kreuzer/new_train'])
validation_dataset = PreprocessedDataSet(['/workspace/students/kreuzer/new_val'])
test_dataset = PreprocessedDataSet(['/workspace/students/kreuzer/new_test'])
# hyperparameters
epochs=20
epochs=10 # <= actor trained epochs
batch_size=20
learning_rate=0.001
training_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn = lambda x: x)
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn = lambda x: x)
# load state dict from training
model_actor_only = ActorOnlySummarisationModel()
......@@ -39,11 +39,11 @@ for e in range(epochs):
print('-' * 10)
# train phase
epoch_loss = m.training_epoch(training_dataloader) # collate_fn
epoch_loss = m.training_epoch(train_dataloader) # collate_fn
print('Train Loss: {:.4f}'.format(epoch_loss))
# validation phase
val_epoch_loss = m.test(dataset.validation)
val_epoch_loss = m.test(validation_dataset)
val_loss_history.append(val_epoch_loss)
print('Validation Loss: {:.4f}'.format(val_epoch_loss))
......@@ -51,6 +51,19 @@ for e in range(epochs):
if val_epoch_loss < best_loss:
best_loss = val_epoch_loss
best_model_wts = copy.deepcopy(m.state_dict())
# (over)write best model in file
torch.save(m.state_dict(), 'model_critic_wts.pth')
time_elapsed = time.time() - since
print('Training + validation running already {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
print('Best val result so far: {:4f}'.format(best_loss))
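# save model weights after each validation
# NOTE(review): the enclosing loop shown in the hunk header is `for e in range(epochs)`, so `epoch` on the following line looks undefined in this script — confirm and use `e` if so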
torch.save(m.state_dict(), f'model_critic_wts{epoch}.pth')
# (over)write val_loss_history in file
torch.save(val_loss_history, 'model_critic_val_loss_hist.pt')
# after training completed
time_elapsed = time.time() - since
......@@ -64,7 +77,7 @@ m.load_state_dict(best_model_wts)
# testing
since = time.time()
test_loss = m.test(dataset.test)
test_loss = m.test(test_dataset)
print('Test rouge score difference: {:.4f}'.format(test_loss))
# after testing completed
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment