Skip to content
Snippets Groups Projects
Commit 0b2cb530 authored by kreuzer's avatar kreuzer
Browse files

Aktualisieren models.py

parent b8259f3f
No related branches found
No related tags found
No related merge requests found
......@@ -82,6 +82,10 @@ class SummarisationModel(nn.Module):
epoch_rouge_l = running_rouge_l / len(dataset)
return epoch_rouge_1, epoch_rouge_2, epoch_rouge_l
def validation(self, dataset):
return sum(self.test(dataset)) / 3.0
......@@ -164,6 +168,45 @@ class ActorOnlySummarisationModel(SummarisationModel):
# load best model weights
self.load_state_dict(best_model_wts)
def __init__(self):
super().__init__()
self.optimizer = torch.optim.Adam(self.parameters(), lr=0.001)
def epoch(self, dataloader, learning_rate=0.001):
if learning_rate != 0.001:
self.optimizer = torch.optim.Adam(self.parameters(), lr=learning_rate)
self.train()
epoch_loss = 0.0
epoch_rouge = 0.0
for batch in dataloader:
self.optimizer.zero_grad()
for datapoint in batch:
_, probs = self.__call__(datapoint.document)
o = datapoint.p_searchspace @ torch.log(probs) + datapoint.n_searchspace @ torch.log(1 - probs)
idx_sample = torch.argmax(o)
loss = - datapoint.top_rouge[idx_sample] * o[idx_sample]
loss.backward()
epoch_loss += loss.item()
epoch_rouge += datapoint.top_rouge[idx_sample]
self.optimizer.step()
return epoch_loss / len(dataloader.dataset), epoch_rouge / len(dataloader.dataset)
class SummarisationModelWithCrossEntropyLoss(SummarisationModel):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment