Recurrent Neural Network Example

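Build a recurrent neural network (LSTM) with TensorFlow. The code targets the TensorFlow 1.x API (tf.contrib, tf.placeholder, tf.Session) and will not run unmodified on TensorFlow 2.x.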

MNIST Dataset Overview

This example uses MNIST handwritten digits. The dataset contains 60,000 examples for training and 10,000 examples for testing. The digits have been size-normalized and centered in a fixed-size image (28x28 pixels) with values from 0 to 1. For simplicity, each image has been flattened and converted to a 1-D numpy array of 784 features (28*28).

To classify images with a recurrent neural network, we treat each image row as one timestep. Because MNIST images are 28x28 pixels, every sample becomes a sequence of 28 timesteps, each holding one row of 28 pixel values.

More info: http://yann.lecun.com/exdb/mnist/
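The row-as-timestep layout can be checked with a small NumPy sketch. It assumes NumPy is installed and uses a synthetic batch in place of mnist.train.next_batch; the reshape is the same one the training loop applies further down.

```python
import numpy as np

# Synthetic stand-in for one batch from mnist.train.next_batch:
# 2 flattened images of 784 pixel values each, scaled into [0, 1).
batch_size, timesteps, num_input = 2, 28, 28
flat = np.arange(batch_size * 784, dtype=np.float32).reshape(batch_size, 784)
flat = flat / (batch_size * 784)

# Same reshape as the training loop: (batch, 784) -> (batch, timesteps, num_input)
seq = flat.reshape((batch_size, timesteps, num_input))

# Timestep t of a sample is image row t, i.e. pixels [28*t, 28*t + 28) of the flat vector.
t = 5
assert np.array_equal(seq[0, t], flat[0, 28 * t: 28 * t + 28])
print(seq.shape)  # (2, 28, 28)
```

The reshape does not move any data; it only reinterprets each consecutive run of 28 pixels as one timestep.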


In [1]:
from __future__ import print_function

import tensorflow as tf
from tensorflow.contrib import rnn

# Import MNIST data
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("/tmp/data/", one_hot=True)


Extracting /tmp/data/train-images-idx3-ubyte.gz
Extracting /tmp/data/train-labels-idx1-ubyte.gz
Extracting /tmp/data/t10k-images-idx3-ubyte.gz
Extracting /tmp/data/t10k-labels-idx1-ubyte.gz

In [2]:
# Training Parameters
learning_rate = 0.001
training_steps = 10000
batch_size = 128
display_step = 200

# Network Parameters
num_input = 28 # MNIST data input (pixels per image row; img shape: 28*28)
timesteps = 28 # timesteps (one per image row)
num_hidden = 128 # number of LSTM hidden units
num_classes = 10 # MNIST total classes (0-9 digits)

# tf Graph input
X = tf.placeholder(tf.float32, [None, timesteps, num_input])
Y = tf.placeholder(tf.float32, [None, num_classes])

In [3]:
# Define weights
weights = {
    'out': tf.Variable(tf.random_normal([num_hidden, num_classes]))
}
biases = {
    'out': tf.Variable(tf.random_normal([num_classes]))
}
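As a rough size check, the trainable parameters can be counted by hand. This sketch assumes BasicLSTMCell stores a single kernel of shape [num_input + num_hidden, 4 * num_hidden] and a bias of shape [4 * num_hidden] (the four gates stacked side by side), which is how the TensorFlow 1.x implementation lays them out; the output layer is weights['out'] and biases['out'] from the cell above.

```python
num_input, num_hidden, num_classes = 28, 128, 10

# LSTM: one kernel over [input, previous hidden state] for all 4 gates, plus one bias.
lstm_params = (num_input + num_hidden) * 4 * num_hidden + 4 * num_hidden

# Output layer: weights['out'] is [num_hidden, num_classes], biases['out'] is [num_classes].
out_params = num_hidden * num_classes + num_classes

total = lstm_params + out_params
print(lstm_params, out_params, total)  # 80384 1290 81674
```

Almost all of the parameters sit in the LSTM kernel, and its size grows with num_hidden squared because the previous hidden state is fed back into every gate.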

In [4]:
def RNN(x, weights, biases):

    # Prepare data shape to match `static_rnn` input requirements
    # Current data input shape: (batch_size, timesteps, num_input)
    # Required shape: a list of 'timesteps' tensors of shape (batch_size, num_input)

    # Unstack along axis 1 to get that list of 'timesteps' tensors
    x = tf.unstack(x, timesteps, 1)

    # Define an LSTM cell with TensorFlow
    lstm_cell = rnn.BasicLSTMCell(num_hidden, forget_bias=1.0)

    # Unroll the cell over the timesteps; `outputs` holds one
    # (batch_size, num_hidden) tensor per timestep
    outputs, states = rnn.static_rnn(lstm_cell, x, dtype=tf.float32)

    # Linear output layer applied to the output of the last timestep
    return tf.matmul(outputs[-1], weights['out']) + biases['out']
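To make the unstack and static_rnn steps concrete, here is a NumPy sketch of the same forward pass: the input is split into one (batch, num_input) slice per timestep, an LSTM step is applied to each slice in order, and only the final output reaches the linear layer. It is an illustration, not TensorFlow's code: the gate order (i, j, f, o) and adding forget_bias to the forget gate are assumed to mirror BasicLSTMCell, the weights are random, and num_hidden is shrunk to 8 to keep it small.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def lstm_step(x_t, h, c, kernel, bias, forget_bias=1.0):
    """One LSTM step, assuming a BasicLSTMCell-style gate layout (i, j, f, o)."""
    # Concatenate input and previous hidden state, then one matmul for all 4 gates.
    gates = np.concatenate([x_t, h], axis=1) @ kernel + bias
    i, j, f, o = np.split(gates, 4, axis=1)
    # forget_bias is added to the forget gate pre-activation before the sigmoid.
    new_c = c * sigmoid(f + forget_bias) + sigmoid(i) * np.tanh(j)
    new_h = np.tanh(new_c) * sigmoid(o)
    return new_h, new_c

def rnn_logits(x, kernel, bias, w_out, b_out, num_hidden):
    """Unroll over timesteps (like tf.unstack + static_rnn) and keep the last output."""
    batch = x.shape[0]
    h = np.zeros((batch, num_hidden))
    c = np.zeros((batch, num_hidden))
    # Equivalent of tf.unstack(x, timesteps, 1): one (batch, num_input) slice per row.
    for x_t in [x[:, t, :] for t in range(x.shape[1])]:
        h, c = lstm_step(x_t, h, c, kernel, bias)
    # Linear output layer on the final hidden state only (outputs[-1] in the notebook).
    return h @ w_out + b_out

rng = np.random.default_rng(0)
batch, timesteps, num_input, num_hidden, num_classes = 4, 28, 28, 8, 10
x = rng.random((batch, timesteps, num_input))
kernel = rng.normal(scale=0.1, size=(num_input + num_hidden, 4 * num_hidden))
bias = np.zeros(4 * num_hidden)
w_out = rng.normal(size=(num_hidden, num_classes))
b_out = np.zeros(num_classes)

logits = rnn_logits(x, kernel, bias, w_out, b_out, num_hidden)
print(logits.shape)  # (4, 10)
```

The same kernel and bias are reused at every timestep; only h and c carry information from one row of the image to the next.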

In [5]:
logits = RNN(X, weights, biases)
prediction = tf.nn.softmax(logits)

# Define loss and optimizer
loss_op = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(
    logits=logits, labels=Y))
optimizer = tf.train.GradientDescentOptimizer(learning_rate=learning_rate)
train_op = optimizer.minimize(loss_op)

# Evaluate model: fraction of samples whose predicted class matches the label
correct_pred = tf.equal(tf.argmax(prediction, 1), tf.argmax(Y, 1))
accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))

# Initialize the variables (i.e. assign their default value)
init = tf.global_variables_initializer()
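The loss and accuracy ops above can be reproduced in NumPy on toy values, which is a handy way to sanity-check the numbers in the training log. This is a sketch assuming NumPy, not TensorFlow's implementation; with all-equal logits over 10 classes the mean cross-entropy is ln(10), about 2.303, which is what a classifier outputting a uniform distribution would score.

```python
import numpy as np

def softmax(logits):
    # Subtract the row max for numerical stability; the result is unchanged.
    z = logits - logits.max(axis=1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)

def mean_softmax_cross_entropy(logits, labels):
    # Per-example loss: -sum_k y_k * log(softmax(logits)_k), then mean over the batch.
    log_probs = np.log(softmax(logits))
    return float(np.mean(-np.sum(labels * log_probs, axis=1)))

def accuracy(logits, labels):
    # softmax preserves the ordering of logits, so comparing argmaxes of logits
    # gives the same result as comparing argmaxes of the softmax output.
    return float(np.mean(np.argmax(logits, 1) == np.argmax(labels, 1)))

uniform = np.zeros((2, 10))
labels = np.eye(10)[[0, 3]]  # one-hot rows for classes 0 and 3
print(mean_softmax_cross_entropy(uniform, labels))  # about 2.302585 (= ln 10)
print(accuracy(uniform, labels))  # 0.5 (np.argmax of a tie returns index 0)
```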

In [6]:
# Start training
with tf.Session() as sess:

    # Run the initializer
    sess.run(init)

    for step in range(1, training_steps+1):
        batch_x, batch_y = mnist.train.next_batch(batch_size)
        # Reshape data to get 28 seq of 28 elements
        batch_x = batch_x.reshape((batch_size, timesteps, num_input))
        # Run optimization op (backprop)
        sess.run(train_op, feed_dict={X: batch_x, Y: batch_y})
        if step % display_step == 0 or step == 1:
            # Calculate batch loss and accuracy
            loss, acc = sess.run([loss_op, accuracy], feed_dict={X: batch_x,
                                                                 Y: batch_y})
            print("Step {}, Minibatch Loss= {:.4f}, Training Accuracy= {:.3f}".format(
                step, loss, acc))

    print("Optimization Finished!")

    # Calculate accuracy for the first 128 MNIST test images
    test_len = 128
    test_data = mnist.test.images[:test_len].reshape((-1, timesteps, num_input))
    test_label = mnist.test.labels[:test_len]
    print("Testing Accuracy:", \
        sess.run(accuracy, feed_dict={X: test_data, Y: test_label}))
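The test accuracy above uses only the first 128 test images. One way to score all 10,000 without feeding them in a single call is to evaluate in chunks and count correct predictions. The sketch below assumes NumPy; predict_fn is a hypothetical callable that, inside the notebook's session, could wrap sess.run(logits, feed_dict={X: batch}) with the batch reshaped to (-1, timesteps, num_input). It is demonstrated here on synthetic data.

```python
import numpy as np

def batched_accuracy(predict_fn, images, labels, batch_size=1000):
    """Accuracy over a whole array, evaluated chunk by chunk.

    predict_fn is a hypothetical callable mapping a batch of images to logits.
    Counting correct predictions (rather than averaging per-chunk accuracies)
    keeps the result exact when the last chunk is smaller.
    """
    correct = 0
    for start in range(0, len(images), batch_size):
        x = images[start:start + batch_size]
        y = labels[start:start + batch_size]
        preds = np.argmax(predict_fn(x), axis=1)
        correct += int(np.sum(preds == np.argmax(y, axis=1)))
    return correct / len(images)

# Synthetic data: the class is stored in pixel (0, 0) so a "perfect" predictor exists.
classes = np.arange(2500) % 10
images = np.zeros((2500, 28, 28))
images[:, 0, 0] = classes
labels = np.eye(10)[classes]

def perfect(x):
    return np.eye(10)[x[:, 0, 0].astype(int)]

def always_zero(x):
    return np.tile(np.eye(10)[0], (len(x), 1))

print(batched_accuracy(perfect, images, labels))      # 1.0
print(batched_accuracy(always_zero, images, labels))  # 0.1
```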


Step 1, Minibatch Loss= 2.6268, Training Accuracy= 0.102
Step 200, Minibatch Loss= 2.0722, Training Accuracy= 0.328
Step 400, Minibatch Loss= 1.9181, Training Accuracy= 0.336
Step 600, Minibatch Loss= 1.8858, Training Accuracy= 0.336
Step 800, Minibatch Loss= 1.7022, Training Accuracy= 0.422
Step 1000, Minibatch Loss= 1.6365, Training Accuracy= 0.477
Step 1200, Minibatch Loss= 1.6691, Training Accuracy= 0.516
Step 1400, Minibatch Loss= 1.4626, Training Accuracy= 0.547
Step 1600, Minibatch Loss= 1.4707, Training Accuracy= 0.539
Step 1800, Minibatch Loss= 1.4087, Training Accuracy= 0.570
Step 2000, Minibatch Loss= 1.3033, Training Accuracy= 0.570
Step 2200, Minibatch Loss= 1.3773, Training Accuracy= 0.508
Step 2400, Minibatch Loss= 1.3092, Training Accuracy= 0.570
Step 2600, Minibatch Loss= 1.2272, Training Accuracy= 0.609
Step 2800, Minibatch Loss= 1.1827, Training Accuracy= 0.633
Step 3000, Minibatch Loss= 1.0453, Training Accuracy= 0.641
Step 3200, Minibatch Loss= 1.0400, Training Accuracy= 0.648
Step 3400, Minibatch Loss= 1.1145, Training Accuracy= 0.656
Step 3600, Minibatch Loss= 0.9884, Training Accuracy= 0.688
Step 3800, Minibatch Loss= 1.0395, Training Accuracy= 0.703
Step 4000, Minibatch Loss= 1.0096, Training Accuracy= 0.664
Step 4200, Minibatch Loss= 0.8806, Training Accuracy= 0.758
Step 4400, Minibatch Loss= 0.9090, Training Accuracy= 0.766
Step 4600, Minibatch Loss= 1.0060, Training Accuracy= 0.703
Step 4800, Minibatch Loss= 0.8954, Training Accuracy= 0.703
Step 5000, Minibatch Loss= 0.8163, Training Accuracy= 0.750
Step 5200, Minibatch Loss= 0.7620, Training Accuracy= 0.773
Step 5400, Minibatch Loss= 0.7388, Training Accuracy= 0.758
Step 5600, Minibatch Loss= 0.7604, Training Accuracy= 0.695
Step 5800, Minibatch Loss= 0.7459, Training Accuracy= 0.734
Step 6000, Minibatch Loss= 0.7448, Training Accuracy= 0.734
Step 6200, Minibatch Loss= 0.7208, Training Accuracy= 0.773
Step 6400, Minibatch Loss= 0.6557, Training Accuracy= 0.773
Step 6600, Minibatch Loss= 0.8616, Training Accuracy= 0.758
Step 6800, Minibatch Loss= 0.6089, Training Accuracy= 0.773
Step 7000, Minibatch Loss= 0.5020, Training Accuracy= 0.844
Step 7200, Minibatch Loss= 0.5980, Training Accuracy= 0.812
Step 7400, Minibatch Loss= 0.6786, Training Accuracy= 0.766
Step 7600, Minibatch Loss= 0.4891, Training Accuracy= 0.859
Step 7800, Minibatch Loss= 0.7042, Training Accuracy= 0.797
Step 8000, Minibatch Loss= 0.4200, Training Accuracy= 0.859
Step 8200, Minibatch Loss= 0.6442, Training Accuracy= 0.742
Step 8400, Minibatch Loss= 0.5569, Training Accuracy= 0.828
Step 8600, Minibatch Loss= 0.5838, Training Accuracy= 0.836
Step 8800, Minibatch Loss= 0.5579, Training Accuracy= 0.812
Step 9000, Minibatch Loss= 0.4337, Training Accuracy= 0.867
Step 9200, Minibatch Loss= 0.4366, Training Accuracy= 0.844
Step 9400, Minibatch Loss= 0.5051, Training Accuracy= 0.844
Step 9600, Minibatch Loss= 0.5244, Training Accuracy= 0.805
Step 9800, Minibatch Loss= 0.4932, Training Accuracy= 0.805
Step 10000, Minibatch Loss= 0.4833, Training Accuracy= 0.852
Optimization Finished!
Testing Accuracy: 0.882812
