Bi-directional Recurrent Neural Network Example

Build a bi-directional recurrent neural network (LSTM) with TensorFlow.

BiRNN Overview

MNIST Dataset Overview

This example uses the MNIST handwritten digits dataset, which contains 60,000 examples for training and 10,000 examples for testing. The digits have been size-normalized and centered in a fixed-size image (28x28 pixels) with values from 0 to 1. For simplicity, each image has been flattened into a 1-D NumPy array of 784 features (28*28).

To classify images with a recurrent neural network, we treat each image as a sequence of pixel rows. Because an MNIST image is 28x28 pixels, every sample becomes one sequence of 28 timesteps, and each timestep carries one row of 28 pixel values.

More info: http://yann.lecun.com/exdb/mnist/
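
The row-as-timestep layout described above can be illustrated with a small sketch. It uses plain Python lists and a made-up 4x4 "image" rather than a real 28x28 MNIST digit, and `flat_to_sequence` is a hypothetical helper name that does not appear in the notebook, which performs the same step with NumPy's `reshape`.

```python
def flat_to_sequence(flat, timesteps, num_input):
    """Split a flattened image into `timesteps` rows of `num_input` pixels."""
    if len(flat) != timesteps * num_input:
        raise ValueError("flat image has the wrong number of pixels")
    return [flat[t * num_input:(t + 1) * num_input] for t in range(timesteps)]


# Toy 4x4 "image" flattened row by row (MNIST would be 28x28 -> 784 values).
flat_image = [float(i) for i in range(16)]
sequence = flat_to_sequence(flat_image, timesteps=4, num_input=4)

print(len(sequence))     # number of timesteps: 4
print(len(sequence[0]))  # features per timestep: 4
print(sequence[1])       # second image row, fed at timestep 1: [4.0, 5.0, 6.0, 7.0]
```

With the real data the same split yields 28 timesteps of 28 features per image.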


In [1]:
from __future__ import print_function

import tensorflow as tf
from tensorflow.contrib import rnn
import numpy as np

# Import MNIST data
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("/tmp/data/", one_hot=True)


Extracting /tmp/data/train-images-idx3-ubyte.gz
Extracting /tmp/data/train-labels-idx1-ubyte.gz
Extracting /tmp/data/t10k-images-idx3-ubyte.gz
Extracting /tmp/data/t10k-labels-idx1-ubyte.gz

In [2]:
# Training Parameters
learning_rate = 0.001
training_steps = 10000
batch_size = 128
display_step = 200

# Network Parameters
num_input = 28 # MNIST data input (img shape: 28*28)
timesteps = 28 # timesteps
num_hidden = 128 # hidden layer num of features
num_classes = 10 # MNIST total classes (0-9 digits)

# tf Graph input
X = tf.placeholder("float", [None, timesteps, num_input])
Y = tf.placeholder("float", [None, num_classes])

In [3]:
# Define weights
weights = {
    # Output layer weights => 2*num_hidden rows because the forward and
    # backward cell outputs are concatenated
    'out': tf.Variable(tf.random_normal([2*num_hidden, num_classes]))
}
biases = {
    'out': tf.Variable(tf.random_normal([num_classes]))
}
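
To see why the output weight matrix needs `2*num_hidden` rows, here is a minimal sketch of the shape arithmetic in plain Python. The sizes and values are toy assumptions chosen for readability, and the lists stand in for the tensors the notebook uses.

```python
num_hidden = 3   # toy size; the notebook uses 128
num_classes = 2  # toy size; the notebook uses 10

fw_output = [0.1, 0.2, 0.3]  # forward cell output at one timestep
bw_output = [0.4, 0.5, 0.6]  # backward cell output at the same timestep
combined = fw_output + bw_output  # concatenation: length 2 * num_hidden

# Output weights: one row per combined feature, one column per class.
weights_out = [[1.0, 0.0] for _ in range(2 * num_hidden)]
biases_out = [0.0, 0.5]

# Equivalent of matmul(combined, weights_out) + biases_out for one sample.
logits = [
    sum(combined[i] * weights_out[i][c] for i in range(len(combined)))
    + biases_out[c]
    for c in range(num_classes)
]

print(len(combined))  # 6, i.e. 2 * num_hidden
print(len(logits))    # 2, i.e. num_classes
```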

In [4]:
def BiRNN(x, weights, biases):

    # Prepare data shape to match `rnn` function requirements
    # Current data input shape: (batch_size, timesteps, num_input)
    # Required shape: 'timesteps' tensors list of shape (batch_size, num_input)

    # Unstack to get a list of 'timesteps' tensors of shape (batch_size, num_input)
    x = tf.unstack(x, timesteps, 1)

    # Define lstm cells with tensorflow
    # Forward direction cell
    lstm_fw_cell = rnn.BasicLSTMCell(num_hidden, forget_bias=1.0)
    # Backward direction cell
    lstm_bw_cell = rnn.BasicLSTMCell(num_hidden, forget_bias=1.0)

    # Get lstm cell output
    try:
        outputs, _, _ = rnn.static_bidirectional_rnn(lstm_fw_cell, lstm_bw_cell, x,
                                              dtype=tf.float32)
    except Exception: # Old TensorFlow version only returns outputs not states
        outputs = rnn.static_bidirectional_rnn(lstm_fw_cell, lstm_bw_cell, x,
                                        dtype=tf.float32)

    # Linear output layer applied to the output at the last timestep
    return tf.matmul(outputs[-1], weights['out']) + biases['out']
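
The way a bidirectional RNN pairs its two directions can be sketched without TensorFlow. The toy below replaces the LSTM with a deliberately trivial cell (a running sum), so it is an illustration of the output alignment only, not of the notebook's model; `run_cell` and `toy_bidirectional` are hypothetical names.

```python
def run_cell(inputs):
    """Toy recurrent cell: the state is the running sum of inputs seen so far."""
    state, outputs = 0, []
    for x in inputs:
        state = state + x
        outputs.append(state)
    return outputs


def toy_bidirectional(inputs):
    """Pair the forward and backward outputs at each timestep."""
    fw = run_cell(inputs)
    # Run over the reversed sequence, then flip the outputs back so that
    # index t lines up with input t.
    bw = run_cell(inputs[::-1])[::-1]
    return list(zip(fw, bw))


outputs = toy_bidirectional([1, 2, 3, 4])
print(outputs)      # [(1, 10), (3, 9), (6, 7), (10, 4)]
print(outputs[-1])  # (10, 4)
```

In this toy, `outputs[-1]` pairs a forward state that has seen the whole sequence with a backward state that has seen only the final input. Assuming `static_bidirectional_rnn` aligns its outputs the same way, as its documentation describes, the `outputs[-1]` used above has the same property.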

In [5]:
logits = BiRNN(X, weights, biases)
prediction = tf.nn.softmax(logits)

# Define loss and optimizer
loss_op = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(
    logits=logits, labels=Y))
optimizer = tf.train.GradientDescentOptimizer(learning_rate=learning_rate)
train_op = optimizer.minimize(loss_op)

# Evaluate model: compare the predicted class with the true class
correct_pred = tf.equal(tf.argmax(prediction, 1), tf.argmax(Y, 1))
accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))

# Initialize the variables (i.e. assign their default value)
init = tf.global_variables_initializer()
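
For readers who want to see what the loss and accuracy ops compute, the following is a minimal plain-Python sketch of softmax cross-entropy and argmax accuracy on a made-up batch of two samples. The helper names and values are assumptions for illustration; the notebook itself relies on TensorFlow's built-in ops.

```python
import math


def softmax(logits):
    """Numerically stable softmax over one list of logits."""
    shift = max(logits)
    exps = [math.exp(v - shift) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def cross_entropy(logits, one_hot_label):
    """Cross-entropy between softmax(logits) and a one-hot label."""
    probs = softmax(logits)
    return -sum(y * math.log(p) for y, p in zip(one_hot_label, probs))


def argmax(values):
    """Index of the largest value."""
    return max(range(len(values)), key=lambda i: values[i])


# Toy batch: the first sample is classified correctly, the second is not.
batch_logits = [[2.0, 0.5, 0.1], [0.2, 0.1, 3.0]]
batch_labels = [[1, 0, 0], [0, 1, 0]]

losses = [cross_entropy(l, y) for l, y in zip(batch_logits, batch_labels)]
mean_loss = sum(losses) / len(losses)  # mean over the batch, like reduce_mean
accuracy = sum(
    argmax(l) == argmax(y) for l, y in zip(batch_logits, batch_labels)
) / len(batch_labels)

print(mean_loss)
print(accuracy)  # 0.5
```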

In [6]:
# Start training
with tf.Session() as sess:

    # Run the initializer
    sess.run(init)

    for step in range(1, training_steps+1):
        batch_x, batch_y = mnist.train.next_batch(batch_size)
        # Reshape each flattened image into 28 timesteps of 28 features
        batch_x = batch_x.reshape((batch_size, timesteps, num_input))
        # Run optimization op (backprop)
        sess.run(train_op, feed_dict={X: batch_x, Y: batch_y})
        if step % display_step == 0 or step == 1:
            # Calculate batch loss and accuracy
            loss, acc = sess.run([loss_op, accuracy], feed_dict={X: batch_x,
                                                                 Y: batch_y})
            print("Step " + str(step) + ", Minibatch Loss= " +
                  "{:.4f}".format(loss) + ", Training Accuracy= " +
                  "{:.3f}".format(acc))

    print("Optimization Finished!")

    # Calculate accuracy for 128 mnist test images
    test_len = 128
    test_data = mnist.test.images[:test_len].reshape((-1, timesteps, num_input))
    test_label = mnist.test.labels[:test_len]
    print("Testing Accuracy:",
          sess.run(accuracy, feed_dict={X: test_data, Y: test_label}))


Step 1, Minibatch Loss= 2.6218, Training Accuracy= 0.086
Step 200, Minibatch Loss= 2.1900, Training Accuracy= 0.211
Step 400, Minibatch Loss= 2.0144, Training Accuracy= 0.375
Step 600, Minibatch Loss= 1.8729, Training Accuracy= 0.445
Step 800, Minibatch Loss= 1.8000, Training Accuracy= 0.469
Step 1000, Minibatch Loss= 1.7244, Training Accuracy= 0.453
Step 1200, Minibatch Loss= 1.5657, Training Accuracy= 0.523
Step 1400, Minibatch Loss= 1.5473, Training Accuracy= 0.547
Step 1600, Minibatch Loss= 1.5288, Training Accuracy= 0.500
Step 1800, Minibatch Loss= 1.4203, Training Accuracy= 0.555
Step 2000, Minibatch Loss= 1.2525, Training Accuracy= 0.641
Step 2200, Minibatch Loss= 1.2696, Training Accuracy= 0.594
Step 2400, Minibatch Loss= 1.2000, Training Accuracy= 0.664
Step 2600, Minibatch Loss= 1.1017, Training Accuracy= 0.625
Step 2800, Minibatch Loss= 1.2656, Training Accuracy= 0.578
Step 3000, Minibatch Loss= 1.0830, Training Accuracy= 0.656
Step 3200, Minibatch Loss= 1.1522, Training Accuracy= 0.633
Step 3400, Minibatch Loss= 0.9484, Training Accuracy= 0.680
Step 3600, Minibatch Loss= 1.0470, Training Accuracy= 0.641
Step 3800, Minibatch Loss= 1.0609, Training Accuracy= 0.586
Step 4000, Minibatch Loss= 1.1853, Training Accuracy= 0.648
Step 4200, Minibatch Loss= 0.9438, Training Accuracy= 0.750
Step 4400, Minibatch Loss= 0.7986, Training Accuracy= 0.766
Step 4600, Minibatch Loss= 0.8070, Training Accuracy= 0.750
Step 4800, Minibatch Loss= 0.8382, Training Accuracy= 0.734
Step 5000, Minibatch Loss= 0.7397, Training Accuracy= 0.766
Step 5200, Minibatch Loss= 0.7870, Training Accuracy= 0.727
Step 5400, Minibatch Loss= 0.6380, Training Accuracy= 0.828
Step 5600, Minibatch Loss= 0.7975, Training Accuracy= 0.719
Step 5800, Minibatch Loss= 0.7934, Training Accuracy= 0.766
Step 6000, Minibatch Loss= 0.6628, Training Accuracy= 0.805
Step 6200, Minibatch Loss= 0.7958, Training Accuracy= 0.672
Step 6400, Minibatch Loss= 0.6582, Training Accuracy= 0.773
Step 6600, Minibatch Loss= 0.5908, Training Accuracy= 0.812
Step 6800, Minibatch Loss= 0.6182, Training Accuracy= 0.820
Step 7000, Minibatch Loss= 0.5513, Training Accuracy= 0.812
Step 7200, Minibatch Loss= 0.6683, Training Accuracy= 0.789
Step 7400, Minibatch Loss= 0.5337, Training Accuracy= 0.828
Step 7600, Minibatch Loss= 0.6428, Training Accuracy= 0.805
Step 7800, Minibatch Loss= 0.6708, Training Accuracy= 0.797
Step 8000, Minibatch Loss= 0.4664, Training Accuracy= 0.852
Step 8200, Minibatch Loss= 0.4249, Training Accuracy= 0.859
Step 8400, Minibatch Loss= 0.7723, Training Accuracy= 0.773
Step 8600, Minibatch Loss= 0.4706, Training Accuracy= 0.859
Step 8800, Minibatch Loss= 0.4800, Training Accuracy= 0.867
Step 9000, Minibatch Loss= 0.4636, Training Accuracy= 0.891
Step 9200, Minibatch Loss= 0.5734, Training Accuracy= 0.828
Step 9400, Minibatch Loss= 0.5548, Training Accuracy= 0.875
Step 9600, Minibatch Loss= 0.3575, Training Accuracy= 0.922
Step 9800, Minibatch Loss= 0.4566, Training Accuracy= 0.844
Step 10000, Minibatch Loss= 0.5125, Training Accuracy= 0.844
Optimization Finished!
Testing Accuracy: 0.890625
