Commit c2820af0 authored by Kartikay Khandelwal, committed by Facebook GitHub Bot

Rename embedding layers to be the same as NMT (#628)

Summary:
Pull Request resolved: https://github.com/pytorch/fairseq/pull/628

Rename the embedding layers in TransformerSentenceEncoder so their attribute names match those used by the transformer (NMT) model.

Reviewed By: liezl200

Differential Revision: D14836883

fbshipit-source-id: 2240f61bf40b191d01b4efdaac4dd7562b4166c6
parent 94e9d77c
+5 −5
@@ -106,7 +106,7 @@ class TransformerSentenceEncoder(nn.Module):
         self.use_position_embeddings = use_position_embeddings
         self.apply_bert_init = apply_bert_init

-        self.token_embeddings = nn.Embedding(
+        self.embed_tokens = nn.Embedding(
             self.vocab_size, self.embedding_dim, self.padding_idx
         )

@@ -116,7 +116,7 @@ class TransformerSentenceEncoder(nn.Module):
             else None
         )

-        self.position_embeddings = (
+        self.embed_positions = (
             PositionalEmbedding(
                 self.max_seq_len,
                 self.embedding_dim,
@@ -161,8 +161,8 @@ class TransformerSentenceEncoder(nn.Module):

         # embed positions
         positions = (
-            self.position_embeddings(tokens)
-            if self.position_embeddings is not None else None
+            self.embed_positions(tokens)
+            if self.embed_positions is not None else None
         )

         # embed segments
@@ -172,7 +172,7 @@ class TransformerSentenceEncoder(nn.Module):
             else None
         )

-        x = self.token_embeddings(tokens)
+        x = self.embed_tokens(tokens)
         if positions is not None:
             x += positions
         if segments is not None:
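
The practical effect of these renames is on parameter names: PyTorch derives state_dict keys from module attribute names, so the encoder's checkpoint keys now read embed_tokens.* / embed_positions.* and line up with the transformer NMT model's. A minimal sketch in plain PyTorch, assuming a hypothetical TinyEncoder with illustrative sizes:

import torch.nn as nn

# Hypothetical toy module: attribute names determine checkpoint (state_dict) keys.
class TinyEncoder(nn.Module):
    def __init__(self, vocab_size=100, embedding_dim=16, padding_idx=0):
        super().__init__()
        # Renamed attribute, as in this commit (was `token_embeddings`).
        self.embed_tokens = nn.Embedding(vocab_size, embedding_dim, padding_idx)

# Prints ['embed_tokens.weight']; before the rename it would have been
# ['token_embeddings.weight'], which would not match the NMT transformer's keys.
print(list(TinyEncoder().state_dict().keys()))
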
+2 −2
@@ -51,7 +51,7 @@ class TransformerSentenceEncoderLayer(nn.Module):

         # Initialize blocks
         self.activation_fn = gelu if use_gelu else F.relu
-        self.self_attention = MultiheadAttention(
+        self.self_attn = MultiheadAttention(
             self.embedding_dim, num_attention_heads, dropout=attention_dropout
         )

@@ -97,7 +97,7 @@ class TransformerSentenceEncoderLayer(nn.Module):

         residual = x
         x = self._maybe_layer_norm(self.self_attn_layer_norm, x, before=True)
-        x, attn = self.self_attention(
+        x, attn = self.self_attn(
             query=x,
             key=x,
             value=x,
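
The same reasoning applies to the layer: renaming self_attention to self_attn keeps the attention submodule's parameter names consistent with the transformer encoder layer used for NMT. A minimal sketch, using torch.nn.MultiheadAttention as a stand-in for fairseq's MultiheadAttention; TinyLayer and the shapes are hypothetical:

import torch
import torch.nn as nn

# Hypothetical toy layer wiring self-attention under the renamed attribute.
class TinyLayer(nn.Module):
    def __init__(self, embedding_dim=16, num_attention_heads=4, attention_dropout=0.1):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(
            embedding_dim, num_attention_heads, dropout=attention_dropout
        )

    def forward(self, x):
        # Self-attention: query, key, and value are all the same tensor.
        out, attn = self.self_attn(query=x, key=x, value=x)
        return out, attn

layer = TinyLayer()
x = torch.randn(10, 2, 16)  # (seq_len, batch, embedding_dim)
out, attn = layer(x)
# Parameter names start with "self_attn.", e.g. "self_attn.in_proj_weight".
print(out.shape, sorted(k for k in layer.state_dict() if k.startswith("self_attn")))
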