Commit 2d481623 authored by kreuzer

Update structures.py, utils.py

parent 659da407

structures.py

import torch
import numpy as np
import utils
import pathlib

@@ -27,7 +28,7 @@ class DataPoint:
         doc_embeddings, raw_doc = [], []
-        for sent in doc_preprocessed.sents:
+        for sent in doc_preprocessed.sentences:
             sent_embeddings, raw_sent = [], []
             for tok in sent.tokens:
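
The rename from doc_preprocessed.sents to doc_preprocessed.sentences (with tokens still reached via sent.tokens) matches stanza's Document API rather than spaCy's, which suggests the preprocessing pipeline was switched to stanza. A minimal sketch under that assumption:

import stanza  # assumption: nlp is a stanza pipeline; stanza.download('en') must have run once

nlp = stanza.Pipeline('en', processors='tokenize')
doc_preprocessed = nlp("Summarization is hard. Good sampling helps.")
for sent in doc_preprocessed.sentences:
    for tok in sent.tokens:
        print(tok.text)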

@@ -58,7 +59,7 @@ class DataPoint:
         self.gold_sent_vecs = None
         # Searchspace
-        self.p_searchspace, self.n_searchspace, self.top_rouge, self.bin_summary = utils.searchspace(self.raw_document, self.raw_summary)
+        self.p_searchspace, self.n_searchspace, self.top_rouge, self.bin_summary = utils.searchspace(self.raw_document, self.raw_summary, 3)

     def compute_sent_vecs(self, model):
         # model.sentence_encoder

@@ -71,24 +72,49 @@ class DataPoint:

 # DataSet
 # assembles all datapoints, representing cnn_dailymail
 class PreprocessedDataSet:
-    def __init__(self, dataset, model_gensim, nlp):
-        def read_data(part):
-            l = []
-            for y in dataset[part]:
-                article, highlights = y['article'], y['highlights']
-                # for every document
-                l.append(DataPoint(article, highlights, model_gensim, nlp))
-            return l
-        self.train = read_data('train')
-        self.test = read_data('test')
-        self.validation = read_data('validation')
+    def __init__(self, path, dataset=None, model_gensim=None, nlp=None):  # path: name of the split (train, test, validation) or a path; the directory is created
+        self.path = pathlib.Path(path)
+        if dataset is None and model_gensim is None and nlp is None:
+            # load mode: the cache already exists, just count the stored files
+            self.length = len(list(filter(lambda x: x.is_file(), self.path.iterdir())))
+        elif dataset is not None and model_gensim is not None and nlp is not None:
+            # build mode: preprocess every document and cache one file per DataPoint
+            self.path.mkdir()
+            self.length = len(dataset)
+            for i, element in enumerate(dataset):
+                article, highlights = element['article'], element['highlights']
+                dp = DataPoint(article, highlights, model_gensim, nlp)
+                torch.save(dp, self.path / str(i))
+        else:
+            raise ValueError("inconsistent arguments: pass dataset, model_gensim and nlp together, or none of them")
+    def __len__(self):
+        return self.length
+
+    def __getitem__(self, index):
+        if index not in range(0, self.length):
+            raise IndexError()
+        try:
+            dp = torch.load(self.path / str(index))
+        except OSError:  # narrowed from a bare except: only I/O failures map to FileNotFoundError
+            raise FileNotFoundError()
+        return dp
     def compute_sent_vecs(self, model):
-        for datapoint in self.train + self.test + self.validation:
-            datapoint.compute_sent_vecs(model)
+        for i in range(self.length):
+            dp = self.__getitem__(i)
+            dp.compute_sent_vecs(model)
+            torch.save(dp, self.path / str(i))
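
With __len__ and __getitem__ in place, the cached set behaves like a map-style torch dataset. A hedged usage sketch; dataset, model_gensim, nlp and model stand in for the caller's real objects (for example the HuggingFace cnn_dailymail splits, a gensim word-vector model, the tokenization pipeline and the sentence encoder) and are assumptions here, not part of this commit:

# build mode: preprocess once, caching each DataPoint under ./train/<i>
train_set = PreprocessedDataSet('train', dataset=dataset['train'],
                                model_gensim=model_gensim, nlp=nlp)

# load mode: later runs just point at the cached directory
train_set = PreprocessedDataSet('train')
print(len(train_set), train_set[0])

# re-encodes every cached datapoint and writes it back to disk
train_set.compute_sent_vecs(model)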

@@ -98,8 +124,3 @@ class PreprocessedDataSet:

utils.py

@@ -144,13 +144,9 @@ def rouge(summary, gold_summary, verbose=False):
     length_gold_summary = len(gold_tokens)
     if verbose:
-        return rouge_n(unigram_counts, gold_unigram_counts, length_summary, length_gold_summary),
-        rouge_n(bigram_counts, gold_bigram_counts, length_summary - 1, length_gold_summary - 1),
-        rouge_l(summary, gold_summary, unigram_counts, gold_unigram_counts, length_summary, length_gold_summary)
+        return rouge_n(unigram_counts, gold_unigram_counts, length_summary, length_gold_summary), rouge_n(bigram_counts, gold_bigram_counts, length_summary - 1, length_gold_summary - 1), rouge_l(summary, gold_summary, unigram_counts, gold_unigram_counts, length_summary, length_gold_summary)
-    return (rouge_n(unigram_counts, gold_unigram_counts, length_summary, length_gold_summary)
-    + rouge_n(bigram_counts, gold_bigram_counts, length_summary - 1, length_gold_summary - 1)
-    + rouge_l(summary, gold_summary, unigram_counts, gold_unigram_counts, length_summary, length_gold_summary)) / 3.0
+    return (rouge_n(unigram_counts, gold_unigram_counts, length_summary, length_gold_summary) + rouge_n(bigram_counts, gold_bigram_counts, length_summary - 1, length_gold_summary - 1) + rouge_l(summary, gold_summary, unigram_counts, gold_unigram_counts, length_summary, length_gold_summary)) / 3.0

 # pre-compute searchspace of a document for high quality sampling in training
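
Collapsing the verbose return onto one line also fixes a real bug: the old trailing-comma version returned a 1-tuple containing only the ROUGE-1 score and left the two following lines unreachable. For orientation, a sketch of what an n-gram overlap F1 such as rouge_n typically computes; the real utils.rouge_n is not part of this diff, so the signature and details below are assumptions:

from collections import Counter

def rouge_n_sketch(counts, gold_counts, length, gold_length):
    # counts / gold_counts: Counters of n-grams; length / gold_length: n-gram totals (assumed shapes)
    overlap = sum(min(c, gold_counts[ng]) for ng, c in counts.items())
    if length == 0 or gold_length == 0 or overlap == 0:
        return 0.0
    precision = overlap / length
    recall = overlap / gold_length
    return 2 * precision * recall / (precision + recall)  # F1 of overlap precision and recall

print(rouge_n_sketch(Counter(['the', 'cat']), Counter(['the', 'dog']), 2, 2))  # 0.5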