Commit d7e19573 authored by Peng-Jen Chen's avatar Peng-Jen Chen Committed by Facebook Github Bot
Browse files

Back translation + denoising in MultilingualTranslation task (#620)

Summary:
- Add language token to MultilingualTranslation task
- Add back translation and denoising loss to MultilingualTranslation task
Pull Request resolved: https://github.com/pytorch/fairseq/pull/620

Reviewed By: liezl200

Differential Revision: D14756873

Pulled By: pipibjc

fbshipit-source-id: 89d668db26848fd95f446edf5923bab2113636f7
parent c2820af0
Loading
Loading
Loading
Loading
+4 −0
Original line number Diff line number Diff line
@@ -13,9 +13,11 @@ from .indexed_dataset import IndexedCachedDataset, IndexedDataset, IndexedRawTex
from .language_pair_dataset import LanguagePairDataset
from .lm_context_window_dataset import LMContextWindowDataset
from .monolingual_dataset import MonolingualDataset
from .noising import NoisingDataset
from .round_robin_zip_datasets import RoundRobinZipDatasets
from .token_block_dataset import TokenBlockDataset
from .transform_eos_dataset import TransformEosDataset
from .transform_eos_lang_pair_dataset import TransformEosLangPairDataset

from .iterators import (
    CountingIterator,
@@ -38,8 +40,10 @@ __all__ = [
    'LanguagePairDataset',
    'LMContextWindowDataset',
    'MonolingualDataset',
    'NoisingDataset',
    'RoundRobinZipDatasets',
    'ShardedIterator',
    'TokenBlockDataset',
    'TransformEosDataset',
    'TransformEosLangPairDataset',
]
+34 −13
Original line number Diff line number Diff line
@@ -10,6 +10,7 @@ import torch
from fairseq import utils

from . import FairseqDataset
from .language_pair_dataset import collate as language_pair_collate, generate_dummy_batch


def backtranslate_samples(samples, collate_fn, generate_fn, cuda=True):
@@ -36,22 +37,18 @@ def backtranslate_samples(samples, collate_fn, generate_fn, cuda=True):
    """
    collated_samples = collate_fn(samples)
    s = utils.move_to_cuda(collated_samples) if cuda else collated_samples
    generated_sources = generate_fn(s['net_input'])
    generated_sources = generate_fn(s)

    def update_sample(sample, generated_source):
        sample['target'] = sample['source']  # the original source becomes the target
        sample['source'] = generated_source
        return sample
    id_to_src = {
        sample['id']: sample['source'] for sample in samples
    }

    # Go through each tgt sentence in batch and its corresponding best
    # generated hypothesis and create a backtranslation data pair
    # {id: id, source: generated backtranslation, target: original tgt}
    return [
        update_sample(
            sample=input_sample,
            generated_source=hypos[0]['tokens'].cpu(),  # highest scoring hypo is first
        )
        for input_sample, hypos in zip(samples, generated_sources)
        {'id': id.item(), 'target': id_to_src[id.item()], 'source': hypos[0]['tokens'].cpu()}
        for id, hypos in zip(collated_samples['id'], generated_sources)
    ]


@@ -66,9 +63,15 @@ class BacktranslationDataset(FairseqDataset):
            backtranslated. Only the source side of this dataset will be used.
            After backtranslation, the source sentences in this dataset will be
            returned as the targets.
        backtranslation_fn (callable): function to call to generate
        src_dict (~fairseq.data.Dictionary): the dictionary of backtranslated
            sentences.
        tgt_dict (~fairseq.data.Dictionary, optional): the dictionary of
            sentences to be backtranslated.
        backtranslation_fn (callable, optional): function to call to generate
            backtranslations. This is typically the `generate` method of a
            :class:`~fairseq.sequence_generator.SequenceGenerator` object.
            Pass in None when it is not available at initialization time, and
            use set_backtranslation_fn function to set it when available.
        output_collater (callable, optional): function to call on the
            backtranslated samples to create the final batch
            (default: ``tgt_dataset.collater``).
@@ -78,7 +81,9 @@ class BacktranslationDataset(FairseqDataset):
    def __init__(
        self,
        tgt_dataset,
        backtranslation_fn,
        src_dict,
        tgt_dict=None,
        backtranslation_fn=None,
        output_collater=None,
        cuda=True,
        **kwargs
@@ -88,6 +93,8 @@ class BacktranslationDataset(FairseqDataset):
        self.output_collater = output_collater if output_collater is not None \
            else tgt_dataset.collater
        self.cuda = cuda if torch.cuda.is_available() else False
        self.src_dict = src_dict
        self.tgt_dict = tgt_dict

    def __getitem__(self, index):
        """
@@ -100,6 +107,9 @@ class BacktranslationDataset(FairseqDataset):
    def __len__(self):
        return len(self.tgt_dataset)

    def set_backtranslation_fn(self, backtranslation_fn):
        self.backtranslation_fn = backtranslation_fn

    def collater(self, samples):
        """Merge and backtranslate a list of samples to form a mini-batch.

@@ -119,6 +129,8 @@ class BacktranslationDataset(FairseqDataset):
        Returns:
            dict: a mini-batch with keys coming from *output_collater*
        """
        if samples[0].get('is_dummy', False):
            return samples
        samples = backtranslate_samples(
            samples=samples,
            collate_fn=self.tgt_dataset.collater,
@@ -131,7 +143,16 @@ class BacktranslationDataset(FairseqDataset):

    def get_dummy_batch(self, num_tokens, max_positions):
        """Just use the tgt dataset get_dummy_batch"""
        return self.tgt_dataset.get_dummy_batch(num_tokens, max_positions)
        def collate_fn(samples):
            return language_pair_collate(
                samples, pad_idx=self.src_dict.pad(), eos_idx=self.src_dict.eos(),
                input_feeding=True,
            )
        dummy_batch = generate_dummy_batch(
            num_tokens, collate_fn,
            self.src_dict, tgt_dict=self.tgt_dict)
        dummy_batch['is_dummy'] = True
        return dummy_batch

    def num_tokens(self, index):
        """Just use the tgt dataset num_tokens"""
+17 −11
Original line number Diff line number Diff line
@@ -68,6 +68,19 @@ def collate(
    return batch


def generate_dummy_batch(num_tokens, collate_fn, src_dict, src_len=128, tgt_dict=None, tgt_len=128):
    """Return a dummy batch with a given number of tokens.

    Args:
        num_tokens (int): target number of tokens in the batch
        collate_fn (callable): function used to merge the sample dicts into
            a mini-batch
        src_dict (~fairseq.data.Dictionary): dictionary used to generate
            dummy source sentences
        src_len (int, optional): length of each dummy source sentence
        tgt_dict (~fairseq.data.Dictionary, optional): dictionary used to
            generate dummy target sentences; if None, targets are None
        tgt_len (int, optional): length of each dummy target sentence

    Returns:
        the result of ``collate_fn`` on the generated samples
    """
    # Clamp to at least one sample: when num_tokens < max(src_len, tgt_len)
    # the integer division yields 0 and the dummy batch would be empty
    # (this guard existed in LanguagePairDataset.get_dummy_batch before the
    # logic was extracted into this helper).
    bsz = max(num_tokens // max(src_len, tgt_len), 1)
    return collate_fn([
        {
            'id': i,
            'source': src_dict.dummy_sentence(src_len),
            'target': tgt_dict.dummy_sentence(tgt_len) if tgt_dict is not None else None,
        }
        for i in range(bsz)
    ])


class LanguagePairDataset(FairseqDataset):
    """
    A pair of torch.utils.data.Datasets.
@@ -192,15 +205,7 @@ class LanguagePairDataset(FairseqDataset):
            max_positions,
            (self.max_source_positions, self.max_target_positions),
        )
        bsz = max(num_tokens // max(src_len, tgt_len), 1)
        return self.collater([
            {
                'id': i,
                'source': self.src_dict.dummy_sentence(src_len),
                'target': self.tgt_dict.dummy_sentence(tgt_len) if self.tgt_dict is not None else None,
            }
            for i in range(bsz)
        ])
        return generate_dummy_batch(num_tokens, self.collater, self.src_dict, src_len, self.tgt_dict, tgt_len)

    def num_tokens(self, index):
        """Return the number of tokens in a sample. This value is used to
@@ -227,9 +232,10 @@ class LanguagePairDataset(FairseqDataset):
    def supports_prefetch(self):
        return (
            getattr(self.src, 'supports_prefetch', False)
            and getattr(self.tgt, 'supports_prefetch', False)
            and (getattr(self.tgt, 'supports_prefetch', False) or self.tgt is None)
        )

    def prefetch(self, indices):
        """Eagerly load *indices* on the source and (if present) target datasets."""
        targets = [self.src] if self.tgt is None else [self.src, self.tgt]
        for dataset in targets:
            dataset.prefetch(indices)
+8 −0
Original line number Diff line number Diff line
@@ -301,3 +301,11 @@ class NoisingDataset(torch.utils.data.Dataset):
        The length of the noising dataset is the length of src.
        """
        return len(self.src_dataset)

    @property
    def supports_prefetch(self):
        """Whether the wrapped source dataset supports prefetching.

        Uses ``getattr`` with a ``False`` default — consistent with the other
        dataset wrappers in this package — so a source dataset without the
        attribute is reported as non-prefetchable instead of raising
        AttributeError.
        """
        return getattr(self.src_dataset, 'supports_prefetch', False)

    def prefetch(self, indices):
        """Prefetch *indices* on the wrapped source dataset when it supports it."""
        source = self.src_dataset
        if source.supports_prefetch:
            source.prefetch(indices)
+80 −0
Original line number Diff line number Diff line
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.


from . import FairseqDataset
from typing import Optional


class TransformEosLangPairDataset(FairseqDataset):
    """A :class:`~fairseq.data.FairseqDataset` wrapper that rewrites the
    source eos and/or target bos symbols on collated samples of a language
    pair dataset.

    The substitution happens in :func:`collater`, after the wrapped dataset
    has produced its batch; item access itself is untouched.

    Args:
        dataset (~fairseq.data.FairseqDataset): dataset whose collater
            produces batches in the LanguagePairDataset schema
        src_eos (int): original source end-of-sentence symbol index to be replaced
        new_src_eos (int, optional): new end-of-sentence symbol index to replace source eos symbol
        tgt_bos (int, optional): original target beginning-of-sentence symbol index to be replaced
        new_tgt_bos (int, optional): new beginning-of-sentence symbol index to replace at the
            beginning of 'prev_output_tokens'
    """

    def __init__(
        self,
        dataset: FairseqDataset,
        src_eos: int,
        new_src_eos: Optional[int] = None,
        tgt_bos: Optional[int] = None,
        new_tgt_bos: Optional[int] = None,
    ):
        self.dataset = dataset
        self.src_eos = src_eos
        self.new_src_eos = new_src_eos
        self.tgt_bos = tgt_bos
        self.new_tgt_bos = new_tgt_bos

    def __getitem__(self, index):
        """Delegate item access to the wrapped dataset."""
        return self.dataset[index]

    def __len__(self):
        """Length of the wrapped dataset."""
        return len(self.dataset)

    def collater(self, samples):
        """Collate via the wrapped dataset, then rewrite eos/bos in place.

        Assumes the batch layout puts the source eos in the last column of
        ``src_tokens`` and the target bos in the first column of
        ``prev_output_tokens`` (see the padding-direction TODO below).
        """
        batch = self.dataset.collater(samples)
        net_input = batch['net_input']

        # TODO: support different padding direction
        if self.new_src_eos is not None:
            src_tokens = net_input['src_tokens']
            # every sequence must currently end with the old source eos
            assert (src_tokens[:, -1] != self.src_eos).sum() == 0
            src_tokens[:, -1] = self.new_src_eos

        if self.new_tgt_bos is not None:
            prev_output_tokens = net_input['prev_output_tokens']
            # every sequence must currently begin with the old target bos
            assert (prev_output_tokens[:, 0] != self.tgt_bos).sum() == 0
            prev_output_tokens[:, 0] = self.new_tgt_bos

        return batch

    def get_dummy_batch(self, *args, **kwargs):
        """Delegate dummy-batch creation to the wrapped dataset."""
        return self.dataset.get_dummy_batch(*args, **kwargs)

    def num_tokens(self, index):
        """Delegate token counting to the wrapped dataset."""
        return self.dataset.num_tokens(index)

    def size(self, index):
        """Delegate size lookup to the wrapped dataset."""
        return self.dataset.size(index)

    def ordered_indices(self):
        """Delegate index ordering to the wrapped dataset."""
        return self.dataset.ordered_indices()

    @property
    def supports_prefetch(self):
        """Whether the wrapped dataset supports prefetching."""
        return getattr(self.dataset, 'supports_prefetch', False)

    def prefetch(self, indices):
        """Delegate prefetching to the wrapped dataset."""
        return self.dataset.prefetch(indices)
Loading