Commit e88ad84b authored by Myle Ott's avatar Myle Ott Committed by Facebook Github Bot
Browse files

Use --train-subset and --valid-subset properly

Summary: Pull Request resolved: https://github.com/pytorch/fairseq/pull/614

Differential Revision: D14712321

Pulled By: myleott

fbshipit-source-id: 8ef973c5d30ebccf0df0f1cabdddd590248a8f8d
parent 3efc39ee
Loading
Loading
Loading
Loading
+14 −16
Original line number Diff line number Diff line
@@ -39,7 +39,7 @@ def main(args, init_distributed=False):
    task = tasks.setup_task(args)

    # Load dataset splits
    load_dataset_splits(task, ['train', 'valid'])
    load_dataset_splits(args, task)

    # Initialize distributed training (after data loading)
    if init_distributed:
@@ -64,8 +64,8 @@ def main(args, init_distributed=False):
        task.max_positions(),
        model.max_positions(),
    )
    dummy_batch = task.dataset('train').get_dummy_batch(args.max_tokens, max_positions)
    oom_batch = task.dataset('train').get_dummy_batch(1, max_positions)
    dummy_batch = task.dataset(args.train_subset).get_dummy_batch(args.max_tokens, max_positions)
    oom_batch = task.dataset(args.train_subset).get_dummy_batch(1, max_positions)

    # Build trainer
    trainer = Trainer(args, task, model, criterion, dummy_batch, oom_batch)
@@ -358,11 +358,9 @@ def load_checkpoint(args, trainer, epoch_itr):
    return False


def load_dataset_splits(task, splits):
    for split in splits:
        if split == 'train':
            task.load_dataset(split, combine=True)
        else:
def load_dataset_splits(args, task):
    task.load_dataset(args.train_subset, combine=True)
    for split in args.valid_subset.split(','):
        for k in itertools.count():
            split_k = split + (str(k) if k > 0 else '')
            try: