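Paste this solution into the previous notebook, after the cells that define `sess`, `x`, `y`, `mnist`, `np`, and `plt`; the code below uses all of them without defining them.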


In [ ]:
# Graph op giving, for each row of y, the index of its largest entry along
# axis 1 -- i.e. the predicted class per example. (Presumably y is the
# (batch, num_classes) model output built in the previous notebook -- confirm.)
prediction = tf.argmax(y, 1)

def predict(idx):
    """Classify one MNIST test image with the trained model.

    Args:
        idx: Index into ``mnist.test.images``.

    Returns:
        The value of ``prediction`` (``tf.argmax(y, 1)``) evaluated on a
        batch containing just that image: a length-1 array whose only
        element is the predicted class index.
    """
    # The model is fed batches (x presumably expects a leading batch
    # dimension), so wrap the single image in a one-element list.
    batch_of_one = [mnist.test.images[idx]]
    feed = {x: batch_of_one}
    return sess.run(prediction, feed_dict=feed)

# Classify one test image, print the prediction next to the true label,
# and display the image.
idx = 0
# The label row is reduced to a class index with argmax (presumably the
# labels are one-hot encoded -- confirm against how mnist was loaded).
actual = np.argmax(mnist.test.labels[idx])
# predict() returns a length-1 array (one prediction per image in the fed
# batch). Take its single element so "%d" receives a scalar; formatting a
# 1-D array with "%d" relies on ndarray.__int__, which is deprecated for
# ndim > 0 arrays since NumPy 1.25.
predicted = predict(idx)[0]
print("Predicted: %d, Actual: %d" % (predicted, actual))
# Reshape the stored image (presumably a flat 784-vector) to 28x28 for display.
plt.imshow(mnist.test.images[idx].reshape((28, 28)), cmap=plt.cm.gray_r)