Default initial hidden states for recurrent layers (Issue #434, PR #605)
apaszke merged 7 commits into pytorch:master from goelhardik:issue-434
Conversation
torch/nn/modules/rnn.py (Outdated)

```diff
-    def forward(self, input, hx):
+    def forward(self, input, hx=None):
+        if (hx == None):
+            batch_sz = input.size()[0] if self.batch_first else input.size()[1]
+            hx = torch.autograd.Variable(torch.Tensor(self.num_layers, batch_sz,
+                                                      self.input_size).zero_())
+            if (self.mode == 'LSTM'):
+                hx = (torch.autograd.Variable(hx.data),
+                      torch.autograd.Variable(hx.data))
```
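The review comments on this hunk were hidden, but two fixes the reviewers likely flagged are visible in the diff itself: `hx == None` should be `hx is None`, and the zero state should be sized by `hidden_size`, not `input_size`. A minimal sketch of the corrected logic, written against the modern PyTorch API (plain tensors instead of the old `torch.autograd.Variable` wrapper) with an assumed helper name `default_hidden`:

```python
import torch

def default_hidden(input, num_layers, hidden_size, mode, batch_first=False):
    """Build a zero initial hidden state matching the input's batch size.

    Hypothetical helper illustrating the PR's logic with the draft's two
    bugs fixed: identity comparison with None, and hidden_size (not
    input_size) as the last state dimension.
    """
    batch_sz = input.size(0) if batch_first else input.size(1)
    hx = torch.zeros(num_layers, batch_sz, hidden_size)
    if mode == 'LSTM':
        # An LSTM takes a (h_0, c_0) pair; both start at zero.
        hx = (hx, hx.clone())
    return hx
```

Because the logic lives in one place, the same helper serves RNN, GRU, and LSTM, with only the LSTM branch differing.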
One last thing. Can you please add a test that exercises this change? Just instantiate one of each kind of RNN we have and pass a batch through it, once without passing the hidden state and once with a manually constructed one. Then use …
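The requested test might look like the sketch below. It is an assumption about what the reviewer had in mind (the comment is cut off), written against the modern `torch.nn` API: run the same input through each recurrent module with no hidden state and with an explicit zero state, then check the outputs match.

```python
import torch
import torch.nn as nn

def check_default_hidden(rnn_cls, mode):
    """Hypothetical test: the default hidden state should behave like
    an explicitly constructed zero state."""
    torch.manual_seed(0)
    rnn = rnn_cls(input_size=4, hidden_size=6, num_layers=2)
    x = torch.randn(5, 3, 4)  # (seq_len, batch, input_size)

    # Forward once without a hidden state (exercises the default)...
    out_default, _ = rnn(x)

    # ...and once with a manually constructed zero state.
    h0 = torch.zeros(2, 3, 6)
    hx = (h0, h0.clone()) if mode == 'LSTM' else h0
    out_manual, _ = rnn(x, hx)

    assert torch.allclose(out_default, out_manual)
```

Calling `check_default_hidden` for `nn.RNN`, `nn.GRU`, and `nn.LSTM` covers all three module types the PR touches.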
I think I did a merge while trying to rebase my branch, which is why it shows the last commit 722c407. Is this okay? Should I try to revert it, or just go ahead with adding the test case?
Go ahead and add the test case. We'll squash it down before merging.
This sets a default initial hidden state of zeros when the user does not provide one. The change is made in the RNNBase class, so it works for all three recurrent modules: RNN, GRU, and LSTM.