weight_norm on GRU followed by .cuda causes an assert #2343

@gchanan

Description

I haven't looked into this in depth, but this looks fishy:

import torch.nn as nn
from torch.nn.utils.weight_norm import weight_norm
a = nn.GRU(100, 20)
b = weight_norm(a, name='weight_hh_l0')
b = weight_norm(b, name='weight_ih_l0')
b.cuda()
  File "<stdin>", line 1, in <module>
  File "/home/didoyang/anaconda2/lib/python2.7/site-packages/torch/nn/modules/module.py", line 147, in cuda
    return self._apply(lambda t: t.cuda(device_id))
  File "/home/didoyang/anaconda2/lib/python2.7/site-packages/torch/nn/modules/rnn.py", line 116, in _apply
    self.flatten_parameters()
  File "/home/didoyang/anaconda2/lib/python2.7/site-packages/torch/nn/modules/rnn.py", line 107, in flatten_parameters
    rnn._copyParams(all_weights, params)
  File "/home/didoyang/anaconda2/lib/python2.7/site-packages/torch/backends/cudnn/rnn.py", line 186, in _copyParams
    assert param_from.type() == param_to.type()
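
If I had to guess, the assert fires because weight_norm replaces the weight_hh_l0 parameter with weight_hh_l0_g / weight_hh_l0_v and leaves the original name as a plain tensor attribute that is only recomputed by a forward pre-hook; .cuda() moves the registered parameters to the GPU but not that cached attribute, so flatten_parameters() ends up comparing a CPU tensor against CUDA parameters. A minimal sketch of a possible workaround (my assumption, not a verified fix): move the module to the GPU before applying weight_norm, so everything starts on the same device.

import torch.nn as nn
from torch.nn.utils.weight_norm import weight_norm

# Possible workaround (assumption, not a verified fix): apply
# weight_norm *after* moving the GRU to the GPU, so the _g/_v
# parameters and the cached weight attribute all live on the
# same device by the time flatten_parameters() runs.
a = nn.GRU(100, 20).cuda()
b = weight_norm(a, name='weight_hh_l0')
b = weight_norm(b, name='weight_ih_l0')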

More here: https://discuss.pytorch.org/t/built-in-weight-norm-on-rnn/5905

cc @ezyang @gchanan @zou3519 @jerryzh168

Metadata

Labels

high priority
module: rnn - Issues related to RNN support (LSTM, GRU, etc)
triaged - This issue has been looked at by a team member, and triaged and prioritized into an appropriate module
