Commit de8aeab5 authored by freewym's avatar freewym Committed by Facebook Github Bot
Browse files

fix checkpoint timer (#634)

Summary:
If arg.keep_interval_updates or args.keep_last_epochs > 0, `checkpoints` would refer to a list of checkpoint files to be removed, which can be empty. So moved the logging code to the right position.
Pull Request resolved: https://github.com/pytorch/fairseq/pull/634

Differential Revision: D14933655

Pulled By: myleott

fbshipit-source-id: 68182ee99d9701e1536833d31e0a7c5d2eb2d679
parent e12e1d25
Loading
Loading
Loading
Loading
+4 −5
Original line number Diff line number Diff line
@@ -320,6 +320,10 @@ def save_checkpoint(args, trainer, epoch_itr, val_loss):
        for cp in checkpoints:
            trainer.save_checkpoint(cp, extra_state)

        write_timer.stop()
        print('| saved checkpoint {} (epoch {} @ {} updates) (writing took {} seconds)'.format(
            checkpoints[0], epoch, updates, write_timer.sum))

    if not end_of_epoch and args.keep_interval_updates > 0:
        # remove old checkpoints; checkpoints are sorted in descending order
        checkpoints = utils.checkpoint_paths(args.save_dir, pattern=r'checkpoint_\d+_(\d+)\.pt')
@@ -334,11 +338,6 @@ def save_checkpoint(args, trainer, epoch_itr, val_loss):
            if os.path.lexists(old_chk):
                os.remove(old_chk)

    write_timer.stop()

    print('| saved checkpoint {} (epoch {} @ {} updates) (writing took {} seconds)'.format(
        checkpoints[0], epoch, updates, write_timer.sum))


def load_checkpoint(args, trainer, epoch_itr):
    """Load a checkpoint and replay dataloader to match."""