# models.py -- extractive summarisation models (actor, actor-critic, critic)
import copy
import time

import numpy as np
import torch
from torch import nn


class SummarisationModel(nn.Module):
    """Extractive summarisation model.

    Each sentence is encoded by a bank of 1-d convolutions (kernel widths
    1..7, 50 filters each) over its word embeddings, max-pooled over word
    positions and concatenated into a 350-dim sentence vector.  An LSTM over
    the reversed sentence sequence produces the document state, which
    initialises a second LSTM ("sentence extractor") whose outputs are
    projected to per-sentence extraction probabilities.
    """

    def __init__(self):
        super().__init__()

        # nn.ModuleList (not a plain Python list) so the convolution
        # parameters are registered with the module and reached by
        # self.parameters() during training.
        self.cnns = nn.ModuleList(
            nn.Conv1d(200, 50, kernel_width, padding="valid", groups=1)
            for kernel_width in range(1, 8)
        )

        # 7 kernel widths * 50 filters = 350-dim sentence vectors.
        self.document_encoder = nn.LSTM(350, 600)
        self.sentence_extractor = nn.LSTM(350, 600)
        self.projector = nn.Linear(600, 2)
        self.softmax = nn.Softmax(dim=-1)

    def encode_sentences(self, document):
        """Encode every sentence of *document* into a 350-dim vector.

        Args:
            document: word-embedding tensor; assumed shape
                (num_sentences, num_words, 200) with num_words >= 7 so the
                widest "valid" convolution produces output -- TODO confirm
                against the preprocessing code.

        Returns:
            Tensor of shape (num_sentences, 350).
        """
        convolutions = []
        for cnn in self.cnns:
            # Conv1d expects (batch, channels, length): swap the word and
            # embedding dimensions, then max-pool over word positions.
            convolutions.append(cnn(document.transpose(1, 2)).amax(dim=2))
        return torch.cat(convolutions, dim=1)

    def encode_document(self, encoded_sentences):
        """Run the document LSTM over the REVERSED sentence sequence.

        Returns:
            (hidden_state, cell_state) of the final LSTM step.
        """
        _, (hidden_state, cell_state) = self.document_encoder(
            encoded_sentences.flip(dims=(0,)))
        return hidden_state, cell_state

    def encode(self, document):
        """Return (sentence encodings, document LSTM states) for *document*."""
        encoded_sentences = self.encode_sentences(document)
        return encoded_sentences, self.encode_document(encoded_sentences)

    def forward(self, document, k=3):
        """Score sentences and select the top *k* for the summary.

        Args:
            document: word-embedding tensor (see encode_sentences).
            k: number of sentences to extract.

        Returns:
            (indices, probs): indices of the k highest-scoring sentences
            (all sentence indices, in order, when the document has <= k
            sentences) and the per-sentence extraction probabilities.
        """
        encoded_sentences, states = self.encode(document)

        # The extractor LSTM is initialised with the document state.
        logits = self.projector(
            self.sentence_extractor(encoded_sentences, states)[0])

        # Column 0 of the 2-way softmax is treated as P(extract sentence).
        probs = self.softmax(logits)[:, 0]

        if k < len(probs):
            return probs.topk(k).indices, probs
        # Document has at most k sentences: extract all of them.
        return torch.arange(len(probs)), probs

    def test(self, dataset):
        """Evaluate the model on *dataset*.

        NOTE(review): the body of this method was missing in the original
        source (a syntax error); made into an explicit stub.
        """
        raise NotImplementedError
