MNIST Data Set - Basic Approach

Get the MNIST Data


In [1]:
import tensorflow as tf

In [2]:
from tensorflow.examples.tutorials.mnist import input_data



In [3]:
mnist = input_data.read_data_sets("./data/MNIST_data/", 
                                  one_hot = True)


(read_data_sets and the rest of the tf.contrib.learn MNIST helpers print several deprecation warnings here, pointing to tf.data, tf.one_hot, and the official tensorflow/models MNIST example as replacements.)
Extracting ./data/MNIST_data/train-images-idx3-ubyte.gz
Extracting ./data/MNIST_data/train-labels-idx1-ubyte.gz
Extracting ./data/MNIST_data/t10k-images-idx3-ubyte.gz
Extracting ./data/MNIST_data/t10k-labels-idx1-ubyte.gz
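
If you want to avoid the deprecated tf.contrib.learn pipeline entirely, one minimal alternative sketch (not what this notebook uses; note that the Keras loader returns a 60,000/10,000 train/test split with no separate validation set) is to load MNIST through tf.keras.datasets and build the flat float32 images and one-hot labels yourself:

import numpy as np
import tensorflow as tf

# Hypothetical alternative loader, not used in the rest of this notebook
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()

# Flatten to (num_examples, 784) and scale pixels to [0, 1], like read_data_sets does
x_train = x_train.reshape(-1, 784).astype(np.float32) / 255.0
x_test = x_test.reshape(-1, 784).astype(np.float32) / 255.0

# One-hot encode the integer labels 0-9, mirroring one_hot = True above
y_train = np.eye(10, dtype = np.float32)[y_train]
y_test = np.eye(10, dtype = np.float32)[y_test]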

In [4]:
type(mnist)


Out[4]:
tensorflow.contrib.learn.python.learn.datasets.base.Datasets

In [5]:
mnist.train.images


Out[5]:
array([[0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       ...,
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.]], dtype=float32)

In [6]:
mnist.train.num_examples


Out[6]:
55000

In [7]:
mnist.test.num_examples


Out[7]:
10000

In [8]:
mnist.validation.num_examples


Out[8]:
5000

Visualizing the Data


In [9]:
import matplotlib.pyplot as plt
%matplotlib inline

In [10]:
# Each image is stored as a flat array of 784 pixel values (28 x 28)
mnist.train.images[1].shape


Out[10]:
(784,)

In [11]:
# Showing the image reshaped back to 28 x 28
plt.imshow(mnist.train.images[1].reshape(28, 28))


Out[11]:
<matplotlib.image.AxesImage at 0x1a1806a2160>

In [12]:
# Showing the image with the 'gist_gray' (grayscale) colormap
plt.imshow(mnist.train.images[1].reshape(28, 28), 
           cmap = 'gist_gray')


Out[12]:
<matplotlib.image.AxesImage at 0x1a180770668>

In [13]:
# Maximum pixel value (the data is already scaled to the [0, 1] range)
mnist.train.images[1].max()


Out[13]:
1.0

In [14]:
# Visualizing the image as the flat 784 x 1 column the model actually sees
plt.imshow(mnist.train.images[1].reshape(784, 1))


Out[14]:
<matplotlib.image.AxesImage at 0x1a180814080>

In [15]:
plt.imshow(mnist.train.images[1].reshape(784, 1),
           cmap = 'gist_gray', 
           aspect = 0.02)


Out[15]:
<matplotlib.image.AxesImage at 0x1a1808aa240>

Create the Model


In [16]:
# Placeholder for the input images: None lets the number of examples vary,
# and 784 is the number of pixels in each flattened 28 x 28 image
# The pixel values are float32
x = tf.placeholder(tf.float32,
                shape = [None, 784])

In [17]:
# Initializing the weights between the input layer and the output layer
# The shape is number_of_features by number_of_neurons_in_the_layer
# All-zero initialization is a poor choice in general (random initialization
# usually works better), but it keeps this example simple

# 10 because 0-9 possible numbers
W = tf.Variable(tf.zeros([784, 10]))

In [18]:
# Initializing biases 
b = tf.Variable(tf.zeros([10]))

In [19]:
# Create the graph: a single linear layer producing one logit per class
y = tf.matmul(x, W) + b

Loss and Optimizer


In [20]:
# Placeholder for the true labels: None (number of examples) by number_of_classes
y_true = tf.placeholder(tf.float32, 
                        shape = [None, 10])

In [21]:
# Cross-entropy loss; the op applies the softmax to the logits internally
cross_entropy = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(labels = y_true, 
                                                                          logits = y))
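
softmax_cross_entropy_with_logits_v2 applies the softmax and the cross-entropy together in a numerically stable way. Purely as an illustration of what it computes (a sketch, not needed for training), the same quantity can be written out by hand with tf.nn.log_softmax:

# Hand-rolled equivalent of the op above, for illustration only
manual_cross_entropy = tf.reduce_mean(
    -tf.reduce_sum(y_true * tf.nn.log_softmax(y), axis = 1))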

In [22]:
# Gradient Descent
optimizer = tf.train.GradientDescentOptimizer(learning_rate = 0.5)

In [23]:
# Minimizing the loss function 
train = optimizer.minimize(cross_entropy)

Create Session


In [24]:
# Initializing all variables
init = tf.global_variables_initializer()

In [25]:
with tf.Session() as sess:
    sess.run(init)
    
    # Train the model for 1000 steps on the training set
    # Using the built-in batch feeder from mnist for convenience
    for step in range(1000):
        # Training on a batch of 100 examples
        batch_x , batch_y = mnist.train.next_batch(100)
        
        sess.run(train, feed_dict = {x : batch_x,
                                     y_true : batch_y})
        
    # Comparing the predicted class (argmax of the logits) with the true class
    matches = tf.equal(tf.argmax(y, 1), 
                       tf.argmax(y_true, 1))
    
    # The mean of the boolean matches (cast to float32) is the accuracy
    acc = tf.reduce_mean(tf.cast(matches, tf.float32))
    
    # Evaluating the accuracy on the test set
    print(sess.run(acc, feed_dict = {x : mnist.test.images, 
                                     y_true : mnist.test.labels}))


0.9164
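
The same acc tensor can also be checked against the validation split; a minimal sketch of the extra line you could add inside the with tf.Session() block above (not run here):

    # Hypothetical extra check inside the same session, after training
    print(sess.run(acc, feed_dict = {x : mnist.validation.images, 
                                     y_true : mnist.validation.labels}))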

While this may seem pretty good, we can actually do much better: the best models get above 99% accuracy.

How do they do this? By using other models, such as convolutional neural networks!
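
For reference, a rough sketch of what a small convolutional model could look like in the same TF 1.x style (the layer sizes here are arbitrary illustrative choices, not a tested architecture from this notebook):

# Hypothetical CNN sketch, reusing the x placeholder defined above
x_image = tf.reshape(x, [-1, 28, 28, 1])   # back from flat 784 vectors to 28 x 28 images
conv1 = tf.layers.conv2d(x_image, filters = 32, kernel_size = 5,
                         padding = 'same', activation = tf.nn.relu)
pool1 = tf.layers.max_pooling2d(conv1, pool_size = 2, strides = 2)
conv2 = tf.layers.conv2d(pool1, filters = 64, kernel_size = 5,
                         padding = 'same', activation = tf.nn.relu)
pool2 = tf.layers.max_pooling2d(conv2, pool_size = 2, strides = 2)
flat = tf.reshape(pool2, [-1, 7 * 7 * 64])
dense = tf.layers.dense(flat, units = 1024, activation = tf.nn.relu)
y_cnn = tf.layers.dense(dense, units = 10)  # logits, trained with the same loss as above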