
nn.DataParallel with multiple outputs, weird gradient result None #15716

@eric-zhenhai

Description


🐛 Bug

Under PyTorch 1.0, the nn.DataParallel() wrapper does not compute gradients properly for models with multiple outputs: after backward(), the parameters' .grad fields are None.

To Reproduce

On a server with >=2 GPUs, under PyTorch 1.0.0.
Steps to reproduce the behavior:

  1. Use the code below:
import torch.nn as nn
import torch
import torch.nn.functional as F

DEVICE = torch.device('cuda:0')


class NN4(nn.Module):
    def __init__(self):
        super(NN4, self).__init__()
        self.fc1 = nn.Linear(8, 4)
        self.fc21 = nn.Linear(4, 1)

    def forward(self, x):
        x = F.selu(self.fc1(x))
        x1 = torch.sigmoid(self.fc21(x))
        # return x, x  # not None
        return x, x1  # None


def test_NN4():
    images = torch.randn(4, 8).to(DEVICE)
    fimages = torch.randn(4, 8).to(DEVICE)

    D = NN4().to(DEVICE)
    D = nn.DataParallel(D)
    D.zero_grad()

    d_loss = D(images)[0].mean() - D(fimages)[0].mean()
    print('d_loss: -->', d_loss)
    d_loss.backward()

    print('-------->>>')
    aaa = list(D.named_parameters())
    print(aaa[0][0])
    print(aaa[0][1].grad)

    D2 = NN4().to(DEVICE)
    D2.zero_grad()

    d2_loss = D2(images)[0].mean() - D2(fimages)[0].mean()
    print('d2_loss: -->', d2_loss)
    d2_loss.backward()

    print('-------->>>')
    aaa2 = list(D2.named_parameters())
    print(aaa2[0][0])
    print(aaa2[0][1].grad)

Then run the code with "CUDA_VISIBLE_DEVICES=0,1 python dp_test.py" in a console. Under PyTorch 1.0.0, I get:

d_loss: --> tensor(0.1488, device='cuda:0', grad_fn=<SubBackward0>)
-------->>>
module.fc1.weight
None
d2_loss: --> tensor(0.0149, device='cuda:0', grad_fn=<SubBackward0>)
-------->>>
fc1.weight
tensor([[ 0.0284, -0.1972,  0.1553, -0.3356,  0.2737, -0.2083,  0.1420, -0.3533],
        [ 0.0473, -0.1277,  0.0903, -0.3214,  0.2385, -0.1815,  0.0369, -0.1991],
        [ 0.0231, -0.0949,  0.1218, -0.3591,  0.1832, -0.2311,  0.0685, -0.1934],
        [ 0.0858, -0.1129,  0.1216, -0.3774,  0.3795, -0.1308, -0.0006, -0.1790]],
       device='cuda:0')

However, under PyTorch 0.4.0, I get:

d_loss: --> tensor(0.1650, device='cuda:0')
-------->>>
module.fc1.weight
tensor([[-0.2463,  0.0740, -0.2929, -0.2576, -0.0346,  0.1679,  0.1501,
         -0.2375],
        [-0.2666,  0.1135, -0.3788, -0.2865, -0.0519, -0.0217,  0.0564,
         -0.2942],
        [-0.2802,  0.1207, -0.3556, -0.2959, -0.0245, -0.0106,  0.0902,
         -0.2851],
        [-0.3193,  0.0788, -0.4258, -0.2705, -0.1212,  0.0063,  0.0322,
         -0.2649]], device='cuda:0')
d2_loss: --> tensor(1.00000e-02 *
       8.7814, device='cuda:0')
-------->>>
fc1.weight
tensor([[-0.3051,  0.1011, -0.3452, -0.2829, -0.0318, -0.0299,  0.0642,
         -0.2442],
        [-0.2536,  0.1279, -0.3869, -0.3891, -0.0362,  0.0412,  0.1000,
         -0.3384],
        [-0.3321,  0.0059, -0.4514, -0.2517, -0.1013,  0.0374,  0.0124,
         -0.1985],
        [-0.3147,  0.0331, -0.3343, -0.2498, -0.0903, -0.0668,  0.0555,
         -0.2360]], device='cuda:0')

Expected behavior

aaa[0][1].grad should not be None under PyTorch 1.0.0; the DataParallel-wrapped model should produce gradients just like the plain model D2 does.
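A possible workaround (an assumption on my part, not a confirmed fix for this bug) is to give every returned head a term in the loss, even with zero weight, so the entire graph participates in backward. A single-device CPU sketch of the idea using the same NN4:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class NN4(nn.Module):
    def __init__(self):
        super(NN4, self).__init__()
        self.fc1 = nn.Linear(8, 4)
        self.fc21 = nn.Linear(4, 1)

    def forward(self, x):
        x = F.selu(self.fc1(x))
        return x, torch.sigmoid(self.fc21(x))


torch.manual_seed(0)
model = NN4()
out0, out1 = model(torch.randn(4, 8))

# Tie the second head into the loss with zero weight so every output
# contributes to the autograd graph that backward() traverses.
loss = out0.mean() + 0.0 * out1.sum()
loss.backward()

print(model.fc1.weight.grad is not None)  # True
```

This only illustrates the graph wiring on a single device; whether it avoids the multi-GPU DataParallel behavior reported here would need to be verified on a >=2 GPU machine.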

Environment

PyTorch version: 1.0.0
Is debug build: No
CUDA used to build PyTorch: 9.0.176

OS: Ubuntu 16.04.5 LTS
GCC version: (Ubuntu 5.4.0-6ubuntu1~16.04.10) 5.4.0 20160609
CMake version: version 3.5.1

Python version: 3.6
Is CUDA available: Yes
CUDA runtime version: 9.0.176
GPU models and configuration:
GPU 0: Tesla V100-SXM2-16GB
GPU 1: Tesla V100-SXM2-16GB
GPU 2: Tesla V100-SXM2-16GB
GPU 3: Tesla V100-SXM2-16GB

Nvidia driver version: 396.44
cuDNN version: Probably one of the following:
/usr/local/cuda-8.0/lib64/libcudnn.so.6.0.21
/usr/local/cuda-8.0/lib64/libcudnn_static.a
/usr/local/cuda-9.0/lib64/libcudnn.so.7.3.1
/usr/local/cuda-9.0/lib64/libcudnn_static.a
/usr/local/cuda-9.1/lib64/libcudnn.so.7.0.5
/usr/local/cuda-9.1/lib64/libcudnn_static.a
/usr/local/cuda-9.2/lib64/libcudnn.so.7.3.1
/usr/local/cuda-9.2/lib64/libcudnn_static.a

Additional context
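One way to inspect gradients independently of the wrapper is through D.module, which is the original module held by nn.DataParallel; the wrapper's named_parameters() simply prepends a "module." prefix to the same underlying tensors (hence "module.fc1.weight" in the output above). A minimal sketch:

```python
import torch
import torch.nn as nn

model = nn.Linear(8, 1)
wrapped = nn.DataParallel(model)

# The wrapper stores the original module as .module; its parameter
# names gain a "module." prefix but refer to the same tensors.
assert wrapped.module is model
name, param = next(iter(wrapped.named_parameters()))
print(name)                    # module.weight
print(param is model.weight)   # True
```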


Labels

oncall: distributed, triaged