class ActorOnlySummarisationModel(SummarisationModel):
    """REINFORCE-style training against a precomputed search space.

    Each datapoint carries binary masks over candidate extracts
    (p_searchspace / n_searchspace) and their rouge scores (top_rouge);
    the loss rewards the log-likelihood of the best-scoring candidate.
    """

    def _train(self, dataset, epochs=20, batch_size=20, learning_rate=0.001, shuffle=True):
        """Train the extractor and keep the weights of the best epoch.

        Args:
            dataset: object exposing .train, .test and .validation splits;
                validation datapoints need .document, .raw_document and
                .raw_summary attributes.
            epochs, batch_size, learning_rate, shuffle: usual training knobs.
        """
        optimizer = torch.optim.Adam(self.parameters(), lr=learning_rate)

        training_dataloader = torch.utils.data.DataLoader(dataset.train, batch_size=batch_size, shuffle=shuffle)
        # NOTE(review): test_dataloader is built but never used below.
        test_dataloader = torch.utils.data.DataLoader(dataset.test, batch_size=batch_size, shuffle=shuffle)

        since = time.time()
        val_rouge_history = []

        best_rouge = 0.0
        best_model_wts = copy.deepcopy(self.state_dict())

        for epoch in range(epochs):

            print('Epoch {}/{}'.format(epoch, epochs - 1))
            print('-' * 10)

            # ---- training phase of the epoch ----
            running_loss = 0.0
            running_rouge = 0.0

            self.train()
            for batch in training_dataloader:

                optimizer.zero_grad()

                for datapoint in batch:

                    top_indices, probs = self(datapoint.document)

                    # Log-likelihood of every candidate extract: sentences in
                    # the extract contribute log(p), the others log(1 - p).
                    o = datapoint.p_searchspace @ torch.log(probs) + datapoint.n_searchspace @ torch.log(1 - probs)

                    idx_sample = torch.argmax(o)

                    # Reward-weighted log-likelihood of the chosen candidate.
                    loss = - datapoint.top_rouge[idx_sample] * o[idx_sample]

                    loss.backward()

                    # training statistics (train split); train rouge exceeding
                    # val/test rouge is an indication the search space works
                    running_loss += loss.item()
                    running_rouge += datapoint.top_rouge[idx_sample]

                optimizer.step()

            epoch_loss = running_loss / len(training_dataloader.dataset)
            # depends on the dataset's __len__ (see PreprocessedDataSet)
            epoch_rouge = running_rouge / len(training_dataloader.dataset)
            print('Train Loss: {:.4f} Rouge Score: {:.4f}'.format(epoch_loss, epoch_rouge))

            # ---- validation phase of the epoch ----
            self.eval()
            running_rouge = 0.0
            with torch.no_grad():

                # BUGFIX: was `PreprocessedDataSet.validation` -- an undefined
                # name; the validation split of the passed-in dataset is meant.
                for datapoint in dataset.validation:

                    top_indices, probs = self(datapoint.document)
                    # rouge / select_elements are project helpers assumed to be
                    # imported elsewhere in this module.
                    # (compare with the train rouge for search-space sanity)
                    running_rouge += rouge(select_elements(datapoint.raw_document, top_indices), datapoint.raw_summary)

                epoch_rouge = running_rouge / len(dataset.validation)
                val_rouge_history.append(epoch_rouge)

                print('Validation Rouge Score: {:.4f}'.format(epoch_rouge))

            # epoch completed; deep-copy the best model so far
            if epoch_rouge > best_rouge:
                best_rouge = epoch_rouge
                best_model_wts = copy.deepcopy(self.state_dict())

        # training completed
        time_elapsed = time.time() - since
        print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
        print('Best val rouge: {:4f}'.format(best_rouge))
        # TODO: write val_rouge_history to a file

        # load best model weights
        self.load_state_dict(best_model_wts)
class SummarisationModelWithCrossEntropyLoss(SummarisationModel):
    """Supervised variant: fits the extractor with binary cross-entropy
    against per-sentence gold labels (datapoint.bin_summary)."""

    def _train(self, dataset, epochs=20, batch_size=20, learning_rate=0.001, shuffle=True):
        """Train for *epochs* passes over *dataset* with Adam + BCE loss.

        Gradients are accumulated per datapoint within a batch and applied
        once per batch.
        """
        optimizer = torch.optim.Adam(self.parameters(), lr=learning_rate)
        criterion = nn.BCELoss(reduction='sum')

        for _ in range(epochs):

            loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)

            for minibatch in loader:

                optimizer.zero_grad()

                for sample in minibatch:
                    # Extraction probabilities for every sentence.
                    _, sentence_probs = self(sample.document)
                    criterion(sentence_probs, sample.bin_summary).backward()

                optimizer.step()
class ActorCriticSummarisationModel(SummarisationModel):
    """Actor-critic variant of the summarisation model.

    NOTE(review): this class was left unfinished in the original source --
    ``def __init__(self, )`` had neither a colon nor a body (a syntax
    error).  Kept as an explicit, importable stub.
    """

    def __init__(self):
        super().__init__()
