Commit 8776928c authored by Kartikay Khandelwal, committed by Facebook Github Bot

Open Source MLM Implementation in Fairseq (#635)

Summary:
Pull Request resolved: https://github.com/pytorch/fairseq/pull/635

Adding a task and the relevant models, datasets and criteria needed for training cross-lingual language models, similar to the Masked Language Model (MLM) used in XLM (Lample and Conneau, 2019 - https://arxiv.org/abs/1901.07291).

Reviewed By: liezl200

Differential Revision: D14943776

fbshipit-source-id: 3e416a730303d1dd4f5b92550c78db989be27073
parent 303b95ce
# Cross-Lingual Language Model Pre-training

Below are some details for training Cross-Lingual Language Models (XLM) - similar to the ones presented in [Lample & Conneau, 2019](https://arxiv.org/pdf/1901.07291.pdf) - in Fairseq. The current implementation only supports the Masked Language Model (MLM) from the paper above.

## Downloading and Tokenizing Monolingual Data

Pointers to the monolingual Wikipedia data used for training the XLM-style MLM model, as well as details on processing it (tokenization and BPE), can be found in the [XLM Github Repository](https://github.com/facebookresearch/XLM#download--preprocess-monolingual-data).

Let's assume the following so that the code snippets in later sections work (an illustrative directory layout is shown after this list):
- Processed data is in the folder: monolingual_data/processed
- Each language has 3 files, for train, validation and test. For example, for English we have:
    train.en, valid.en, test.en
- We are training a model for 5 languages: Arabic (ar), German (de), English (en), Hindi (hi) and French (fr)
- The vocabulary file is monolingual_data/processed/vocab_mlm
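
For reference, under the assumptions above the processed data directory would look roughly like this (illustrative layout only):

```
monolingual_data/processed/
    train.ar   valid.ar   test.ar
    train.de   valid.de   test.de
    train.en   valid.en   test.en
    train.hi   valid.hi   test.hi
    train.fr   valid.fr   test.fr
    vocab_mlm
```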


## Fairseq Pre-processing and Binarization

Pre-process and binarize the data with the MaskedLMDictionary and the cross_lingual_lm task:

```
# Ensure the output directory exists
mkdir -p monolingual_data/fairseq_processed

for lg in ar de en hi fr
do

  fairseq-preprocess \
  --task cross_lingual_lm \
  --srcdict monolingual_data/processed/vocab_mlm \
  --only-source \
  --trainpref monolingual_data/processed/train \
  --validpref monolingual_data/processed/valid \
  --testpref monolingual_data/processed/test \
  --destdir monolingual_data/fairseq_processed \
  --workers 20 \
  --source-lang $lg

  # Since we only have a source language, the output file has a None for the
  # target language. Remove this

  for stage in train test valid
  do

    mv monolingual_data/fairseq_processed/$stage.$lg-None.$lg.bin monolingual_data/fairseq_processed/$stage.$lg.bin
    mv monolingual_data/fairseq_processed/$stage.$lg-None.$lg.idx monolingual_data/fairseq_processed/$stage.$lg.idx

  done

done
```
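
If the loop above completes successfully, the destination directory should contain the copied dictionary plus the renamed binarized files, roughly as follows (shown for English; the other languages follow the same pattern):

```
monolingual_data/fairseq_processed/
    dict.en.txt
    train.en.bin   train.en.idx
    valid.en.bin   valid.en.idx
    test.en.bin    test.en.idx
```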

## Train a Cross-lingual Language Model similar to the XLM MLM model

Use the following command to train the model on 5 languages.

```
fairseq-train monolingual_data/fairseq_processed \
--task cross_lingual_lm \
--save-dir checkpoints/mlm \
--max-update 2400000 --save-interval 1 --no-epoch-checkpoints \
--arch xlm_base \
--optimizer adam --lr-scheduler reduce_lr_on_plateau \
--lr-shrink 0.5 --lr 0.0001 --min-lr 1e-09 \
--dropout 0.1 \
--criterion masked_lm_loss \
--max-tokens 2048 --tokens-per-sample 256 --no-bias-kv --attention-dropout 0.1 \
--lazy-load --seed 0 \
--masked-lm-only \
--monolingual-langs 'ar,de,en,hi,fr' --num-segment 5 \
--ddp-backend=no_c10d
```

Some Notes:
- Using a --tokens-per-sample value greater than 256 can cause OOM (out-of-memory) issues. Since the MLM task packs streams of text into each sample, this parameter usually does not need much tuning.
- The evaluation workflow for computing MLM perplexity on test data is in progress.
- Fine-tuning this model on a downstream task is not currently supported.
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.


import math
import torch.nn.functional as F

from fairseq import utils
from . import FairseqCriterion, register_criterion


def compute_cross_entropy_loss(logits, targets, ignore_index=-100):
    """
    Function to compute the cross entropy loss. The default value of
    ignore_index is the same as the default value for F.cross_entropy in
    pytorch.
    """
    assert logits.size(0) == targets.size(-1), \
        "Logits and Targets tensor shapes don't match up"

    loss = F.cross_entropy(
        logits,
        targets,
        reduction="sum",
        ignore_index=ignore_index,
    )
    return loss


@register_criterion('masked_lm_loss')
class MaskedLmLoss(FairseqCriterion):
    """
    Implementation for the loss used in masked language model (MLM) training.
    This optionally also computes the next sentence prediction (NSP) loss and
    adds it to the overall loss based on the specified args. There are three
    cases to consider:
        1) Generic MLM training without NSP loss. In this case sentence_targets
           and sentence_logits are both None.
        2) BERT training without NSP loss. In this case sentence_targets is
           not None but sentence_logits is None and we should not be computing
           a sentence level loss.
        3) BERT training with NSP loss. In this case both sentence_targets and
           sentence_logits are not None and we should be computing a sentence
           level loss. The weight of the sentence level loss is specified as
           an argument.
    """

    def __init__(self, args, task):
        super().__init__(args, task)

    @staticmethod
    def add_args(parser):
        """Args for MaskedLM Loss"""
        # Default for masked_lm_only is False so as to not break BERT training
        parser.add_argument('--masked-lm-only', default=False,
                            action='store_true', help='compute MLM loss only')
        parser.add_argument('--nsp-loss-weight', default=1.0, type=float,
                            help='weight for next sentence prediction'
                                 ' loss (default 1)')

    def forward(self, model, sample, reduce=True):
        """Compute the loss for the given sample.
        Returns a tuple with three elements:
        1) the loss
        2) the sample size, which is used as the denominator for the gradient
        3) logging outputs to display while training
        """
        lm_logits, output_metadata = model(**sample["net_input"])

        # reshape lm_logits from (N,T,C) to (N*T,C)
        lm_logits = lm_logits.view(-1, lm_logits.size(-1))
        lm_targets = sample['lm_target'].view(-1)
        lm_loss = compute_cross_entropy_loss(
            lm_logits, lm_targets, self.padding_idx)

        # compute the number of tokens for which loss is computed. This is used
        # to normalize the loss
        ntokens = utils.strip_pad(lm_targets, self.padding_idx).numel()
        loss = lm_loss / ntokens
        nsentences = sample['nsentences']

        # Compute sentence loss if masked_lm_only is False
        sentence_loss = None
        if not self.args.masked_lm_only:
            sentence_logits = output_metadata['sentence_logits']
            sentence_targets = sample['sentence_target'].view(-1)
            # This needs to be recomputed due to some differences between
            # TokenBlock and BlockPair dataset. This can be resolved with a
            # refactor of BERTModel which we will do in the future.
            # TODO: Remove this after refactor of BERTModel
            nsentences = sentence_targets.size(0)

            # Check for logits being none which can happen when remove_heads
            # is set to true in the BERT model. Ideally we should set
            # masked_lm_only to true in this case, but that requires some
            # refactor in the BERT model.
            if sentence_logits is not None:
                sentence_loss = compute_cross_entropy_loss(
                    sentence_logits, sentence_targets)

                loss += self.args.nsp_loss_weight * (sentence_loss / nsentences)

        # NOTE: as we are summing up per token mlm loss and per sentence nsp loss
        # we don't need to use sample_size as denominator for the gradient
        # here sample_size is just used for logging
        sample_size = 1
        logging_output = {
            'loss': utils.item(loss.data) if reduce else loss.data,
            'lm_loss': utils.item(lm_loss.data) if reduce else lm_loss.data,
            # sentence loss is not always computed
            'sentence_loss': (
                (
                    utils.item(sentence_loss.data) if reduce
                    else sentence_loss.data
                ) if sentence_loss is not None else 0.0
            ),
            'ntokens': ntokens,
            'nsentences': nsentences,
            'sample_size': sample_size,
        }
        return loss, sample_size, logging_output

    @staticmethod
    def aggregate_logging_outputs(logging_outputs):
        """Aggregate logging outputs from data parallel training."""
        lm_loss_sum = sum(log.get('lm_loss', 0) for log in logging_outputs)
        sentence_loss_sum = sum(
            log.get('sentence_loss', 0) for log in logging_outputs)
        ntokens = sum(log.get('ntokens', 0) for log in logging_outputs)
        nsentences = sum(log.get('nsentences', 0) for log in logging_outputs)
        sample_size = sum(log.get('sample_size', 0) for log in logging_outputs)
        agg_loss = sum(log.get('loss', 0) for log in logging_outputs)

        agg_output = {
            'loss': agg_loss / sample_size / math.log(2),
            'lm_loss': lm_loss_sum / ntokens / math.log(2),
            'sentence_loss': sentence_loss_sum / nsentences / math.log(2),
            'nll_loss': lm_loss_sum / ntokens / math.log(2),
            'ntokens': ntokens,
            'nsentences': nsentences,
            'sample_size': sample_size,
        }
        return agg_output
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

import math

import numpy as np
import torch

from typing import Dict, List, Tuple, Union

from . import FairseqDataset, data_utils

from fairseq.data import Dictionary
from fairseq.data.fb_block_pair_dataset import BlockPairDataset
from fairseq.data.token_block_dataset import TokenBlockDataset


class MaskedLMDataset(FairseqDataset):
    """
    A wrapper Dataset for masked language modelling. The dataset
    wraps around TokenBlockDataset or BlockedPairDataset and creates a batch
    where the input blocks are masked according to the specified masking
    probability. Additionally the batch can also contain sentence level targets
    if this is specified.

    Args:
        dataset: Dataset which generates blocks of data. Only BlockPairDataset
            and TokenBlockDataset are supported.
        sizes: Sentence lengths
        vocab: Dictionary with the vocabulary and special tokens.
        pad_idx: Id of padding token in dictionary
        mask_idx: Id of mask token in dictionary
        classif_token_idx: Id of classification token in dictionary. This is the
            token associated with the sentence embedding (Eg: CLS for BERT)
        sep_token_idx: Id of separator token in dictionary
            (Eg: SEP in BERT)
        seed: Seed for random number generator for reproducibility.
        shuffle: Shuffle the elements before batching.
        has_pairs: Specifies whether the underlying dataset
            generates a pair of blocks along with a sentence_target or not.
            Setting it to True assumes that the underlying dataset generates a
            label for the pair of sentences which is surfaced as
            sentence_target. The default value assumes a single block with no
            sentence target.
        segment_id: An optional segment id for filling in the segment labels
            when we are in the single block setting (Eg: XLM). Default is 0.
        masking_ratio: specifies what percentage of the blocks should be masked.
        masking_prob: specifies the probability of a given token being
            replaced with the "MASK" token.
        random_token_prob: specifies the probability of a given token being
            replaced by a random token from the vocabulary.
        Note: a token selected for masking is kept unchanged with the
            remaining probability, 1 - masking_prob - random_token_prob.
    """

    def __init__(
            self,
            dataset: FairseqDataset,
            sizes: np.ndarray,
            vocab: Dictionary,
            pad_idx: int,
            mask_idx: int,
            classif_token_idx: int,
            sep_token_idx: int,
            seed: int = 1,
            shuffle: bool = True,
            has_pairs: bool = True,
            segment_id: int = 0,
            masking_ratio: float = 0.15,
            masking_prob: float = 0.8,
            random_token_prob: float = 0.1
    ):
        # Make sure the input datasets are the ones supported
        assert (
            isinstance(dataset, TokenBlockDataset) or
            isinstance(dataset, BlockPairDataset)
        ), "MaskedLMDataset only wraps TokenBlockDataset or BlockPairDataset"

        self.dataset = dataset
        self.sizes = np.array(sizes)
        self.vocab = vocab
        self.pad_idx = pad_idx
        self.mask_idx = mask_idx
        self.classif_token_idx = classif_token_idx
        self.sep_token_idx = sep_token_idx
        self.shuffle = shuffle
        self.seed = seed
        self.has_pairs = has_pairs
        self.segment_id = segment_id
        self.masking_ratio = masking_ratio
        self.masking_prob = masking_prob
        self.random_token_prob = random_token_prob

        # If we have only one block then sizes needs to be updated to include
        # the classification token
        if not has_pairs:
            self.sizes = self.sizes + 1

    def __getitem__(
            self,
            index: int
    ):
        # if has_pairs, then expect 2 blocks and a sentence target
        if self.has_pairs:
            (block_one, block_two, sentence_target) = self.dataset[index]
        else:
            block_one = self.dataset[index]

        return {
            "id": index,
            "block_one": block_one,
            "block_two": block_two if self.has_pairs else None,
            "sentence_target": sentence_target if self.has_pairs else None,
        }

    def __len__(self):
        return len(self.dataset)

    def _mask_block(
            self,
            sentence: np.ndarray,
            mask_idx: int,
            pad_idx: int,
            dictionary_token_range: Tuple,
            masking_ratio: float = 0.15,
            masking_prob: float = 0.8,
            random_token_prob: float = 0.1
    ):
        """
        Mask tokens for Masked Language Model training
        Samples mask_ratio tokens that will be predicted by LM.

        Note: This function may not be efficient since it performs multiple
        conversions between numpy and torch; these can be replaced with torch
        operators later.

        Args:
            sentence: 1d tensor to be masked
            mask_idx: index to use for masking the sentence
            pad_idx: index to use for masking the target for tokens we aren't
                predicting
            dictionary_token_range: range of indices in dictionary which can
                be used for random word replacement
                (e.g. without special characters)
            masking_ratio: specifies what percentage of the blocks should be
                masked.
            masking_prob: specifies the probability of a given token being
                replaced with the "MASK" token.
            random_token_prob: specifies the probability of a given token being
                replaced by a random token from the vocabulary
        Return:
            masked_sent: masked sentence
            target: target with words which we are not predicting replaced
                by pad_idx
        """
        masked_sent = np.copy(sentence)
        sent_length = len(sentence)
        mask_num = math.ceil(sent_length * masking_ratio)
        # sample mask_num distinct positions to (potentially) corrupt
        mask = np.random.choice(sent_length, mask_num, replace=False)
        target = np.copy(sentence)

        for i in range(sent_length):
            if i in mask:
                rand = np.random.random()

                # replace with mask if probability is less than masking_prob
                # (Eg: 0.8)
                if rand < masking_prob:
                    masked_sent[i] = mask_idx

                # replace with random token if probability is less than
                # masking_prob + random_token_prob (Eg: 0.9)
                elif rand < (masking_prob + random_token_prob):
                    # sample random token from dictionary
                    masked_sent[i] = (
                        np.random.randint(
                            dictionary_token_range[0], dictionary_token_range[1]
                        )
                    )
            else:
                target[i] = pad_idx

        return masked_sent, target

    def _collate(
            self,
            samples: List[Dict],
            pad_idx: int,
            eos_idx: int
    ):
        """
        Does the heavy lifting for creating a batch from the input list of
        examples. The logic is as follows:
            1. Mask the input blocks. In case has_pair is True then we have 2
               blocks to mask.
            2. Prepend the first masked block tensor with the special token
               used as sentence embedding. Eg: CLS in BERT. This happens
               irrespective of the value of has_pair.
            3. If has_pair is True, then append the first masked block with the
               special separator token (eg: SEP for BERT) and compute segment
               label accordingly. In this case, also append the second masked
               block with this special separator token and compute its segment
               label.
            4. For the targets tensor, prepend and append with padding index
               accordingly.
            5. Concatenate all tensors.
        """
        if len(samples) == 0:
            return {}
        # To ensure determinism, we reset the state of the PRNG after every
        # batch based on the seed and the first id of the batch. This ensures
        # that across epochs we get the same mask for the same example. This
        # is needed for reproducibility and is how BERT does masking
        # TODO: Can we add determinism without this constraint?
        with data_utils.numpy_seed(self.seed + samples[0]["id"]):
            for s in samples:

                # token range is needed for replacing with random token during
                # masking
                token_range = (self.vocab.nspecial, len(self.vocab))

                # mask according to specified probabilities.
                masked_blk_one, masked_tgt_one = self._mask_block(
                    s["block_one"], self.mask_idx, self.pad_idx, token_range)

                tokens = np.concatenate([
                    [self.classif_token_idx], masked_blk_one
                ])
                targets = np.concatenate([[self.pad_idx], masked_tgt_one])
                segments = np.ones(len(tokens)) * self.segment_id

                # if has_pairs is True then we need to add the SEP token to both
                # the blocks after masking and re-compute segments based on the new
                # lengths.
                if self.has_pairs:
                    tokens_one = np.concatenate([tokens, [self.sep_token_idx]])
                    targets_one = np.concatenate([targets, [self.pad_idx]])

                    masked_blk_two, masked_tgt_two = self._mask_block(
                        s["block_two"], self.mask_idx, self.pad_idx, token_range)
                    tokens_two = np.concatenate(
                        [masked_blk_two, [self.sep_token_idx]])
                    targets_two = np.concatenate([masked_tgt_two, [self.pad_idx]])

                    # block + 1 sep + 1 special (CLS)
                    segments_one = np.zeros(len(tokens_one))
                    # block + 1 sep
                    segments_two = np.ones(len(tokens_two))

                    tokens = np.concatenate([tokens_one, tokens_two])
                    targets = np.concatenate([targets_one, targets_two])
                    segments = np.concatenate([segments_one, segments_two])

                s["source"] = torch.LongTensor(tokens)
                s["segment_labels"] = torch.LongTensor(segments)
                s["lm_target"] = torch.LongTensor(targets)

        def merge(key):
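            # Pad every example's tensor for this key to the length of the
            # longest example in the mini-batch (right padding with pad_idx).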
            return data_utils.collate_tokens(
                [s[key] for s in samples], pad_idx, eos_idx, left_pad=False
            )
        return {
            "id": torch.LongTensor([s["id"] for s in samples]),
            "ntokens": sum(len(s["source"]) for s in samples),
            "net_input": {
                "tokens": merge("source"),
                "segment_labels": merge("segment_labels"),
            },
            "lm_target": merge("lm_target"),
            "sentence_target": torch.LongTensor(
                [s["sentence_target"] for s in samples]
            ) if self.has_pairs else None,
            "nsentences": len(samples),
        }

    def collater(
            self,
            samples: List[Dict]
    ):
        """Merge a list of samples to form a mini-batch.

        Args:
            samples (List[dict]): samples to collate

        Returns:
            dict: a mini-batch of data
        """
        return self._collate(samples, self.vocab.pad(), self.vocab.eos())

    def get_dummy_batch(
            self,
            num_tokens: int,
            max_positions: Union[float, int],
            tgt_len: int = 12
    ):
        """
        Return a dummy batch with a given number of tokens.
        """
        if isinstance(max_positions, float) or isinstance(max_positions, int):
            tgt_len = min(tgt_len, max_positions)
        source = self.vocab.dummy_sentence(tgt_len)
        sentence_target = 0
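        # number of dummy examples of length tgt_len that fit in num_tokens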
        bsz = num_tokens // tgt_len

        return self.collater(
            [
                {
                    "id": i,
                    "block_one": source,
                    "block_two": source if self.has_pairs else None,
                    "sentence_target": sentence_target if self.has_pairs else None,
                }
                for i in range(bsz)
            ]
        )

    def num_tokens(
            self,
            index: int
    ):
        """
        Return the number of tokens in a sample. This value is used to
        enforce max-tokens during batching.
        """
        return self.sizes[index]

    def size(
            self,
            index: int
    ):
        """
        Return an example's size as a float or tuple. This value is used when
        filtering a dataset with max-positions.
        """
        return self.sizes[index]

    def ordered_indices(self):
        """
        Return an ordered list of indices. Batches will be constructed based
        on this order.
        """
        if self.shuffle:
            return np.random.permutation(len(self))
        else:
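            # np.lexsort uses the last key as the primary sort key, so this
            # orders examples by size, breaking ties by original index.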
            order = [np.arange(len(self))]
            order.append(self.sizes)
            return np.lexsort(order)

    @property
    def supports_prefetch(self):
        return getattr(self.dataset, "supports_prefetch", False)

    def prefetch(self, indices):
        self.dataset.prefetch(indices)
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

from fairseq.data import Dictionary


class MaskedLMDictionary(Dictionary):
    """
    Dictionary for Masked Language Modelling tasks. This extends Dictionary by
    adding the mask symbol.
    """
    def __init__(
        self,
        pad='<pad>',
        eos='</s>',
        unk='<unk>',
        mask='<mask>',
    ):
        super().__init__(pad, eos, unk)
        self.mask_word = mask
        self.mask_index = self.add_symbol(mask)
        self.nspecial = len(self.symbols)

    def mask(self):
        """Helper to get index of mask symbol"""
        return self.mask_index


class BertDictionary(MaskedLMDictionary):
    """
    Dictionary for BERT task. This extends MaskedLMDictionary by adding support
    for cls and sep symbols.
    """
    def __init__(
        self,
        pad='<pad>',
        eos='</s>',
        unk='<unk>',
        mask='<mask>',
        cls='<cls>',
        sep='<sep>'
    ):
        super().__init__(pad, eos, unk, mask)
        self.cls_word = cls
        self.sep_word = sep
        self.cls_index = self.add_symbol(cls)
        self.sep_index = self.add_symbol(sep)
        self.nspecial = len(self.symbols)

    def cls(self):
        """Helper to get index of cls symbol"""
        return self.cls_index

    def sep(self):
        """Helper to get index of sep symbol"""
        return self.sep_index