In [1]:
import tensorflow as tf

from tensorflow.examples.tutorials.mnist import input_data

In [2]:
# Load MNIST data
mnist = input_data.read_data_sets("MNIST_data/")


Extracting MNIST_data/train-images-idx3-ubyte.gz
Extracting MNIST_data/train-labels-idx1-ubyte.gz
Extracting MNIST_data/t10k-images-idx3-ubyte.gz
Extracting MNIST_data/t10k-labels-idx1-ubyte.gz

In [3]:
train_data = mnist.train.images
train_labels = mnist.train.labels

test_data = mnist.test.images
test_labels = mnist.test.labels

print("train_data", train_data.shape)
print("train_labels", train_labels.shape)


train_data (55000, 784)
train_labels (55000,)
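
The labels come back as integer class indices rather than one-hot vectors, because read_data_sets was called without one_hot=True; that is why the model function below one-hot encodes them itself. A quick way to confirm this (an illustrative check, not part of the original run) is:

print(train_labels[:5], train_labels.dtype)  # integer digits in the range 0-9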

In [4]:
def my_model_fn(features, labels, mode):
    """Model function for our CNN: three conv/pool blocks followed by a dense head."""

    # Reshape the flat 784-vector back into a 28x28 single-channel image.
    net = tf.reshape(features['x'], [-1, 28, 28, 1])

    # Three blocks of 3x3 convolution (32 filters, ReLU) + 2x2 max pooling.
    for _ in range(3):
        net = tf.layers.conv2d(
            inputs=net,
            filters=32,
            kernel_size=[3, 3],
            padding="same",
            activation=tf.nn.relu
        )
        net = tf.layers.max_pooling2d(
            inputs=net,
            pool_size=[2, 2],
            strides=[2, 2]
        )

    # Classifier head: flatten, one hidden dense layer, then 10 logits (one per digit).
    net = tf.layers.flatten(net)
    net = tf.layers.dense(inputs=net, units=64)
    logits = tf.layers.dense(inputs=net, units=10)

    # Labels arrive as integer class indices, so one-hot encode them for the loss.
    onehot_labels = tf.one_hot(indices=tf.cast(labels, tf.int32), depth=10)
    loss = tf.losses.softmax_cross_entropy(onehot_labels=onehot_labels,
                                           logits=logits)

    if mode == tf.estimator.ModeKeys.TRAIN:
        optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.001)
        train_op = optimizer.minimize(loss=loss, global_step=tf.train.get_global_step())
        return tf.estimator.EstimatorSpec(mode=mode, loss=loss, train_op=train_op)

    # EVAL mode: report accuracy alongside the loss.
    eval_metric_ops = {
        "accuracy": tf.metrics.accuracy(labels=labels,
                                        predictions=tf.argmax(input=logits, axis=1))
    }
    return tf.estimator.EstimatorSpec(mode=mode, loss=loss, eval_metric_ops=eval_metric_ops)
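
The model function above only covers training and evaluation; in PREDICT mode labels is None, so it would fail at the one-hot step if you ever called predict on the estimator. One way to extend it (a minimal sketch, not part of the original notebook) is to return predictions right after the logits are computed:

# Hypothetical PREDICT branch, inserted just after `logits` inside my_model_fn:
if mode == tf.estimator.ModeKeys.PREDICT:
    predictions = {
        "classes": tf.argmax(input=logits, axis=1),
        "probabilities": tf.nn.softmax(logits, name="softmax_tensor")
    }
    return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions)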

In [5]:
mnist_estimator = tf.estimator.Estimator(
    model_fn=my_model_fn,
    model_dir="E:\\temp\\mnist_estimator"
)


INFO:tensorflow:Using default config.
INFO:tensorflow:Using config: {'_task_type': 'worker', '_service': None, '_tf_random_seed': None, '_master': '', '_keep_checkpoint_max': 5, '_session_config': None, '_task_id': 0, '_save_summary_steps': 100, '_cluster_spec': <tensorflow.python.training.server_lib.ClusterSpec object at 0x000001E3979645F8>, '_save_checkpoints_steps': None, '_is_chief': True, '_keep_checkpoint_every_n_hours': 10000, '_num_worker_replicas': 1, '_model_dir': 'E:\\temp\\mnist_estimator', '_log_step_count_steps': 100, '_save_checkpoints_secs': 600, '_num_ps_replicas': 0}

In [6]:
train_input_fn = tf.estimator.inputs.numpy_input_fn(
    x={"x": train_data},
    y=train_labels,
    batch_size=64,
    num_epochs=10,
    shuffle=True
)
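
numpy_input_fn handles batching and shuffling of the in-memory arrays. If you prefer to build the pipeline yourself, an equivalent input function written against tf.data (a sketch using the same batch size and epoch count, not part of the original run) could look like this:

def train_input_fn_tfdata():
    # Pair each flattened image with its integer label, shuffle, repeat for 10 epochs, batch by 64.
    dataset = tf.data.Dataset.from_tensor_slices(({"x": train_data}, train_labels))
    dataset = dataset.shuffle(buffer_size=10000).repeat(10).batch(64)
    return dataset.make_one_shot_iterator().get_next()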

In [ ]:
mnist_estimator.train(input_fn=train_input_fn)


INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Restoring parameters from E:\temp\mnist_estimator\model.ckpt-68754

In [8]:
test_input_fn = tf.estimator.inputs.numpy_input_fn(
    x={"x": test_data},
    y=test_labels,
    num_epochs=1,
    shuffle=False)

In [ ]:
test_results = mnist_estimator.evaluate(input_fn=test_input_fn)
print(test_results)
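
If my_model_fn is extended with the PREDICT branch sketched earlier, class predictions for a few test images can be pulled like this (an illustrative snippet, not part of the original run):

predict_input_fn = tf.estimator.inputs.numpy_input_fn(
    x={"x": test_data[:10]},
    num_epochs=1,
    shuffle=False)

for pred in mnist_estimator.predict(input_fn=predict_input_fn):
    print(pred["classes"], pred["probabilities"].max())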

In [11]:
train_spec = tf.estimator.TrainSpec(
    input_fn=train_input_fn,
    max_steps=50000
)

test_spec = tf.estimator.EvalSpec(
    input_fn=test_input_fn,
    steps=50, throttle_secs=60
)
tf.estimator.train_and_evaluate(mnist_estimator, train_spec, test_spec)


INFO:tensorflow:Running training and evaluation locally (non-distributed).
INFO:tensorflow:Start train and evaluate loop. The evaluate will happen after 60 secs (eval_spec.throttle_secs) or training is finished.
INFO:tensorflow:Skipping training since max_steps has already saved.
INFO:tensorflow:Starting evaluation at 2018-01-14-15:22:35
INFO:tensorflow:Restoring parameters from E:\temp\mnist_estimator\model.ckpt-68754
INFO:tensorflow:Evaluation [1/50]
...
INFO:tensorflow:Evaluation [50/50]
INFO:tensorflow:Finished evaluation at 2018-01-14-15:22:36
INFO:tensorflow:Saving dict for global step 68754: accuracy = 0.978437, global_step = 68754, loss = 0.0697263

In [ ]: