import copy
import torch
from torch import nn
import numpy as np

import utils


class SummarisationModel(nn.Module):
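    """Extractive summariser: a convolutional sentence encoder, an LSTM
    document encoder, and an LSTM sentence extractor that scores every
    sentence for inclusion in the summary."""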

    def __init__(self):

        super().__init__()

        # one convolution per kernel size 1..7, 50 filters each, so a sentence
        # vector has 7 * 50 = 350 dimensions; nn.ModuleList (not a plain list)
        # registers the layers so their parameters are trained with the model
        self.cnns = nn.ModuleList([
            nn.Conv1d(200, 50, kernel_size, padding="valid")
            for kernel_size in range(1, 8)
        ])

        self.document_encoder = nn.LSTM(350, 600)

        self.sentence_extractor = nn.LSTM(350, 600)

        # two logits per sentence; forward() takes column 0 of the softmax
        # as the extraction probability
        self.projector = nn.Linear(600, 2)

        self.softmax = nn.Softmax(dim=-1)

    def encode_sentences(self, document):
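        """Encode each sentence as a 350-dim vector of max-pooled CNN features."""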

        # Conv1d expects (batch, channels, length), hence the transpose;
        # max-over-time pooling (amax over word positions) keeps one value
        # per filter
        convolutions = []
        for cnn in self.cnns:
            convolutions.append(cnn(document.transpose(1, 2)).amax(dim=2))

        return torch.cat(convolutions, dim=1)

    def encode_document(self, encoded_sentences):
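        """Run the document LSTM over the reversed sentence sequence and
        return its final (hidden, cell) states."""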

        # the sentence sequence is fed in reverse order (flip along dim 0)
        _, (hidden_state, cell_state) = self.document_encoder(encoded_sentences.flip(dims=(0,)))

        return hidden_state, cell_state

    def encode(self, document):
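        """Return the sentence encodings plus the document-level LSTM states."""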

        encoded_sentences = self.encode_sentences(document)

        return encoded_sentences, self.encode_document(encoded_sentences)
    def forward(self, document, k=3):
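        """Score all sentences; returns (top-k indices, per-sentence probabilities).

        Documents with at most k sentences are returned in full."""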

        encoded_sentences, states = self.encode(document)

        # run the extractor LSTM over the sentence encodings, initialised with
        # the document states, then score each sentence
        logits = self.projector(self.sentence_extractor(encoded_sentences, states)[0])

        probs = self.softmax(logits)[:, 0]

        if k < len(probs):
            return probs.topk(k).indices, probs  # handle docs with fewer than 3 sentences?
        return torch.arange(len(probs)), probs
    def test(self, dataset):
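        """Evaluate on a dataset; returns mean ROUGE-1, ROUGE-2 and ROUGE-L."""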

        running_rouge_1 = 0.0
        running_rouge_2 = 0.0
        running_rouge_l = 0.0
        
        self.eval()
        with torch.no_grad():

            for datapoint in dataset:

                top_indices, probs = self(datapoint.document)
                r_1, r_2, r_l = utils.rouge(utils.select_elements(datapoint.raw_document, top_indices), datapoint.raw_summary, verbose=True)
                running_rouge_1 += r_1
                running_rouge_2 += r_2
                running_rouge_l += r_l
            
        epoch_rouge_1 = running_rouge_1 / len(dataset)
        epoch_rouge_2 = running_rouge_2 / len(dataset)  
        epoch_rouge_l = running_rouge_l / len(dataset)
        return epoch_rouge_1, epoch_rouge_2, epoch_rouge_l
    def validation(self, dataset):
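        """Return the mean of the three ROUGE scores on the dataset."""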

        return sum(self.test(dataset)) / 3.0


class ActorOnlySummarisationModel(SummarisationModel):
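    """SummarisationModel trained actor-only: each update reinforces the
    candidate extract the model currently ranks highest, weighted by that
    candidate's ROUGE score."""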
    
    def __init__(self):

        super().__init__()

        self.optimizer = torch.optim.Adam(self.parameters(), lr=0.001)

    def training_epoch(self, dataloader, learning_rate=None): # def scheduler? or global variable?
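        """Run one policy-gradient-style epoch; returns mean loss and mean ROUGE."""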

        if learning_rate is not None:
            for g in self.optimizer.param_groups:
                g['lr'] = learning_rate

        self.train()

        epoch_loss = 0.0
        epoch_rouge = 0.0
            
        for batch in dataloader:

            self.optimizer.zero_grad()

            for datapoint in batch:

                _, probs = self(datapoint.document)

                # log-likelihood of each candidate extract: sum of log p over
                # its sentences plus sum of log(1-p) over the rest
                o = datapoint.p_searchspace @ torch.log(probs) + datapoint.n_searchspace @ torch.log(1 - probs)

                # reinforce the candidate the model currently ranks highest,
                # weighted by its ROUGE score
                idx_sample = torch.argmax(o)

                loss = - datapoint.top_rouge[idx_sample] * o[idx_sample]

                loss.backward()

                epoch_loss += loss.item()
                epoch_rouge += datapoint.top_rouge[idx_sample]
            
            self.optimizer.step()
        
        return epoch_loss / len(dataloader.dataset), epoch_rouge / len(dataloader.dataset)


class SummarisationModelWithCrossEntropyLoss(SummarisationModel):
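    """SummarisationModel trained with summed binary cross-entropy against
    binary in-summary / not-in-summary sentence labels."""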
    
    def __init__(self):

        super().__init__()

        self.loss_fn = nn.BCELoss(reduction='sum')
        self.optimizer = torch.optim.Adam(self.parameters(), lr=0.001)

    def training_epoch(self, dataloader, learning_rate=None):
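        """Run one supervised epoch; returns the mean BCE loss per datapoint."""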

        if learning_rate is not None:
            for g in self.optimizer.param_groups:
                g['lr'] = learning_rate

        self.train()

        epoch_loss = 0.0

        for batch in dataloader:

            self.optimizer.zero_grad()

            for datapoint in batch:

                _, probs = self(datapoint.document)

                loss = self.loss_fn(probs, datapoint.bin_summary)

                loss.backward()

                epoch_loss += loss.item()

            self.optimizer.step()

        return epoch_loss / len(dataloader.dataset)


class ActorCriticSummarisationModel(SummarisationModel):

    # def __init__(self, )
    pass


class Critic(nn.Module):
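    """Critic that scores candidate extracts against the gold summary.

    A frozen copy of a trained SummarisationModel embeds both sentence sets;
    a small MLP, initialised near an explicit comparator of the two document
    vectors, is trained to regress the ROUGE score of the candidate."""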

    def __init__(self, model, steepness=8, denoise=100):
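        """Wrap a frozen copy of a trained summarisation model.

        `steepness` sharpens the final tanh; `denoise` shrinks the random
        initialisation noise around the hand-set comparator weights."""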

        super().__init__()

        self.steepness = steepness
        self.denoise = denoise

        # freeze a deep copy of the trained model so the critic can reuse its
        # document encoder without updating it
        model = copy.deepcopy(model)
        # model.eval()
        for param in model.parameters():
            param.requires_grad = False

        self.document_encoder = model.encode_document  # encode, encode_document
        self.layer_1 = nn.Linear(1200, 600)
        self.layer_2 = nn.Linear(600, 600)
        self.layer_3 = nn.Linear(600, 1)

        # initialise close to a hand-set comparator: layer 1 approximates the
        # difference of the two document vectors, layer 2 the identity, and
        # layer 3 an average; dividing the random init by `denoise` leaves only
        # a small perturbation around these target weights
        W_1 = torch.cat((torch.eye(600), -torch.eye(600)), 1)
        W_2 = torch.eye(600)
        W_3 = torch.ones(600)

        self.layer_1.weight.data /= self.denoise
        self.layer_1.weight.data += W_1
        self.layer_2.weight.data /= self.denoise
        self.layer_2.weight.data += W_2
        self.layer_3.weight.data /= self.denoise
        self.layer_3.weight.data += W_3
        self.layer_3.weight.data /= 600

        # possibly move to main; created after the layers so the optimizer
        # sees all critic parameters
        self.optimizer = torch.optim.Adam(self.parameters(), lr=0.001)
        self.loss_fn = nn.MSELoss()

    def forward(self, encoded_sentences_1, encoded_sentences_2):
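        """Score a pair of encoded sentence sets; the result lies in [0, 1)."""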

        _, document_vec_1 = self.document_encoder(encoded_sentences_1)
        _, document_vec_2 = self.document_encoder(encoded_sentences_2)

        double_document = torch.cat((torch.squeeze(document_vec_1), torch.squeeze(document_vec_2)), dim=-1)

        # layer 1 approximates the difference of the two document vectors,
        # passed through a gaussian activation; tanh(steepness * relu(.))
        # squashes the final score into [0, 1)
        return torch.tanh(self.steepness * nn.functional.relu(self.layer_3(
            nn.functional.relu(self.layer_2(
                utils.gaussian(self.layer_1(double_document)))))))

    def training_epoch(self, dataloader, learning_rate=None):
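        """Train the critic to regress the ROUGE scores of sampled extracts."""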

        if learning_rate is not None:
            for g in self.optimizer.param_groups:
                g['lr'] = learning_rate

        self.train()
        # half of the samples come from the precomputed search space,
        # half are random 3-sentence extracts
        pos_samples = 0.5

        epoch_loss = 0.0
        for batch in dataloader:

            self.optimizer.zero_grad()
        
            for datapoint in batch:

                r = np.random.random()
                if r > pos_samples:

                    k = np.random.choice(len(datapoint.p_searchspace)) 
                    sample = datapoint.sent_vecs.masked_select(datapoint.p_searchspace[k].bool())  # unpadded sentence embeddings

                    score = self.__call__(sample, datapoint.gold_sent_vecs) 
                    loss = self.loss_fn(score, datapoint.top_rouge[k])

                else:
                    if len(datapoint.sent_vecs) >= 3:
                        narray = np.random.choice(len(datapoint.sent_vecs), 3, replace=False)
                        narray.sort()
                        sample = datapoint.sent_vecs[narray]
                    else:
                        continue  # handle len(sent_vecs) < 3

                    score = self.__call__(sample, datapoint.gold_sent_vecs)
                    loss = self.loss_fn(score, utils.rouge(datapoint.raw_document[narray], datapoint.raw_summary))
                    # computes the ROUGE score for the negative sample
                    # => computing and storing it externally would be better?

                epoch_loss += loss.item()

                loss.backward()

            self.optimizer.step()
        
        return epoch_loss / len(dataloader.dataset)

    def test(self, dataset):
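        """Return the mean absolute error between critic scores and ROUGE."""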
        self.eval()
        pos_samples = 0.5
        running_diff = 0.0
        with torch.no_grad():

            for datapoint in dataset:
                r = np.random.random()
                if r > pos_samples:
                    k = np.random.choice(len(datapoint.p_searchspace)) 
                    sample = datapoint.sent_vecs.masked_select(datapoint.p_searchspace[k].bool())  # unpadded sentence embeddings
                    score = self.__call__(sample, datapoint.gold_sent_vecs) 
                    score_diff = score - datapoint.top_rouge[k] # tensor
                    
                else:
                    if len(datapoint.sent_vecs) >= 3:
                        narray = np.random.choice(len(datapoint.sent_vecs), 3, replace=False)
                        narray.sort()
                        sample = datapoint.sent_vecs[narray]
                    else:
                        continue  # handle len(sent_vecs) < 3

                    score = self.__call__(sample, datapoint.gold_sent_vecs)
                    score_diff = score - utils.rouge(datapoint.raw_document[narray], datapoint.raw_summary)
                    # computes the ROUGE score for the negative sample
                    # => computing and storing it externally would be better?

                running_diff += abs(score_diff.item())
    
        return running_diff / len(dataset)
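

# Rough usage sketch (assumption: the project's data pipeline provides
# dataloaders/datasets whose datapoints expose the attributes used above,
# e.g. .document, .raw_document, .raw_summary, .bin_summary, .sent_vecs,
# .gold_sent_vecs, .p_searchspace, .n_searchspace and .top_rouge; the names
# train_dataloader / validation_dataset / test_dataset are placeholders):
#
#   model = SummarisationModelWithCrossEntropyLoss()
#   for epoch in range(10):
#       loss = model.training_epoch(train_dataloader)
#       print(epoch, loss, model.validation(validation_dataset))
#   print(model.test(test_dataset))
#
#   critic = Critic(model)
#   critic.training_epoch(train_dataloader)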