class Critic(nn.Module):
    """Scores a candidate extract against the gold summary.

    Both sentence sets are encoded with a frozen copy of the actor's
    document encoder; the resulting document vectors are compared by a
    small MLP whose weights are initialised near a structured "difference"
    network (identity blocks plus shrunken random noise), so the untrained
    critic already behaves roughly like a similarity score.
    """

    def __init__(self, model, steepness=8, denoise=100):
        """
        Args:
            model: a SummarisationModel; its ``encode_document`` is used via
                a frozen deep copy, so critic training never touches the
                actor's weights.
            steepness: multiplier on the pre-tanh activation in forward().
            denoise: divisor that shrinks the random linear-layer init
                before the structured weights are added.
        """
        super().__init__()

        self.steepness = steepness
        self.denoise = denoise

        # Freeze a private copy of the actor.
        model = copy.deepcopy(model)
        for param in model.parameters():
            param.requires_grad = False

        # Bound method of the frozen copy; returns (hidden_state, cell_state).
        self.document_encoder = model.encode_document

        self.layer_1 = nn.Linear(1200, 600)
        self.layer_2 = nn.Linear(600, 600)
        self.layer_3 = nn.Linear(600, 1)

        # Structured target weights: layer 1 computes the difference of the
        # two 600-dim document vectors; layers 2 and 3 pass it through.
        W_1 = torch.cat((torch.eye(600), -torch.eye(600)), 1)
        W_2 = torch.eye(600)
        W_3 = torch.ones(600)

        # Shrink the random init ("denoise"), then add the structured part.
        self.layer_1.weight.data /= self.denoise
        self.layer_1.weight.data += W_1
        self.layer_2.weight.data /= self.denoise
        self.layer_2.weight.data += W_2
        self.layer_3.weight.data /= self.denoise
        self.layer_3.weight.data += W_3
        # Average (rather than sum) over the 600 features of layer 3.
        self.layer_3.weight.data /= 600

    def forward(self, encoded_sentences_1, encoded_sentences_2):
        """Return a scalar similarity score for the two sentence sets.

        tanh(steepness * relu(...)) maps the MLP output into [0, 1).
        """
        # NOTE(review): encode_document returns (hidden_state, cell_state),
        # so `_, vec` selects the CELL state as the document vector --
        # confirm this is intended (the hidden state would be the usual pick).
        _, document_vec_1 = self.document_encoder(encoded_sentences_1)
        _, document_vec_2 = self.document_encoder(encoded_sentences_2)

        double_document = torch.cat((torch.squeeze(document_vec_1), torch.squeeze(document_vec_2)), dim=-1)

        # utils.gaussian is a project helper assumed imported elsewhere.
        return torch.tanh(self.steepness * nn.functional.relu(self.layer_3(
            nn.functional.relu(self.layer_2(
                utils.gaussian(self.layer_1(double_document)))))))

    def _train(self, dataset, epochs=200, batch_size=20, learning_rate=0.001, shuffle=True, pos_samples=0.5):
        """Train the critic by regressing scores onto rouge values.

        With probability (1 - pos_samples) a candidate from the precomputed
        search space is scored against its stored rouge; otherwise three
        random sentences form a negative sample whose rouge is computed on
        the fly.

        NOTE(review): the original body had scrape-broken indentation (the
        datapoint loop sat one level too deep) and an unbalanced-parenthesis
        loss line; both repaired below.
        """
        optimizer = torch.optim.Adam(self.parameters(), lr=learning_rate)
        loss_fn = nn.MSELoss()

        for _ in range(epochs):

            training_dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)

            for batch in training_dataloader:

                optimizer.zero_grad()

                for datapoint in batch:

                    r = np.random.random()
                    if r > pos_samples:

                        # Positive sample: a candidate extract from the
                        # precomputed search space.
                        k = np.random.choice(len(datapoint.p_searchspace))
                        # non-padded sentence embeddings of candidate k
                        sample = datapoint.sent_vecs.masked_select(datapoint.p_searchspace[k].bool())

                        score = self.__call__(sample, datapoint.gold_sent_vecs)
                        loss = loss_fn(score, datapoint.top_rouge[k])

                    else:
                        # Negative sample: three random sentences.
                        if len(datapoint.sent_vecs) >= 3:
                            narray = np.random.choice(len(datapoint.sent_vecs), 3, replace=False)
                            narray.sort()
                            sample = datapoint.sent_vecs[narray]
                        else:
                            continue  # TODO: handle documents with fewer than 3 sentences

                        score = self.__call__(sample, datapoint.gold_sent_vecs)
                        # BUGFIX: original line had unbalanced parentheses and
                        # passed raw_summary to loss_fn instead of to rouge;
                        # the rouge of the sampled sentences vs the gold
                        # summary is the regression target.  Precomputing and
                        # storing these scores offline would be cheaper.
                        loss = loss_fn(score, utils.rouge(datapoint.raw_document[narray], datapoint.raw_summary))

                    loss.backward()

                optimizer.step()

            # TODO: per-epoch evaluation / rouge test