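All we need is TensorFlow. Note that this code is written against the TensorFlow 1.x API (tf.placeholder, tf.contrib, sessions), so it will not run unmodified on TensorFlow 2, where tf.contrib was removed.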
In [1]:
import tensorflow as tf
First, define the constants.
Let's say we're dealing with 1-dimensional vectors and sequences of length 3.
In [2]:
input_dim = 1
seq_size = 3
Next up, define the placeholder(s).
We only need one for this simple example: the input placeholder.
In [3]:
input_placeholder = tf.placeholder(dtype=tf.float32, shape=[None, seq_size, input_dim])
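The None in the placeholder's shape leaves the batch size open, while the sequence length (seq_size) and the vector size (input_dim) are fixed. The plain-Python sketch below illustrates that rule without TensorFlow. nested_shape and matches_placeholder are hypothetical helpers written for this illustration, not TensorFlow functions; TensorFlow performs its own shape check when you feed the placeholder.

```python
# Hypothetical helpers (not TensorFlow APIs) for checking a nested list
# against a placeholder-style shape in which None matches any size.

def nested_shape(x):
    """Shape of a nested list, assuming it is rectangular (not ragged)."""
    shape = []
    while isinstance(x, list):
        shape.append(len(x))
        x = x[0] if x else None
    return shape


def matches_placeholder(shape, placeholder_shape):
    """True if every dimension agrees, treating None as a wildcard."""
    if len(shape) != len(placeholder_shape):
        return False
    return all(p is None or p == s for s, p in zip(shape, placeholder_shape))


input_dim = 1
seq_size = 3
placeholder_shape = [None, seq_size, input_dim]

# A batch holding one sequence of three 1-dimensional vectors.
batch = [[[1], [2], [3]]]
print(nested_shape(batch))                                          # [1, 3, 1]
print(matches_placeholder(nested_shape(batch), placeholder_shape))  # True

# A batch of two sequences also fits, since the batch size is left as None.
print(matches_placeholder([2, seq_size, input_dim], placeholder_shape))  # True
```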
Now let's make a helper function that creates an LSTM cell with state_dim units.
In [4]:
def make_cell(state_dim):
    # state_dim sets the number of units in the LSTM's state and output
    return tf.contrib.rnn.LSTMCell(state_dim)
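The state_dim argument becomes the cell's number of units: the size of the cell's output and of each half of its state, which by default is a (c, h) pair. As a rough illustration, here is one time step of a textbook LSTM in plain Python. The weight layout, the random initialization, and the helper names (make_lstm_params, lstm_step) are assumptions for this sketch; TensorFlow's LSTMCell stores and initializes its variables differently and adds options such as a forget bias, peepholes, and projections.

```python
import math
import random


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def matvec(W, v):
    """Multiply matrix W (a list of rows) by vector v."""
    return [sum(w * x for w, x in zip(row, v)) for row in W]


def make_lstm_params(input_dim, state_dim, seed=0):
    """Hypothetical parameters: for each gate, a weight matrix over the
    concatenated [x, h] vector plus a bias vector (zeros here)."""
    rng = random.Random(seed)

    def weights():
        return [[rng.uniform(-0.1, 0.1) for _ in range(input_dim + state_dim)]
                for _ in range(state_dim)]

    return {gate: (weights(), [0.0] * state_dim) for gate in ("i", "f", "o", "g")}


def lstm_step(params, x, state):
    """One step of a textbook LSTM; state is a (c, h) pair."""
    c, h = state
    xh = list(x) + list(h)  # every gate sees the input and the previous output

    def gate(name, activation):
        W, b = params[name]
        return [activation(z + bias) for z, bias in zip(matvec(W, xh), b)]

    i = gate("i", sigmoid)    # input gate: how much new content to write
    f = gate("f", sigmoid)    # forget gate: how much old cell state to keep
    o = gate("o", sigmoid)    # output gate: how much of the cell to expose
    g = gate("g", math.tanh)  # candidate cell content
    new_c = [fk * ck + ik * gk for fk, ck, ik, gk in zip(f, c, i, g)]
    new_h = [ok * math.tanh(ck) for ok, ck in zip(o, new_c)]
    return new_h, (new_c, new_h)


state_dim = 10
params = make_lstm_params(input_dim=1, state_dim=state_dim)
zero_state = ([0.0] * state_dim, [0.0] * state_dim)
output, state = lstm_step(params, [1.0], zero_state)
print(len(output), len(state[0]), len(state[1]))  # 10 10 10
```

Because each gate reads the input concatenated with the previous hidden state, every weight matrix in this sketch has input_dim + state_dim columns.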
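Call make_cell to create a cell with a 10-dimensional state, then unroll it over the input with tf.nn.dynamic_rnn to get the outputs at every time step along with the final state.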
In [5]:
with tf.variable_scope("first_cell") as scope:
    cell = make_cell(state_dim=10)
    # Unroll the cell over the time dimension of the input
    outputs, states = tf.nn.dynamic_rnn(cell, input_placeholder, dtype=tf.float32)
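tf.nn.dynamic_rnn calls the same cell, with the same weights, once per time step, passing the state from one step to the next. It returns the outputs of every step (shape [batch, seq_size, 10] here) together with the final state. Below is a minimal plain-Python sketch of that loop. To keep it short, a hypothetical 2-unit tanh cell (toy_cell) stands in for the LSTM, and unroll is an illustrative name, not a TensorFlow function; the real op builds the loop into the graph and also supports options such as sequence_length, which this sketch ignores.

```python
import math


def toy_cell(x, h):
    """Hypothetical 2-unit tanh cell with fixed weights, standing in for an LSTM.
    Returns (output, new_state); for this cell they are the same vector."""
    new_h = [math.tanh(0.5 * x[0] + 0.1 * h[0]),
             math.tanh(-0.5 * x[0] + 0.1 * h[1])]
    return new_h, new_h


def unroll(cell, batch, initial_state):
    """Run cell over every time step of every sequence in batch
    ([batch, time, input_dim]). Returns all outputs ([batch, time, units])
    and the final state of each sequence."""
    all_outputs, final_states = [], []
    for sequence in batch:
        state = initial_state
        outputs = []
        for x in sequence:  # the same cell (same weights) at every step
            out, state = cell(x, state)
            outputs.append(out)
        all_outputs.append(outputs)
        final_states.append(state)
    return all_outputs, final_states


outputs, states = unroll(toy_cell, [[[1.0], [2.0], [3.0]]], initial_state=[0.0, 0.0])
print(len(outputs), len(outputs[0]), len(outputs[0][0]))  # 1 3 2
print(outputs[0][-1] == states[0])                         # True
```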
You know what? We can just keep stacking cells on top of each other. In a new variable scope, so that the new cell's variables stay separate from the first cell's, you can feed the outputs of the previous cell in as the input of the new cell. Check it out:
In [6]:
with tf.variable_scope("second_cell") as scope:
    cell2 = make_cell(state_dim=10)
    # The first layer's outputs, shape [batch, 3, 10], are this layer's inputs
    outputs2, states2 = tf.nn.dynamic_rnn(cell2, outputs, dtype=tf.float32)
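The second layer's input is the first layer's output sequence, so its inputs are 10-dimensional rather than input_dim-dimensional; TensorFlow infers that input size when the cell is first applied. The plain-Python sketch below mirrors this data flow with hypothetical tanh cells, each built with its own weights much as each variable scope gets its own variables. make_toy_cell and run_layer are illustrative names, not TensorFlow functions.

```python
import math
import random


def make_toy_cell(input_dim, state_dim, seed):
    """Hypothetical tanh RNN cell with its own random weights, playing the
    role of one LSTM cell in its own variable scope."""
    rng = random.Random(seed)
    W = [[rng.uniform(-0.5, 0.5) for _ in range(input_dim + state_dim)]
         for _ in range(state_dim)]

    def cell(x, h):
        xh = list(x) + list(h)
        new_h = [math.tanh(sum(w * v for w, v in zip(row, xh))) for row in W]
        return new_h, new_h

    return cell


def run_layer(cell, sequence, state_dim):
    """Unroll one cell over a single sequence and return every step's output."""
    h = [0.0] * state_dim
    outputs = []
    for x in sequence:
        out, h = cell(x, h)
        outputs.append(out)
    return outputs


sequence = [[1.0], [2.0], [3.0]]
first = make_toy_cell(input_dim=1, state_dim=10, seed=1)
second = make_toy_cell(input_dim=10, state_dim=10, seed=2)  # consumes 10-dim outputs
outputs = run_layer(first, sequence, state_dim=10)
outputs2 = run_layer(second, outputs, state_dim=10)
print(len(outputs2), len(outputs2[0]))  # 3 10
```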
What if we wanted 5 layers of RNNs?
There's a useful shortcut that the TensorFlow library supplies, called MultiRNNCell, which wraps a list of cells into a single stacked cell. Here's a helper function to use it:
In [7]:
def make_multi_cell(state_dim, num_layers):
    # Build a separate LSTM cell for each layer, then stack them
    cells = [make_cell(state_dim) for _ in range(num_layers)]
    return tf.contrib.rnn.MultiRNNCell(cells)
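The list comprehension calls make_cell once per layer, so each layer is a separate cell with its own weights. At every time step, MultiRNNCell feeds its input through the cells in order, each layer's output becoming the next layer's input, and keeps a tuple of per-layer states. The sketch below imitates that in plain Python. stack_cells is a hypothetical stand-in for MultiRNNCell, and the 1-unit running-sum cell is a toy chosen so the numbers are easy to follow by hand.

```python
def make_running_sum_cell():
    """Hypothetical 1-unit toy cell: new state = old state + input; output = new state."""
    def cell(x, h):
        new_h = [h[0] + x[0]]
        return new_h, new_h
    return cell


def stack_cells(cells):
    """Hypothetical stand-in for MultiRNNCell: one step passes the input up
    through every layer and returns the top output plus a tuple of layer states."""
    def step(x, states):
        new_states = []
        for cell, state in zip(cells, states):
            x, new_state = cell(x, state)  # this layer's output feeds the next layer
            new_states.append(new_state)
        return x, tuple(new_states)
    return step


def unroll(cell, sequence, initial_state):
    """Apply cell once per time step, threading the state through."""
    state, outputs = initial_state, []
    for x in sequence:
        out, state = cell(x, state)
        outputs.append(out)
    return outputs, state


num_layers = 5
multi = stack_cells([make_running_sum_cell() for _ in range(num_layers)])
zero_state = tuple([0.0] for _ in range(num_layers))
outputs5, states5 = unroll(multi, [[1.0], [2.0], [3.0]], zero_state)
print(outputs5)  # [[1.0], [7.0], [28.0]]
print(states5)   # ([6.0], [10.0], [15.0], [21.0], [28.0])
```

With running-sum cells, each layer outputs the cumulative sum of the layer below it, so by the third step the five layers hold 6, 10, 15, 21, and 28, and only the top value is returned as the output.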
Here's make_multi_cell in action, building a five-layer stack with a 10-dimensional state per layer:
In [8]:
multi_cell = make_multi_cell(state_dim=10, num_layers=5)
outputs5, states5 = tf.nn.dynamic_rnn(multi_cell, input_placeholder, dtype=tf.float32)
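Structurally, this five-layer stack is the same kind of model you would get by chaining five dynamic_rnn calls, as in the second_cell example; the loops over layers and over time are simply nested in the opposite order. The sketch below checks this in plain Python with hypothetical 1-unit tanh cells: running each layer over the whole sequence and running each time step up through all layers give identical top-layer outputs when the weights are the same. In the TensorFlow version, outputs5 likewise contains only the top layer's outputs, while states5 is a tuple with one final state per layer (a (c, h) pair for each LSTM).

```python
import math


def make_toy_cell(weight):
    """Hypothetical 1-unit tanh cell; weight is its input weight."""
    def cell(x, h):
        new_h = [math.tanh(weight * x[0] + 0.5 * h[0])]
        return new_h, new_h
    return cell


def layer_by_layer(cells, sequence):
    """Like chaining dynamic_rnn calls: run each layer over the whole
    sequence, then feed its outputs to the next layer."""
    for cell in cells:
        h, outputs = [0.0], []
        for x in sequence:
            out, h = cell(x, h)
            outputs.append(out)
        sequence = outputs
    return sequence


def step_by_step(cells, sequence):
    """Like one dynamic_rnn over a stacked cell: at each time step,
    pass the input up through every layer."""
    states = [[0.0] for _ in cells]
    top_outputs = []
    for x in sequence:
        for i, cell in enumerate(cells):
            x, states[i] = cell(x, states[i])
        top_outputs.append(x)
    return top_outputs


cells = [make_toy_cell(w) for w in (0.9, -0.4, 1.3, 0.7, -1.1)]
seq = [[1.0], [2.0], [3.0]]
a = layer_by_layer(cells, seq)
b = step_by_step(cells, seq)
print(a == b)  # True
```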
Before starting a session, let's prepare some simple input to the network.
In [9]:
input_seq = [[1], [2], [3]]
Start the session and initialize the variables.
In [10]:
init_op = tf.global_variables_initializer()
sess = tf.InteractiveSession()
sess.run(init_op)
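Finally, evaluate all three output tensors on the sample sequence to check that the graphs run end to end. Each printed array has shape (1, 3, 10): one sequence, three time steps, and 10 output units from the last layer. The values come from randomly initialized weights, so they change from run to run unless you set a random seed.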
In [11]:
# Wrap input_seq in a list to form a batch of one sequence, shape [1, 3, 1]
outputs_val, outputs2_val, outputs5_val = sess.run(
    [outputs, outputs2, outputs5],
    feed_dict={input_placeholder: [input_seq]})
print(outputs_val)
print(outputs2_val)
print(outputs5_val)