Commit eef6663c authored by Myle Ott's avatar Myle Ott Committed by Facebook Github Bot
Browse files

Add checkpoint write timer

Summary: Pull Request resolved: https://github.com/pytorch/fairseq/pull/613

Differential Revision: D14712311

Pulled By: myleott

fbshipit-source-id: 3e7646629b539c10b6af89dece2c0c564f31125f
parent e88ad84b
Loading
Loading
Loading
Loading
+10 −1
Original line number Diff line number Diff line
@@ -11,8 +11,8 @@ Train a new model on one or across multiple GPUs.

import collections
import itertools
import os
import math
import os
import random

import torch
@@ -282,6 +282,10 @@ def get_perplexity(loss):
def save_checkpoint(args, trainer, epoch_itr, val_loss):
    if args.no_save or not distributed_utils.is_master(args):
        return

    write_timer = StopwatchMeter()
    write_timer.start()

    epoch = epoch_itr.epoch
    end_of_epoch = epoch_itr.end_of_epoch()
    updates = trainer.get_num_updates()
@@ -330,6 +334,11 @@ def save_checkpoint(args, trainer, epoch_itr, val_loss):
            if os.path.lexists(old_chk):
                os.remove(old_chk)

    write_timer.stop()

    print('| saved checkpoint {} (epoch {} @ {} updates) (writing took {} seconds)'.format(
        checkpoints[0], epoch, updates, write_timer.sum))


def load_checkpoint(args, trainer, epoch_itr):
    """Load a checkpoint and replay dataloader to match."""