Sequence classification with LSTM on MNIST

In this notebook you will learn how to use TensorFlow to create a Recurrent Neural Network.


Introduction

Recurrent Neural Networks are Deep Learning models with simple structures and a built-in feedback mechanism; in other words, the output of a layer is combined with the next input and fed back into the same layer.

The Recurrent Neural Network is a specialized type of Neural Network that solves the issue of maintaining context for sequential data -- such as weather measurements, stock prices, genes, etc. At each iterative step, the processing unit takes in an input and the current state of the network, and produces an output and a new state that is re-fed into the network.
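
To make that step concrete, here is a minimal sketch of a single vanilla recurrent step in NumPy (illustrative only; the model built below uses TensorFlow's LSTM cell instead):

import numpy as np

def rnn_step(x_t, h_prev, W_xh, W_hh, b_h):
    # Combine the current input with the previous state and squash with a nonlinearity
    h_t = np.tanh(x_t @ W_xh + h_prev @ W_hh + b_h)
    # The new state is both the output at this step and the state re-fed at the next step
    return h_t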

However, this model has some problems. It is very computationally expensive to maintain the state for a large number of units, even more so over a long span of time. Additionally, Recurrent Networks are very sensitive to changes in their parameters. As such, they are prone to problems during Gradient Descent optimization -- the gradients either grow exponentially (Exploding Gradient) or shrink to near zero (Vanishing Gradient), both of which greatly harm a model's learning capability.
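
A toy example makes the gradient problem visible (plain Python, not part of the model below): backpropagating through many time steps multiplies the gradient by the recurrent weight again and again, so it either collapses toward zero or blows up:

for w in (0.5, 1.5):
    grad = 1.0
    for _ in range(28):   # 28 time steps, as in the MNIST sequences used later
        grad *= w         # each backpropagation step multiplies by the recurrent weight
    print(w, grad)        # 0.5 -> ~3.7e-09 (vanishing), 1.5 -> ~8.5e+04 (exploding)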

To solve these problems, Hochreiter and Schmidhuber published a paper in 1997 describing the Long Short-Term Memory (LSTM), a way to keep information over long periods of time while also curbing the oversensitivity to parameter changes, i.e., making backpropagation through Recurrent Networks more viable.

(In this notebook, we will cover only LSTM and its implementation using TensorFlow)

Architectures

  • Fully Recurrent Network
  • Recursive Neural Networks
  • Hopfield Networks
  • Elman Networks and Jordan Networks
  • Echo State Networks
  • Neural history compressor
  • The Long Short-Term Memory Model (LSTM)

LSTM

LSTM is one of the proposed solutions or upgrades to the Recurrent Neural Network model.

It is an abstraction of how computer memory works. It is "bundled" with whatever processing unit is implemented in the Recurrent Network, although outside of its flow, and is responsible for keeping, reading, and outputting information for the model. The way it works is simple: you have a linear unit, which is the information cell itself, surrounded by three logistic gates responsible for maintaining the data. One gate is for writing data into the information cell, one is for reading data out of the information cell, and the last one is for keeping or forgetting data depending on the needs of the network.

Thanks to that, it not only solves the problem of keeping states, because the network can choose to forget data whenever information is not needed; it also mitigates the gradient problems, since the logistic gates have a very well-behaved derivative.
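
Concretely, the gates use the logistic (sigmoid) function $\sigma(x) = \frac{1}{1 + e^{-x}}$, whose derivative $\sigma'(x) = \sigma(x)\,(1 - \sigma(x))$ is smooth, bounded, and cheap to compute, which is what keeps backpropagation through the gates well behaved.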

Long Short-Term Memory Architecture

As seen before, the Long Short-Term Memory is composed of a linear unit surrounded by three logistic gates. The names of these gates vary from place to place, but the most common are:

  • the "Input" or "Write" Gate, which handles the writing of data into the information cell,
  • the "Output" or "Read" Gate, which handles the sending of data back onto the Recurrent Network, and
  • the "Keep" or "Forget" Gate, which handles the maintaining and modification of the data stored in the information cell.

*Diagram of the Long Short-Term Memory Unit*

The three gates are the centerpiece of the LSTM unit. The gates, when activated by the network, perform their respective functions. For example, the Input Gate will write whatever data it is passed onto the information cell, the Output Gate will return whatever data is in the information cell, and the Keep Gate will maintain the data in the information cell. These gates are analog and multiplicative, and as such, can modify the data based on the signal they are sent.
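
To make the roles of the gates concrete, here is a minimal single-step LSTM sketch in NumPy (variable names and the exact parameterization are illustrative; the BasicLSTMCell used later in TensorFlow implements the same idea internally):

import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def lstm_step(x_t, h_prev, c_prev, W, b):
    # One shared affine transform of [input, previous hidden state],
    # then split into the three gates and the candidate values
    z = np.concatenate([x_t, h_prev]) @ W + b
    i, f, o, g = np.split(z, 4)
    i, f, o = sigmoid(i), sigmoid(f), sigmoid(o)  # "write", "keep/forget" and "read" gates
    g = np.tanh(g)                                # candidate data to write
    c_t = f * c_prev + i * g                      # information cell: keep old data, write new data
    h_t = o * np.tanh(c_t)                        # read the cell back out to the rest of the network
    return h_t, c_t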


Building an LSTM with TensorFlow

LSTM for Classification

Although RNNs are mostly used to model sequences and predict sequential data, we can still classify images with an LSTM network. If we consider every image row as a sequence of pixels, we can feed the rows to an LSTM network one by one for classification. Let's use the famous MNIST dataset here. Because each MNIST image is 28x28 pixels, every sample becomes a sequence of 28 steps with 28 values per step.

MNIST Dataset

TensorFlow already provides helper functions to download and process the MNIST dataset.


In [1]:
%matplotlib inline
import warnings
warnings.filterwarnings('ignore')

import numpy as np
import matplotlib.pyplot as plt

import tensorflow as tf

In [2]:
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("../../data/MNIST/", one_hot=True)


Extracting ../../data/MNIST/train-images-idx3-ubyte.gz
Extracting ../../data/MNIST/train-labels-idx1-ubyte.gz
Extracting ../../data/MNIST/t10k-images-idx3-ubyte.gz
Extracting ../../data/MNIST/t10k-labels-idx1-ubyte.gz

The function input_data.read_data_sets(...) loads the entire dataset and returns an object of type tensorflow.contrib.learn.python.learn.datasets.mnist.DataSets.

The argument one_hot=True encodes the labels as 10-dimensional binary vectors (only zeros and ones), in which the index of the single one is the class label.
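
For example (a small illustrative snippet), the class can be recovered from a one-hot label vector with np.argmax:

label = np.array([0., 0., 0., 0., 0., 0., 1., 0., 0., 0.])   # one-hot encoding of class 6
print(np.argmax(label))                                       # prints 6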


In [3]:
trainimgs = mnist.train.images
trainlabels = mnist.train.labels
testimgs = mnist.test.images
testlabels = mnist.test.labels 

ntrain = trainimgs.shape[0]
ntest = testimgs.shape[0]
dim = trainimgs.shape[1]
nclasses = trainlabels.shape[1]
print("Train Images: ", trainimgs.shape)
print("Train Labels  ", trainlabels.shape)
print()
print("Test Images:  " , testimgs.shape)
print("Test Labels:  ", testlabels.shape)


Train Images:  (55000, 784)
Train Labels   (55000, 10)

Test Images:   (10000, 784)
Test Labels:   (10000, 10)

Let's look at a few samples, just to understand the structure of the MNIST dataset.

The next code snippet prints the label vector (one-hot format), the class, and the actual sample rendered as an image:


In [4]:
samplesIdx = [100, 101, 102]  #<-- You can change these numbers here to see other samples

from mpl_toolkits.mplot3d import Axes3D
fig = plt.figure()

ax1 = fig.add_subplot(121)
ax1.imshow(testimgs[samplesIdx[0]].reshape([28,28]), cmap='gray')


xx, yy = np.meshgrid(np.linspace(0,28,28), np.linspace(0,28,28))
X =  xx ; Y =  yy
Z =  100*np.ones(X.shape)

ax = fig.add_subplot(122, projection='3d')
ax.set_zlim((0,200))


offset=200
for i in samplesIdx:
    img = testimgs[i].reshape([28,28]).transpose()
    ax.contourf(X, Y, img, 200, zdir='z', offset=offset, cmap="gray")
    offset -= 100

ax.set_xticks([])
ax.set_yticks([])
ax.set_zticks([])

plt.show()


for i in samplesIdx:
    print("Sample: {0} - Class: {1} - Label Vector: {2} ".format(i, np.nonzero(testlabels[i])[0], testlabels[i]))


Sample: 100 - Class: [6] - Label Vector: [ 0.  0.  0.  0.  0.  0.  1.  0.  0.  0.] 
Sample: 101 - Class: [0] - Label Vector: [ 1.  0.  0.  0.  0.  0.  0.  0.  0.  0.] 
Sample: 102 - Class: [5] - Label Vector: [ 0.  0.  0.  0.  0.  1.  0.  0.  0.  0.] 

Let's understand the parameters, inputs and outputs

We will treat each MNIST image $\in \mathcal{R}^{28 \times 28}$ as a sequence of $28$ row vectors $\mathbf{x} \in \mathcal{R}^{28}$.

Our simple RNN consists of

  1. One input layer which converts a $28$-dimensional input into the $128$-dimensional hidden representation,
  2. One intermediate recurrent neural network (LSTM), and
  3. One output layer which converts the $128$-dimensional output of the LSTM into a $10$-dimensional output indicating the class label (see the shape summary below).
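
Putting those pieces together, the tensor shapes flow roughly as follows (assuming the batch size of 100 used below): input batch [100, 28, 28] -> LSTM outputs [100, 28, 128] -> last time step [100, 128] -> linear layer (weights [128, 10], biases [10]) -> class scores [100, 10].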

In [5]:
n_input = 28 # MNIST data input (img shape: 28*28)
n_steps = 28 # timesteps
n_hidden = 128 # hidden layer num of features
n_classes = 10 # MNIST total classes (0-9 digits)


learning_rate = 0.001
training_iters = 100000
batch_size = 100
display_step = 10

Construct a Recurrent Neural Network


In [6]:
x = tf.placeholder(dtype="float", shape=[None, n_steps, n_input], name="x")
y = tf.placeholder(dtype="float", shape=[None, n_classes], name="y")

In [7]:
weights = {
    'out': tf.Variable(tf.random_normal([n_hidden, n_classes]))
}
biases = {
    'out': tf.Variable(tf.random_normal([n_classes]))
}

The input should be a Tensor of shape [batch_size, max_time, ...]; in our case it will be (?, 28, 28).
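
Since mnist.train.next_batch returns flat [batch_size, 784] arrays, each batch has to be reshaped into that [batch_size, 28, 28] layout before being fed to the placeholder (the same reshape appears in the training loop below):

batch_x, batch_y = mnist.train.next_batch(batch_size)       # batch_x: [100, 784]
batch_x = batch_x.reshape((batch_size, n_steps, n_input))   # -> [100, 28, 28]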


In [8]:
# Define an LSTM cell with TensorFlow
# (note: the RNN() function below creates its own cell; this one is just for illustration)
lstm_cell = tf.contrib.rnn.BasicLSTMCell(n_hidden, forget_bias=1.0, state_is_tuple=True)

# initial state (not needed here: tf.nn.dynamic_rnn creates a zero initial state by default)
#initial_state = (tf.zeros([1,n_hidden]),)*2

In [9]:
def RNN(x, weights, biases):

    # Prepare data shape to match `rnn` function requirements
    # Current data input shape: (batch_size, n_steps, n_input) [100x28x28]

    # Define an LSTM cell with TensorFlow
    lstm_cell = tf.contrib.rnn.BasicLSTMCell(n_hidden, forget_bias=1.0)

    # Get the LSTM cell output over all time steps: [batch_size, n_steps, n_hidden] = [100x28x128]
    outputs, states = tf.nn.dynamic_rnn(lstm_cell, inputs=x, dtype=tf.float32)

    # Alternative (commented out): call the cell directly with an explicit initial state
    #outputs, states = lstm_cell(x , initial_state)

    # The output of the RNN is a [100x28x128] tensor. We apply a linear activation to the
    # last time step to map it to a [100x10] matrix of class scores:
    # output [100x128] x weights [128x10] + biases [10]
    output = tf.reshape(tf.split(outputs, 28, axis=1, num=None, name='split')[-1], [-1, 128])
    return tf.matmul(output, weights['out']) + biases['out']
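
As a side note, the split-and-reshape above simply selects the LSTM output at the last time step; slicing the outputs tensor directly should give the same result:

output = outputs[:, -1, :]   # [batch_size, n_hidden]: the hidden state after the last image row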

In [10]:
with tf.variable_scope('forward3'):
    pred = RNN(x, weights, biases)

Both labels and logits should be tensors of shape [batch_size x n_classes], i.e. [100x10] here.


In [11]:
cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=pred ))
optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(cost)

In [12]:
correct_pred = tf.equal(tf.argmax(pred,1), tf.argmax(y,1))
accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))

In [13]:
accuracy_v2 = tf.contrib.metrics.accuracy(
    labels=tf.arg_max(y, dimension=1), 
    predictions=tf.arg_max(pred, dimension=1)
)

Just recall that we treat each MNIST image $\in \mathcal{R}^{28 \times 28}$ as a sequence of $28$ row vectors $\mathbf{x} \in \mathcal{R}^{28}$.


In [14]:
sess = tf.InteractiveSession()
init = tf.global_variables_initializer()

In [15]:
sess.run(init)
step = 1
# Keep training until we reach the maximum number of iterations
while step * batch_size < training_iters:

    # Read a batch of 100 images [100 x 784] as batch_x
    # batch_y is a matrix of shape [100x10]
    batch_x, batch_y = mnist.train.next_batch(batch_size)

    # We consider each row of the image as one sequence
    # Reshape the data to get 28 sequences of 28 elements, so that batch_x is [100x28x28]
    batch_x = batch_x.reshape((batch_size, n_steps, n_input))


    # Run optimization op (backprop)
    sess.run(optimizer, feed_dict={x: batch_x, y: batch_y})


    if step % display_step == 0:
        # Calculate batch accuracy
        acc, acc2 = sess.run([accuracy, accuracy_v2], feed_dict={x: batch_x, y: batch_y})
        # Calculate batch loss
        loss = sess.run(cost, feed_dict={x: batch_x, y: batch_y})
        print("({} / {})    Minibatch loss={:.6f}    Accuracy={:.5f}    Accuracy (tf)={:.5f}".format(
                step*batch_size,
                training_iters,
                loss,
                acc,
                acc2
            ))
    step += 1
print("Optimization Finished!")


(1000 / 100000)    Minibatch loss=1.850881    Accuracy=0.43000    Accuracy (tf)=0.43000
(2000 / 100000)    Minibatch loss=1.456735    Accuracy=0.55000    Accuracy (tf)=0.55000
(3000 / 100000)    Minibatch loss=1.155387    Accuracy=0.59000    Accuracy (tf)=0.59000
(4000 / 100000)    Minibatch loss=1.249778    Accuracy=0.54000    Accuracy (tf)=0.54000
(5000 / 100000)    Minibatch loss=1.142365    Accuracy=0.64000    Accuracy (tf)=0.64000
(6000 / 100000)    Minibatch loss=0.885371    Accuracy=0.74000    Accuracy (tf)=0.74000
(7000 / 100000)    Minibatch loss=0.507802    Accuracy=0.81000    Accuracy (tf)=0.81000
(8000 / 100000)    Minibatch loss=0.628446    Accuracy=0.75000    Accuracy (tf)=0.75000
(9000 / 100000)    Minibatch loss=0.534296    Accuracy=0.80000    Accuracy (tf)=0.80000
(10000 / 100000)    Minibatch loss=0.328939    Accuracy=0.91000    Accuracy (tf)=0.91000
(11000 / 100000)    Minibatch loss=0.335713    Accuracy=0.90000    Accuracy (tf)=0.90000
(12000 / 100000)    Minibatch loss=0.470027    Accuracy=0.86000    Accuracy (tf)=0.86000
(13000 / 100000)    Minibatch loss=0.404602    Accuracy=0.86000    Accuracy (tf)=0.86000
(14000 / 100000)    Minibatch loss=0.193695    Accuracy=0.97000    Accuracy (tf)=0.97000
(15000 / 100000)    Minibatch loss=0.379994    Accuracy=0.89000    Accuracy (tf)=0.89000
(16000 / 100000)    Minibatch loss=0.209487    Accuracy=0.93000    Accuracy (tf)=0.93000
(17000 / 100000)    Minibatch loss=0.195182    Accuracy=0.96000    Accuracy (tf)=0.96000
(18000 / 100000)    Minibatch loss=0.296144    Accuracy=0.89000    Accuracy (tf)=0.89000
(19000 / 100000)    Minibatch loss=0.340441    Accuracy=0.90000    Accuracy (tf)=0.90000
(20000 / 100000)    Minibatch loss=0.350278    Accuracy=0.91000    Accuracy (tf)=0.91000
(21000 / 100000)    Minibatch loss=0.291284    Accuracy=0.87000    Accuracy (tf)=0.87000
(22000 / 100000)    Minibatch loss=0.297869    Accuracy=0.89000    Accuracy (tf)=0.89000
(23000 / 100000)    Minibatch loss=0.173801    Accuracy=0.93000    Accuracy (tf)=0.93000
(24000 / 100000)    Minibatch loss=0.234894    Accuracy=0.91000    Accuracy (tf)=0.91000
(25000 / 100000)    Minibatch loss=0.174178    Accuracy=0.96000    Accuracy (tf)=0.96000
(26000 / 100000)    Minibatch loss=0.182272    Accuracy=0.93000    Accuracy (tf)=0.93000
(27000 / 100000)    Minibatch loss=0.172375    Accuracy=0.93000    Accuracy (tf)=0.93000
(28000 / 100000)    Minibatch loss=0.191111    Accuracy=0.92000    Accuracy (tf)=0.92000
(29000 / 100000)    Minibatch loss=0.246552    Accuracy=0.92000    Accuracy (tf)=0.92000
(30000 / 100000)    Minibatch loss=0.189268    Accuracy=0.94000    Accuracy (tf)=0.94000
(31000 / 100000)    Minibatch loss=0.138168    Accuracy=0.97000    Accuracy (tf)=0.97000
(32000 / 100000)    Minibatch loss=0.102459    Accuracy=0.99000    Accuracy (tf)=0.99000
(33000 / 100000)    Minibatch loss=0.401999    Accuracy=0.89000    Accuracy (tf)=0.89000
(34000 / 100000)    Minibatch loss=0.110543    Accuracy=0.97000    Accuracy (tf)=0.97000
(35000 / 100000)    Minibatch loss=0.168854    Accuracy=0.95000    Accuracy (tf)=0.95000
(36000 / 100000)    Minibatch loss=0.125984    Accuracy=0.97000    Accuracy (tf)=0.97000
(37000 / 100000)    Minibatch loss=0.176938    Accuracy=0.94000    Accuracy (tf)=0.94000
(38000 / 100000)    Minibatch loss=0.105703    Accuracy=0.96000    Accuracy (tf)=0.96000
(39000 / 100000)    Minibatch loss=0.257360    Accuracy=0.93000    Accuracy (tf)=0.93000
(40000 / 100000)    Minibatch loss=0.178438    Accuracy=0.96000    Accuracy (tf)=0.96000
(41000 / 100000)    Minibatch loss=0.160108    Accuracy=0.95000    Accuracy (tf)=0.95000
(42000 / 100000)    Minibatch loss=0.100121    Accuracy=0.97000    Accuracy (tf)=0.97000
(43000 / 100000)    Minibatch loss=0.190255    Accuracy=0.96000    Accuracy (tf)=0.96000
(44000 / 100000)    Minibatch loss=0.242325    Accuracy=0.95000    Accuracy (tf)=0.95000
(45000 / 100000)    Minibatch loss=0.262376    Accuracy=0.93000    Accuracy (tf)=0.93000
(46000 / 100000)    Minibatch loss=0.082209    Accuracy=0.97000    Accuracy (tf)=0.97000
(47000 / 100000)    Minibatch loss=0.163795    Accuracy=0.94000    Accuracy (tf)=0.94000
(48000 / 100000)    Minibatch loss=0.122494    Accuracy=0.97000    Accuracy (tf)=0.97000
(49000 / 100000)    Minibatch loss=0.153163    Accuracy=0.95000    Accuracy (tf)=0.95000
(50000 / 100000)    Minibatch loss=0.170939    Accuracy=0.93000    Accuracy (tf)=0.93000
(51000 / 100000)    Minibatch loss=0.087172    Accuracy=0.98000    Accuracy (tf)=0.98000
(52000 / 100000)    Minibatch loss=0.147273    Accuracy=0.97000    Accuracy (tf)=0.97000
(53000 / 100000)    Minibatch loss=0.082001    Accuracy=0.96000    Accuracy (tf)=0.96000
(54000 / 100000)    Minibatch loss=0.071467    Accuracy=0.99000    Accuracy (tf)=0.99000
(55000 / 100000)    Minibatch loss=0.152581    Accuracy=0.97000    Accuracy (tf)=0.97000
(56000 / 100000)    Minibatch loss=0.117195    Accuracy=0.96000    Accuracy (tf)=0.96000
(57000 / 100000)    Minibatch loss=0.078302    Accuracy=0.97000    Accuracy (tf)=0.97000
(58000 / 100000)    Minibatch loss=0.162988    Accuracy=0.96000    Accuracy (tf)=0.96000
(59000 / 100000)    Minibatch loss=0.098463    Accuracy=0.96000    Accuracy (tf)=0.96000
(60000 / 100000)    Minibatch loss=0.075145    Accuracy=0.98000    Accuracy (tf)=0.98000
(61000 / 100000)    Minibatch loss=0.151033    Accuracy=0.95000    Accuracy (tf)=0.95000
(62000 / 100000)    Minibatch loss=0.031515    Accuracy=0.99000    Accuracy (tf)=0.99000
(63000 / 100000)    Minibatch loss=0.092133    Accuracy=0.96000    Accuracy (tf)=0.96000
(64000 / 100000)    Minibatch loss=0.056619    Accuracy=0.97000    Accuracy (tf)=0.97000
(65000 / 100000)    Minibatch loss=0.066067    Accuracy=0.98000    Accuracy (tf)=0.98000
(66000 / 100000)    Minibatch loss=0.065037    Accuracy=0.98000    Accuracy (tf)=0.98000
(67000 / 100000)    Minibatch loss=0.040570    Accuracy=0.98000    Accuracy (tf)=0.98000
(68000 / 100000)    Minibatch loss=0.124935    Accuracy=0.96000    Accuracy (tf)=0.96000
(69000 / 100000)    Minibatch loss=0.135773    Accuracy=0.94000    Accuracy (tf)=0.94000
(70000 / 100000)    Minibatch loss=0.060647    Accuracy=0.98000    Accuracy (tf)=0.98000
(71000 / 100000)    Minibatch loss=0.049782    Accuracy=0.98000    Accuracy (tf)=0.98000
(72000 / 100000)    Minibatch loss=0.088643    Accuracy=0.97000    Accuracy (tf)=0.97000
(73000 / 100000)    Minibatch loss=0.150619    Accuracy=0.96000    Accuracy (tf)=0.96000
(74000 / 100000)    Minibatch loss=0.368645    Accuracy=0.92000    Accuracy (tf)=0.92000
(75000 / 100000)    Minibatch loss=0.104583    Accuracy=0.97000    Accuracy (tf)=0.97000
(76000 / 100000)    Minibatch loss=0.183688    Accuracy=0.95000    Accuracy (tf)=0.95000
(77000 / 100000)    Minibatch loss=0.118187    Accuracy=0.97000    Accuracy (tf)=0.97000
(78000 / 100000)    Minibatch loss=0.105380    Accuracy=0.96000    Accuracy (tf)=0.96000
(79000 / 100000)    Minibatch loss=0.075565    Accuracy=0.97000    Accuracy (tf)=0.97000
(80000 / 100000)    Minibatch loss=0.141826    Accuracy=0.96000    Accuracy (tf)=0.96000
(81000 / 100000)    Minibatch loss=0.111309    Accuracy=0.96000    Accuracy (tf)=0.96000
(82000 / 100000)    Minibatch loss=0.133840    Accuracy=0.97000    Accuracy (tf)=0.97000
(83000 / 100000)    Minibatch loss=0.198778    Accuracy=0.94000    Accuracy (tf)=0.94000
(84000 / 100000)    Minibatch loss=0.197886    Accuracy=0.94000    Accuracy (tf)=0.94000
(85000 / 100000)    Minibatch loss=0.115872    Accuracy=0.97000    Accuracy (tf)=0.97000
(86000 / 100000)    Minibatch loss=0.084435    Accuracy=0.98000    Accuracy (tf)=0.98000
(87000 / 100000)    Minibatch loss=0.053613    Accuracy=0.98000    Accuracy (tf)=0.98000
(88000 / 100000)    Minibatch loss=0.075559    Accuracy=0.98000    Accuracy (tf)=0.98000
(89000 / 100000)    Minibatch loss=0.120616    Accuracy=0.96000    Accuracy (tf)=0.96000
(90000 / 100000)    Minibatch loss=0.216229    Accuracy=0.94000    Accuracy (tf)=0.94000
(91000 / 100000)    Minibatch loss=0.064998    Accuracy=0.98000    Accuracy (tf)=0.98000
(92000 / 100000)    Minibatch loss=0.113708    Accuracy=0.95000    Accuracy (tf)=0.95000
(93000 / 100000)    Minibatch loss=0.064808    Accuracy=0.97000    Accuracy (tf)=0.97000
(94000 / 100000)    Minibatch loss=0.076773    Accuracy=0.97000    Accuracy (tf)=0.97000
(95000 / 100000)    Minibatch loss=0.181465    Accuracy=0.94000    Accuracy (tf)=0.94000
(96000 / 100000)    Minibatch loss=0.119581    Accuracy=0.96000    Accuracy (tf)=0.96000
(97000 / 100000)    Minibatch loss=0.103791    Accuracy=0.97000    Accuracy (tf)=0.97000
(98000 / 100000)    Minibatch loss=0.057231    Accuracy=0.98000    Accuracy (tf)=0.98000
(99000 / 100000)    Minibatch loss=0.125223    Accuracy=0.95000    Accuracy (tf)=0.95000
Optimization Finished!

In [16]:
# Calculate accuracy for the whole test set
test_data = mnist.test.images.reshape((-1, n_steps, n_input))
test_label = mnist.test.labels
print("Testing Accuracy: {:.3%}".format(sess.run(accuracy, feed_dict={x: test_data, y: test_label})))


Testing Accuracy: 96.670%

In [17]:
sess.close()