Commit f040158a authored by Kartikay Khandelwal, committed by Facebook GitHub Bot

Add Transformer Sentence Encoder for BERT and XLM Pre-training in PyText (#621)

Summary:
Pull Request resolved: https://github.com/pytorch/fairseq/pull/621

In this commit, I add some modules to fairseq that are needed to set up BERT/XLM-style pre-training.

Reviewed By: borguz

Differential Revision: D14719663

fbshipit-source-id: 1c5c36b6b2cde1c9bcd3c9e9ac853d2b7ae64102
parent 3658fa32
fairseq/modules/__init__.py (+6 −0)
@@ -8,6 +8,7 @@
from .adaptive_input import AdaptiveInput
from .adaptive_softmax import AdaptiveSoftmax
from .beamable_mm import BeamableMM
from .bert_layer_norm import BertLayerNorm
from .character_token_embedder import CharacterTokenEmbedder
from .conv_tbc import ConvTBC
from .downsampled_multihead_attention import DownsampledMultiHeadAttention
@@ -23,12 +24,15 @@ from .mean_pool_gating_network import MeanPoolGatingNetwork
from .multihead_attention import MultiheadAttention
from .scalar_bias import ScalarBias
from .sinusoidal_positional_embedding import SinusoidalPositionalEmbedding
from .transformer_sentence_encoder_layer import TransformerSentenceEncoderLayer
from .transformer_sentence_encoder import TransformerSentenceEncoder
from .unfold import unfold1d

__all__ = [
    'AdaptiveInput',
    'AdaptiveSoftmax',
    'BeamableMM',
    'BertLayerNorm',
    'CharacterTokenEmbedder',
    'ConvTBC',
    'DownsampledMultiHeadAttention',
@@ -44,5 +48,7 @@ __all__ = [
    'MultiheadAttention',
    'ScalarBias',
    'SinusoidalPositionalEmbedding',
    'TransformerSentenceEncoderLayer',
    'TransformerSentenceEncoder',
    'unfold1d',
]
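
With these registrations in place, the new modules can be imported directly from fairseq.modules; a minimal check, assuming this revision of fairseq is installed:

from fairseq.modules import (
    BertLayerNorm,
    TransformerSentenceEncoder,
    TransformerSentenceEncoderLayer,
)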
fairseq/modules/bert_layer_norm.py (+27 −0)
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

import torch
import torch.nn as nn


class BertLayerNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-12):
        """
        Construct a layernorm module in the TF style used with BERT
        (epsilon inside the square root).
        """
        super(BertLayerNorm, self).__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.bias = nn.Parameter(torch.zeros(hidden_size))
        self.variance_epsilon = eps

    def forward(self, x):
        u = x.mean(-1, keepdim=True)
        s = (x - u).pow(2).mean(-1, keepdim=True)
        x = (x - u) / torch.sqrt(s + self.variance_epsilon)
        return self.weight * x + self.bias
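
As a quick sanity check of the module above, its output should match nn.LayerNorm with the same eps when the affine parameters are left at their defaults of ones and zeros. A minimal sketch, assuming this revision of fairseq and a recent PyTorch:

import torch
import torch.nn as nn
from fairseq.modules import BertLayerNorm

hidden_size = 768
x = torch.randn(2, 5, hidden_size)               # (batch, seq_len, hidden)

bert_ln = BertLayerNorm(hidden_size, eps=1e-12)
torch_ln = nn.LayerNorm(hidden_size, eps=1e-12)  # weight=1, bias=0 by default

# Both normalize over the last dimension with eps added inside the square root,
# so the outputs should agree up to floating-point tolerance.
print(torch.allclose(bert_ln(x), torch_ln(x), atol=1e-6))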
fairseq/modules/transformer_sentence_encoder.py (+198 −0)
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List, Tuple

from fairseq.modules import (
    MultiheadAttention, LearnedPositionalEmbedding, TransformerSentenceEncoderLayer
)


def init_bert_params(module):
    """
    Initialize the weights specific to the BERT Model.
    This overrides the default initializations depending on the specified arguments.
        1. If normal_init_linear_weights is set then the weights of the
           linear layer will be initialized using the normal distribution
           and the bias will be set to the specified value.
        2. If normal_init_embed_weights is set then the weights of the
           embedding layer will be initialized using the normal distribution.
        3. If normal_init_proj_weights is set then the weights of the
           in_proj_weight for MultiheadAttention will be initialized using
           the normal distribution (to be validated).
    """

    if isinstance(module, nn.Linear):
        module.weight.data.normal_(mean=0.0, std=0.02)
        if module.bias is not None:
            module.bias.data.zero_()
    if isinstance(module, nn.Embedding):
        module.weight.data.normal_(mean=0.0, std=0.02)
    if isinstance(module, MultiheadAttention):
        module.in_proj_weight.data.normal_(mean=0.0, std=0.02)


def PositionalEmbedding(
        num_embeddings: int,
        embedding_dim: int,
        padding_idx: int,
        left_pad: bool
) -> nn.Embedding:
    m = LearnedPositionalEmbedding(
        num_embeddings + padding_idx + 1, embedding_dim, padding_idx, left_pad)
    nn.init.normal_(m.weight, mean=0, std=embedding_dim ** -0.5)
    nn.init.constant_(m.weight[padding_idx], 0)
    return m


class TransformerSentenceEncoder(nn.Module):
    """
    Implementation for a Bi-directional Transformer based Sentence Encoder used
    in BERT/XLM style pre-trained models.

    This first computes the token embedding using the token embedding matrix,
    position embeddings (if specified) and segment embeddings
    (if specified). After applying the specified number of
    TransformerEncoderLayers, it outputs all the internal states of the
    encoder as well as the final representation associated with the first
    token (usually CLS token).

    Input:
        - tokens: B x T matrix representing sentences
        - segment_labels: B x T matrix representing segment label for tokens

    Output:
        - a tuple of the following:
            - a list of internal model states used to compute the
              predictions where each tensor has shape T x B x C
            - sentence representation associated with first input token
              in format B x C.
    """

    def __init__(
        self,
        padding_idx: int,
        vocab_size: int,
        num_encoder_layers: int = 6,
        embedding_dim: int = 768,
        ffn_embedding_dim: int = 3072,
        num_attention_heads: int = 8,
        dropout: float = 0.1,
        attention_dropout: float = 0.1,
        activation_dropout: float = 0.1,
        max_seq_len: int = 256,
        num_segments: int = 2,
        use_position_embeddings: bool = True,
        encoder_normalize_before: bool = False,
        use_bert_layer_norm: bool = False,
        use_gelu: bool = True,
        apply_bert_init: bool = False,
    ) -> None:

        super().__init__()
        self.padding_idx = padding_idx
        self.vocab_size = vocab_size
        self.dropout = dropout
        self.max_seq_len = max_seq_len
        self.embedding_dim = embedding_dim
        self.num_segments = num_segments
        self.use_position_embeddings = use_position_embeddings
        self.apply_bert_init = apply_bert_init

        self.token_embeddings = nn.Embedding(
            self.vocab_size, self.embedding_dim, self.padding_idx
        )

        self.segment_embeddings = (
            nn.Embedding(self.num_segments, self.embedding_dim, self.padding_idx)
            if self.num_segments > 0
            else None
        )

        self.position_embeddings = (
            PositionalEmbedding(
                self.max_seq_len,
                self.embedding_dim,
                self.padding_idx,
                left_pad=False,
            )
            if self.use_position_embeddings
            else None
        )

        self.layers = nn.ModuleList(
            [
                TransformerSentenceEncoderLayer(
                    embedding_dim=self.embedding_dim,
                    ffn_embedding_dim=ffn_embedding_dim,
                    num_attention_heads=num_attention_heads,
                    dropout=self.dropout,
                    attention_dropout=attention_dropout,
                    activation_dropout=activation_dropout,
                    encoder_normalize_before=encoder_normalize_before,
                    use_bert_layer_norm=use_bert_layer_norm,
                    use_gelu=use_gelu,
                )
                for _ in range(num_encoder_layers)
            ]
        )

        # Apply initialization of model params after building the model
        if self.apply_bert_init:
            self.apply(init_bert_params)

    def forward(
        self,
        tokens: torch.Tensor,
        segment_labels: torch.Tensor
    ) -> Tuple[List[torch.Tensor], torch.Tensor]:

        # compute padding mask. This is needed for multi-head attention
        padding_mask = tokens.eq(self.padding_idx)
        if not padding_mask.any():
            padding_mask = None

        # embed positions
        positions = (
            self.position_embeddings(tokens)
            if self.position_embeddings is not None else None
        )

        # embed segments
        segments = (
            self.segment_embeddings(segment_labels)
            if self.segment_embeddings is not None
            else None
        )

        x = self.token_embeddings(tokens)
        if positions is not None:
            x += positions
        if segments is not None:
            x += segments
        x = F.dropout(x, p=self.dropout, training=self.training)

        # B x T x C -> T x B x C
        x = x.transpose(0, 1)
        inner_states = [x]

        for layer in self.layers:
            x, _ = layer(
                x,
                self_attn_padding_mask=padding_mask,
            )
            inner_states.append(x)

        # T x B x C -> B x T x C
        x = x.transpose(0, 1)

        sentence_rep = x[:, 0, :]

        return inner_states, sentence_rep
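
A minimal end-to-end sketch of the encoder above (illustrative hyper-parameters; assumes this revision of fairseq and that index 0 is the padding symbol):

import torch
from fairseq.modules import TransformerSentenceEncoder

encoder = TransformerSentenceEncoder(
    padding_idx=0,
    vocab_size=1000,
    num_encoder_layers=2,
    embedding_dim=64,
    ffn_embedding_dim=128,
    num_attention_heads=4,
)

tokens = torch.randint(1, 1000, (8, 16))      # B x T token ids (no padding)
segment_labels = torch.zeros_like(tokens)     # single segment for all tokens
inner_states, sentence_rep = encoder(tokens, segment_labels)

print(len(inner_states))        # num_encoder_layers + 1 (embedding output included)
print(inner_states[-1].shape)   # T x B x C -- inner states are kept in T x B x C
print(sentence_rep.shape)       # B x C representation of the first token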
fairseq/modules/transformer_sentence_encoder_layer.py (+119 −0)
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from fairseq.modules import MultiheadAttention, BertLayerNorm


def gelu(x: torch.Tensor) -> torch.Tensor:
    """
    Implementation of the gelu activation function.
    """
    return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))


class TransformerSentenceEncoderLayer(nn.Module):
    """
    Implements a Transformer Encoder Layer used in BERT/XLM style pre-trained
    models.

    If the flag use_bert_layer_norm is set then we use the custom
    BertLayerNorm module instead of nn.LayerNorm.
    """

    def __init__(
        self,
        embedding_dim: int = 768,
        ffn_embedding_dim: int = 3072,
        num_attention_heads: int = 8,
        dropout: float = 0.1,
        attention_dropout: float = 0.1,
        activation_dropout: float = 0.1,
        encoder_normalize_before: bool = True,
        use_bert_layer_norm: bool = True,
        use_gelu: bool = True,
    ) -> None:

        super().__init__()
        # Initialize parameters
        self.embedding_dim = embedding_dim
        self.dropout = dropout
        self.activation_dropout = activation_dropout
        self.normalize_before = encoder_normalize_before

        # Initialize blocks
        self.activation_fn = gelu if use_gelu else F.relu
        self.self_attention = MultiheadAttention(
            self.embedding_dim, num_attention_heads, dropout=attention_dropout
        )

        # layer norm associated with the self attention layer
        self.self_attn_layer_norm = (
            BertLayerNorm(self.embedding_dim)
            if use_bert_layer_norm
            else nn.LayerNorm(self.embedding_dim, eps=1e-12)
        )
        self.fc1 = nn.Linear(self.embedding_dim, ffn_embedding_dim)
        self.fc2 = nn.Linear(ffn_embedding_dim, self.embedding_dim)

        # layer norm associated with the position wise feed-forward NN
        self.final_layer_norm = (
            BertLayerNorm(self.embedding_dim)
            if use_bert_layer_norm
            else nn.LayerNorm(self.embedding_dim, eps=1e-12)
        )

    def _maybe_layer_norm(
        self,
        layer_norm: nn.Module,
        x: torch.Tensor,
        before: bool = False,
        after: bool = False,
    ):
        assert before ^ after
        if after ^ self.normalize_before:
            return layer_norm(x)
        else:
            return x

    def forward(
        self,
        x: torch.Tensor,
        self_attn_mask: torch.Tensor = None,
        self_attn_padding_mask: torch.Tensor = None,
    ):
        """
        LayerNorm is applied either before or after the self-attention/ffn
        modules similar to the original Transformer implementation.
        """

        residual = x
        x = self._maybe_layer_norm(self.self_attn_layer_norm, x, before=True)
        x, attn = self.self_attention(
            query=x,
            key=x,
            value=x,
            key_padding_mask=self_attn_padding_mask,
            need_weights=False,
            attn_mask=self_attn_mask,
        )
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = residual + x
        x = self._maybe_layer_norm(self.self_attn_layer_norm, x, after=True)
        residual = x
        x = self._maybe_layer_norm(self.final_layer_norm, x, before=True)
        x = self.activation_fn(self.fc1(x))
        x = F.dropout(x, p=self.activation_dropout, training=self.training)
        x = self.fc2(x)
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = residual + x
        x = self._maybe_layer_norm(self.final_layer_norm, x, after=True)
        return x, attn
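
A short usage sketch of the layer above (illustrative sizes; inputs are expected in T x B x C order, matching how TransformerSentenceEncoder calls it):

import torch
from fairseq.modules import TransformerSentenceEncoderLayer

layer = TransformerSentenceEncoderLayer(
    embedding_dim=64,
    ffn_embedding_dim=128,
    num_attention_heads=4,
    encoder_normalize_before=False,   # post-norm, as in the original BERT
)

x = torch.randn(16, 8, 64)            # T x B x C
out, attn = layer(x)                  # attn is None since need_weights=False

print(out.shape)                      # same T x B x C shape as the input

With encoder_normalize_before=False the layer normalization is applied after each residual connection (post-norm); setting it to True applies it before the self-attention and feed-forward blocks instead.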