Loading train.py +14 −16 Original line number Diff line number Diff line Loading @@ -39,7 +39,7 @@ def main(args, init_distributed=False): task = tasks.setup_task(args) # Load dataset splits load_dataset_splits(task, ['train', 'valid']) load_dataset_splits(args, task) # Initialize distributed training (after data loading) if init_distributed: Loading @@ -64,8 +64,8 @@ def main(args, init_distributed=False): task.max_positions(), model.max_positions(), ) dummy_batch = task.dataset('train').get_dummy_batch(args.max_tokens, max_positions) oom_batch = task.dataset('train').get_dummy_batch(1, max_positions) dummy_batch = task.dataset(args.train_subset).get_dummy_batch(args.max_tokens, max_positions) oom_batch = task.dataset(args.train_subset).get_dummy_batch(1, max_positions) # Build trainer trainer = Trainer(args, task, model, criterion, dummy_batch, oom_batch) Loading Loading @@ -358,11 +358,9 @@ def load_checkpoint(args, trainer, epoch_itr): return False def load_dataset_splits(task, splits): for split in splits: if split == 'train': task.load_dataset(split, combine=True) else: def load_dataset_splits(args, task): task.load_dataset(args.train_subset, combine=True) for split in args.valid_subset.split(','): for k in itertools.count(): split_k = split + (str(k) if k > 0 else '') try: Loading Loading
train.py +14 −16 Original line number Diff line number Diff line Loading @@ -39,7 +39,7 @@ def main(args, init_distributed=False): task = tasks.setup_task(args) # Load dataset splits load_dataset_splits(task, ['train', 'valid']) load_dataset_splits(args, task) # Initialize distributed training (after data loading) if init_distributed: Loading @@ -64,8 +64,8 @@ def main(args, init_distributed=False): task.max_positions(), model.max_positions(), ) dummy_batch = task.dataset('train').get_dummy_batch(args.max_tokens, max_positions) oom_batch = task.dataset('train').get_dummy_batch(1, max_positions) dummy_batch = task.dataset(args.train_subset).get_dummy_batch(args.max_tokens, max_positions) oom_batch = task.dataset(args.train_subset).get_dummy_batch(1, max_positions) # Build trainer trainer = Trainer(args, task, model, criterion, dummy_batch, oom_batch) Loading Loading @@ -358,11 +358,9 @@ def load_checkpoint(args, trainer, epoch_itr): return False def load_dataset_splits(task, splits): for split in splits: if split == 'train': task.load_dataset(split, combine=True) else: def load_dataset_splits(args, task): task.load_dataset(args.train_subset, combine=True) for split in args.valid_subset.split(','): for k in itertools.count(): split_k = split + (str(k) if k > 0 else '') try: Loading