Classifying MNIST with RNN


In [1]:
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import numpy as np

In [2]:
#pd.set_option("display.max_rows", 200)
#pd.set_option("display.max_columns",None)
#pd.set_option('display.float_format', lambda x: '%.4f' % x)

In [3]:
n_steps = 28
n_inputs = 28
n_neurons = 300
n_outputs = 10

n_epochs = 100
batch_size = 300

learning_rate = 0.0005

In [4]:
mnist = input_data.read_data_sets('/tmp/data/')
X_test = mnist.test.images.reshape((-1, n_steps, n_inputs))
y_test = mnist.test.labels


Extracting /tmp/data/train-images-idx3-ubyte.gz
Extracting /tmp/data/train-labels-idx1-ubyte.gz
Extracting /tmp/data/t10k-images-idx3-ubyte.gz
Extracting /tmp/data/t10k-labels-idx1-ubyte.gz

In [5]:
X = tf.placeholder(tf.float32, [None, n_steps, n_inputs])
y = tf.placeholder(tf.int32, [None])

In [6]:
basic_cell = tf.contrib.rnn.BasicRNNCell(num_units=n_neurons)
outputs, states = tf.nn.dynamic_rnn(basic_cell, X, dtype=tf.float32)

In [7]:
logits = tf.layers.dense(states, n_outputs)
x_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=y, logits=logits)

In [8]:
loss = tf.reduce_mean(x_entropy)
optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)
training_op = optimizer.minimize(loss)
correct = tf.nn.in_top_k(logits, y, 1)
accuracy = tf.reduce_mean(tf.cast(correct, tf.float32))

In [9]:
tf.variable_scope('rnn', initializer=tf.contrib.layers.variance_scaling_initializer(factor=1.0, mode='FAN_AVG', uniform=True))
init = tf.global_variables_initializer()

