This makes some hooks simply not work: if your hook sets some attribute on the module input, that attribute is lost. Spectral norm is an example, so DataParallel + spectral norm does not converge.
Also, DataParallel + spectral norm doesn't work in eval mode, because we don't properly broadcast buffers that require grad. See the following repro:
import torch
import torch.nn as nn
from torch.nn.utils import spectral_norm
net = nn.Conv2d(3, 3, 3).cuda()
net = spectral_norm(net)
net = nn.DataParallel(net)
inp = torch.randn(2, 3, 4, 4).cuda()
net(inp)  # training-mode forward runs
net.eval()
net(inp)  # fails: buffers that require grad are not broadcast properly
I'll fix both.
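To make the first failure mode concrete, here is a minimal, torch-free sketch (hypothetical `Module`, `data_parallel_call`, and hook names; not the actual torch internals): DataParallel replicates the module on every forward and discards the replicas, so any attribute a forward pre-hook writes onto the module lands on a throwaway replica and never reaches the original.

```python
import copy

class Module:
    def __init__(self):
        self.hooks = []    # forward pre-hooks, as on nn.Module
        self.sigma = None  # state a spectral-norm-style hook would update

    def __call__(self, x):
        for hook in self.hooks:
            hook(self, x)
        return x

def data_parallel_call(module, x):
    # DataParallel-style dispatch: replicate, run the replica, discard it.
    replica = copy.deepcopy(module)
    return replica(x)

def power_iteration_hook(module, x):
    # Stand-in for spectral norm's per-forward update: writes onto the module.
    module.sigma = 1.0

net = Module()
net.hooks.append(power_iteration_hook)

data_parallel_call(net, 42)
print(net.sigma)  # None -- the update landed on the discarded replica
```

Since the per-step power-iteration state never accumulates on the original module, training effectively runs with a stale estimate every iteration, which is why convergence breaks rather than the forward erroring out.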