Commit d63477e1 authored by Yongqiang Wang's avatar Yongqiang Wang Committed by Facebook Github Bot
Browse files

reduce memory footprint for average_checkpoints (#647)

Summary:
Pull Request resolved: https://github.com/pytorch/fairseq/pull/647

the current implementation of average_checkpoints requires loading all
the model parameters into memory and then do the averaging. To average large
models (e.g., transformer) over a large number of checkpoints (e.g., >50),
it may require over 100GB memory.

Loading all the parameters is not necessary, as we know the number of models in advance.

Reviewed By: skritika

Differential Revision: D15027513

fbshipit-source-id: 0afe37c9a031a9ab0f1e78844a37be49ec5f76f1
parent d2f3007c
Loading
Loading
Loading
Loading
+7 −7
Original line number Diff line number Diff line
@@ -21,6 +21,8 @@ def average_checkpoints(inputs):
    params_dict = collections.OrderedDict()
    params_keys = None
    new_state = None
    num_models = len(inputs)

    for f in inputs:
        state = torch.load(
            f,
@@ -44,20 +46,18 @@ def average_checkpoints(inputs):
            )

        for k in params_keys:
            if k not in params_dict:
                params_dict[k] = []
            p = model_params[k]
            if isinstance(p, torch.HalfTensor):
                p = p.float()
            params_dict[k].append(p)
            if k not in params_dict:
                params_dict[k] = p
            else:
                params_dict[k] += p

    averaged_params = collections.OrderedDict()
    # v should be a list of torch Tensor.
    for k, v in params_dict.items():
        summed_v = None
        for x in v:
            summed_v = summed_v + x if summed_v is not None else x
        averaged_params[k] = summed_v / len(v)
        averaged_params[k] = v / num_models
    new_state['model'] = averaged_params
    return new_state