Skip to content
Snippets Groups Projects
Commit b92531ab authored by wu's avatar wu
Browse files

Merge branch 'wu-master-patch-25975' into 'master'

Update models.py

See merge request kreuzer/nn-projekt-ss22!7
parents 73d67dfb f6e5b8fa
No related branches found
No related tags found
No related merge requests found
......@@ -78,12 +78,13 @@ class ActorOnlySummarisationModel(SummarisationModel):
best_rouge = 0.0
best_model_wts = copy.deepcopy(self.state_dict())
for epoch in range(epochs):
print('Epoch {}/{}'.format(epoch, epochs - 1))
print('-' * 10)
# each epoch has a training and validation phase
# training phase of the epoch
running_loss = 0.0
running_rouge = 0.0
......@@ -114,7 +115,7 @@ class ActorOnlySummarisationModel(SummarisationModel):
epoch_rouge = running_rouge/ len(training_dataloader.dataset) # abh. von __len__ of PreprocessedDataSet
print('Train Loss: {:.4f} Rouge Score: {:.4f}'.format(epoch_loss, epoch_rouge))
# validation
# validation phase of the epoch
self.eval()
running_rouge = 0.0
with torch.no_grad():
......@@ -124,15 +125,16 @@ class ActorOnlySummarisationModel(SummarisationModel):
top_indices, probs = self.__call__(datapoint.document)
running_rouge += rouge(select_elements(datapoint.raw_document, top_indices), datapoint.raw_summary)
# vgl. train rouge for searchspace funtionality
epoch_rouge = running_rouge/ len(PreprocessedDataSet.validation)
epoch_rouge = running_rouge/ len(PreprocessedDataSet.validation)
val_rouge_history.append(epoch_rouge)
print('Validation Rouge Score: {:.4f}'.format(epoch_rouge))
# deep copy the model
# epoch completed, deep copy the best model sofar
if epoch_rouge > best_rouge:
best_rouge = epoch_rouge
best_model_wts = copy.deepcopy(self.state_dict())
val_rouge_history.append(epoch_rouge)
# training completed
time_elapsed = time.time() - since
print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
print('Best val rouge: {:4f}'.format(best_rouge))
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment