Commit 4a30a5f6 authored by Myle Ott, committed by Facebook Github Bot

Fix inconsistent gradient check

Summary: Pull Request resolved: https://github.com/pytorch/fairseq/pull/692

Differential Revision: D15174954

fbshipit-source-id: 1a7bff9aeed3e2cc658577be9d79e8c9f72314c2
parent ffc9c8cc
@@ -11,6 +11,7 @@ Train a network across multiple GPUs.
 from collections import OrderedDict
 from itertools import chain
+import math
 import os
 import torch
@@ -239,8 +240,10 @@ class Trainer(object):
             logging_outputs = list(chain.from_iterable(logging_outputs))
             sample_sizes = list(chain.from_iterable(sample_sizes))
             ooms = sum(ooms)
-            assert all(norm == prev_norms[0] for norm in prev_norms), \
-                'Fatal error: gradients are inconsistent between workers'
+            assert (
+                all(norm == prev_norms[0] for norm in prev_norms)
+                or all(math.isnan(norm) or math.isinf(norm) for norm in prev_norms)
+            ), 'Fatal error: gradients are inconsistent between workers'
         self.meters['oom'].update(ooms, len(samples))
         if ooms == self.args.distributed_world_size * len(samples):
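For context: under IEEE 754, NaN compares unequal to everything, including itself, so the original norm == prev_norms[0] check fired even when every worker agreed that the gradient norm had overflowed to NaN or inf. The fix treats "all workers report a non-finite norm" as consistent. A minimal standalone sketch of the relaxed check (the grad_norms_consistent helper is hypothetical; in the commit the logic lives inline in the Trainer class):

import math

def grad_norms_consistent(prev_norms):
    """Return True if all workers report the same gradient norm, or if
    every worker reports a non-finite norm (NaN/inf from overflow).

    NaN != NaN under IEEE 754, so the equality check alone would reject
    the legitimate case where all workers overflowed simultaneously.
    """
    return (
        all(norm == prev_norms[0] for norm in prev_norms)
        or all(math.isnan(norm) or math.isinf(norm) for norm in prev_norms)
    )

# Consistent: identical finite norms across workers.
assert grad_norms_consistent([0.25, 0.25, 0.25])

# Previously a false positive: all-NaN norms now count as consistent,
# because NaN == NaN is False even though every worker agrees.
assert grad_norms_consistent([float('nan'), float('nan')])

# Still caught: genuinely divergent gradients between workers.
assert not grad_norms_consistent([0.25, 0.50])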