Commit 17cef3f6 authored by Ning Dong, committed by Facebook Github Bot

Black formatting for multi_corpus_sampled_dataset.py (#638)

Summary:
Pull Request resolved: https://github.com/pytorch/fairseq/pull/638

RT

Reviewed By: liezl200

Differential Revision: D14967268

fbshipit-source-id: 2da361497743d90a841fdbf2a50085136c70b468
parent 8776928c
@@ -28,10 +28,10 @@ class MultiCorpusSampledDataset(FairseqDataset):
     """
 
     def __init__(
-            self,
-            datasets: Dict[str, FairseqDataset],
-            sampling_dist: str = 'uniform',
-            default_key: str = ''
+        self,
+        datasets: Dict[str, FairseqDataset],
+        sampling_dist: str = "uniform",
+        default_key: str = "",
     ):
         super().__init__()
         assert isinstance(datasets, OrderedDict)
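Construction itself is unchanged by this hunk; only quoting and argument indentation move to black's style. For orientation, a minimal construction sketch, assuming the module path implied by the commit title; the corpus keys and the `en_de_dataset`/`en_fr_dataset` objects are made up for illustration:

from collections import OrderedDict
from fairseq.data.multi_corpus_sampled_dataset import MultiCorpusSampledDataset

# `en_de_dataset` and `en_fr_dataset` stand in for real FairseqDataset
# instances built elsewhere; the keys are arbitrary corpus names.
dataset = MultiCorpusSampledDataset(
    datasets=OrderedDict([("en_de", en_de_dataset), ("en_fr", en_fr_dataset)]),
    sampling_dist="uniform",  # the only distribution collater() implements
    default_key="en_de",      # corpus used for dummy batches
)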
@@ -62,34 +62,26 @@ class MultiCorpusSampledDataset(FairseqDataset):
         if self._ordered_indices is None:
             self._ordered_indices = OrderedDict(
                 [
-                    (
-                        key, dataset.ordered_indices()
-                    )
+                    (key, dataset.ordered_indices())
                     for key, dataset in self.datasets.items()
                 ]
             )
         return np.arange(len(self))
 
-    def _map_index_to_dataset(
-            self,
-            key: int,
-            index: int
-    ):
+    def _map_index_to_dataset(self, key: int, index: int):
         """
         Different underlying datasets have different lengths. In order to ensure
         we are not accessing an index outside the range of the current dataset
         size, we wrap around. This function should be called after we have
         created an ordering for this and all underlying datasets.
         """
-        assert self._ordered_indices is not None, \
-            'Must call MultiCorpusSampledDataset.ordered_indices() first'
+        assert (
+            self._ordered_indices is not None
+        ), "Must call MultiCorpusSampledDataset.ordered_indices() first"
         mapped_index = index % len(self.datasets[key])
         return self._ordered_indices[key][mapped_index]
 
-    def __getitem__(
-            self,
-            index: int
-    ):
+    def __getitem__(self, index: int):
         """
         Get the item associated with index from each underlying dataset.
         Since index is in the range of [0, TotalNumInstances], we need to
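The wrap-around described in the _map_index_to_dataset docstring is plain modulo arithmetic: a global index is reduced into each corpus's own range before the per-corpus ordered indices are consulted. A standalone toy illustration of just the index math (corpus sizes are hypothetical, not fairseq code):

# Two corpora of different (made-up) lengths.
lengths = {"big": 10, "small": 4}
index = 11  # a global index in [0, TotalNumInstances)

print(index % lengths["big"])    # 1: wraps inside the 10-example corpus
print(index % lengths["small"])  # 3: wraps inside the 4-example corpus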
@@ -97,17 +89,12 @@ class MultiCorpusSampledDataset(FairseqDataset):
         """
         return OrderedDict(
             [
-                (
-                    key, dataset[self._map_index_to_dataset(key, index)]
-                )
+                (key, dataset[self._map_index_to_dataset(key, index)])
                 for key, dataset in self.datasets.items()
             ]
         )
 
-    def collater(
-            self,
-            samples: List[Dict]
-    ):
+    def collater(self, samples: List[Dict]):
         """
         Generate a mini-batch for this dataset.
         To convert this into a regular mini-batch we use the following
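Per the docstring above, __getitem__ returns one example per corpus, keyed like the datasets dict. A toy mirror of that shape (it skips the ordered-indices permutation for brevity; the keys and values are invented):

from collections import OrderedDict

data = OrderedDict([("en_de", ["a", "b", "c"]), ("en_fr", ["x", "y"])])
index = 2
item = OrderedDict((k, v[index % len(v)]) for k, v in data.items())
print(item)  # OrderedDict([('en_de', 'c'), ('en_fr', 'x')])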
@@ -118,35 +105,26 @@ class MultiCorpusSampledDataset(FairseqDataset):
         if len(samples) == 0:
             return None
 
-        if self.sampling_dist == 'uniform':
+        if self.sampling_dist == "uniform":
             candidates = list(self.datasets.keys())
             selected_key = np.random.choice(candidates, 1).item()
-            selected_samples = [
-                sample[selected_key]
-                for sample in samples
-            ]
+            selected_samples = [sample[selected_key] for sample in samples]
             return self.datasets[selected_key].collater(selected_samples)
         else:
             raise NotImplementedError(
                 "Specified sampling is currently not Implemented."
             )
 
-    def get_dummy_batch(
-            self,
-            num_tokens: int,
-            max_positions: int,
-    ):
+    def get_dummy_batch(self, num_tokens: int, max_positions: int):
         """
         Return a dummy batch with a given number of tokens. Assumes that the
         max_positions specified is the same for all underlying datasets.
         """
         return self.datasets[self.default_key].get_dummy_batch(
-            num_tokens, max_positions)
+            num_tokens, max_positions
+        )
 
-    def num_tokens(
-            self,
-            index: int
-    ):
+    def num_tokens(self, index: int):
         """
         Return an example's length (number of tokens), used for batching. Here
         we return the max across all examples at index across all underlying
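The uniform branch reformatted above draws one corpus key per mini-batch and delegates to that corpus's own collater. The selection step in isolation (numpy only; the keys are made up):

import numpy as np

candidates = ["en_de", "en_fr"]  # list(self.datasets.keys()) in the real code

# Each corpus is equally likely; .item() unwraps the length-1 array
# returned by np.random.choice into a plain Python string.
selected_key = np.random.choice(candidates, 1).item()
print(selected_key)  # e.g. 'en_fr'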
@@ -157,10 +135,7 @@ class MultiCorpusSampledDataset(FairseqDataset):
             for key, dataset in self.datasets.items()
         )
 
-    def size(
-            self,
-            index: int
-    ):
+    def size(self, index: int):
         """
         Return an example's size as a float or tuple. Here we return the max
         across all underlying datasets. This value is used when filtering a
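num_tokens and size both reduce with max over the underlying datasets at the wrapped index, so batching and filtering use the most pessimistic value. A toy mirror of that reduction (token counts are invented):

# Per-corpus token counts, indexed by mapped position.
per_corpus = {"en_de": [23, 15], "en_fr": [9, 40]}
index = 1

print(max(counts[index % len(counts)] for counts in per_corpus.values()))  # 40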
@@ -174,14 +149,12 @@ class MultiCorpusSampledDataset(FairseqDataset):
     @property
     def supports_prefetch(self):
         return all(
-            getattr(dataset, 'supports_prefetch', False)
+            getattr(dataset, "supports_prefetch", False)
             for dataset in self.datasets.values()
         )
 
     def prefetch(self, indices):
         for key, dataset in self.datasets.items():
             dataset.prefetch(
-                [
-                    self._map_index_to_dataset(key, index) for index in indices
-                ]
+                [self._map_index_to_dataset(key, index) for index in indices]
             )
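prefetch fans every requested global index out to each corpus through the same wrap-around mapping, and supports_prefetch is advertised only when every underlying dataset supports it. A toy mirror of the fan-out (corpus sizes are hypothetical; the ordered-indices indirection is omitted):

indices = [0, 5, 11]
lengths = {"en_de": 10, "en_fr": 4}

# Mirrors the per-corpus index list built inside prefetch().
for key, n in lengths.items():
    print(key, [i % n for i in indices])
# en_de [0, 5, 1]
# en_fr [0, 1, 3]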