In this post I’ll talk through the pieces of the Tensorflow API most relevant to building recurrent neural networks. The tensorflow documentation is great for explaining how to build standard RNNs, but it falls a little flat when it comes to building highly customized RNNs.
I’ll use the network described in Hierarchical Multiscale Recurrent Neural Networks by Chung et al. as an example of a fairly non-standard RNN. There’s an open-source implementation of that network on github.
The basics
In this section I’ll provide a quick overview of the tools available to create standard RNNs in tensorflow.
Standard RNN Cells
If you need a standard RNN, GRU, or LSTM, tensorflow has you covered. The API contains these pre-written cells, all of which extend the base class RNNCell. Their tutorial on RNNs gives a good overview of how to use these cells, so I won’t spend much time here. If you are completely new to RNNs in tensorflow, it may be a good idea to review that tutorial before continuing.
One thing that is worth pointing out about their tutorial is their use of the MultiRNNCell. This is a class that is constructed with a list of objects that extend RNNCell. It is used for creating multi-layered RNNs: the lowest layer is fed the input, each subsequent layer is fed the output of the layer below it, and the output of the top layer is the output of the network at a given timestep. If you want to pass information between layers in a different way, you’ll need a custom multicell. We’ll come back to this later.
Note that MultiRNNCell itself extends RNNCell as well, so it can be used anywhere any other RNNCell can be used.
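As a quick illustration, here’s a minimal sketch of stacking two prebuilt LSTM cells. The namespace and sizes are assumptions for illustration (TF 1.x, tf.nn.rnn_cell, 128 units), not anything from the HMLSTM code:

```python
import tensorflow as tf

# A minimal sketch: two stacked LSTM layers, assuming the TF 1.x
# tf.nn.rnn_cell namespace. The sizes are arbitrary placeholder values.
num_units = 128
cells = [tf.nn.rnn_cell.BasicLSTMCell(num_units) for _ in range(2)]
stacked_cell = tf.nn.rnn_cell.MultiRNNCell(cells)

# Because MultiRNNCell is itself an RNNCell, stacked_cell can be wrapped
# (e.g. with DropoutWrapper) or unrolled just like a single cell.
```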
Dynamic v. Unrolled RNNs
I find it easiest to think about RNNs when they are unrolled. In tensorflow, this means that we build an identical copy of the computational graph for each timestep, and manually pass the hidden state(s) forward from one timestep to the next.
This is the approach taken in the tensorflow tutorial model. The upside to this approach is that it is easy to think about, and it is flexible. If you want to process the hidden states from one time step in any way before you pass them on, you can easily put more nodes in the graph to do so.
There are two major downsides. First, a graph composed in this way has a fixed length, which means you’ll have to rebuild the graph for signals of different lengths, or pad them out with zeros. Neither solution is great.
Second, large graphs take much longer to build and consume much more RAM. Depending on the constraints of your computing environment, this could be prohibitive.
The tf.nn.dynamic_rnn function will transform your RNNCell into a dynamically generated graph that passes the state, whatever that may be, from one time step to the next, and keeps track of the outputs. If you have other needs, tf.scan can serve a similar purpose more flexibly, as we will see later.
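For example, wrapping a prebuilt cell with tf.nn.dynamic_rnn might look like the sketch below; the shapes, sizes, and placeholders are assumptions for illustration:

```python
import tensorflow as tf

# Assumed setup: batches of sequences padded to a common length, with 50
# features per time step and the true sequence lengths supplied separately.
inputs = tf.placeholder(tf.float32, [None, None, 50])   # [batch, time, features]
lengths = tf.placeholder(tf.int32, [None])               # true length of each sequence

cell = tf.nn.rnn_cell.GRUCell(128)

# dynamic_rnn builds a symbolic loop, so one graph handles sequences of any
# length; it returns the output at every time step plus the final state.
outputs, final_state = tf.nn.dynamic_rnn(
    cell, inputs, sequence_length=lengths, dtype=tf.float32)
```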
Novel RNN Configurations
In this section, I’ll review some available options for creating RNNs with less standard architectures.
Extending RNNCell Directly
If you need a network that’s a little different from any of the standard implementations, you can extend RNNCell directly. You can use your own subclass with the MultiRNNCell class described above as well as the DropoutWrapper and other predefined RNNCell wrappers.
Extending RNNCell means overriding at least the state_size property, the output_size property, and the call method. Tensorflow’s prebuilt cells represent state either as a single tensor or as a tuple of tensors. If the state is a single tensor, it gets split into cell and hidden states (or whatever the case may be) upon entry into the cell, and then the new states are stitched back together at the end.
I’ve found it simpler to treat the cell state as a tuple. In that case, the state_size property is just a tuple containing the sizes of the states you’re keeping track of. The call method is where the logic of your cell will live; RNNCell’s __call__ method will wrap your call method and help with scoping and other logistics.
In order for your subclass to be a valid RNNCell, the call method must accept the parameters input and state, and return a tuple of output, new_state, where state and new_state must have the same form.
Note that if you construct a new RNNCell that you want to use with tensorflow variables that already exist in your session, you can pass a _reuse=True argument to the parent constructor within your __init__ method. If the variables already exist but you do not pass _reuse=True, you’ll get an error, because tensorflow will neither overwrite existing variables nor reuse them without explicit instruction.
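Putting those pieces together, a bare-bones custom cell might look something like the following sketch. It’s a hypothetical vanilla RNN cell written against the TF 1.x tf.nn.rnn_cell.RNNCell base class, not anything taken from the HMLSTM implementation:

```python
import tensorflow as tf

class MinimalRNNCell(tf.nn.rnn_cell.RNNCell):
    """A hypothetical vanilla RNN cell, shown only to illustrate the interface."""

    def __init__(self, num_units, reuse=None):
        # Forwarding reuse to the parent constructor lets the cell share
        # variables that already exist instead of raising an error.
        super(MinimalRNNCell, self).__init__(_reuse=reuse)
        self._num_units = num_units

    @property
    def state_size(self):
        # A single hidden state vector; a tuple of sizes works just as well.
        return self._num_units

    @property
    def output_size(self):
        return self._num_units

    def call(self, inputs, state):
        # new_state = tanh(W . [inputs; state] + b)
        concat = tf.concat([inputs, state], axis=1)
        in_size = concat.get_shape()[-1].value
        weights = tf.get_variable('weights', [in_size, self._num_units])
        biases = tf.get_variable('biases', [self._num_units],
                                 initializer=tf.zeros_initializer())
        new_state = tf.tanh(tf.matmul(concat, weights) + biases)
        # The return value is (output, new_state), and new_state must have
        # the same form as the incoming state.
        return new_state, new_state
```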
For reference, the HMLSTMCell class is an RNNCell used to represent one cell of the Hierarchical and Multiscale RNN mentioned above. Its implementation covers all the main points above.
That code also makes use of an undocumented function in the rnn_cell_impl module called _linear, which is used in most of the baked-in RNNCell subclasses. This is a little risky, because it’s clearly not meant for outside use, but it’s a useful little function that properly handles the matrix multiplication and the addition of weights and biases.
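If you’d rather not depend on it, a rough stand-in is easy enough to write yourself. The sketch below is my own approximation of what such a helper does, not the actual rnn_cell_impl._linear code:

```python
import tensorflow as tf

def linear(args, output_size, bias=True, scope='linear'):
    """Rough stand-in for a _linear-style helper: concatenate the inputs,
    multiply by a weight matrix, and optionally add a bias."""
    concat = tf.concat(args, axis=1) if len(args) > 1 else args[0]
    in_size = concat.get_shape()[-1].value
    with tf.variable_scope(scope):
        weights = tf.get_variable('weights', [in_size, output_size])
        output = tf.matmul(concat, weights)
        if bias:
            biases = tf.get_variable('biases', [output_size],
                                     initializer=tf.zeros_initializer())
            output = output + biases
    return output
```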
Unusual Multilayered RNNs
If you’re building a multi-layered RNN where the layers don’t simply pass their output up from layer to layer, you’ll have to build your own version of a MultiCell. Much like the built in MultiRNNCell, your multicell should extend RNNCell.
In this case the cell state will be a list, where each element is the cell state at the layer corresponding to its index.
Writing your own multicell is useful in two cases. First, when you want to do something to the output of one layer before passing it into the cell at the next layer, but you don’t want to apply that operation to the lowest layer (otherwise you could just build it into the cell itself).
Second, it’s useful if there’s information from the previous time step that you need to distribute among the different layers, but that doesn’t fit neatly into the paradigm of passing state along from one time step to the next.
For example, in the hierarchical multiscale LSTM, each cell expects to receive the hidden state from the layer above it at the previous time step as part of its input. This doesn’t fit neatly into the standard idea of stacked RNNs, so we can’t use the usual MultiRNNCell. For reference, here is the implementation of the MultiHMLSTMCell.
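That implementation is fairly involved, but the skeleton of a custom multicell is simple. Here’s a schematic sketch (not the actual MultiHMLSTMCell) of a multicell whose state is a per-layer tuple and whose loop body is the place to pass information between layers however you like:

```python
import tensorflow as tf

class CustomMultiCell(tf.nn.rnn_cell.RNNCell):
    """Schematic multicell: the state holds one entry per layer, and the
    loop body is where any inter-layer plumbing would go."""

    def __init__(self, cells):
        super(CustomMultiCell, self).__init__()
        self._cells = cells

    @property
    def state_size(self):
        # One state entry per layer, indexed by layer.
        return tuple(cell.state_size for cell in self._cells)

    @property
    def output_size(self):
        return self._cells[-1].output_size

    def call(self, inputs, state):
        current = inputs
        new_states = []
        for i, cell in enumerate(self._cells):
            with tf.variable_scope('layer_%d' % i):
                current, new_state = cell(current, state[i])
                # Here you could mix in information from other layers, or
                # transform `current` before handing it to the next layer.
                new_states.append(new_state)
        return current, tuple(new_states)
```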
Building Dynamic RNNs with tf.scan
The Hierarchical Multiscale LSTM network calls for the hidden states to be fed into some output network. We’ve already seen that this HMLSTM network doesn’t neatly fit into the tensorflow RNN paradigm because of how it handles passing information between layers; now we’ve hit another obstacle. Instead of considering the last output of the last layer to be the output of the network, we need to pass the hidden states of all layers through another network to get the output we really care about. Not only that, but we also care about the value of some of the indicator neurons, which are treated as cell state.
For these reasons, we can’t use the tf.nn.dynamic_rnn function, which returns only the output at each time step and the final state.
tf.scan, on the other hand, takes an arbitrary function and an ordered collection of elements. It then applies the function to each element in the collection, keeping track of some accumulator, and returns an ordered collection of the accumulator’s value at each step in the process.
This is perfect for a more customized RNN. Because you get to define the function, you can manipulate the inputs, outputs, and states however you so choose. Afterwards, you get a full accounting of the state at every time step, rather than just the output.
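As a sketch of the pattern (placeholder shapes, and a prebuilt GRU cell standing in for the real thing), driving an RNNCell with tf.scan looks roughly like this:

```python
import tensorflow as tf

# Placeholder assumptions: [batch, time, features] inputs and a GRU cell
# standing in for whatever custom cell you actually care about.
inputs = tf.placeholder(tf.float32, [None, None, 50])
cell = tf.nn.rnn_cell.GRUCell(128)

batch_size = tf.shape(inputs)[0]
inputs_by_time = tf.transpose(inputs, [1, 0, 2])   # scan walks the first axis

def step(accumulated, x_t):
    # The accumulator carries (output, state) from the previous time step.
    _, prev_state = accumulated
    output, new_state = cell(x_t, prev_state)
    return output, new_state

initial = (tf.zeros([batch_size, cell.output_size]),
           cell.zero_state(batch_size, tf.float32))

# tf.scan returns the accumulator at every time step, so we get the full
# history of states, not just the outputs.
outputs, states = tf.scan(step, inputs_by_time, initializer=initial)
```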
In the case of the HMLSTM, we use these states to keep track of the boundary detection, and we also map over them to obtain the final predictions.
Here’s the code for reference.
Conclusion
In this post, we looked at the standard tools for dealing with RNNs in tensorflow, and explored some more flexible options to use when those tools fall short.