In [1]:
import os
import numpy as np
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import time

In [2]:
# Define parameters for the model
learning_rate = 0.01
batch_size = 128
n_epochs = 40

Step 1: read in data
use TF Learn's built-in function to load the MNIST data into the folder ../data/mnist


In [3]:
mnist = input_data.read_data_sets('../data/mnist', one_hot=True)


Extracting ../data/mnist/train-images-idx3-ubyte.gz
Extracting ../data/mnist/train-labels-idx1-ubyte.gz
Extracting ../data/mnist/t10k-images-idx3-ubyte.gz
Extracting ../data/mnist/t10k-labels-idx1-ubyte.gz
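Note: the tensorflow.examples.tutorials.mnist module was removed in TensorFlow 2.x. If read_data_sets is unavailable, here is a minimal sketch of an equivalent loader, assuming a TensorFlow build that bundles tf.keras:

# Sketch of an alternative loader (not part of the original notebook).
# Produces flattened float features in [0, 1] and one-hot int labels,
# matching the shapes used below ([None, 784] and [None, 10]).
import numpy as np
import tensorflow as tf

(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
x_train = x_train.reshape(-1, 784).astype(np.float32) / 255.0
x_test = x_test.reshape(-1, 784).astype(np.float32) / 255.0
y_train = np.eye(10, dtype=np.int32)[y_train]  # one-hot encode
y_test = np.eye(10, dtype=np.int32)[y_test]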

Step 2: create placeholders for features and labels
each image in the MNIST data has shape 28*28 = 784,
so each image is represented as a 1x784 tensor.
there are 10 classes for each image, corresponding to digits 0-9.
features are of type float, and labels (one-hot) are of type int


In [4]:
X = tf.placeholder(tf.float32, shape=[None, 784], name='features')
Y = tf.placeholder(tf.int32, shape=[None, 10], name='labels')
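The None in each placeholder's shape leaves the batch dimension unspecified, so the same graph accepts any batch size at run time. A quick illustration (a sketch only, separate from the model graph):

# Sketch: the None dimension accepts batches of any size.
demo = tf.placeholder(tf.float32, shape=[None, 784], name='demo')
demo_shape = tf.shape(demo)
with tf.Session() as s:
    for bs in (1, 32, 128):
        print(s.run(demo_shape, feed_dict={demo: np.zeros((bs, 784), np.float32)}))
        # prints [  1 784], [ 32 784], [128 784]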

Step 3: create weights and bias
weights and bias are initialized to 0
the shape of W depends on the dimensions of X and Y, so that Y = X * W + b
the shape of b depends on Y


In [5]:
W = tf.Variable(tf.zeros([784, 10]), name='weights')
b = tf.Variable(tf.zeros([10]), name='bias')
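Initializing everything to zero is safe here because the model is a single linear layer: each output column receives a different gradient as soon as the inputs and labels differ. In a multi-layer network, zero initialization would leave all hidden units identical, so one would use small random weights instead, e.g. (a hypothetical alternative, not used below):

# Sketch: random initialization, as a deeper model would require.
W_rand = tf.Variable(tf.random_normal([784, 10], stddev=0.01), name='weights_rand')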

Step 4: build model
the model returns the logits.
these logits will later be passed through a softmax layer
to get a probability distribution over the possible labels of the image.
DO NOT APPLY SOFTMAX HERE


In [6]:
logits = tf.matmul(X, W) + b
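Shape check: X is [batch, 784] and W is [784, 10], so tf.matmul(X, W) is [batch, 10], and b broadcasts across the batch dimension. The same arithmetic in plain NumPy (illustration only):

# NumPy sketch of the logits computation and its shapes.
import numpy as np

X_np = np.zeros((128, 784), np.float32)  # a batch of 128 flattened images
W_np = np.zeros((784, 10), np.float32)
b_np = np.zeros(10, np.float32)
print((X_np @ W_np + b_np).shape)        # (128, 10): one logit per class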

Step 5: define loss function
use the cross-entropy loss between the true labels and the softmax of the logits
use the method:
tf.nn.softmax_cross_entropy_with_logits(labels=Y, logits=logits)
then use tf.reduce_mean to get the mean loss over the batch


In [7]:
loss_func = tf.reduce_mean(
    tf.nn.softmax_cross_entropy_with_logits(labels=Y, logits=logits), name="loss")
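This is also why the softmax is not applied in Step 4: softmax_cross_entropy_with_logits fuses the softmax and the log into one op, computed in a numerically stable way. A NumPy sketch of the log-sum-exp form it is equivalent to (an illustration, not TensorFlow's actual kernel):

# Sketch: numerically stable softmax cross entropy.
# loss_i = logsumexp(z_i) - sum_j y_ij * z_ij, with the row max subtracted
# first so exp() cannot overflow.
import numpy as np

def stable_softmax_xent(logits, labels_one_hot):
    z = logits - logits.max(axis=1, keepdims=True)   # shift for stability
    log_sum_exp = np.log(np.exp(z).sum(axis=1))
    return log_sum_exp - (labels_one_hot * z).sum(axis=1)

z = np.array([[2.0, 1.0, 0.1]])
y = np.array([[1.0, 0.0, 0.0]])
print(stable_softmax_xent(z, y))  # ~[0.417], i.e. -log(softmax(z)[0, 0])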

Step 6: define training op
use gradient descent to minimize the loss


In [8]:
optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss_func)

with tf.Session() as sess:
    start_time = time.time()
    sess.run(tf.global_variables_initializer())
    n_batches = int(mnist.train.num_examples/batch_size)
    for i in range(n_epochs): # train the model n_epochs times
        total_loss = 0

        for _ in range(n_batches):
            X_batch, Y_batch = mnist.train.next_batch(batch_size)

            # run optimizer + fetch loss_batch
            _, loss_batch = sess.run([optimizer, loss_func], feed_dict={X:X_batch, Y:Y_batch})

            total_loss += loss_batch
        print('Average loss epoch {0}: {1}'.format(i, total_loss/n_batches))

    print('Total time: {0} seconds'.format(time.time() - start_time))

    print('Optimization Finished!') # average loss should be around 0.35 after 25 epochs

    # test the model
    preds = tf.nn.softmax(logits)
    correct_preds = tf.equal(tf.argmax(preds, 1), tf.argmax(Y, 1))
    accuracy = tf.reduce_sum(tf.cast(correct_preds, tf.float32)) # count of correct predictions in the batch

    n_batches = int(mnist.test.num_examples/batch_size)
    total_correct_preds = 0

    for i in range(n_batches):
        X_batch, Y_batch = mnist.test.next_batch(batch_size)
        accuracy_batch = sess.run([accuracy], feed_dict={X: X_batch, Y:Y_batch})[0]
        total_correct_preds += accuracy_batch

    # note: n_batches * batch_size = 9984 < 10000, so the last partial test
    # batch is dropped and the reported accuracy is slightly underestimated
    print('Accuracy {0}'.format(total_correct_preds/mnist.test.num_examples))


Average loss epoch 0: 1.2882110213899947
Average loss epoch 1: 0.7328066400714688
Average loss epoch 2: 0.6006202877679349
Average loss epoch 3: 0.536202555769807
Average loss epoch 4: 0.4982368903004484
Average loss epoch 5: 0.4710763553222576
Average loss epoch 6: 0.4505740274221469
Average loss epoch 7: 0.43710910926609886
Average loss epoch 8: 0.42283396127618555
Average loss epoch 9: 0.412501236219784
Average loss epoch 10: 0.4046692543930107
Average loss epoch 11: 0.3974340678928615
Average loss epoch 12: 0.3900773207416068
Average loss epoch 13: 0.38467453145758534
Average loss epoch 14: 0.3790666154631368
Average loss epoch 15: 0.3750903131592246
Average loss epoch 16: 0.3697237441520313
Average loss epoch 17: 0.3659326490405556
Average loss epoch 18: 0.3638986906944177
Average loss epoch 19: 0.35999457944523205
Average loss epoch 20: 0.35540290576316813
Average loss epoch 21: 0.3540274834994114
Average loss epoch 22: 0.35177521327714545
Average loss epoch 23: 0.3487731095724728
Average loss epoch 24: 0.34598777048932366
Average loss epoch 25: 0.34475264873676925
Average loss epoch 26: 0.34249747000929914
Average loss epoch 27: 0.33968463501869106
Average loss epoch 28: 0.3389067274012488
Average loss epoch 29: 0.3357210760211055
Average loss epoch 30: 0.3365112953669541
Average loss epoch 31: 0.3317439413709796
Average loss epoch 32: 0.33296743747376617
Average loss epoch 33: 0.33020946036129845
Average loss epoch 34: 0.3296469686887203
Average loss epoch 35: 0.32650633416809405
Average loss epoch 36: 0.32747466254345464
Average loss epoch 37: 0.32562732119938154
Average loss epoch 38: 0.3245558879383794
Average loss epoch 39: 0.3231874586263181
Total time: 55.77837538719177 seconds
Optimization Finished!
Accuracy 0.9144
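Since the test loop drops the final partial batch (78 * 128 = 9984 of the 10000 test images), the accuracy above is a slight underestimate. The whole test set easily fits in memory for a model this small, so it could be evaluated in one pass; a sketch, assuming it runs inside the same session as the training code:

# Sketch: evaluate the full test set in a single run.
correct = sess.run(accuracy,
                   feed_dict={X: mnist.test.images, Y: mnist.test.labels})
print('Accuracy {0}'.format(correct / mnist.test.num_examples))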
