🐛 Bug
If DistributedDataParallel is used with non-default streams, averaging the gradients afterwards sometimes doesn't give you the average.
To Reproduce
Here's a minimal example that demonstrates both the buggy and the correct behaviour; the two runs differ only in whether the default stream or an independent stream is used.
import os

import torch
from torch.nn.parallel import DistributedDataParallel as DDP
from torch import distributed as dist
from torch import multiprocessing as mp
from torch import nn

def worker(rank, streamed):
    # Set up the worker to use NCCL
    os.environ['MASTER_ADDR'] = '127.0.0.1'
    os.environ['MASTER_PORT'] = '29500'
    torch.cuda.set_device(rank)
    dist.init_process_group("nccl", rank=rank, world_size=2)

    # Pick the stream we're going to work on
    stream = torch.cuda.Stream() if streamed else torch.cuda.default_stream()
    with torch.cuda.stream(stream):
        # Create the network
        net = DDP(nn.Linear(1, 1, bias=False).cuda(rank), device_ids=[rank])

        # Do a bunch of forward-backward passes and check the gradients are
        # consistent after each
        print(f'{rank}: Looping')
        bad = False
        for i in range(100):
            # Clear the gradients manually since we don't have an optimizer around
            grad = net.module.weight.grad
            if grad is not None:
                grad.detach_()
                grad.zero_()

            # Forward-backward. Gradient on worker #rank should be [[rank]].
            batch = torch.tensor([rank]).float().cuda(rank)
            loss = net(batch).sum()
            loss.backward()

            # Get the new gradient
            grad = net.module.weight.grad

            # Calculate the average gradient across workers. It should be .5.
            average = grad.clone()
            dist.all_reduce(average)  # takes the sum across ranks
            average = average/2

            if average[0, 0] != .5:
                print(f'{rank}: Average-of-averaged-gradients is wrong on loop {i}; it\'s {average[0, 0]} when my own averaged-gradient is {grad[0, 0]}. They should be both .5')
                bad = True

        if not bad:
            print(f'{rank}: Looped successfully; all grads consistent')

def run(streamed):
    # Launch a pair of workers that'll use DDP to compute gradients together
    workers = [mp.Process(target=worker, args=(r, streamed)) for r in [0, 1]]
    for w in workers:
        w.start()
    for w in workers:
        w.join()

if __name__ == '__main__':
    mp.set_start_method('spawn')

    print('\nRunning with-stream, with-bug version')
    run(streamed=True)

    print('Running no-stream, no-bug version')
    run(streamed=False)
Expected behavior
The output of the script - showing both buggy and expected behaviour - is
Running with-stream, with-bug version
0: Looping
1: Looping
0: Average-of-averaged-gradients is wrong on loop 0; it's 0.25 when my own averaged-gradient is 0.5. They should be both .5
1: Average-of-averaged-gradients is wrong on loop 0; it's 0.25 when my own averaged-gradient is 0.5. They should be both .5
...
0: Average-of-averaged-gradients is wrong on loop 0; it's 0.75 when my own averaged-gradient is 0.5. They should be both .5
1: Average-of-averaged-gradients is wrong on loop 0; it's 0.75 when my own averaged-gradient is 0.5. They should be both .5
...
Running no-stream, no-bug version
0: Looping
1: Looping
1: Looped successfully; all grads consistent
0: Looped successfully; all grads consistent
Keep in mind that the correct averaged gradient is 0.5, which means the correctly averaged gradient is being reported on both workers; it's only the average of those averages that is wrong.
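To spell out the arithmetic behind the check (this is just a restatement of the expectation above, not part of the repro):

# The local gradients are 0 on rank 0 and 1 on rank 1; DDP averages them to 0.5 in
# both workers' .grad, so all_reduce-summing the two (identical) averaged gradients
# and dividing by the world size should give 0.5 again.
local_grads = [0.0, 1.0]                          # d(loss)/d(weight) = batch = [[rank]]
ddp_grad = sum(local_grads) / len(local_grads)    # 0.5 in every worker's .grad after backward()
average_of_averages = (ddp_grad + ddp_grad) / 2   # what all_reduce-then-divide computes: 0.5
print(ddp_grad, average_of_averages)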
Environment
The background env is an nvidia/cuda:10.2-cudnn7-devel-ubuntu18.04 container with a conda-installed PyTorch 1.5.
PyTorch version: 1.5.0
Is debug build: No
CUDA used to build PyTorch: 10.2
OS: Ubuntu 18.04.3 LTS
GCC version: (Ubuntu 7.4.0-1ubuntu1~18.04.1) 7.4.0
CMake version: Could not collect
Python version: 3.7
Is CUDA available: Yes
CUDA runtime version: 10.2.89
GPU models and configuration:
GPU 0: GeForce RTX 2080 Ti
GPU 1: GeForce RTX 2080 Ti
Nvidia driver version: 440.64
cuDNN version: /usr/lib/x86_64-linux-gnu/libcudnn.so.7.6.5
Versions of relevant libraries:
[pip] numpy==1.18.1
[pip] pytorch-memlab==0.0.4
[pip] torch==1.5.0
[pip] torchfile==0.1.0
[pip] torchvision==0.6.0a0+82fd1c8
[conda] blas 1.0 mkl
[conda] cudatoolkit 10.2.89 hfd86e86_1
[conda] mkl 2020.0 166
[conda] mkl-service 2.3.0 py37he904b0f_0
[conda] mkl_fft 1.0.15 py37ha843d7b_0
[conda] mkl_random 1.1.0 py37hd6b4f25_0
[conda] numpy 1.18.1 py37h4f9e942_0
[conda] numpy-base 1.18.1 py37hde5b4d6_1
[conda] pytorch 1.5.0 py3.7_cuda10.2.89_cudnn7.6.5_0 pytorch
[conda] pytorch-memlab 0.0.4 pypi_0 pypi
[conda] torchfile 0.1.0 pypi_0 pypi
[conda] torchvision 0.6.0 py37_cu102 pytorch
Additional context
The genesis of this is that I had some code to check my DDP workers weren't diverging, and it all worked fine until I upgraded to PyTorch 1.5. I upgraded a bunch of other things simultaneously though - most notably CUDA 10.1 to CUDA 10.2 - so unfortunately I can't point the finger directly at 1.5.
I tried simplifying the example further by constructing the grad tensors directly and averaging them myself rather than getting them from DDP, but that didn't produce the same buggy behaviour. The combination of DDP, reducing the result, and streams seems critical.
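A sketch of that kind of simplified check, with hand-built tensors in place of DDP's gradients, would look something like this (illustrative only, not the exact code I ran):

import os

import torch
from torch import distributed as dist
from torch import multiprocessing as mp

def worker(rank):
    # Same process-group setup as the main repro, but no DDP in the picture
    os.environ['MASTER_ADDR'] = '127.0.0.1'
    os.environ['MASTER_PORT'] = '29500'
    torch.cuda.set_device(rank)
    dist.init_process_group("nccl", rank=rank, world_size=2)

    stream = torch.cuda.Stream()
    with torch.cuda.stream(stream):
        for i in range(100):
            # Stand-in for the local gradient: [[rank]], as in the main repro
            grad = torch.tensor([[rank]]).float().cuda(rank)
            average = grad.clone()
            dist.all_reduce(average)  # sum across ranks
            average = average/2
            if average[0, 0] != .5:
                print(f'{rank}: inconsistent on loop {i}: {average[0, 0]}')

if __name__ == '__main__':
    mp.set_start_method('spawn')
    workers = [mp.Process(target=worker, args=(r,)) for r in [0, 1]]
    for w in workers:
        w.start()
    for w in workers:
        w.join()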
That the two bad values it spits out are .25 and .75 suggests that the all_reduce is happening when the DDP-averaged gradient has been written to one of the gradient tensors but not yet the other.
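Spelled out with the numbers from the run above (this is only a plain-Python illustration of that hypothesis, not a repro):

# If the all_reduce reads one rank's buffer while it still holds the local,
# pre-reduction gradient, the sum mixes a local gradient with an averaged one.
local = {0: 0.0, 1: 1.0}   # per-rank local gradients, since the batch is [rank]
averaged = 0.5             # what DDP writes into .grad on both ranks

print((local[0] + averaged) / 2)  # 0.25 -- rank 0's buffer not yet overwritten
print((averaged + local[1]) / 2)  # 0.75 -- rank 1's buffer not yet overwritten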
cc @pietern @mrshenli @pritamdamania87 @zhaojuanmao @satgera @rohan-varma @gqchen @aazzolini @xush6528 @osalpekar