Skip to content
Snippets Groups Projects
Commit e112d501 authored by Myle Ott's avatar Myle Ott Committed by Facebook Github Bot
Browse files

Make MultiCorpusSampledDataset and IndexedCachedDataset Picklable

Summary: Pull Request resolved: https://github.com/pytorch/fairseq/pull/691

Differential Revision: D15172543

Pulled By: myleott

fbshipit-source-id: f2b626ff7f5e95f0ddc83c105af7ab9d092a135e
parent 91c78477
No related branches found
No related tags found
No related merge requests found
......@@ -139,6 +139,10 @@ class IndexedCachedDataset(IndexedDataset):
self.data_file.seek(self.data_offsets[i] * self.element_size)
self.data_file.readinto(a)
ptx += size
if self.data_file:
# close and delete data file after prefetch so we can pickle
self.data_file.close()
self.data_file = None
def __getitem__(self, i):
self.check_index(i)
......
......@@ -13,6 +13,11 @@ import numpy as np
from . import FairseqDataset
def uniform_sampler(x):
# Sample from uniform distribution
return np.random.choice(x, 1).item()
class MultiCorpusSampledDataset(FairseqDataset):
"""
Stores multiple instances of FairseqDataset together and in every iteration
......@@ -30,16 +35,15 @@ class MultiCorpusSampledDataset(FairseqDataset):
def __init__(
self,
datasets: Dict[str, FairseqDataset],
sampling_func: Callable[[List], int] = (
# Sample from uniform distribution
lambda x: np.random.choice(x, 1).item()
),
sampling_func: Callable[[List], int] = None,
default_key: str = "",
):
super().__init__()
assert isinstance(datasets, OrderedDict)
assert default_key in datasets
self.datasets = datasets
if sampling_func is None:
sampling_func = uniform_sampler
self.sampling_func = sampling_func
self.default_key = default_key
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment