Commit d63477e1 authored by Yongqiang Wang, committed by Facebook Github Bot

reduce memory footprint for average_checkpoints (#647)

Summary:
Pull Request resolved: https://github.com/pytorch/fairseq/pull/647

The current implementation of average_checkpoints loads the model parameters of every checkpoint into memory and only then does the averaging. Averaging a large model (e.g., a transformer) over a large number of checkpoints (e.g., >50) can therefore require over 100 GB of memory.

Keeping all the parameters around is not necessary: the number of models is known in advance, so the parameters can be accumulated into a running sum and divided by that count at the end.
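
For illustration, a minimal sketch of the same running-sum idea, separate from the fairseq script itself (the running_average name, the checkpoint paths, and the plain map_location='cpu' argument are assumptions for this example, not part of the commit):

import collections

import torch


def running_average(checkpoint_paths):
    # Average the 'model' parameters of several checkpoints while keeping
    # only a running sum in memory, not one copy per checkpoint.
    num_models = len(checkpoint_paths)
    summed = collections.OrderedDict()
    for path in checkpoint_paths:
        state = torch.load(path, map_location='cpu')
        for name, param in state['model'].items():
            # Accumulate in fp32, mirroring the HalfTensor handling in the script.
            param = param.float()
            if name not in summed:
                summed[name] = param.clone()
            else:
                summed[name] += param
        # Only one checkpoint's tensors plus the running sums are resident.
        del state
    return collections.OrderedDict(
        (name, total / num_models) for name, total in summed.items()
    )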

Reviewed By: skritika

Differential Revision: D15027513

fbshipit-source-id: 0afe37c9a031a9ab0f1e78844a37be49ec5f76f1
parent d2f3007c
@@ -21,6 +21,8 @@ def average_checkpoints(inputs):
     params_dict = collections.OrderedDict()
     params_keys = None
     new_state = None
+    num_models = len(inputs)
+
     for f in inputs:
         state = torch.load(
             f,
@@ -44,20 +46,18 @@ def average_checkpoints(inputs):
             )

         for k in params_keys:
-            if k not in params_dict:
-                params_dict[k] = []
             p = model_params[k]
             if isinstance(p, torch.HalfTensor):
                 p = p.float()
-            params_dict[k].append(p)
+            if k not in params_dict:
+                params_dict[k] = p
+            else:
+                params_dict[k] += p

     averaged_params = collections.OrderedDict()
-    # v should be a list of torch Tensor.
     for k, v in params_dict.items():
-        summed_v = None
-        for x in v:
-            summed_v = summed_v + x if summed_v is not None else x
-        averaged_params[k] = summed_v / len(v)
+        averaged_params[k] = v / num_models
     new_state['model'] = averaged_params
     return new_state
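
A rough usage sketch of the changed function (the checkpoint file names are placeholders, and the import assumes the function lives in scripts/average_checkpoints.py as in the fairseq tree at the time of this change):

import torch

from scripts.average_checkpoints import average_checkpoints

inputs = [
    'checkpoints/checkpoint48.pt',
    'checkpoints/checkpoint49.pt',
    'checkpoints/checkpoint50.pt',
]
# Peak memory is roughly one loaded checkpoint plus one running sum per
# parameter, instead of one copy of every parameter for every checkpoint.
new_state = average_checkpoints(inputs)
torch.save(new_state, 'checkpoints/checkpoint_avg.pt')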