Skip to content
Snippets Groups Projects
Commit 7d1ec45c authored by kreuzer's avatar kreuzer
Browse files

Aktualisieren main_CrossEntropy.py

parent 6c4b0f7a
No related branches found
No related tags found
No related merge requests found
import torch
from torch import nn
import numpy as np
from structures import *
from models import *
dataset = torch.load("dataset.data")
# hyperparameters
epochs=2
batch_size=20
learning_rate=0.001
train_dataloader = torch.utils.data.DataLoader(dataset.train, batch_size=batch_size, shuffle=True, collate_fn=lambda x:x)
m = SummarisationModelWithCrossEntropyLoss()
since = time.time()
val_rouge_history = []
best_rouge = 0.0
best_model_wts = copy.deepcopy(m.state_dict())
for epoch in range(epochs):
print()
print('Epoch {}/{}'.format(epoch+1, epochs))
print('-' * 10)
# train phase
epoch_loss = m.training_epoch(train_dataloader)
print('Train Loss: {:.4f}'.format(epoch_loss))
# validation phase
epoch_rouge = m.validation(dataset.validation)
val_rouge_history.append(epoch_rouge)
print('Validation Rouge Score: {:.4f}'.format(epoch_rouge))
# epoch completed, deep copy the best model sofar
if epoch_rouge > best_rouge:
best_rouge = epoch_rouge
best_model_wts = copy.deepcopy(m.state_dict())
# after training completed
time_elapsed = time.time() - since
print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
print('Best val rouge: {:4f}'.format(best_rouge))
# write val_rouge_history in file
# load best model weights
m.load_state_dict(best_model_wts)
# testing
since = time.time()
epoch_rouge_1, epoch_rouge_2, epoch_rouge_l = m.test(dataset.test)
print('Test rouge_1: {:.4f} rouge_2: {:.4f} rouge_l: {:.4f} mean: {:.4f}'.format(epoch_rouge_1, epoch_rouge_2, epoch_rouge_l, (epoch_rouge_1+epoch_rouge_2+epoch_rouge_l)/3.0))
# after testing completed
time_elapsed = time.time() - since
print('Testing complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment