torch.cuda.current_device() is always 0 at backward in DataParallel #2017

@hzaskywalker

Description

Consider the following code, and run it with multiple GPUs (e.g. 4):

import torch
from torch import nn
from torch.autograd import Function

class Func(Function):
    def forward(self, x):
        print('forward devices', torch.cuda.current_device())
        return x

    def backward(self, grad):
        print('backward devices', torch.cuda.current_device())
        return grad

class Module(nn.Module):
    def forward(self, x):
        return Func()(x)

f = nn.DataParallel(Module().cuda())
x = torch.autograd.Variable(torch.zeros(4, 1).cuda(), requires_grad=True)
loss = f(x)
loss.sum().backward()

It will output:

forward devices 0
forward devices 1
forward devices 2
forward devices 3
backward devices 0
backward devices 0
backward devices 0
backward devices 0

That is, during backward the current device is always 0, even inside DataParallel. My PyTorch version is 0.1.12_2.

I don't think this is the desired behavior, and it can cause trouble. When I launched my own CUDA kernel in backward to compute the gradients, it became very slow; I fixed it by calling torch.cuda.set_device(grad.get_device()).
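The workaround described above can be sketched as follows, using the same old-style Function API as the repro (DeviceFixFunc is a hypothetical name for illustration; the guard on grad.is_cuda is an addition so the sketch also runs on CPU tensors):

```python
import torch
from torch.autograd import Function

class DeviceFixFunc(Function):
    # Hypothetical variant of the Func above, with the workaround applied.
    def forward(self, x):
        return x

    def backward(self, grad):
        # At backward time DataParallel leaves device 0 current, so
        # explicitly switch to the device that holds this replica's
        # gradient before launching any custom CUDA kernels.
        if grad.is_cuda:
            torch.cuda.set_device(grad.get_device())
        return grad
```

With this in place, a custom kernel launched inside backward runs on the gradient's own device instead of implicitly targeting device 0.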

In any case, I think current_device should be the same in forward and backward. Could anyone explain why they differ?