In [10]:
with tf.Session() as session:
    init.run()
    for epoch in range(n_epochs):
        for i in range(mnist.train.num_examples // batch_size):
            X_batch, y_batch = mnist.train.next_batch(batch_size)
            X_batch = X_batch.reshape((-1, n_steps, n_inputs))
            session.run(training_op, feed_dict={X: X_batch, y: y_batch})
        acc_train = accuracy.eval(feed_dict={X: X_batch, y: y_batch})
        acc_test = accuracy.eval(feed_dict={X: X_test, y: y_test})
        print(epoch, "Training Accuracy:", acc_train, "Test Accuracy:", acc_test)


0 Training Accuracy: 0.913333 Test Accuracy: 0.9338
1 Training Accuracy: 0.943333 Test Accuracy: 0.9471
2 Training Accuracy: 0.956667 Test Accuracy: 0.964
3 Training Accuracy: 0.966667 Test Accuracy: 0.9611
4 Training Accuracy: 0.973333 Test Accuracy: 0.9695
5 Training Accuracy: 0.96 Test Accuracy: 0.9693
6 Training Accuracy: 0.983333 Test Accuracy: 0.9753
7 Training Accuracy: 0.986667 Test Accuracy: 0.9753
8 Training Accuracy: 0.993333 Test Accuracy: 0.9749
9 Training Accuracy: 0.99 Test Accuracy: 0.9735
10 Training Accuracy: 0.983333 Test Accuracy: 0.9791
11 Training Accuracy: 0.996667 Test Accuracy: 0.9775
12 Training Accuracy: 0.97 Test Accuracy: 0.9778
13 Training Accuracy: 0.996667 Test Accuracy: 0.9793
14 Training Accuracy: 0.99 Test Accuracy: 0.9737
15 Training Accuracy: 0.973333 Test Accuracy: 0.9803
16 Training Accuracy: 0.986667 Test Accuracy: 0.9814
17 Training Accuracy: 0.996667 Test Accuracy: 0.9767
18 Training Accuracy: 1.0 Test Accuracy: 0.98
19 Training Accuracy: 0.996667 Test Accuracy: 0.9809
20 Training Accuracy: 1.0 Test Accuracy: 0.9794
21 Training Accuracy: 0.986667 Test Accuracy: 0.9819
22 Training Accuracy: 0.99 Test Accuracy: 0.9812
23 Training Accuracy: 0.996667 Test Accuracy: 0.9794
24 Training Accuracy: 0.993333 Test Accuracy: 0.9822
25 Training Accuracy: 0.996667 Test Accuracy: 0.9788
26 Training Accuracy: 1.0 Test Accuracy: 0.9804
27 Training Accuracy: 0.993333 Test Accuracy: 0.9799
28 Training Accuracy: 0.996667 Test Accuracy: 0.9843
29 Training Accuracy: 0.99 Test Accuracy: 0.9788
30 Training Accuracy: 0.993333 Test Accuracy: 0.9816
31 Training Accuracy: 0.99 Test Accuracy: 0.9832
32 Training Accuracy: 1.0 Test Accuracy: 0.9814
33 Training Accuracy: 0.983333 Test Accuracy: 0.9827
34 Training Accuracy: 0.993333 Test Accuracy: 0.9807
35 Training Accuracy: 0.986667 Test Accuracy: 0.9792
36 Training Accuracy: 0.996667 Test Accuracy: 0.981
37 Training Accuracy: 1.0 Test Accuracy: 0.9828
38 Training Accuracy: 0.993333 Test Accuracy: 0.9813
39 Training Accuracy: 0.996667 Test Accuracy: 0.982
40 Training Accuracy: 1.0 Test Accuracy: 0.985
41 Training Accuracy: 0.993333 Test Accuracy: 0.9803
42 Training Accuracy: 0.993333 Test Accuracy: 0.9778
43 Training Accuracy: 0.993333 Test Accuracy: 0.9828
44 Training Accuracy: 1.0 Test Accuracy: 0.9804
45 Training Accuracy: 0.996667 Test Accuracy: 0.9844
46 Training Accuracy: 0.996667 Test Accuracy: 0.9794
47 Training Accuracy: 0.996667 Test Accuracy: 0.9812
48 Training Accuracy: 0.996667 Test Accuracy: 0.9805
49 Training Accuracy: 1.0 Test Accuracy: 0.9835
50 Training Accuracy: 0.996667 Test Accuracy: 0.9842
51 Training Accuracy: 0.996667 Test Accuracy: 0.982
52 Training Accuracy: 0.996667 Test Accuracy: 0.9814
53 Training Accuracy: 1.0 Test Accuracy: 0.9836
54 Training Accuracy: 1.0 Test Accuracy: 0.9805
55 Training Accuracy: 0.996667 Test Accuracy: 0.9818
56 Training Accuracy: 1.0 Test Accuracy: 0.9844
57 Training Accuracy: 0.996667 Test Accuracy: 0.9815
58 Training Accuracy: 1.0 Test Accuracy: 0.9812
59 Training Accuracy: 0.996667 Test Accuracy: 0.9831
60 Training Accuracy: 1.0 Test Accuracy: 0.9804
61 Training Accuracy: 1.0 Test Accuracy: 0.9857
62 Training Accuracy: 1.0 Test Accuracy: 0.9878
63 Training Accuracy: 1.0 Test Accuracy: 0.9822
64 Training Accuracy: 1.0 Test Accuracy: 0.9814
65 Training Accuracy: 1.0 Test Accuracy: 0.9853
66 Training Accuracy: 1.0 Test Accuracy: 0.9816
67 Training Accuracy: 1.0 Test Accuracy: 0.9832
68 Training Accuracy: 0.996667 Test Accuracy: 0.9799
69 Training Accuracy: 0.996667 Test Accuracy: 0.9831
70 Training Accuracy: 1.0 Test Accuracy: 0.983
71 Training Accuracy: 1.0 Test Accuracy: 0.9857
72 Training Accuracy: 1.0 Test Accuracy: 0.9842
73 Training Accuracy: 0.996667 Test Accuracy: 0.9786
74 Training Accuracy: 1.0 Test Accuracy: 0.9811
75 Training Accuracy: 1.0 Test Accuracy: 0.9849
76 Training Accuracy: 0.996667 Test Accuracy: 0.9829
77 Training Accuracy: 1.0 Test Accuracy: 0.9805
78 Training Accuracy: 1.0 Test Accuracy: 0.9817
79 Training Accuracy: 0.996667 Test Accuracy: 0.9839
80 Training Accuracy: 1.0 Test Accuracy: 0.9825
81 Training Accuracy: 1.0 Test Accuracy: 0.9853
82 Training Accuracy: 1.0 Test Accuracy: 0.9829
83 Training Accuracy: 0.996667 Test Accuracy: 0.9841
84 Training Accuracy: 1.0 Test Accuracy: 0.9842
85 Training Accuracy: 0.996667 Test Accuracy: 0.9837
86 Training Accuracy: 1.0 Test Accuracy: 0.984
87 Training Accuracy: 1.0 Test Accuracy: 0.982
88 Training Accuracy: 1.0 Test Accuracy: 0.9814
89 Training Accuracy: 1.0 Test Accuracy: 0.986
90 Training Accuracy: 1.0 Test Accuracy: 0.9854
91 Training Accuracy: 0.996667 Test Accuracy: 0.979
92 Training Accuracy: 0.993333 Test Accuracy: 0.9825
93 Training Accuracy: 0.993333 Test Accuracy: 0.9818
94 Training Accuracy: 1.0 Test Accuracy: 0.9857
95 Training Accuracy: 1.0 Test Accuracy: 0.9878
96 Training Accuracy: 1.0 Test Accuracy: 0.9827
97 Training Accuracy: 1.0 Test Accuracy: 0.9828
98 Training Accuracy: 1.0 Test Accuracy: 0.9867
99 Training Accuracy: 1.0 Test Accuracy: 0.9746