DataParallel replicates network object for device 0 #11476

@ssnl

Description

This breaks some hooks: DataParallel replicates the module even for device 0, so if a hook sets an attribute on the module, the attribute lands on the replica and is lost. Spectral norm relies on such a hook, so DataParallel + spectral norm does not converge.
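A minimal pure-Python sketch of the failure mode (no torch; `Module` and `replicate` are hypothetical stand-ins for `nn.Module` and DataParallel's replication): because the replica is a copy rather than the original object, an attribute set by a hook on the replica never reaches the original module.

```python
import copy

class Module:
    """Stand-in for an nn.Module whose state a hook mutates."""
    def __init__(self):
        self.hook_flag = False

def replicate(module):
    # DataParallel-style replication: even device 0 receives a
    # copy of the module, not the original object.
    return copy.deepcopy(module)

original = Module()
replica = replicate(original)

# A hook that sets an attribute runs on the replica...
replica.hook_flag = True

# ...so the original module never sees the update.
print(original.hook_flag)  # False: the hook's write is lost
```

Any per-forward state the hook maintains (e.g., the power-iteration vectors spectral norm updates) is discarded with the replica after each forward pass.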

DataParallel + spectral norm also fails in eval mode, because buffers that require grad are not properly broadcast. See the following repro:

import torch
import torch.nn as nn
from torch.nn.utils import spectral_norm

net = nn.Conv2d(3, 3, 3).cuda()
net = spectral_norm(net)
net = nn.DataParallel(net)

inp = torch.randn(2, 3, 4, 4).cuda()

net(inp)   # training-mode forward works

net.eval()
net(inp)   # eval-mode forward fails: buffers requiring grad are not broadcast

I'll fix both.
