MNIST Image Classification with TensorFlow

This notebook demonstrates how to implement a simple linear image model on MNIST using the Estimator API.


A companion notebook extends the basic harness developed here to a variety of models, including DNNs and CNNs with dropout and pooling.


In [1]:
import numpy as np
import shutil
import os
import tensorflow as tf
print(tf.__version__)


1.13.1

Exploring the data

Let's download MNIST data and examine the shape. We will need these numbers ...


In [2]:
from tensorflow.examples.tutorials.mnist import input_data

mnist = input_data.read_data_sets("mnist/data", one_hot = True, reshape = False)
print(mnist.train.images.shape)
print(mnist.train.labels.shape)


WARNING:tensorflow:From <ipython-input-2-09f4f2344a9e>:3: read_data_sets (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use alternatives such as official/mnist/dataset.py from tensorflow/models.
WARNING:tensorflow:From /usr/local/lib/python3.5/dist-packages/tensorflow/contrib/learn/python/learn/datasets/mnist.py:260: maybe_download (from tensorflow.contrib.learn.python.learn.datasets.base) is deprecated and will be removed in a future version.
Instructions for updating:
Please write your own downloading logic.
WARNING:tensorflow:From /usr/local/lib/python3.5/dist-packages/tensorflow/contrib/learn/python/learn/datasets/mnist.py:262: extract_images (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use tf.data to implement this functionality.
Extracting mnist/data/train-images-idx3-ubyte.gz
WARNING:tensorflow:From /usr/local/lib/python3.5/dist-packages/tensorflow/contrib/learn/python/learn/datasets/mnist.py:267: extract_labels (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use tf.data to implement this functionality.
Extracting mnist/data/train-labels-idx1-ubyte.gz
WARNING:tensorflow:From /usr/local/lib/python3.5/dist-packages/tensorflow/contrib/learn/python/learn/datasets/mnist.py:110: dense_to_one_hot (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use tf.one_hot on tensors.
Extracting mnist/data/t10k-images-idx3-ubyte.gz
Extracting mnist/data/t10k-labels-idx1-ubyte.gz
WARNING:tensorflow:From /usr/local/lib/python3.5/dist-packages/tensorflow/contrib/learn/python/learn/datasets/mnist.py:290: DataSet.__init__ (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use alternatives such as official/mnist/dataset.py from tensorflow/models.
(55000, 28, 28, 1)
(55000, 10)

In [3]:
HEIGHT = 28
WIDTH = 28
NCLASSES = 10
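The shapes printed above can be unpacked with a small, dependency-free sketch. Each image is a 28 x 28 grid with one channel, so it flattens to 784 values, and because the data was loaded with one_hot = True, each label is a length-10 vector containing a single 1. The helper names one_hot and from_one_hot are hypothetical and exist only for this illustration.

```python
HEIGHT, WIDTH, NCLASSES = 28, 28, 10  # same values as the notebook constants

# Number of pixel values per image once it is flattened for a linear model.
pixels_per_image = HEIGHT * WIDTH


def one_hot(digit, nclasses=NCLASSES):
    """Return a list with 1.0 at index `digit` and 0.0 everywhere else."""
    return [1.0 if i == digit else 0.0 for i in range(nclasses)]


def from_one_hot(vector):
    """Recover the digit as the index of the largest entry (an argmax)."""
    return max(range(len(vector)), key=lambda i: vector[i])


label = one_hot(3)
print(pixels_per_image)     # 784
print(label)
print(from_one_hot(label))  # 3
```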

In [4]:
import matplotlib.pyplot as plt
IMGNO = 12
plt.imshow(mnist.test.images[IMGNO].reshape(HEIGHT, WIDTH));


Define the model

Let's start with a very simple linear classifier. All our models will have this basic interface -- they will take an image and return logits.


In [5]:
# Using low-level tensorflow
def linear_model(img):
    X = tf.reshape(tensor = img, shape = [-1, HEIGHT * WIDTH]) #flatten
    W = tf.get_variable(
        name = "W",
        shape = [HEIGHT * WIDTH, NCLASSES],
        initializer = tf.truncated_normal_initializer(stddev = 0.1, seed = 1))
    b = tf.get_variable(name = "b", shape = [NCLASSES], initializer = tf.zeros_initializer())
    ylogits = tf.matmul(a = X, b = W) + b
    return ylogits, NCLASSES
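To make the flatten-then-matmul step concrete, here is a minimal plain-Python sketch of the same computation for a single image. It is not the TensorFlow implementation: the toy sizes (a 2 x 2 image and 3 classes) and the helper names flatten and linear_logits are assumptions chosen so the arithmetic can be checked by hand.

```python
def flatten(image):
    """Flatten a 2-D image (a list of rows) into one list of pixels."""
    return [pixel for row in image for pixel in row]


def linear_logits(image, W, b):
    """Compute logits = x . W + b for one image.

    W has one row per pixel and one column per class; b has one entry per class.
    """
    x = flatten(image)
    return [
        sum(x[i] * W[i][k] for i in range(len(x))) + b[k]
        for k in range(len(b))
    ]


# Toy sizes (hypothetical): 2x2 image, 3 classes, instead of 28x28 and 10.
image = [[1.0, 0.0],
         [0.0, 2.0]]
W = [[1.0, 0.0, 0.0],
     [0.0, 1.0, 0.0],
     [0.0, 0.0, 1.0],
     [0.5, 0.5, 0.5]]
b = [0.0, 0.0, 1.0]

logits = linear_logits(image, W, b)
print(logits)  # [2.0, 1.0, 2.0]
```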

Note that we can also build our linear classifier using the tf.layers API. With tf.layers we don't have to define or initialize the weights and biases ourselves; this happens automatically in the background. Because the next cell redefines linear_model, the tf.layers version is the one used in the rest of the notebook.

When building more complex models such as DNNs and CNNs, our code will be much more readable if we use the tf.layers API.


In [6]:
# Using tf.layers API
def linear_model(img):
    X = tf.reshape(tensor = img, shape = [-1, HEIGHT * WIDTH]) #flatten
    ylogits = tf.layers.dense(inputs = X, units = NCLASSES, activation = None)
    return ylogits, NCLASSES
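Both versions of linear_model learn the same number of parameters: one weight per (pixel, class) pair plus one bias per class. The short sketch below works out that count; dense_param_count is a hypothetical helper written for this illustration, not a TensorFlow function.

```python
HEIGHT, WIDTH, NCLASSES = 28, 28, 10  # same values as the notebook constants


def dense_param_count(n_inputs, n_units, use_bias=True):
    """Count trainable parameters of a fully connected layer."""
    weights = n_inputs * n_units
    biases = n_units if use_bias else 0
    return weights + biases


n_params = dense_param_count(HEIGHT * WIDTH, NCLASSES)
print(n_params)  # 7850
```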

Write Input Functions

As usual, we need to specify input functions for training, evaluation, and prediction.


In [7]:
train_input_fn = tf.estimator.inputs.numpy_input_fn(
    x = {"image": mnist.train.images},
    y = mnist.train.labels,
    batch_size = 100,
    num_epochs = None,
    shuffle = True,
    queue_capacity = 5000
)

eval_input_fn = tf.estimator.inputs.numpy_input_fn(
    x = {"image": mnist.test.images},
    y = mnist.test.labels,
    batch_size = 100,
    num_epochs = 1,
    shuffle = False,
    queue_capacity = 5000
)

def serving_input_fn():
    inputs = {"image": tf.placeholder(dtype = tf.float32, shape = [None, HEIGHT, WIDTH])}
    features = inputs # as-is
    return tf.estimator.export.ServingInputReceiver(features = features, receiver_tensors = inputs)
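The difference between the training and evaluation input functions comes down to num_epochs and shuffle. The generator below is a simplified sketch of those semantics in plain Python. It is not how numpy_input_fn is implemented (that uses queue runners, as the log shows), and the name batches and its seed parameter are assumptions made for illustration.

```python
import itertools
import random


def batches(examples, batch_size, num_epochs=1, shuffle=False, seed=0):
    """Yield lists of up to `batch_size` examples.

    num_epochs=None repeats forever, loosely mirroring the training input
    function; num_epochs=1 makes a single pass, as for evaluation.
    """
    rng = random.Random(seed)
    epochs = itertools.count() if num_epochs is None else range(num_epochs)
    for _ in epochs:
        order = list(examples)
        if shuffle:
            rng.shuffle(order)
        for start in range(0, len(order), batch_size):
            yield order[start:start + batch_size]


# Evaluation style: one ordered pass, so the stream ends by itself.
eval_batches = list(batches(range(10), batch_size=4))
print(eval_batches)

# Training style: endless shuffled stream, so the caller decides when to stop.
train_stream = batches(range(10), batch_size=4, num_epochs=None, shuffle=True)
first_five = list(itertools.islice(train_stream, 5))
print(len(first_five))  # 5
```

Because the training stream never ends, something else has to stop training; in this notebook that is the max_steps argument of TrainSpec.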

Write Custom Estimator

We could have simply used the canned LinearClassifier, but later on we will want to swap in different models, so let's write a custom estimator.


In [8]:
def image_classifier(features, labels, mode, params):
    ylogits, nclasses = linear_model(features["image"])
    probabilities = tf.nn.softmax(logits = ylogits)
    class_ids = tf.cast(x = tf.argmax(input = probabilities, axis = 1), dtype = tf.uint8)

    if mode == tf.estimator.ModeKeys.TRAIN or mode == tf.estimator.ModeKeys.EVAL:
        loss = tf.reduce_mean(input_tensor = tf.nn.softmax_cross_entropy_with_logits_v2(logits = ylogits, labels = labels))
        
        if mode == tf.estimator.ModeKeys.TRAIN:
            train_op = tf.contrib.layers.optimize_loss(
                loss = loss, 
                global_step = tf.train.get_global_step(),
                learning_rate = params["learning_rate"], 
                optimizer = "Adam")
            eval_metric_ops = None
        else:
            train_op = None
            eval_metric_ops =  {"accuracy": tf.metrics.accuracy(labels = tf.argmax(input = labels, axis = 1), predictions = class_ids)}
    else:
        loss = None
        train_op = None
        eval_metric_ops = None
 
    return tf.estimator.EstimatorSpec(
        mode = mode,
        predictions = {"probabilities": probabilities, "class_ids": class_ids},
        loss = loss,
        train_op = train_op,
        eval_metric_ops = eval_metric_ops,
        export_outputs = {"predictions": tf.estimator.export.PredictOutput({"probabilities": probabilities, "class_ids": class_ids})}
    )
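The loss in the model function is the mean softmax cross-entropy between the logits and the one-hot labels. As a rough illustration of what that computes, here is a plain-Python sketch for small lists; the function names are hypothetical and the code does not reproduce TensorFlow's numerics exactly.

```python
import math


def softmax(logits):
    """Turn logits into probabilities that sum to 1 (shifted by the max for stability)."""
    shift = max(logits)
    exps = [math.exp(z - shift) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def cross_entropy(logits, one_hot_label):
    """Softmax cross-entropy for one example: minus the log-probability of the true class."""
    probs = softmax(logits)
    return -sum(y * math.log(p) for y, p in zip(one_hot_label, probs) if y > 0)


def mean_loss(batch_logits, batch_labels):
    """Average the per-example losses over a batch."""
    losses = [cross_entropy(l, y) for l, y in zip(batch_logits, batch_labels)]
    return sum(losses) / len(losses)


# With all-zero logits every class gets probability 1/10, so the loss is ln(10).
uniform_logits = [0.0] * 10
label = [0.0] * 10
label[3] = 1.0
print(cross_entropy(uniform_logits, label))
print(math.log(10))
```

The value ln(10), roughly 2.30, is the loss of a classifier that guesses uniformly over ten classes, which gives a reference point for the step-1 loss of 2.462242 in the training log.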

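tf.estimator.train_and_evaluate can run distributed training when a cluster is configured. In this notebook it runs locally, training the model and evaluating it after each checkpoint.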


In [9]:
def train_and_evaluate(output_dir, hparams):
    estimator = tf.estimator.Estimator(
        model_fn = image_classifier,
        model_dir = output_dir,
        params = hparams)

    train_spec = tf.estimator.TrainSpec(
        input_fn = train_input_fn,
        max_steps = hparams["train_steps"])

    exporter = tf.estimator.LatestExporter(name = "exporter", serving_input_receiver_fn = serving_input_fn)

    eval_spec = tf.estimator.EvalSpec(
        input_fn = eval_input_fn,
        steps = None,
        exporters = exporter)

    tf.estimator.train_and_evaluate(estimator = estimator, train_spec = train_spec, eval_spec = eval_spec)

This is the equivalent of a main() function: it clears the output directory, sets the hyperparameters, and launches training.


In [10]:
OUTDIR = "mnist/learned"
shutil.rmtree(path = OUTDIR, ignore_errors = True) # start fresh each time

hparams = {"train_steps": 1000, "learning_rate": 0.01}
train_and_evaluate(OUTDIR, hparams)


INFO:tensorflow:Using default config.
INFO:tensorflow:Using config: {'_keep_checkpoint_max': 5, '_master': '', '_is_chief': True, '_task_type': 'worker', '_service': None, '_save_checkpoints_steps': None, '_evaluation_master': '', '_num_ps_replicas': 0, '_experimental_distribute': None, '_tf_random_seed': None, '_num_worker_replicas': 1, '_task_id': 0, '_model_dir': 'mnist/learned', '_global_id_in_cluster': 0, '_cluster_spec': <tensorflow.python.training.server_lib.ClusterSpec object at 0x7fd993011550>, '_log_step_count_steps': 100, '_save_checkpoints_secs': 600, '_device_fn': None, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_every_n_hours': 10000, '_save_summary_steps': 100, '_train_distribute': None, '_eval_distribute': None, '_protocol': None}
INFO:tensorflow:Not using Distribute Coordinator.
INFO:tensorflow:Running training and evaluation locally (non-distributed).
INFO:tensorflow:Start train and evaluate loop. The evaluate will happen after every checkpoint. Checkpoint frequency is determined based on RunConfig arguments: save_checkpoints_steps None or save_checkpoints_secs 600.
WARNING:tensorflow:From /usr/local/lib/python3.5/dist-packages/tensorflow/python/framework/op_def_library.py:263: colocate_with (from tensorflow.python.framework.ops) is deprecated and will be removed in a future version.
Instructions for updating:
Colocations handled automatically by placer.
WARNING:tensorflow:From /usr/local/lib/python3.5/dist-packages/tensorflow_estimator/python/estimator/inputs/queues/feeding_queue_runner.py:62: QueueRunner.__init__ (from tensorflow.python.training.queue_runner_impl) is deprecated and will be removed in a future version.
Instructions for updating:
To construct input pipelines, use the `tf.data` module.
WARNING:tensorflow:From /usr/local/lib/python3.5/dist-packages/tensorflow_estimator/python/estimator/inputs/queues/feeding_functions.py:500: add_queue_runner (from tensorflow.python.training.queue_runner_impl) is deprecated and will be removed in a future version.
Instructions for updating:
To construct input pipelines, use the `tf.data` module.
INFO:tensorflow:Calling model_fn.
WARNING:tensorflow:From <ipython-input-6-bae7dbd14ceb>:4: dense (from tensorflow.python.layers.core) is deprecated and will be removed in a future version.
Instructions for updating:
Use keras.layers.dense instead.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
WARNING:tensorflow:From /usr/local/lib/python3.5/dist-packages/tensorflow/python/training/monitored_session.py:809: start_queue_runners (from tensorflow.python.training.queue_runner_impl) is deprecated and will be removed in a future version.
Instructions for updating:
To construct input pipelines, use the `tf.data` module.
INFO:tensorflow:Saving checkpoints for 0 into mnist/learned/model.ckpt.
INFO:tensorflow:loss = 2.462242, step = 1
INFO:tensorflow:global_step/sec: 205.865
INFO:tensorflow:loss = 0.28505534, step = 101 (0.488 sec)
INFO:tensorflow:global_step/sec: 472.335
INFO:tensorflow:loss = 0.33613923, step = 201 (0.212 sec)
INFO:tensorflow:global_step/sec: 663.842
INFO:tensorflow:loss = 0.38363686, step = 301 (0.151 sec)
INFO:tensorflow:global_step/sec: 663.681
INFO:tensorflow:loss = 0.33253673, step = 401 (0.151 sec)
INFO:tensorflow:global_step/sec: 596.742
INFO:tensorflow:loss = 0.118054286, step = 501 (0.167 sec)
INFO:tensorflow:global_step/sec: 201.554
INFO:tensorflow:loss = 0.39076436, step = 601 (0.496 sec)
INFO:tensorflow:global_step/sec: 363.999
INFO:tensorflow:loss = 0.20863627, step = 701 (0.275 sec)
INFO:tensorflow:global_step/sec: 581.541
INFO:tensorflow:loss = 0.47417432, step = 801 (0.173 sec)
INFO:tensorflow:global_step/sec: 673.19
INFO:tensorflow:loss = 0.45061016, step = 901 (0.148 sec)
INFO:tensorflow:Saving checkpoints for 1000 into mnist/learned/model.ckpt.
INFO:tensorflow:Calling model_fn.
WARNING:tensorflow:From /usr/local/lib/python3.5/dist-packages/tensorflow/python/ops/metrics_impl.py:455: to_float (from tensorflow.python.ops.math_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.cast instead.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Starting evaluation at 2019-04-10T18:16:11Z
INFO:tensorflow:Graph was finalized.
WARNING:tensorflow:From /usr/local/lib/python3.5/dist-packages/tensorflow/python/training/saver.py:1266: checkpoint_exists (from tensorflow.python.training.checkpoint_management) is deprecated and will be removed in a future version.
Instructions for updating:
Use standard file APIs to check for files with this prefix.
INFO:tensorflow:Restoring parameters from mnist/learned/model.ckpt-1000
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Finished evaluation at 2019-04-10-18:16:12
INFO:tensorflow:Saving dict for global step 1000: accuracy = 0.9207, global_step = 1000, loss = 0.28777623
INFO:tensorflow:Saving 'checkpoint_path' summary for global step 1000: mnist/learned/model.ckpt-1000
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
WARNING:tensorflow:From /usr/local/lib/python3.5/dist-packages/tensorflow/python/saved_model/signature_def_utils_impl.py:205: build_tensor_info (from tensorflow.python.saved_model.utils_impl) is deprecated and will be removed in a future version.
Instructions for updating:
This function will only be available through the v1 compatibility library as tf.compat.v1.saved_model.utils.build_tensor_info or tf.compat.v1.saved_model.build_tensor_info.
INFO:tensorflow:Signatures INCLUDED in export for Regress: None
INFO:tensorflow:Signatures INCLUDED in export for Train: None
INFO:tensorflow:Signatures INCLUDED in export for Predict: ['serving_default', 'predictions']
INFO:tensorflow:Signatures INCLUDED in export for Classify: None
INFO:tensorflow:Signatures INCLUDED in export for Eval: None
INFO:tensorflow:Restoring parameters from mnist/learned/model.ckpt-1000
INFO:tensorflow:Assets added to graph.
INFO:tensorflow:No assets to write.
INFO:tensorflow:SavedModel written to: mnist/learned/export/exporter/temp-b'1554920172'/saved_model.pb
INFO:tensorflow:Loss for final step: 0.4148974.
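It can help to translate the step count into passes over the data. The sketch below uses the values that appear in this notebook (1000 training steps, a batch size of 100, and 55000 training images); epochs_covered is a hypothetical helper written for this illustration.

```python
def epochs_covered(train_steps, batch_size, n_examples):
    """Number of passes over the training set made by `train_steps` batches."""
    return train_steps * batch_size / n_examples


epochs = epochs_covered(train_steps=1000, batch_size=100, n_examples=55000)
print(round(epochs, 2))  # 1.82
```

So this run sees each training image a little under twice on average.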

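In an earlier run, I got:

Saving dict for global step 1000: accuracy = 0.9158, global_step = 1000, loss = 0.29720208

In other words, that run achieved 91.6% accuracy with the simple linear model! The run logged above reached 92.07% (accuracy = 0.9207). The exact figure varies from run to run because no random seed is fixed ('_tf_random_seed': None in the config).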
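The accuracy metric compares the argmax of each one-hot label with the predicted class id. Here is a minimal plain-Python sketch of that comparison on toy data; the helper names and the three-class example are assumptions for illustration. The last lines assume the standard 10,000-image MNIST test split.

```python
def argmax(values):
    """Index of the largest entry."""
    return max(range(len(values)), key=lambda i: values[i])


def accuracy(one_hot_labels, class_ids):
    """Fraction of examples whose predicted class id matches the label."""
    correct = sum(
        1 for label, pred in zip(one_hot_labels, class_ids)
        if argmax(label) == pred
    )
    return correct / len(class_ids)


# Toy data (hypothetical): true classes 0, 1, 2, 1 and one wrong prediction.
labels = [[1, 0, 0], [0, 1, 0], [0, 0, 1], [0, 1, 0]]
predicted = [0, 1, 1, 1]
print(accuracy(labels, predicted))  # 0.75

# Assumes the standard 10,000-image MNIST test split.
correct_images = round(0.9207 * 10000)
print(correct_images)  # 9207
```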

# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.