This notebook demonstrates how to implement a simple linear image models on MNIST using tf.keras.
In [1]:
!sudo chown -R jupyter:jupyter /home/jupyter/training-data-analyst
In [2]:
import numpy as np
import shutil
import os
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()
print(tf.__version__)
In [3]:
HEIGHT = 28
WIDTH = 28
NCLASSES = 10
In [4]:
# Get mnist data
mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()
# Scale our features between 0 and 1
x_train, x_test = x_train / 255.0, x_test / 255.0
# Convert labels to categorical one-hot encoding
y_train = tf.keras.utils.to_categorical(y = y_train, num_classes = NCLASSES)
y_test = tf.keras.utils.to_categorical(y = y_test, num_classes = NCLASSES)
print("x_train.shape = {}".format(x_train.shape))
print("y_train.shape = {}".format(y_train.shape))
print("x_test.shape = {}".format(x_test.shape))
print("y_test.shape = {}".format(y_test.shape))
In [5]:
import matplotlib.pyplot as plt
IMGNO = 12
plt.imshow(x_test[IMGNO].reshape(HEIGHT, WIDTH));
In [6]:
# Build Keras Model Using Keras Sequential API
def linear_model():
model = tf.keras.models.Sequential()
model.add(tf.keras.layers.InputLayer(input_shape = [HEIGHT, WIDTH], name = "image"))
model.add(tf.keras.layers.Flatten())
model.add(tf.keras.layers.Dense(units = NCLASSES, activation = tf.nn.softmax, name = "probabilities"))
return model
In [7]:
# Create training input function
train_input_fn = tf.estimator.inputs.numpy_input_fn(
x = {"image": x_train},
y = y_train,
batch_size = 100,
num_epochs = None,
shuffle = True,
queue_capacity = 5000
)
# Create evaluation input function
eval_input_fn = tf.estimator.inputs.numpy_input_fn(
x = {"image": x_test},
y = y_test,
batch_size = 100,
num_epochs = 1,
shuffle = False,
queue_capacity = 5000
)
# Create serving input function for inference
def serving_input_fn():
placeholders = {"image": tf.placeholder(dtype = tf.float32, shape = [None, HEIGHT, WIDTH])}
features = placeholders # as-is
return tf.estimator.export.ServingInputReceiver(features = features, receiver_tensors = placeholders)
tf.estimator.train_and_evaluate does distributed training.
In [8]:
def train_and_evaluate(output_dir, hparams):
# Build Keras model
model = linear_model()
# Compile Keras model with optimizer, loss function, and eval metrics
model.compile(
optimizer = "adam",
loss = "categorical_crossentropy",
metrics = ["accuracy"])
# Convert Keras model to an Estimator
estimator = tf.keras.estimator.model_to_estimator(
keras_model = model,
model_dir = output_dir)
# Set estimator's train_spec to use train_input_fn and train for so many steps
train_spec = tf.estimator.TrainSpec(
input_fn = train_input_fn,
max_steps = hparams["train_steps"])
# Create exporter that uses serving_input_fn to create saved_model for serving
exporter = tf.estimator.LatestExporter(
name = "exporter",
serving_input_receiver_fn = serving_input_fn)
# Set estimator's eval_spec to use eval_input_fn and export saved_model
eval_spec = tf.estimator.EvalSpec(
input_fn = eval_input_fn,
steps = None,
exporters = exporter)
# Run train_and_evaluate loop
tf.estimator.train_and_evaluate(
estimator = estimator,
train_spec = train_spec,
eval_spec = eval_spec)
This is the main() function
In [9]:
OUTDIR = "mnist/learned"
shutil.rmtree(OUTDIR, ignore_errors = True) # start fresh each time
hparams = {"train_steps": 1000, "learning_rate": 0.01}
train_and_evaluate(OUTDIR, hparams)
I got:
Saving dict for global step 1000: categorical_accuracy = 0.9112, global_step = 1000, loss = 0.32516304
In other words, we achieved 91.12% accuracy with the simple linear model!
# Copyright 2020 Google Inc. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License.