Commit 7d964612, authored by kreuzer

Update scripts/ActorOnly+CEL/main_ActorOnly.py, scripts/ActorOnly+CEL/main_CrossEntropy.py

parent 62879b27
--- a/scripts/ActorOnly+CEL/main_ActorOnly.py
+++ b/scripts/ActorOnly+CEL/main_ActorOnly.py
 import torch
 from torch import nn
 import numpy as np
 import time
 import copy
-import stanza
-from gensim.models import KeyedVectors
-from datasets import load_dataset
 from structures import *
 from models import *
 import utils
-# load the skip-gram gensim model
-file_name = "data/1-billion-word-language-modeling-benchmark-r13output.word2vec.vec"
-model_gensim = KeyedVectors.load_word2vec_format(file_name)
-# initialize the tokenizer for sentence splitting and tokenization (requires: pip install stanza)
-nlp = stanza.Pipeline(lang='en', processors='tokenize')
-# load the cnn_dailymail dataset (requires: pip install datasets)
-dataset = load_dataset('cnn_dailymail', '3.0.0', split='train[:100]+validation[:20]+test[:20]')  # extract a small subset for testing
-dataset = PreprocessedDataSet(dataset, model_gensim, nlp)
+# load the preprocessed cnn_dailymail dataset
+train_dataset = PreprocessedDataSet('/workspace/students/kreuzer/train')
+validation_dataset = PreprocessedDataSet('/workspace/students/kreuzer/val')
+test_dataset = PreprocessedDataSet('/workspace/students/kreuzer/test')
 # hyperparameters
 epochs = 20
 batch_size = 20
 learning_rate = 0.001
-training_dataloader = torch.utils.data.DataLoader(dataset.train, batch_size=batch_size, shuffle=True)
-test_dataloader = torch.utils.data.DataLoader(dataset.test, batch_size=batch_size, shuffle=True)
+training_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=lambda x: x)
 m = ActorOnlySummarisation()
 since = time.time()
 val_rouge_history = []
@@ -37,16 +27,17 @@ best_model_wts = copy.deepcopy(m.state_dict())
 for epoch in range(epochs):
-    print('Epoch {}/{}'.format(epoch, epochs - 1))
-    print()
+    print('Epoch {}/{}'.format(epoch + 1, epochs))
     print('-' * 10)
     # training phase
-    epoch_loss, epoch_rouge = m.training_epoch(D.train)
+    epoch_loss, epoch_rouge = m.training_epoch(training_dataloader)
     print('Train Loss: {:.4f} Rouge Score: {:.4f}'.format(epoch_loss, epoch_rouge))
     # validation phase
-    epoch_rouge = m.validation(dataset.validation)
+    epoch_rouge = m.validation(validation_dataset)
     val_rouge_history.append(epoch_rouge)
     print('Validation Rouge Score: {:.4f}'.format(epoch_rouge))
     # epoch completed, deep-copy the best model so far
@@ -56,9 +47,15 @@ for epoch in range(epochs):
 # after training has completed
 time_elapsed = time.time() - since
-print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
+print('Training + validation complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
 print('Best val rouge: {:.4f}'.format(best_rouge))
+# write val_rouge_history to a file
+with open("ActorOnly_val_history.txt", "a") as f:
+    f.write(">>>Start\n")
+    for event in val_rouge_history:
+        f.write(str(event) + "\n")
+    f.write("\n")
 # load the best model weights
 m.load_state_dict(best_model_wts)
@@ -66,14 +63,16 @@ m.load_state_dict(best_model_wts)
 # for the Critic: save model_actor_only
 torch.save(m.state_dict(), 'model_actor_only_wts.pth')
 # set sent_vecs
-dataset.compute_sent_vecs(model)
+train_dataset.compute_sent_vecs(m)
+validation_dataset.compute_sent_vecs(m)
+test_dataset.compute_sent_vecs(m)
 # testing
 since = time.time()
-epoch_rouge_1, epoch_rouge_2, epoch_rouge_l = m.test(dataset.test)
-print('Test rouge_1: {:.4f} rouge_2: {:.4f} rouge_l: {:.4f}'.format(epoch_rouge_1, epoch_rouge_2, epoch_rouge_l))
+epoch_rouge_1, epoch_rouge_2, epoch_rouge_l = m.test(test_dataset)
+print('Test rouge_1: {:.4f} rouge_2: {:.4f} rouge_l: {:.4f} mean: {:.4f}'.format(epoch_rouge_1, epoch_rouge_2, epoch_rouge_l, (epoch_rouge_1 + epoch_rouge_2 + epoch_rouge_l) / 3.0))
 # after testing has completed
 time_elapsed = time.time() - since
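Both scripts depend on the best-checkpoint bookkeeping that the hunk headers reference (best_model_wts = copy.deepcopy(m.state_dict())) but that the diff elides. The following is a minimal sketch of that usual pattern, not the repository's code: nn.Linear stands in for ActorOnlySummarisation, and a random number stands in for the ROUGE score that m.validation would return.

import copy
import random
import torch
from torch import nn

model = nn.Linear(4, 2)                 # placeholder for ActorOnlySummarisation
best_rouge = 0.0
best_model_wts = copy.deepcopy(model.state_dict())
val_rouge_history = []

for epoch in range(3):
    # ... training phase elided ...
    epoch_rouge = random.random()       # placeholder for m.validation(...)
    val_rouge_history.append(epoch_rouge)
    if epoch_rouge > best_rouge:
        # deep-copy so later training steps cannot mutate the saved weights
        best_rouge = epoch_rouge
        best_model_wts = copy.deepcopy(model.state_dict())

model.load_state_dict(best_model_wts)   # restore the best checkpoint
torch.save(model.state_dict(), 'model_actor_only_wts.pth')

The deepcopy matters because state_dict() returns references to the live parameter tensors; without it, the "best" weights would silently track the current ones.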
--- a/scripts/ActorOnly+CEL/main_CrossEntropy.py
+++ b/scripts/ActorOnly+CEL/main_CrossEntropy.py
@@ -3,15 +3,19 @@ from torch import nn
 import numpy as np
 from structures import *
 from models import *
+import utils
-dataset = torch.load("dataset.data")
+# load the preprocessed cnn_dailymail dataset
+train_dataset = PreprocessedDataSet('/workspace/students/kreuzer/train')
+validation_dataset = PreprocessedDataSet('/workspace/students/kreuzer/val')
+test_dataset = PreprocessedDataSet('/workspace/students/kreuzer/test')
 # hyperparameters
-epochs = 2
+epochs = 20
 batch_size = 20
 learning_rate = 0.001
-train_dataloader = torch.utils.data.DataLoader(dataset.train, batch_size=batch_size, shuffle=True, collate_fn=lambda x: x)
+training_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=lambda x: x)
 m = SummarisationModelWithCrossEntropyLoss()
@@ -28,11 +32,11 @@ for epoch in range(epochs):
     print('-' * 10)
     # training phase
-    epoch_loss = m.training_epoch(train_dataloader)
+    epoch_loss = m.training_epoch(training_dataloader)
     print('Train Loss: {:.4f}'.format(epoch_loss))
     # validation phase
-    epoch_rouge = m.validation(dataset.validation)
+    epoch_rouge = m.validation(validation_dataset)
     val_rouge_history.append(epoch_rouge)
     print('Validation Rouge Score: {:.4f}'.format(epoch_rouge))
@@ -43,9 +47,15 @@ for epoch in range(epochs):
 # after training has completed
 time_elapsed = time.time() - since
-print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
+print('Training + validation complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
 print('Best val rouge: {:.4f}'.format(best_rouge))
+# write val_rouge_history to a file
+with open("CEL_val_history.txt", "a") as f:
+    f.write(">>>Start\n")
+    for event in val_rouge_history:
+        f.write(str(event) + "\n")
+    f.write("\n")
 # load the best model weights
 m.load_state_dict(best_model_wts)
@@ -53,7 +63,7 @@ m.load_state_dict(best_model_wts)
 # testing
 since = time.time()
-epoch_rouge_1, epoch_rouge_2, epoch_rouge_l = m.test(dataset.test)
+epoch_rouge_1, epoch_rouge_2, epoch_rouge_l = m.test(test_dataset)
 print('Test rouge_1: {:.4f} rouge_2: {:.4f} rouge_l: {:.4f} mean: {:.4f}'.format(epoch_rouge_1, epoch_rouge_2, epoch_rouge_l, (epoch_rouge_1 + epoch_rouge_2 + epoch_rouge_l) / 3.0))
 # after testing has completed
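A note on collate_fn=lambda x: x, which both scripts now pass to DataLoader: the default collate function tries to stack the samples of a batch into a single tensor, which fails when documents contain different numbers of sentences. The identity collate instead hands training_epoch a plain Python list of per-document samples. A toy sketch of the effect, where ToyDocs is a hypothetical stand-in for PreprocessedDataSet:

import torch
from torch.utils.data import DataLoader, Dataset

class ToyDocs(Dataset):
    """Hypothetical stand-in for PreprocessedDataSet: variable-length docs."""
    def __init__(self):
        # three documents with 3, 5, and 2 sentence vectors of size 4
        self.docs = [torch.randn(n, 4) for n in (3, 5, 2)]
    def __len__(self):
        return len(self.docs)
    def __getitem__(self, i):
        return self.docs[i]

loader = DataLoader(ToyDocs(), batch_size=2, shuffle=True,
                    collate_fn=lambda x: x)     # identity: batch is a list
for batch in loader:
    # each element keeps its own length, e.g. [(5, 4), (3, 4)]
    print([tuple(doc.shape) for doc in batch])

With the default collate, the same loop would raise a RuntimeError as soon as a batch mixed documents of different lengths.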