# Minimal repro: weight_norm applied to an RNN breaks .cuda().
import torch.nn as nn
from torch.nn.utils.weight_norm import weight_norm
# Plain GRU (input_size=100, hidden_size=20); its flat cuDNN weights are
# registered as weight_ih_l0 / weight_hh_l0.
a = nn.GRU(100, 20)
# weight_norm re-registers each wrapped weight as a derived tensor computed
# from new weight_*_g / weight_*_v parameters in its pre-forward hook.
b = weight_norm(a, name='weight_hh_l0')
b = weight_norm(b, name='weight_ih_l0')
# Fails here: RNNBase._apply() calls flatten_parameters(), which copies
# all_weights into the flat cuDNN buffer; the derived (still-CPU) weight and
# the moved parameters no longer have matching types, so the assert in
# torch/backends/cudnn/rnn.py:_copyParams trips (see traceback below).
b.cuda()
File "<stdin>", line 1, in <module>
File "/home/didoyang/anaconda2/lib/python2.7/site-packages/torch/nn/modules/module.py", line 147, in cuda
return self._apply(lambda t: t.cuda(device_id))
File "/home/didoyang/anaconda2/lib/python2.7/site-packages/torch/nn/modules/rnn.py", line 116, in _apply
self.flatten_parameters()
File "/home/didoyang/anaconda2/lib/python2.7/site-packages/torch/nn/modules/rnn.py", line 107, in flatten_parameters
rnn._copyParams(all_weights, params)
File "/home/didoyang/anaconda2/lib/python2.7/site-packages/torch/backends/cudnn/rnn.py", line 186, in _copyParams
assert param_from.type() == param_to.type()
AssertionError
I haven't looked into this in depth, but the interaction between weight_norm's pre-forward hook and the RNN's flatten_parameters() looks fishy.
More here: https://discuss.pytorch.org/t/built-in-weight-norm-on-rnn/5905
cc @ezyang @gchanan @zou3519 @jerryzh168