- PyTorch version: 0.3.0
- Python version: 3.5
- GPU models and configuration: 8x NVIDIA Titan
I found that the requires_grad setting on a module's parameters is ignored when the module is wrapped with nn.DataParallel. For example,
import torch
from torch.nn import DataParallel
from torchvision.models.resnet import resnet50

module_ = resnet50()
# Freeze everything except the final fc layer.
for name, param in module_.named_parameters():
    if name.startswith('conv1') or name.startswith('bn1'):
        param.requires_grad = False
    if name.startswith('layer1') or name.startswith('layer2'):
        param.requires_grad = False
    if name.startswith('layer3') or name.startswith('layer4'):
        param.requires_grad = False

module_ = DataParallel(module_).cuda()

x = torch.rand(32, 3, 1600, 1600)
x = torch.autograd.Variable(x)
x = module_(x)
print(x)
This code should run, because only a very small part of the network's activations (just those of the fc layer) needs to be stored for the backward pass, but it instead results in a runtime error:
RuntimeError: cuda runtime error (2) : out of memory at /tmp/pip-8pfswvat-build/torch/lib/THC/generic/THCStorage.cu:58
Internally, I found a crack in the replicate function in torch.nn.parallel.replicate. replicate copies every parameter of the module (once per replica) with Broadcast.apply, and the broadcasting code just builds new torch.nn.Parameter objects with the constructor's default requires_grad argument, which is always True.
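The behaviour is easy to reproduce in isolation. Here is a minimal sketch (assuming CUDA and at least two visible GPUs): the original parameter is frozen, but on the affected version every replica reports requires_grad=True.

import torch
from torch import nn
from torch.nn.parallel import replicate

m = nn.Linear(4, 4).cuda()
m.weight.requires_grad = False   # freeze the weight on the original module

replicas = replicate(m, [0, 1])  # same call DataParallel makes on every forward
for replica in replicas:
    # On the affected version this prints True for both replicas,
    # even though the original parameter has requires_grad=False.
    print(replica.weight.requires_grad)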
I think there are a few possible ways to address this issue.
- It is intended behavior for DataParallel, so we should use volatile (or torch.no_grad) to implement transfer learning with DataParallel (see the sketch after this list).
- The Broadcast function should be fixed to preserve requires_grad when copying parameters.
- requires_grad should be synchronized in torch.nn.parallel.replicate.
I think none of these solutions would take a lot of effort.
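For reference, here is a minimal user-side sketch of option 1 using torch.no_grad (available in later releases; on 0.3 the equivalent is a volatile input Variable). The FrozenBackboneHead wrapper is a hypothetical helper, not part of torchvision: the frozen layers run without saving activations, so memory usage no longer depends on the replicas' requires_grad flags.

import torch

class FrozenBackboneHead(torch.nn.Module):
    # Hypothetical wrapper: run the frozen part under no_grad, train only the head.
    def __init__(self, backbone, head):
        super().__init__()
        self.backbone = backbone   # frozen layers (conv1 .. layer4)
        self.head = head           # trainable fc layer

    def forward(self, x):
        with torch.no_grad():      # activations of frozen layers are not saved
            features = self.backbone(x)
        return self.head(features)

Wrapping a model like this before applying DataParallel keeps the memory benefit of freezing even though the replicas reset requires_grad.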