Skip to content
Snippets Groups Projects
Commit b31d1646 authored by kreuzer's avatar kreuzer
Browse files

Merge branch 'kreuzer-master-patch-03847' into 'master'

Aktualisieren models.py

See merge request kreuzer/nn-projekt-ss22!8
parents b92531ab b8d18532
No related branches found
No related tags found
No related merge requests found
......@@ -62,6 +62,28 @@ class SummarisationModel(nn.Module):
def test(self, dataset):
running_rouge_1 = 0.0
running_rouge_2 = 0.0
running_rouge_l = 0.0
self.eval()
with torch.no_grad():
for datapoint in dataset:
top_indices, probs = self.__call__(datapoint)
r_1, r_2, r_l = utils.rouge(utils.select_elements(datapoint.raw_document, top_indices), datapoint.raw_summary, verbose=True)
running_rouge_1 += r_1
running_rouge_2 += r_2
running_rouge_L += r_l
epoch_rouge_1 = running_rouge_1 / len(dataset)
epoch_rouge_2 = running_rouge_2 / len(dataset)
epoch_rouge_l = running_rouge_l / len(dataset)
return epoch_rouge_1, epoch_rouge_2, epoch_rouge_l
class ActorOnlySummarisationModel(SummarisationModel):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment