Commit 90d6eac2 authored by Ning Dong's avatar Ning Dong Committed by Facebook Github Bot
Browse files

Enable custom sampling strategy in MultiCorpusSampledDataset (#639)

Summary:
Pull Request resolved: https://github.com/pytorch/fairseq/pull/639

Add a `sampling_func` argument to the constructor to enable custom sampling over a list of dataset keys. The default strategy is to sample uniformly, matching the previous behavior.

Reviewed By: liezl200

Differential Revision: D14965774

fbshipit-source-id: f3285688a9ae3729c0ba12c22254c1144d0eea9e
parent 17cef3f6
Loading
Loading
Loading
Loading
+12 −15
Original line number Diff line number Diff line
@@ -6,7 +6,7 @@
# can be found in the PATENTS file in the same directory.

from collections import OrderedDict
from typing import Dict, List
from typing import Callable, Dict, List

import numpy as np

@@ -16,13 +16,13 @@ from . import FairseqDataset
class MultiCorpusSampledDataset(FairseqDataset):
    """
    Stores multiple instances of FairseqDataset together and in every iteration
    creates a batch by first sampling a dataset occording to a specified
    creates a batch by first sampling a dataset according to a specified
    probability distribution and then getting instances from that dataset.

    Args:
        datasets: an OrderedDict of FairseqDataset instances.
        sampling_dist: the sampling distribution used to select the dataset
            from which the batch is created in a given iteration.
        sampling_func: A function for sampling over list of dataset keys.
                Default strategy is to sample uniformly.
        default_key: string which specifies the default key to be used for
            generating dummy batches etc.
    """
@@ -30,14 +30,17 @@ class MultiCorpusSampledDataset(FairseqDataset):
    def __init__(
        self,
        datasets: Dict[str, FairseqDataset],
        sampling_dist: str = "uniform",
        sampling_func: Callable[[List], int] = (
            # Sample from uniform distribution
            lambda x: np.random.choice(x, 1).item()
        ),
        default_key: str = "",
    ):
        super().__init__()
        assert isinstance(datasets, OrderedDict)
        assert default_key in datasets
        self.datasets = datasets
        self.sampling_dist = sampling_dist
        self.sampling_func = sampling_func
        self.default_key = default_key

        self.total_num_instances = 0
@@ -105,15 +108,9 @@ class MultiCorpusSampledDataset(FairseqDataset):
        if len(samples) == 0:
            return None

        if self.sampling_dist == "uniform":
            candidates = list(self.datasets.keys())
            selected_key = np.random.choice(candidates, 1).item()
        selected_key = self.sampling_func(list(self.datasets.keys()))
        selected_samples = [sample[selected_key] for sample in samples]
        return self.datasets[selected_key].collater(selected_samples)
        else:
            raise NotImplementedError(
                "Specified sampling is currently not Implemented."
            )

    def get_dummy_batch(self, num_tokens: int, max_positions: int):
        """
+98 −0
Original line number Diff line number Diff line
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

import unittest
from collections import OrderedDict

import numpy as np
import torch
from fairseq.data import LanguagePairDataset, TokenBlockDataset
from fairseq.data.multi_corpus_sampled_dataset import MultiCorpusSampledDataset
from tests.test_train import mock_dict


class TestMultiCorpusSampledDataset(unittest.TestCase):
    """Tests for MultiCorpusSampledDataset's dataset-selection behavior."""

    def _make_pair_dataset(self, token_value, dictionary):
        """Build a one-example LanguagePairDataset whose source token is `token_value`."""
        tokens = torch.LongTensor([token_value]).view(1, -1)
        block = TokenBlockDataset(
            tokens,
            sizes=[tokens.size(-1)],
            block_size=1,
            pad=0,
            eos=1,
            include_targets=False,
        )
        return LanguagePairDataset(block, block.sizes, dictionary, shuffle=False)

    def setUp(self):
        d = mock_dict()
        # Two single-example datasets, distinguishable by their source token.
        self.dataset_1 = self._make_pair_dataset(1, d)
        self.dataset_2 = self._make_pair_dataset(2, d)

    def _test_sample_helper(
        self,
        expected_sample_from_first_ds_percentage,
        num_samples=1000,
        sampling_func=None,
    ):
        """Draw `num_samples` batches and check how often dataset 1 was picked.

        The observed selection rate for the first dataset must be within
        0.01 of `expected_sample_from_first_ds_percentage`. When
        `sampling_func` is None the dataset's default (uniform) sampler
        is exercised.
        """
        # Fixed seed keeps the sampled sequence deterministic, so the
        # tolerance check cannot be flaky.
        np.random.seed(0)
        kwargs = {"default_key": 0}
        if sampling_func is not None:
            kwargs["sampling_func"] = sampling_func
        m = MultiCorpusSampledDataset(
            OrderedDict({0: self.dataset_1, 1: self.dataset_2}), **kwargs
        )
        m.ordered_indices()
        hits_on_first_dataset = 0
        for _ in range(num_samples):
            batch = m.collater([m[0], m[1]])
            # Source token 1 means the batch came from dataset_1.
            if batch["net_input"]["src_tokens"][0] == 1:
                hits_on_first_dataset += 1
        observed = hits_on_first_dataset / float(num_samples)
        self.assertLess(
            abs(observed - expected_sample_from_first_ds_percentage),
            0.01,
        )

    def test_multi_corpus_sampled_dataset_uniform_sample(self):
        # Default sampler should pick each of the two datasets ~50% of the time.
        self._test_sample_helper(expected_sample_from_first_ds_percentage=0.5)

    def test_multi_corpus_sampled_dataset_weighted_sample(self):
        def naive_weighted_sample(weights):
            # Inverse-CDF sampling: walk the cumulative weights until they
            # exceed a uniform draw, and return that index.
            def f(keys):
                draw = np.random.random()
                cumulative = 0
                for index, weight in enumerate(weights):
                    cumulative += weight
                    if cumulative > draw:
                        return index

            return f

        self._test_sample_helper(
            expected_sample_from_first_ds_percentage=0.9,
            sampling_func=naive_weighted_sample(weights=[0.9, 0.1]),
        )