Skip to content
Snippets Groups Projects
Commit 38fe2086 authored by wu's avatar wu
Browse files

compute sent vecs mit best model

parent ca546013
No related branches found
No related tags found
No related merge requests found
......@@ -54,6 +54,13 @@ for epoch in range(epochs):
# (over)write best model in file, and for Critic: save model_actor_only
torch.save(m.state_dict(), 'model_actor_only_wts.pth')
# load best model and compute sent_vecs anew
best_m = ActorOnlySummarisationModel()
best_m.load_state_dict(best_model_wts)
train_data.compute_sent_vecs(best_m, 'workspace/students/kreuzer/new_train') #
val_data.compute_sent_vecs(best_m, 'workspace/students/kreuzer/new_val')
test_data.compute_sent_vecs(best_m, 'workspace/students/kreuzer/new_test')
time_elapsed = time.time() - since
print('Training + validation running already {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
print('Best val rouge so far: {:4f}'.format(best_rouge))
......@@ -72,14 +79,6 @@ print('Best val rouge: {:4f}'.format(best_rouge))
# load best model weights
m.load_state_dict(best_model_wts)
# set sent_vecs
train_data.compute_sent_vecs(m, 'workspace/students/kreuzer/new_train') #
val_data.compute_sent_vecs(m, 'workspace/students/kreuzer/new_val')
test_data.compute_sent_vecs(m, 'workspace/students/kreuzer/new_test')
# testing
since = time.time()
......
......@@ -92,7 +92,7 @@ class PreprocessedDataSet:
self.path = path
self.path.mkdir()
self.length = len(dataset)
for i, element in enumerate(dataset):
......@@ -132,8 +132,9 @@ class PreprocessedDataSet:
def compute_sent_vecs(self, model, path):
path = pathlib.Path(path)
path.mkdir()
# writes in file, allows overwriting in main_ActorOnly
pathlib.Path(path).mkdir(parents=True, exist_ok=True)
for i in range(self.length):
dp = self.__getitem__(i)
dp.compute_sent_vecs(model)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment