Ch 11: Concept 01

Multi RNN

All we need is TensorFlow:

In [1]:
import tensorflow as tf

First, define the constants.

Let's say we're dealing with 1-dimensional input vectors and sequences of at most 3 time steps.

In [2]:
input_dim = 1
seq_size = 3

Next up, define the placeholder(s).

We only need one for this simple example: the input placeholder.

In [3]:
input_placeholder = tf.placeholder(dtype=tf.float32, shape=[None, seq_size, input_dim])

Now let's make a helper function to create LSTM cells:

In [4]:
def make_cell(state_dim):
    return tf.contrib.rnn.LSTMCell(state_dim)

Call the function to create a cell, then unroll it over the input with tf.nn.dynamic_rnn to extract the cell outputs.

In [5]:
with tf.variable_scope("first_cell") as scope:
    cell = make_cell(state_dim=10)
    outputs, states = tf.nn.dynamic_rnn(cell, input_placeholder, dtype=tf.float32)

You know what? We can just keep stacking cells on top of each other. In a new variable scope, you can pipe the output of the previous cell to the input of the new cell. Check it out:

In [6]:
with tf.variable_scope("second_cell") as scope:
    cell2 = make_cell(state_dim=10)
    outputs2, states2 = tf.nn.dynamic_rnn(cell2, outputs, dtype=tf.float32)
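
To make the stacking idea concrete without TensorFlow, here is a minimal sketch in plain Python. It uses a simple tanh recurrence rather than an LSTM, and the names make_toy_cell and run_toy_rnn are hypothetical helpers invented for this illustration. The point is only that the full output sequence of one layer becomes the input sequence of the next.

```python
import math
import random


def make_toy_cell(input_dim, state_dim, seed):
    """Create random weights for a toy tanh recurrence (not an LSTM)."""
    rng = random.Random(seed)
    w_in = [[rng.uniform(-0.5, 0.5) for _ in range(state_dim)]
            for _ in range(input_dim)]
    w_rec = [[rng.uniform(-0.5, 0.5) for _ in range(state_dim)]
             for _ in range(state_dim)]
    return w_in, w_rec


def run_toy_rnn(cell, sequence):
    """Unroll the toy cell over a sequence and return one output per step."""
    w_in, w_rec = cell
    state_dim = len(w_rec)
    state = [0.0] * state_dim
    outputs = []
    for x in sequence:
        # New state mixes the current input with the previous state.
        state = [
            math.tanh(
                sum(x[i] * w_in[i][j] for i in range(len(x)))
                + sum(state[k] * w_rec[k][j] for k in range(state_dim))
            )
            for j in range(state_dim)
        ]
        outputs.append(state)
    return outputs


toy_sequence = [[1.0], [2.0], [3.0]]
toy_layer1 = make_toy_cell(input_dim=1, state_dim=10, seed=0)
toy_layer2 = make_toy_cell(input_dim=10, state_dim=10, seed=1)

toy_outputs1 = run_toy_rnn(toy_layer1, toy_sequence)
# The second layer consumes the first layer's outputs, step by step.
toy_outputs2 = run_toy_rnn(toy_layer2, toy_outputs1)

print(len(toy_outputs2), len(toy_outputs2[0]))  # 3 10
```

Note that the second layer's input dimension has to equal the first layer's state dimension, which is why both cells above can share state_dim=10.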

What if we wanted 5 layers of RNNs?

There's a useful shortcut that the TensorFlow library supplies, called MultiRNNCell. Here's a helper function to use it:

In [7]:
def make_multi_cell(state_dim, num_layers):
    cells = [make_cell(state_dim) for _ in range(num_layers)]
    return tf.contrib.rnn.MultiRNNCell(cells)
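
One detail in the helper is worth a closer look: the list comprehension calls make_cell once per layer, so every layer gets its own cell object. The sketch below uses a hypothetical DummyCell class as a stand-in for a real cell to show how that differs from multiplying a one-element list, which would repeat a single object num_layers times.

```python
class DummyCell:
    """Hypothetical stand-in for an RNN cell; it only records its size."""

    def __init__(self, state_dim):
        self.state_dim = state_dim


num_layers = 5

# One constructor call per layer: five separate objects.
distinct_cells = [DummyCell(10) for _ in range(num_layers)]

# One constructor call in total: the same object referenced five times.
repeated_cells = [DummyCell(10)] * num_layers

num_distinct = len({id(cell) for cell in distinct_cells})
num_repeated = len({id(cell) for cell in repeated_cells})

print(num_distinct)  # 5
print(num_repeated)  # 1
```

With real cells, the repeated form would hand the same cell instance to every layer, which is usually not what a stacked network is meant to do, so the comprehension is the safer habit.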

Here's the helper function in action:

In [8]:
multi_cell = make_multi_cell(state_dim=10, num_layers=5)
outputs5, states5 = tf.nn.dynamic_rnn(multi_cell, input_placeholder, dtype=tf.float32)
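
To get a feel for what those five layers contain, the sketch below estimates the number of trainable weights. It assumes a standard LSTM with no peepholes and no projection layer, where each layer holds one weight matrix of shape [input_dim + state_dim, 4 * state_dim] and one bias of length 4 * state_dim. The function names are hypothetical, and the count is an estimate under that assumption rather than something read back from the graph.

```python
def lstm_param_count(input_dim, state_dim):
    """Weights plus biases of one LSTM layer (assumed: no peepholes, no projection)."""
    weights = (input_dim + state_dim) * 4 * state_dim
    biases = 4 * state_dim
    return weights + biases


def stacked_param_count(input_dim, state_dim, num_layers):
    """Sum the parameters of a stack where each layer feeds the next."""
    total = 0
    layer_input = input_dim
    for _ in range(num_layers):
        total += lstm_param_count(layer_input, state_dim)
        # Every layer after the first receives state_dim-sized inputs.
        layer_input = state_dim
    return total


print(lstm_param_count(1, 10))        # 480
print(stacked_param_count(1, 10, 5))  # 3840
```

Only the first layer sees the 1-dimensional input, so it is smaller than the four layers above it.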

Before starting a session, let's prepare some simple input to the network.

In [9]:
input_seq = [[1], [2], [3]]
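
The placeholder has a leading batch dimension, so a single sequence must be wrapped in one more list before it is fed. The sketch below checks the shapes in plain Python; nested_shape is a hypothetical helper written for this illustration and assumes regularly nested, non-empty lists.

```python
def nested_shape(data):
    """Return the shape of a regularly nested, non-empty list as a tuple."""
    shape = []
    while isinstance(data, list):
        shape.append(len(data))
        data = data[0]
    return tuple(shape)


input_seq = [[1], [2], [3]]
batch = [input_seq]  # a batch containing exactly one sequence

print(nested_shape(input_seq))  # (3, 1)
print(nested_shape(batch))      # (1, 3, 1)
```

The shape (1, 3, 1) lines up with the placeholder's [None, seq_size, input_dim], with None filled in by a batch size of 1.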

Start the session, and initialize variables.

In [10]:
init_op = tf.global_variables_initializer()

sess = tf.InteractiveSession()
sess.run(init_op)

We can run the three output tensors to verify that the code is sound. The exact values depend on the random weight initialization, so yours will differ.

In [11]:
outputs_val, outputs2_val, outputs5_val = sess.run([outputs, outputs2, outputs5],
                                                   feed_dict={input_placeholder: [input_seq]})
for val in (outputs_val, outputs2_val, outputs5_val):
    print(val)

[[[ 0.00141981  0.0722062  -0.08216076  0.03216607  0.02329798  0.06957388
   -0.04552787 -0.05649291  0.05059107  0.01796713]
  [-0.00228921  0.16875982 -0.21222715  0.0772687   0.05970224  0.16551635
   -0.10631067 -0.13780777  0.12389956  0.05248111]
  [-0.01359726  0.24421771 -0.33965409  0.11412902  0.0964628   0.25151449
   -0.16440172 -0.22563797  0.18972857  0.09557904]]]
[[[-0.00224876  0.01402885  0.00929528 -0.00392457  0.00333697 -0.00213898
   -0.0046619  -0.01061259  0.00368386  0.00040365]
  [-0.00939647  0.04281064  0.02773804 -0.01503811  0.01025065 -0.00612708
   -0.01655139 -0.03407493  0.01263932  0.00136939]
  [-0.02229579  0.07774397  0.0480789  -0.03287651  0.017735   -0.01063949
   -0.03610384 -0.06736942  0.02458673  0.00139557]]]
[[[  1.42336748e-05  -1.58296571e-05  -1.62987853e-05   1.59381907e-05
     1.33105495e-05  -1.38451333e-05  -2.20941274e-05   2.54621627e-05
     2.26147549e-05  -3.30040712e-05]
  [  8.32868463e-05  -8.99614606e-05  -1.02287340e-04   8.68237403e-05
     8.19651614e-05  -7.38111194e-05  -1.29947555e-04   1.52955734e-04
     1.36927833e-04  -1.89826897e-04]
  [  2.79068598e-04  -2.77705112e-04  -3.54834105e-04   2.59170338e-04
     2.73422222e-04  -2.13255596e-04  -4.27101302e-04   5.09521109e-04
     4.57556103e-04  -6.11643540e-04]]]