Fix all-reduce for new versions of PyTorch
We previously assumed that once a model parameter's gradient buffer was allocated, it stayed fixed for the rest of training. Recent versions of PyTorch violate this assumption: the gradient buffer may be reallocated during training, so it is no longer safe to hold onto the old buffer. This matters primarily for the all-reduce, since we all-reduce a flattened (i.e., contiguous) copy of the gradients rather than the buffers themselves. We can make this robust by copying the result of the all-reduce back into each parameter's current gradient buffer after each update. Intra-device copies are cheap, so this doesn't affect performance.
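A minimal sketch of the idea, not the repo's actual code (the function name `all_reduce_grads` and averaging by `world_size` are assumptions for illustration): the gradients are re-read from `.grad` each step, all-reduced as one flattened copy, and the result is copied back into whatever buffer each parameter currently holds.

```python
import torch
import torch.distributed as dist

def all_reduce_grads(model: torch.nn.Module, world_size: int) -> None:
    # Re-read the current .grad tensors each time; the buffers may have
    # been reallocated since the previous step.
    grads = [p.grad for p in model.parameters() if p.grad is not None]
    if not grads:
        return

    # All-reduce a single flattened (contiguous) copy of the gradients.
    flat = torch.cat([g.reshape(-1) for g in grads])
    dist.all_reduce(flat, op=dist.ReduceOp.SUM)
    flat.div_(world_size)  # average across workers (assumed convention)

    # Copy the reduced values back into each parameter's current gradient
    # buffer instead of assuming the flattened copy aliases it.
    offset = 0
    for g in grads:
        n = g.numel()
        g.copy_(flat[offset:offset + n].view_as(g))
        offset += n
```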