Commit e12e1d25 authored by Myle Ott, committed by Facebook Github Bot

Simplify and generalize utils.make_positions

Summary: Pull Request resolved: https://github.com/pytorch/fairseq/pull/625

Differential Revision: D14822123

Pulled By: myleott

fbshipit-source-id: 8a263d30020588577ee02fb8c6959ff918705103
parent a47630e1
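
This commit removes the left_pad argument from utils.make_positions and from every positional-embedding call site. Below is a minimal sketch of the cumsum-based position computation that makes the flag unnecessary, modeled on fairseq's utils.make_positions after this change (treat the details as illustrative rather than verbatim):

import torch

def make_positions(tensor, padding_idx, onnx_trace=False):
    # Replace non-padding symbols with their position numbers, starting at
    # padding_idx + 1; padding symbols keep padding_idx. The cumulative sum
    # over the non-padding mask yields correct positions whether padding sits
    # on the left or the right, so no left_pad flag is needed.
    mask = tensor.ne(padding_idx).int()
    return (torch.cumsum(mask, dim=1).type_as(mask) * mask).long() + padding_idx

pad = 1
make_positions(torch.tensor([[5, 6, 7, pad]]), pad)  # tensor([[2, 3, 4, 1]])
make_positions(torch.tensor([[pad, 5, 6, 7]]), pad)  # tensor([[1, 2, 3, 4]])

Either way the real tokens receive positions 2, 3, 4, which is what lets the diffs below delete the flag from every encoder, decoder, and embedding module.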
fairseq/models/fconv.py  +8 −15
@@ -179,17 +179,14 @@ class FConvEncoder(FairseqEncoder):
             connections are added between layers when ``residual=1`` (which is
             the default behavior).
         dropout (float, optional): dropout to be applied before each conv layer
-        left_pad (bool, optional): whether the input is left-padded
-            (default: True).
     """
 
     def __init__(
         self, dictionary, embed_dim=512, embed_dict=None, max_positions=1024,
-        convolutions=((512, 3),) * 20, dropout=0.1, left_pad=True,
+        convolutions=((512, 3),) * 20, dropout=0.1,
     ):
         super().__init__(dictionary)
         self.dropout = dropout
-        self.left_pad = left_pad
         self.num_attention_layers = None
 
         num_embeddings = len(dictionary)
@@ -202,7 +199,6 @@ class FConvEncoder(FairseqEncoder):
             max_positions,
             embed_dim,
             self.padding_idx,
-            left_pad=self.left_pad,
         )
 
         convolutions = extend_conv_spec(convolutions)
@@ -391,12 +387,10 @@ class FConvDecoder(FairseqIncrementalDecoder):
         max_positions=1024, convolutions=((512, 3),) * 20, attention=True,
         dropout=0.1, share_embed=False, positional_embeddings=True,
         adaptive_softmax_cutoff=None, adaptive_softmax_dropout=0,
-        left_pad=False,
     ):
         super().__init__(dictionary)
         self.register_buffer('version', torch.Tensor([2]))
         self.dropout = dropout
-        self.left_pad = left_pad
         self.need_attn = True
 
         convolutions = extend_conv_spec(convolutions)
@@ -418,7 +412,6 @@ class FConvDecoder(FairseqIncrementalDecoder):
             max_positions,
             embed_dim,
             padding_idx,
-            left_pad=self.left_pad,
         ) if positional_embeddings else None
 
         self.fc1 = Linear(embed_dim, in_channels, dropout=dropout)
@@ -616,8 +609,8 @@ def Embedding(num_embeddings, embedding_dim, padding_idx):
     return m
 
 
-def PositionalEmbedding(num_embeddings, embedding_dim, padding_idx, left_pad):
-    m = LearnedPositionalEmbedding(num_embeddings, embedding_dim, padding_idx, left_pad)
+def PositionalEmbedding(num_embeddings, embedding_dim, padding_idx):
+    m = LearnedPositionalEmbedding(num_embeddings, embedding_dim, padding_idx)
     nn.init.normal_(m.weight, 0, 0.1)
     nn.init.constant_(m.weight[padding_idx], 0)
     return m
fairseq/models/fconv_self_att.py  +4 −8
@@ -140,12 +140,11 @@ class FConvEncoder(FairseqEncoder):
     def __init__(
         self, dictionary, embed_dim=512, max_positions=1024,
         convolutions=((512, 3),) * 20, dropout=0.1, attention=False,
-        attention_nheads=1, left_pad=True,
+        attention_nheads=1,
     ):
         super().__init__(dictionary)
         self.dropout = dropout
         self.num_attention_layers = None
-        self.left_pad = left_pad
 
         num_embeddings = len(dictionary)
         self.padding_idx = dictionary.pad()
@@ -154,7 +153,6 @@ class FConvEncoder(FairseqEncoder):
             max_positions,
             embed_dim,
             self.padding_idx,
-            left_pad=self.left_pad,
         )
 
         def expand_bool_array(val):
@@ -269,14 +267,13 @@ class FConvDecoder(FairseqDecoder):
         convolutions=((512, 3),) * 8, attention=True, dropout=0.1,
         selfattention=False, attention_nheads=1, selfattention_nheads=1,
         project_input=False, gated_attention=False, downsample=False,
-        pretrained=False, trained_decoder=None, left_pad=False,
+        pretrained=False, trained_decoder=None,
     ):
         super().__init__(dictionary)
         self.register_buffer('version', torch.Tensor([2]))
         self.pretrained = pretrained
         self.pretrained_decoder = trained_decoder
         self.dropout = dropout
-        self.left_pad = left_pad
         self.need_attn = True
         in_channels = convolutions[0][0]
 
@@ -301,7 +298,6 @@ class FConvDecoder(FairseqDecoder):
             max_positions,
             embed_dim,
             padding_idx,
-            left_pad=self.left_pad,
         )
 
         self.fc1 = Linear(embed_dim, in_channels, dropout=dropout)
@@ -487,8 +483,8 @@ def Embedding(num_embeddings, embedding_dim, padding_idx):
     return m
 
 
-def PositionalEmbedding(num_embeddings, embedding_dim, padding_idx, left_pad):
-    m = LearnedPositionalEmbedding(num_embeddings, embedding_dim, padding_idx, left_pad)
+def PositionalEmbedding(num_embeddings, embedding_dim, padding_idx):
+    m = LearnedPositionalEmbedding(num_embeddings, embedding_dim, padding_idx)
     m.weight.data.normal_(0, 0.1)
     return m

fairseq/models/lightconv.py  +9 −11
@@ -291,11 +291,9 @@ class LightConvEncoder(FairseqEncoder):
         args (argparse.Namespace): parsed command-line arguments
         dictionary (~fairseq.data.Dictionary): encoding dictionary
         embed_tokens (torch.nn.Embedding): input embedding
-        left_pad (bool, optional): whether the input is left-padded. Default:
-            ``True``
     """
 
-    def __init__(self, args, dictionary, embed_tokens, left_pad=True):
+    def __init__(self, args, dictionary, embed_tokens):
         super().__init__(dictionary)
         self.dropout = args.dropout
 
@@ -307,7 +305,6 @@ class LightConvEncoder(FairseqEncoder):
         self.embed_scale = math.sqrt(embed_dim)
         self.embed_positions = PositionalEmbedding(
             args.max_source_positions, embed_dim, self.padding_idx,
-            left_pad=left_pad,
             learned=args.encoder_learned_pos,
         ) if not args.no_token_positional_embeddings else None
 
@@ -399,11 +396,9 @@ class LightConvDecoder(FairseqIncrementalDecoder):
         embed_tokens (torch.nn.Embedding): output embedding
         no_encoder_attn (bool, optional): whether to attend to encoder outputs.
             Default: ``False``
-        left_pad (bool, optional): whether the input is left-padded. Default:
-            ``False``
     """
 
-    def __init__(self, args, dictionary, embed_tokens, no_encoder_attn=False, left_pad=False, final_norm=True):
+    def __init__(self, args, dictionary, embed_tokens, no_encoder_attn=False, final_norm=True):
         super().__init__(dictionary)
         self.dropout = args.dropout
         self.share_input_output_embed = args.share_decoder_input_output_embed
@@ -422,7 +417,6 @@ class LightConvDecoder(FairseqIncrementalDecoder):
 
         self.embed_positions = PositionalEmbedding(
             args.max_target_positions, embed_dim, padding_idx,
-            left_pad=left_pad,
             learned=args.decoder_learned_pos,
         ) if not args.no_token_positional_embeddings else None
 
@@ -778,13 +772,17 @@ def Linear(in_features, out_features, bias=True):
     return m
 
 
-def PositionalEmbedding(num_embeddings, embedding_dim, padding_idx, left_pad, learned=False):
+def PositionalEmbedding(num_embeddings, embedding_dim, padding_idx, learned=False):
     if learned:
-        m = LearnedPositionalEmbedding(num_embeddings + padding_idx + 1, embedding_dim, padding_idx, left_pad)
+        m = LearnedPositionalEmbedding(
+            num_embeddings + padding_idx + 1, embedding_dim, padding_idx,
+        )
         nn.init.normal_(m.weight, mean=0, std=embedding_dim ** -0.5)
         nn.init.constant_(m.weight[padding_idx], 0)
     else:
-        m = SinusoidalPositionalEmbedding(embedding_dim, padding_idx, left_pad, num_embeddings + padding_idx + 1)
+        m = SinusoidalPositionalEmbedding(
+            embedding_dim, padding_idx, init_size=num_embeddings + padding_idx + 1,
+        )
     return m
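
A note on the sizing in this helper: positions produced by make_positions start at padding_idx + 1 and run through padding_idx + num_embeddings, while row padding_idx stays reserved for padding tokens, so the table needs num_embeddings + padding_idx + 1 rows. A quick worked example under assumed values (the same helper is duplicated in transformer.py below):

num_embeddings, padding_idx = 1024, 1  # assumed values for illustration
table_rows = num_embeddings + padding_idx + 1   # 1026 rows, indexed 0..1025
max_position_id = padding_idx + num_embeddings  # 1025, the largest index used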


fairseq/models/transformer.py  +9 −11
@@ -263,11 +263,9 @@ class TransformerEncoder(FairseqEncoder):
         args (argparse.Namespace): parsed command-line arguments
         dictionary (~fairseq.data.Dictionary): encoding dictionary
         embed_tokens (torch.nn.Embedding): input embedding
-        left_pad (bool, optional): whether the input is left-padded
-            (default: True).
     """
 
-    def __init__(self, args, dictionary, embed_tokens, left_pad=True):
+    def __init__(self, args, dictionary, embed_tokens):
         super().__init__(dictionary)
         self.dropout = args.dropout
 
@@ -279,7 +277,6 @@ class TransformerEncoder(FairseqEncoder):
         self.embed_scale = math.sqrt(embed_dim)
         self.embed_positions = PositionalEmbedding(
             args.max_source_positions, embed_dim, self.padding_idx,
-            left_pad=left_pad,
             learned=args.encoder_learned_pos,
         ) if not args.no_token_positional_embeddings else None
 
@@ -390,13 +387,11 @@ class TransformerDecoder(FairseqIncrementalDecoder):
         embed_tokens (torch.nn.Embedding): output embedding
         no_encoder_attn (bool, optional): whether to attend to encoder outputs
             (default: False).
-        left_pad (bool, optional): whether the input is left-padded
-            (default: False).
         final_norm (bool, optional): apply layer norm to the output of the
             final decoder layer (default: True).
     """
 
-    def __init__(self, args, dictionary, embed_tokens, no_encoder_attn=False, left_pad=False, final_norm=True):
+    def __init__(self, args, dictionary, embed_tokens, no_encoder_attn=False, final_norm=True):
         super().__init__(dictionary)
         self.dropout = args.dropout
         self.share_input_output_embed = args.share_decoder_input_output_embed
@@ -415,7 +410,6 @@ class TransformerDecoder(FairseqIncrementalDecoder):
 
         self.embed_positions = PositionalEmbedding(
             args.max_target_positions, embed_dim, padding_idx,
-            left_pad=left_pad,
             learned=args.decoder_learned_pos,
         ) if not args.no_token_positional_embeddings else None
 
@@ -796,13 +790,17 @@ def Linear(in_features, out_features, bias=True):
     return m
 
 
-def PositionalEmbedding(num_embeddings, embedding_dim, padding_idx, left_pad, learned=False):
+def PositionalEmbedding(num_embeddings, embedding_dim, padding_idx, learned=False):
     if learned:
-        m = LearnedPositionalEmbedding(num_embeddings + padding_idx + 1, embedding_dim, padding_idx, left_pad)
+        m = LearnedPositionalEmbedding(
+            num_embeddings + padding_idx + 1, embedding_dim, padding_idx,
+        )
         nn.init.normal_(m.weight, mean=0, std=embedding_dim ** -0.5)
         nn.init.constant_(m.weight[padding_idx], 0)
     else:
-        m = SinusoidalPositionalEmbedding(embedding_dim, padding_idx, left_pad, num_embeddings + padding_idx + 1)
+        m = SinusoidalPositionalEmbedding(
+            embedding_dim, padding_idx, init_size=num_embeddings + padding_idx + 1,
+        )
     return m


fairseq/modules/learned_positional_embedding.py  +5 −5
@@ -13,13 +13,11 @@ from fairseq import utils
 class LearnedPositionalEmbedding(nn.Embedding):
     """This module learns positional embeddings up to a fixed maximum size.
 
-    Padding symbols are ignored, but it is necessary to specify whether padding
-    is added on the left side (left_pad=True) or right side (left_pad=False).
+    Padding symbols are ignored.
     """
 
-    def __init__(self, num_embeddings, embedding_dim, padding_idx, left_pad):
+    def __init__(self, num_embeddings, embedding_dim, padding_idx):
         super().__init__(num_embeddings, embedding_dim, padding_idx)
-        self.left_pad = left_pad
         self.onnx_trace = False
 
     def forward(self, input, incremental_state=None):
@@ -28,7 +26,9 @@ class LearnedPositionalEmbedding(nn.Embedding):
             # positions is the same for every token when decoding a single step
             positions = input.data.new(1, 1).fill_(self.padding_idx + input.size(1))
         else:
-            positions = utils.make_positions(input.data, self.padding_idx, self.left_pad, self.onnx_trace)
+            positions = utils.make_positions(
+                input.data, self.padding_idx, onnx_trace=self.onnx_trace,
+            )
         return super().forward(positions)
 
     def max_positions(self):
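
For illustration, a hypothetical usage of the simplified module (tensor values and sizes are made up; the class itself is as in the diff above):

import torch

emb = LearnedPositionalEmbedding(num_embeddings=1026, embedding_dim=512, padding_idx=1)

# No left_pad flag: make_positions handles either padding side.
right_padded = torch.tensor([[5, 6, 7, 1]])  # positions [[2, 3, 4, 1]]
left_padded = torch.tensor([[1, 5, 6, 7]])   # positions [[1, 2, 3, 4]]
out_r = emb(right_padded)  # shape (1, 4, 512)
out_l = emb(left_padded)   # shape (1, 4, 512)

# During incremental decoding, the single-step branch above fills positions
# with padding_idx + input.size(1) instead of calling make_positions.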