Commit a47630e1 authored by Liezl Puzon, committed by Facebook Github Bot

Fix hybrid transformer state dict update after encoder layernorm rename (#633)

Summary:
Pull Request resolved: https://github.com/pytorch/fairseq/pull/633

Pull Request resolved: https://github.com/pytorch/translate/pull/456

This diff moves the encoder layer norm key renaming into TransformerEncoderLayer.upgrade_state_dict_named, making it easier to upgrade the state dict for components that use TransformerEncoderLayer.
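
As a rough illustration of the intended usage (a minimal standalone sketch, not fairseq code: ToyEncoderLayer and HybridEncoder below are hypothetical stand-ins), a component that owns encoder layers can now delegate the checkpoint key renaming to each layer's upgrade_state_dict_named hook instead of duplicating the rename map:

    # Hypothetical stand-in for TransformerEncoderLayer with the new hook.
    class ToyEncoderLayer:
        def upgrade_state_dict_named(self, state_dict, name):
            # Mirrors the rename map added in this diff.
            layer_norm_map = {'0': 'self_attn_layer_norm', '1': 'final_layer_norm'}
            for old, new in layer_norm_map.items():
                for m in ('weight', 'bias'):
                    k = f'{name}.layer_norms.{old}.{m}'
                    if k in state_dict:
                        state_dict[f'{name}.{new}.{m}'] = state_dict.pop(k)

    # Hypothetical component (e.g. a hybrid encoder) built on such layers.
    class HybridEncoder:
        def __init__(self, num_layers):
            self.layers = [ToyEncoderLayer() for _ in range(num_layers)]

        def upgrade_state_dict_named(self, state_dict, name):
            # Each layer rewrites its own keys; the owning component no longer
            # needs to know about the old `layer_norms.{0,1}` naming scheme.
            for i, layer in enumerate(self.layers):
                layer.upgrade_state_dict_named(state_dict, f'{name}.layers.{i}')

    state_dict = {
        'encoder.layers.0.layer_norms.0.weight': 'w',
        'encoder.layers.0.layer_norms.1.bias': 'b',
    }
    HybridEncoder(num_layers=1).upgrade_state_dict_named(state_dict, 'encoder')
    print(state_dict)
    # {'encoder.layers.0.self_attn_layer_norm.weight': 'w',
    #  'encoder.layers.0.final_layer_norm.bias': 'b'}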

Reviewed By: jhcross

Differential Revision: D14916941

fbshipit-source-id: 6d0258c8a9492a720684dadce59c90fc87cbf5cf
parent 58b912f6
+21 −12
@@ -368,18 +368,8 @@ class TransformerEncoder(FairseqEncoder):
            state_dict['{}.embed_positions._float_tensor'.format(name)] = torch.FloatTensor(1)
        for i in range(len(self.layers)):
            # update layer norms
            layer_norm_map = {
                '0': 'self_attn_layer_norm',
                '1': 'final_layer_norm'
            }
            for old, new in layer_norm_map.items():
                for m in ('weight', 'bias'):
                    k = '{}.layers.{}.layer_norms.{}.{}'.format(name, i, old, m)
                    if k in state_dict:
                        state_dict[
                            '{}.layers.{}.{}.{}'.format(name, i, new, m)
                        ] = state_dict[k]
                        del state_dict[k]
            self.layers[i].upgrade_state_dict_named(state_dict, f"{name}.layers.{i}")

        version_key = '{}.version'.format(name)
        if utils.item(state_dict.get(version_key, torch.Tensor([1]))[0]) < 2:
            # earlier checkpoints did not normalize after the stack of layers
@@ -605,6 +595,25 @@ class TransformerEncoderLayer(nn.Module):
        self.fc2 = Linear(args.encoder_ffn_embed_dim, self.embed_dim)
        self.final_layer_norm = LayerNorm(self.embed_dim)

    def upgrade_state_dict_named(self, state_dict, name):
        """
        Rename layer norm states from `...layer_norms.0.weight` to
        `...self_attn_layer_norm.weight` and `...layer_norms.1.weight` to
        `...final_layer_norm.weight`
        """
        layer_norm_map = {
            '0': 'self_attn_layer_norm',
            '1': 'final_layer_norm'
        }
        for old, new in layer_norm_map.items():
            for m in ('weight', 'bias'):
                k = f'{name}.layer_norms.{old}.{m}'
                if k in state_dict:
                    state_dict[
                        f'{name}.{new}.{m}'
                    ] = state_dict[k]
                    del state_dict[k]

    def forward(self, x, encoder_padding_mask):
        """
        Args: