This notebook contains code to train a linear classifier on MNIST using tf.contrib.learn. At the end is a short exercise in which you replace it with a fully connected neural network.
In [ ]:
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
import tensorflow as tf
learn = tf.contrib.learn
tf.logging.set_verbosity(tf.logging.ERROR)
In [ ]:
mnist = learn.datasets.load_dataset('mnist')
data = mnist.train.images
labels = np.asarray(mnist.train.labels, dtype=np.int32)
test_data = mnist.test.images
test_labels = np.asarray(mnist.test.labels, dtype=np.int32)
The training set has 55,000 examples and the test set has 10,000. You may wish to limit the size of the training set to experiment faster.
In [ ]:
max_examples = 10000
data = data[:max_examples]
labels = labels[:max_examples]
In [ ]:
def display(i):
    # Show test image i with its label; each image is stored as a flat vector.
    img = test_data[i]
    plt.title('Example %d. Label: %d' % (i, test_labels[i]))
    plt.imshow(img.reshape((28, 28)), cmap=plt.cm.gray_r)
In [ ]:
display(0)
In [ ]:
display(1)
These digits are clearly drawn. Here's one that's not.
In [ ]:
display(8)
Now let's take a look at how many features we have.
In [ ]:
print(len(data[0]))
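The count printed above should be 784, because each 28x28 image is stored as a flat vector of pixel values. Here is a minimal NumPy sketch of that round trip. It uses a synthetic array as a stand-in for one MNIST row; the names `fake_image`, `flat`, and `restored` are illustrative and not part of the notebook.

```python
import numpy as np

# Synthetic stand-in for one MNIST image (hypothetical data, not the real dataset).
fake_image = np.arange(28 * 28, dtype=np.float32).reshape(28, 28)

# Flatten to the 1-D feature vector layout that the classifier consumes.
flat = fake_image.reshape(-1)
num_features = len(flat)

# Reshaping back recovers the 2-D grid, which is what display() relies on.
restored = flat.reshape((28, 28))
```

Each of the 784 entries is treated as one real-valued feature, so the classifier sees no explicit information about which pixels are neighbors.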
In [ ]:
feature_columns = learn.infer_real_valued_columns_from_input(data)
classifier = learn.LinearClassifier(feature_columns=feature_columns, n_classes=10)
classifier.fit(data, labels, batch_size=100, steps=1000)
In [ ]:
classifier.evaluate(test_data, test_labels)["accuracy"]
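The "accuracy" value is the fraction of test examples whose predicted class equals the true label. A minimal sketch of that metric in plain NumPy follows. The helper `accuracy` and the toy values are hypothetical, and this is not the library's implementation.

```python
import numpy as np

def accuracy(predictions, labels):
    """Return the fraction of predictions that equal the true labels."""
    predictions = np.asarray(predictions)
    labels = np.asarray(labels)
    return float(np.mean(predictions == labels))

# Toy example: three of the four predictions match their labels.
toy_accuracy = accuracy([7, 2, 1, 0], [7, 2, 1, 5])
```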
In [ ]:
# here's one it gets right
print ("Predicted %d, Label: %d" % (list(classifier.predict(test_data[0:1]))[0], test_labels[0]))
display(0)
In [ ]:
# and one it gets wrong
print ("Predicted %d, Label: %d" % (list(classifier.predict(test_data[8:9]))[0], test_labels[8]))
display(8)
Let's see if we can reproduce the pictures of the weights in the TensorFlow Basic MNIST tutorial.
In [ ]:
weights = classifier.weights_
f, axes = plt.subplots(2, 5, figsize=(10,4))
axes = axes.reshape(-1)
for i in range(len(axes)):
    a = axes[i]
    # Each row of weights.T holds the 784 weights for one digit class.
    a.imshow(weights.T[i].reshape(28, 28), cmap=plt.cm.seismic)
    a.set_title(i)
    a.set_xticks(())  # ticks be gone
    a.set_yticks(())
plt.show()
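Each panel can be read as a per-digit template: a linear classifier scores an image by taking the dot product of its pixels with each class's weights, adding a bias, and picking the class with the largest score. The sketch below illustrates this with random stand-in values. The shape (784, 10) for `W` is an assumption, consistent with the `weights.T[i].reshape(28, 28)` call above, and `W`, `b`, and `x` are hypothetical names rather than values read from the trained classifier.

```python
import numpy as np

rng = np.random.RandomState(0)

# Hypothetical stand-ins: random weights with the assumed shape (784, 10),
# one bias per class, and one flattened 28x28 image.
W = rng.randn(784, 10)
b = np.zeros(10)
x = rng.rand(784)

# One score (logit) per digit class: the image dotted with that class's weight column.
logits = x.dot(W) + b
predicted = int(np.argmax(logits))

# Column `predicted` of W, reshaped to 28x28, is the kind of template the plot shows.
template_for_predicted = W.T[predicted].reshape(28, 28)
```

Pixels with large positive weights push the score for that digit up, and pixels with large negative weights push it down. This is what the two ends of the seismic colormap show.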
In [ ]:
# Build 2 layer DNN with 128, 32 units respectively.
# Play with these parameters to see if you can do better
# How? See https://www.tensorflow.org/versions/r0.12/tutorials/tflearn/index.html#tf-contrib-learn-quickstart
In [ ]:
classifier.evaluate(test_data, test_labels)["accuracy"]