[WIP] Recursive container #620
Description
What?
Over the weekend I worked on a way to design arbitrary RNNs using the Keras API. The result lets us write a vanilla RNN as:
self.input_dim = 2
self.state_dim = 2
self.model = Recursive(return_sequences=True)
self.model.add_input('input', ndim=3) # Input is 3D tensor
self.model.add_state('h', dim=self.state_dim)
self.model.add_node(Dense(self.input_dim, self.state_dim, init='one'),
                    name='i2h', inputs=['input', ])
self.model.add_node(Dense(self.state_dim, self.state_dim, init='orthogonal'),
                    name='h2h', inputs=['h', ])
self.model.add_node(Activation('tanh'), name='rec', inputs=['i2h', 'h2h'],
                    merge_mode='sum', return_state='h', create_output=True)
Note that the class definition of SimpleRNN is much bigger than this, and it gives us no way to output intermediate values such as the input-to-hidden projection. This container should make it easier to design different state-based models without having to dig into Theano code. I started the development in the repo where I usually keep my Keras extensions. There is a test here showing how to use this new container (it still prints a lot of output that I used for debugging; that should be cleaned up soon). If there is general interest in this, I could just PR a new branch here.
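For readers who want to see the recurrence that the model above describes, here is a minimal NumPy sketch. It is not the container's implementation: the function name `vanilla_rnn`, the explicit weight arguments, and the zero initial state are assumptions made for illustration. It also returns the input-to-hidden projection, the kind of intermediate value mentioned above.

```python
import numpy as np

def vanilla_rnn(x, W_i2h, b_i2h, W_h2h, b_h2h, h0=None):
    """Hypothetical reference loop for the 'i2h' / 'h2h' / 'rec' graph.

    x has shape (batch, time, input_dim). Returns the hidden states and the
    input-to-hidden projections, each of shape (batch, time, state_dim).
    """
    batch, time, _ = x.shape
    state_dim = W_h2h.shape[0]
    # Assumption: the state 'h' starts at zero when no initial state is given.
    h = np.zeros((batch, state_dim)) if h0 is None else h0
    hs, i2hs = [], []
    for t in range(time):
        i2h = x[:, t, :] @ W_i2h + b_i2h  # node 'i2h', fed by 'input'
        h2h = h @ W_h2h + b_h2h           # node 'h2h', fed by state 'h'
        h = np.tanh(i2h + h2h)            # node 'rec', merge_mode='sum', fed back as 'h'
        hs.append(h)
        i2hs.append(i2h)
    return np.stack(hs, axis=1), np.stack(i2hs, axis=1)
```

Calling it with `x` of shape `(batch, time, 2)` and 2x2 weight matrices corresponds to `input_dim = state_dim = 2` in the example.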
How? (dev details)
The Recursive container is basically a Graph container with a few modifications. The most important difference is the way we connect layers. Unlike in regular feedforward networks, we cannot use set_previous inside the add_node method. Everything has to be done inside a _step function, and we have to take care of the order in which we pass arguments to scan (I haven't explored the idea of using dictionaries as inputs to scan yet). In other words, the entire connection logic is moved from add_node (where it lives for Sequential and Graph) to _step.
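The idea of deferring all wiring to a step function can be sketched in plain NumPy. This is a toy stand-in, not the actual container: the class `TinyRecursive`, its `run` method, and the dictionary-based bookkeeping are all hypothetical, and a Python loop replaces scan, so the argument-ordering issue does not arise here.

```python
import numpy as np

class TinyRecursive:
    """Hypothetical toy container: add_node only records, _step connects."""

    def __init__(self):
        self.input_names = []
        self.states = {}  # state name -> dimension
        self.nodes = []   # (name, fn, input names, return_state), in insertion order

    def add_input(self, name):
        self.input_names.append(name)

    def add_state(self, name, dim):
        self.states[name] = dim

    def add_node(self, fn, name, inputs, return_state=None):
        # Nothing is connected here; the node is only registered.
        self.nodes.append((name, fn, inputs, return_state))

    def _step(self, inputs_t, states_prev):
        # All connection logic lives here: each node reads the current
        # inputs, the previous states, or nodes evaluated earlier in the step.
        values = dict(inputs_t)
        values.update(states_prev)
        new_states = dict(states_prev)
        for name, fn, inputs, return_state in self.nodes:
            merged = sum(values[i] for i in inputs)  # merge_mode='sum' only
            values[name] = fn(merged)
            if return_state is not None:
                new_states[return_state] = values[name]
        return values, new_states

    def run(self, sequences):
        # sequences maps input name -> array of shape (batch, time, dim).
        first = sequences[self.input_names[0]]
        batch, time = first.shape[0], first.shape[1]
        states = {n: np.zeros((batch, d)) for n, d in self.states.items()}
        outputs = []
        for t in range(time):
            inputs_t = {n: sequences[n][:, t, :] for n in self.input_names}
            values, states = self._step(inputs_t, states)
            outputs.append(values)
        return outputs, states

# Usage: the vanilla RNN, with plain functions in place of Keras layers.
model = TinyRecursive()
model.add_input('input')
model.add_state('h', dim=2)
model.add_node(lambda v: v @ np.ones((2, 2)), name='i2h', inputs=['input'])
model.add_node(lambda v: v @ np.eye(2), name='h2h', inputs=['h'])
model.add_node(np.tanh, name='rec', inputs=['i2h', 'h2h'], return_state='h')
```

Because every intermediate value is kept in `values`, any node (for example `'i2h'`) can be read back after a step, not only the final state.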
Next Steps?
There is a lot of code cleanup and refactoring that could make the internals cleaner. For example, in a conversation with @fchollet, he suggested that we define the states as self.model.add_state('h', dim=self.state_dim, input='rec') instead of using a return_state option inside add_node.
Stateful?
Another interesting problem is how to handle stateful models, where the hidden states are not wiped out after each batch. In a previous experiment, I set up the initial states of an RNN as a shared variable and defined its update to be the last state returned by scan. I did that inside the get_output method. Since Keras now collects each individual layer's self.updates, everything else was handled by the Model class. We could do the same here. The problem is that shared variables can't change size, so we have to make sure all batches have the same size; otherwise, we would have to recompile the model. I would love to hear about alternatives to this.
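The stateful scheme described above can be illustrated without Theano. In this rough sketch, a NumPy array stored on the object plays the role of the shared variable, and assigning the last state to it plays the role of the update. The class name `StatefulRNN`, its methods, and the explicit batch-size check are assumptions for illustration only.

```python
import numpy as np

class StatefulRNN:
    """Hypothetical stand-in for an RNN whose state survives across batches."""

    def __init__(self, W_i2h, W_h2h, batch_size):
        self.W_i2h = W_i2h
        self.W_h2h = W_h2h
        # Plays the role of the shared variable: its shape is fixed up front.
        self.h = np.zeros((batch_size, W_h2h.shape[0]))

    def forward(self, x):
        # x has shape (batch, time, input_dim).
        if x.shape[0] != self.h.shape[0]:
            # Mirrors the constraint in the text: every batch must match the
            # size the state was created with.
            raise ValueError("batch size must stay %d" % self.h.shape[0])
        h = self.h
        outputs = []
        for t in range(x.shape[1]):
            h = np.tanh(x[:, t, :] @ self.W_i2h + h @ self.W_h2h)
            outputs.append(h)
        # The "update": the last state becomes the next call's initial state.
        self.h = h
        return np.stack(outputs, axis=1)

    def reset_states(self):
        self.h = np.zeros_like(self.h)
```

With the state carried over, feeding a long sequence in consecutive chunks gives the same outputs as feeding it in one call, which is the behavior a stateful model is meant to provide.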
Final words
Sorry for the long post, but hopefully it will get people interested in developing this API (and/or inspire new ones), which I believe will make our lives much easier when it comes to designing new RNNs.