Add rnn args check#3925
Conversation
apaszke
left a comment
There was a problem hiding this comment.
Looks good, but please check everything for LSTM.
| mini_batch, self.hidden_size) | ||
| if self.mode == 'LSTM': | ||
| hidden = hidden[0] | ||
| if tuple(hidden.size()) != expected_hidden_size: |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
apaszke
left a comment
There was a problem hiding this comment.
Didn't this PR have a test for these cases as well? It would be nice to add it back before merging.
* Add rnn args check * Check both hidden sizes for LSTM * RNN args check test
|
@zou3519 I have similar problem, RuntimeError: Expected hidden[0] size (1, 64, 256), got (64, 256), i tried different ways but unable to get it. RuntimeError: Expected hidden[0] size (1, 64, 256), got (64, 256), i tried different ways but unable to get it. |
|
@Shandilya21 Please ask a question on the forums or open a (new) issue if you think there is a bug. |
* Add rnn args check * Check both hidden sizes for LSTM * RNN args check test
Fixes #3851, #3259
Added a high-level check for arguments to RNNBase (these were moved from arg checks in
cudnn/rnn.py)Test Plan
python test/test_nn.py