In [ ]:
# Copyright 2018 The TensorFlow Hub Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

MNIST on TPU (Tensor Processing Unit)
or GPU using tf.Keras and tf.data.Dataset

Overview

This sample trains an MNIST handwritten digit recognition model on a GPU or TPU backend using a Keras model. Data are handled using the tf.data.Dataset API. This is a very simple sample provided for educational purposes. Do not expect outstanding TPU performance on a dataset as small as MNIST.

This notebook is hosted on GitHub. To view it in its original repository, after opening the notebook, select File > View on GitHub.

Learning objectives

In this notebook, you will learn how to:

  • Authenticate in Colab to access Google Cloud Storage (GCS)
  • Format and prepare a dataset using tf.data.Dataset
  • Create convolutional and dense layers using tf.keras.Sequential
  • Build a Keras classifier with softmax, cross-entropy, and the adam optimizer
  • Run training and validation in Keras using Cloud TPU
  • Export a model for serving from AI Platform
  • Deploy a trained model to AI Platform
  • Test predictions on a deployed model

Instructions

  Train on GPU or TPU  

  1. Select a GPU or TPU backend (Runtime > Change runtime type)
  2. Runtime > Run All
    (Watch out: the "Colab-only auth" cell requires user input.
    The "Deploy" part at the end requires cloud project and bucket configuration.)

  Deploy to AI Platform

At the bottom of this notebook you can deploy your trained model to AI Platform for a serverless, autoscaled, REST API experience. You will need a Google Cloud project and a GCS (Google Cloud Storage) bucket for this last part.

TPUs are located in Google Cloud; for optimal performance, they read data directly from Google Cloud Storage.

Imports


In [ ]:
import os, re, time, json
import PIL.Image, PIL.ImageFont, PIL.ImageDraw
import numpy as np
import tensorflow as tf
from matplotlib import pyplot as plt
print("Tensorflow version " + tf.__version__)


Tensorflow version 2.2.0

In [ ]:
#@title visualization utilities [RUN ME]
"""
This cell contains helper functions used for visualization
and downloads only. You can skip reading it. There is very
little useful Keras/Tensorflow code here.
"""

# Matplotlib config
plt.rc('image', cmap='gray_r')
plt.rc('grid', linewidth=0)
plt.rc('xtick', top=False, bottom=False, labelsize='large')
plt.rc('ytick', left=False, right=False, labelsize='large')
plt.rc('axes', facecolor='F8F8F8', titlesize="large", edgecolor='white')
plt.rc('text', color='a8151a')
plt.rc('figure', facecolor='F0F0F0')

# Matplotlib fonts
MATPLOTLIB_FONT_DIR = os.path.join(os.path.dirname(plt.__file__), "mpl-data/fonts/ttf")

# pull a batch from the datasets. This code is not very nice; it gets much better in eager mode (TODO)
def dataset_to_numpy_util(training_dataset, validation_dataset, N):
  
  # get one batch from each: 10000 validation digits, N training digits
  batch_train_ds = training_dataset.unbatch().batch(N)
  
  # eager execution: loop through datasets normally
  if tf.executing_eagerly():
    for validation_digits, validation_labels in validation_dataset:
      validation_digits = validation_digits.numpy()
      validation_labels = validation_labels.numpy()
      break
    for training_digits, training_labels in batch_train_ds:
      training_digits = training_digits.numpy()
      training_labels = training_labels.numpy()
      break
    
  else:
    v_images, v_labels = tf.compat.v1.data.make_one_shot_iterator(validation_dataset).get_next()
    t_images, t_labels = tf.compat.v1.data.make_one_shot_iterator(batch_train_ds).get_next()
    # Run once, get one batch. Session.run returns numpy results
    with tf.compat.v1.Session() as ses:
      (validation_digits, validation_labels,
       training_digits, training_labels) = ses.run([v_images, v_labels, t_images, t_labels])
  
  # these were one-hot encoded in the dataset
  validation_labels = np.argmax(validation_labels, axis=1)
  training_labels = np.argmax(training_labels, axis=1)
  
  return (training_digits, training_labels,
          validation_digits, validation_labels)

# create digits from local fonts for testing
def create_digits_from_local_fonts(n):
  font_labels = []
  img = PIL.Image.new('LA', (28*n, 28), color = (0,255)) # format 'LA': black in channel 0, alpha in channel 1
  font1 = PIL.ImageFont.truetype(os.path.join(MATPLOTLIB_FONT_DIR, 'DejaVuSansMono-Oblique.ttf'), 25)
  font2 = PIL.ImageFont.truetype(os.path.join(MATPLOTLIB_FONT_DIR, 'STIXGeneral.ttf'), 25)
  d = PIL.ImageDraw.Draw(img)
  for i in range(n):
    font_labels.append(i%10)
    d.text((7+i*28,0 if i<10 else -4), str(i%10), fill=(255,255), font=font1 if i<10 else font2)
  font_digits = np.array(img.getdata(), np.float32)[:,0] / 255.0 # black in channel 0, alpha in channel 1 (discarded)
  font_digits = np.reshape(np.stack(np.split(np.reshape(font_digits, [28, 28*n]), n, axis=1), axis=0), [n, 28*28])
  return font_digits, font_labels

# utility to display a row of digits with their predictions
def display_digits(digits, predictions, labels, title, n):
  plt.figure(figsize=(13,3))
  digits = np.reshape(digits, [n, 28, 28])
  digits = np.swapaxes(digits, 0, 1)
  digits = np.reshape(digits, [28, 28*n])
  plt.yticks([])
  plt.xticks([28*x+14 for x in range(n)], predictions)
  for i,t in enumerate(plt.gca().xaxis.get_ticklabels()):
    if predictions[i] != labels[i]: t.set_color('red') # bad predictions in red
  plt.imshow(digits)
  plt.grid(None)
  plt.title(title)
  
# utility to display multiple rows of digits, sorted by unrecognized/recognized status
def display_top_unrecognized(digits, predictions, labels, n, lines):
  idx = np.argsort(predictions==labels) # sort order: unrecognized first
  for i in range(lines):
    display_digits(digits[idx][i*n:(i+1)*n], predictions[idx][i*n:(i+1)*n], labels[idx][i*n:(i+1)*n],
                   "{} sample validation digits out of {} with bad predictions in red and sorted first".format(n*lines, len(digits)) if i==0 else "", n)
    
# utility to display training and validation curves
def display_training_curves(training, validation, title, subplot):
  if subplot%10==1: # set up the subplots on the first call
    plt.subplots(figsize=(10,10), facecolor='#F0F0F0')
    plt.tight_layout()
  ax = plt.subplot(subplot)
  ax.grid(linewidth=1, color='white')
  ax.plot(training)
  ax.plot(validation)
  ax.set_title('model '+ title)
  ax.set_ylabel(title)
  ax.set_xlabel('epoch')
  ax.legend(['train', 'valid.'])

(you can double-click on collapsed cells to view the non-essential code inside)

Colab-only auth for this notebook and the TPU


In [ ]:
IS_COLAB_BACKEND = 'COLAB_GPU' in os.environ  # this is always set on Colab, the value is 0 or 1 depending on GPU presence
if IS_COLAB_BACKEND:
  from google.colab import auth
  # Authenticates the Colab machine and also the TPU using your
  # credentials so that they can access your private GCS buckets.
  auth.authenticate_user()

TPU or GPU detection


In [ ]:
# Detect hardware
try:
  tpu_resolver = tf.distribute.cluster_resolver.TPUClusterResolver() # TPU detection
except ValueError:
  tpu_resolver = None
  gpus = tf.config.experimental.list_logical_devices("GPU")

# Select appropriate distribution strategy
if tpu_resolver:
  tf.config.experimental_connect_to_cluster(tpu_resolver)
  tf.tpu.experimental.initialize_tpu_system(tpu_resolver)
  strategy = tf.distribute.experimental.TPUStrategy(tpu_resolver)
  print('Running on TPU ', tpu_resolver.cluster_spec().as_dict()['worker'])
elif len(gpus) > 1:
  strategy = tf.distribute.MirroredStrategy([gpu.name for gpu in gpus])
  print('Running on multiple GPUs ', [gpu.name for gpu in gpus])
elif len(gpus) == 1:
  strategy = tf.distribute.get_strategy() # default strategy that works on CPU and single GPU
  print('Running on single GPU ', gpus[0].name)
else:
  strategy = tf.distribute.get_strategy() # default strategy that works on CPU and single GPU
  print('Running on CPU')
  
print("Number of accelerators: ", strategy.num_replicas_in_sync)


WARNING:tensorflow:TPU system grpc://10.41.182.106:8470 has already been initialized. Reinitializing the TPU can cause previously created variables on TPU to be lost.
INFO:tensorflow:Initializing the TPU system: grpc://10.41.182.106:8470
INFO:tensorflow:Clearing out eager caches
INFO:tensorflow:Finished initializing TPU system.
INFO:tensorflow:Found TPU system:
INFO:tensorflow:*** Num TPU Cores: 8
INFO:tensorflow:*** Num TPU Workers: 1
INFO:tensorflow:*** Num TPU Cores Per Worker: 8
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:localhost/replica:0/task:0/device:CPU:0, CPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:localhost/replica:0/task:0/device:XLA_CPU:0, XLA_CPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:CPU:0, CPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:0, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:1, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:2, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:3, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:4, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:5, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:6, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:7, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU_SYSTEM:0, TPU_SYSTEM, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:XLA_CPU:0, XLA_CPU, 0, 0)
Running on TPU  ['10.41.182.106:8470']
Number of accelerators:  8

Parameters


In [ ]:
BATCH_SIZE = 64 * strategy.num_replicas_in_sync # Global batch size.
# The global batch size will be automatically sharded across all
# replicas by the tf.data.Dataset API. A single TPU has 8 cores.
# The best practice is to scale the batch size by the number of
# replicas (cores). The learning rate should be increased as well.

LEARNING_RATE = 0.01
LEARNING_RATE_EXP_DECAY = 0.6 if strategy.num_replicas_in_sync == 1 else 0.7
# Learning rate computed later as LEARNING_RATE * LEARNING_RATE_EXP_DECAY**epoch
# A decay factor of 0.7 instead of 0.6 means a slower decay, i.e. the learning
# rate stays higher for longer.

training_images_file   = 'gs://mnist-public/train-images-idx3-ubyte'
training_labels_file   = 'gs://mnist-public/train-labels-idx1-ubyte'
validation_images_file = 'gs://mnist-public/t10k-images-idx3-ubyte'
validation_labels_file = 'gs://mnist-public/t10k-labels-idx1-ubyte'
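
As a quick illustration (not part of the original notebook), the cell below prints the schedule that the LearningRateScheduler callback will apply later, along with the per-replica batch size once tf.data shards the global batch across replicas.

In [ ]:
# Illustrative only: the learning rate used for the first few epochs,
# and the per-replica batch size after the global batch is sharded.
for epoch in range(4):
  print("epoch {}: lr = {:.5f}".format(epoch, LEARNING_RATE * LEARNING_RATE_EXP_DECAY**epoch))
print("per-replica batch size:", BATCH_SIZE // strategy.num_replicas_in_sync)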

tf.data.Dataset: parse files and prepare training and validation datasets

Please read the best practices for building input pipelines with tf.data.Dataset


In [ ]:
def read_label(tf_bytestring):
    label = tf.io.decode_raw(tf_bytestring, tf.uint8)
    label = tf.reshape(label, [])
    label = tf.one_hot(label, 10)
    return label
  
def read_image(tf_bytestring):
    image = tf.io.decode_raw(tf_bytestring, tf.uint8)
    image = tf.cast(image, tf.float32)/255.0
    image = tf.reshape(image, [28*28])
    return image
  
def load_dataset(image_file, label_file):
    # MNIST idx files are fixed-length records after a short header:
    # 16 header bytes then 28*28 bytes per image; 8 header bytes then 1 byte per label
    imagedataset = tf.data.FixedLengthRecordDataset(image_file, 28*28, header_bytes=16)
    imagedataset = imagedataset.map(read_image, num_parallel_calls=16)
    labelsdataset = tf.data.FixedLengthRecordDataset(label_file, 1, header_bytes=8)
    labelsdataset = labelsdataset.map(read_label, num_parallel_calls=16)
    dataset = tf.data.Dataset.zip((imagedataset, labelsdataset))
    return dataset 
  
def get_training_dataset(image_file, label_file, batch_size):
    dataset = load_dataset(image_file, label_file)
    dataset = dataset.cache()  # this small dataset can be entirely cached in RAM
    dataset = dataset.shuffle(5000, reshuffle_each_iteration=True)
    dataset = dataset.repeat() # Mandatory for Keras for now
    dataset = dataset.batch(batch_size, drop_remainder=True) # drop_remainder is important on TPU, batch size must be fixed
    dataset = dataset.prefetch(-1)  # fetch next batches while training on the current one (-1: autotune prefetch buffer size)
    return dataset
  
def get_validation_dataset(image_file, label_file):
    dataset = load_dataset(image_file, label_file)
    dataset = dataset.cache() # this small dataset can be entirely cached in RAM
    dataset = dataset.batch(10000, drop_remainder=True) # 10000 items in eval dataset, all in one batch
    dataset = dataset.repeat() # Mandatory for Keras for now
    return dataset

# instantiate the datasets
training_dataset = get_training_dataset(training_images_file, training_labels_file, BATCH_SIZE)
validation_dataset = get_validation_dataset(validation_images_file, validation_labels_file)
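
If you want to sanity-check the pipeline (an optional step, not in the original notebook), you can pull one batch and inspect its shapes: images should come out as (BATCH_SIZE, 784) float32 tensors and labels as (BATCH_SIZE, 10) one-hot vectors.

In [ ]:
# Optional sanity check: inspect the shapes and dtypes of one training batch.
for images, labels in training_dataset.take(1):
  print("images:", images.shape, images.dtype)  # (BATCH_SIZE, 784) float32
  print("labels:", labels.shape, labels.dtype)  # (BATCH_SIZE, 10) one-hot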

Let's have a look at the data


In [ ]:
N = 24
(training_digits, training_labels,
 validation_digits, validation_labels) = dataset_to_numpy_util(training_dataset, validation_dataset, N)
display_digits(training_digits, training_labels, training_labels, "training digits and their labels", N)
display_digits(validation_digits[:N], validation_labels[:N], validation_labels[:N], "validation digits and their labels", N)
font_digits, font_labels = create_digits_from_local_fonts(N)


Keras model: 3 convolutional layers, 2 dense layers

If you are not sure what cross-entropy, dropout, softmax or batch-normalization mean, head here for a crash-course: Tensorflow and deep learning without a PhD


In [ ]:
# This model trains to 99.4% accuracy in 10 epochs (with a batch size of 64)  

def make_model():
    model = tf.keras.Sequential(
      [
        tf.keras.layers.Reshape(input_shape=(28*28,), target_shape=(28, 28, 1), name="image"),

        tf.keras.layers.Conv2D(filters=12, kernel_size=3, padding='same', use_bias=False), # no bias necessary before batch norm
        tf.keras.layers.BatchNormalization(scale=False, center=True), # no batch norm scaling necessary before "relu"
        tf.keras.layers.Activation('relu'), # activation after batch norm

        tf.keras.layers.Conv2D(filters=24, kernel_size=6, padding='same', use_bias=False, strides=2),
        tf.keras.layers.BatchNormalization(scale=False, center=True),
        tf.keras.layers.Activation('relu'),

        tf.keras.layers.Conv2D(filters=32, kernel_size=6, padding='same', use_bias=False, strides=2),
        tf.keras.layers.BatchNormalization(scale=False, center=True),
        tf.keras.layers.Activation('relu'),

        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(200, use_bias=False),
        tf.keras.layers.BatchNormalization(scale=False, center=True),
        tf.keras.layers.Activation('relu'),
        tf.keras.layers.Dropout(0.4), # Dropout on dense layer only

        tf.keras.layers.Dense(10, activation='softmax')
      ])

    model.compile(optimizer='adam', # learning rate will be set by LearningRateScheduler
                  loss='categorical_crossentropy',
                  metrics=['accuracy'])
    return model
    
with strategy.scope():
    model = make_model()

# print model layers
model.summary()

# set up learning rate decay
lr_decay = tf.keras.callbacks.LearningRateScheduler(
    lambda epoch: LEARNING_RATE * LEARNING_RATE_EXP_DECAY**epoch,
    verbose=True)


Model: "sequential_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
image (Reshape)              (None, 28, 28, 1)         0         
_________________________________________________________________
conv2d_3 (Conv2D)            (None, 28, 28, 12)        108       
_________________________________________________________________
batch_normalization_4 (Batch (None, 28, 28, 12)        36        
_________________________________________________________________
activation_4 (Activation)    (None, 28, 28, 12)        0         
_________________________________________________________________
conv2d_4 (Conv2D)            (None, 14, 14, 24)        10368     
_________________________________________________________________
batch_normalization_5 (Batch (None, 14, 14, 24)        72        
_________________________________________________________________
activation_5 (Activation)    (None, 14, 14, 24)        0         
_________________________________________________________________
conv2d_5 (Conv2D)            (None, 7, 7, 32)          27648     
_________________________________________________________________
batch_normalization_6 (Batch (None, 7, 7, 32)          96        
_________________________________________________________________
activation_6 (Activation)    (None, 7, 7, 32)          0         
_________________________________________________________________
flatten_1 (Flatten)          (None, 1568)              0         
_________________________________________________________________
dense_2 (Dense)              (None, 200)               313600    
_________________________________________________________________
batch_normalization_7 (Batch (None, 200)               600       
_________________________________________________________________
activation_7 (Activation)    (None, 200)               0         
_________________________________________________________________
dropout_1 (Dropout)          (None, 200)               0         
_________________________________________________________________
dense_3 (Dense)              (None, 10)                2010      
=================================================================
Total params: 354,538
Trainable params: 354,002
Non-trainable params: 536
_________________________________________________________________

Train and validate the model


In [ ]:
EPOCHS = 10
steps_per_epoch = 60000//BATCH_SIZE  # 60,000 items in this dataset
print("Steps per epoch: ", steps_per_epoch)
  
history = model.fit(training_dataset,
                    steps_per_epoch=steps_per_epoch, epochs=EPOCHS,
                    callbacks=[lr_decay])

final_stats = model.evaluate(validation_dataset, steps=1)
print("Validation accuracy: ", final_stats[1])


Steps per epoch:  117

Epoch 00001: LearningRateScheduler reducing learning rate to 0.01.
Epoch 1/10
117/117 [==============================] - 1s 12ms/step - loss: 0.1829 - accuracy: 0.9450 - lr: 0.0100

Epoch 00002: LearningRateScheduler reducing learning rate to 0.006999999999999999.
Epoch 2/10
117/117 [==============================] - 1s 11ms/step - loss: 0.0501 - accuracy: 0.9851 - lr: 0.0070

Epoch 00003: LearningRateScheduler reducing learning rate to 0.0049.
Epoch 3/10
117/117 [==============================] - 1s 11ms/step - loss: 0.0331 - accuracy: 0.9898 - lr: 0.0049

Epoch 00004: LearningRateScheduler reducing learning rate to 0.003429999999999999.
Epoch 4/10
117/117 [==============================] - 1s 11ms/step - loss: 0.0242 - accuracy: 0.9923 - lr: 0.0034

Epoch 00005: LearningRateScheduler reducing learning rate to 0.0024009999999999995.
Epoch 5/10
117/117 [==============================] - 1s 10ms/step - loss: 0.0180 - accuracy: 0.9946 - lr: 0.0024

Epoch 00006: LearningRateScheduler reducing learning rate to 0.0016806999999999994.
Epoch 6/10
117/117 [==============================] - 1s 11ms/step - loss: 0.0134 - accuracy: 0.9959 - lr: 0.0017

Epoch 00007: LearningRateScheduler reducing learning rate to 0.0011764899999999997.
Epoch 7/10
117/117 [==============================] - 1s 11ms/step - loss: 0.0100 - accuracy: 0.9973 - lr: 0.0012

Epoch 00008: LearningRateScheduler reducing learning rate to 0.0008235429999999996.
Epoch 8/10
117/117 [==============================] - 1s 10ms/step - loss: 0.0085 - accuracy: 0.9975 - lr: 8.2354e-04

Epoch 00009: LearningRateScheduler reducing learning rate to 0.0005764800999999997.
Epoch 9/10
117/117 [==============================] - 1s 10ms/step - loss: 0.0071 - accuracy: 0.9982 - lr: 5.7648e-04

Epoch 00010: LearningRateScheduler reducing learning rate to 0.0004035360699999998.
Epoch 10/10
117/117 [==============================] - 1s 10ms/step - loss: 0.0066 - accuracy: 0.9983 - lr: 4.0354e-04
1/1 [==============================] - 0s 2ms/step - loss: 0.0188 - accuracy: 0.9945
Validation accuracy:  0.9944999814033508

Visualize predictions


In [ ]:
# recognize digits from local fonts
probabilities = model.predict(font_digits, steps=1)
predicted_labels = np.argmax(probabilities, axis=1)
display_digits(font_digits, predicted_labels, font_labels, "predictions from local fonts (bad predictions in red)", N)

# recognize validation digits
probabilities = model.predict(validation_digits, steps=1)
predicted_labels = np.argmax(probabilities, axis=1)
display_top_unrecognized(validation_digits, predicted_labels, validation_labels, N, 7)

Deploy the trained model to AI Platform model serving

Push your trained model to production on AI Platform for a serverless, autoscaled, REST API experience.

You will need a GCS (Google Cloud Storage) bucket and a GCP project for this. Models deployed on AI Platform autoscale to zero if not used. There will be no AI Platform charges after you are done testing. Google Cloud Storage incurs charges. Empty the bucket after deployment if you want to avoid these. Once the model is deployed, the bucket is not useful anymore.

Configuration


In [ ]:
PROJECT = "ml-writers" #@param {type:"string"}
BUCKET = "gs://your-bucket"  #@param {type:"string"}
NEW_MODEL = True #@param {type:"boolean"}
MODEL_NAME = "mnist_test" #@param {type:"string"}
MODEL_VERSION = "v1" #@param {type:"string"}

assert PROJECT, 'For this part, you need a GCP project. Head to http://console.cloud.google.com/ and create one.'
assert re.search(r'gs://.+', BUCKET), 'For this part, you need a GCS bucket. Head to http://console.cloud.google.com/storage and create one.'

Export the model for serving from AI Platform


In [ ]:
# Wrap the model so that we can add a serving function
class ExportModel(tf.keras.Model):
  def __init__(self, model):
    super().__init__()
    self.model = model

  # The serving function performing data pre- and post-processing.
  # Pre-processing:  images are received in uint8 format and converted
  #                  to float32 before being sent through the model.
  # Post-processing: the Keras model outputs digit probabilities. We want
  #                  the detected digits. An additional tf.argmax is needed.
  # @tf.function turns the code in this function into a Tensorflow graph that
  # can be exported. This way, the model itself, as well as its pre- and post-
  # processing steps are exported in the SavedModel and deployed in a single step.
  @tf.function(input_signature=[tf.TensorSpec([None, 28*28], dtype=tf.uint8)])
  def my_serve(self, images):
    images = tf.cast(images, tf.float32)/255   # pre-processing
    probabilities = self.model(images)          # prediction from model
    classes = tf.argmax(probabilities, axis=-1) # post-processing
    return {'digits': classes}
    
# The model must be copied from the TPU to the CPU so that it can be
# composed with the serving function below.
restored_model = make_model()
restored_model.set_weights(model.get_weights()) # this copies the weights from TPU, does nothing on GPU

# create the ExportModel and export it to the Tensorflow standard SavedModel format
serving_model = ExportModel(restored_model)
export_path = os.path.join(BUCKET, 'keras_export', str(time.time()))
tf.keras.backend.set_learning_phase(0) # inference only
tf.saved_model.save(serving_model, export_path, signatures={'serving_default': serving_model.my_serve})

print("Model exported to: ", export_path)

# Note: in TensorFlow 2.x it is also possible to export to the
# SavedModel format using model.save():
# serving_model.save(export_path, save_format='tf')
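
To verify the export (an optional check, not in the original notebook), you can reload the SavedModel and call its serving signature directly. The signature expects uint8 pixel values, matching the TensorSpec declared on my_serve.

In [ ]:
# Optional: reload the exported SavedModel and run the serving signature
# on a few validation digits (converted back to uint8 pixels).
reloaded = tf.saved_model.load(export_path)
serve_fn = reloaded.signatures['serving_default']
sample = np.round(validation_digits[:5] * 255).astype(np.uint8)
print(serve_fn(images=tf.constant(sample)))  # {'digits': <tf.Tensor ...>}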

In [ ]:
# saved_model_cli: a useful tool for troubleshooting SavedModels (the tool is part of the Tensorflow installation)
!saved_model_cli show --dir {export_path}
!saved_model_cli show --dir {export_path} --tag_set serve
!saved_model_cli show --dir {export_path} --tag_set serve --signature_def serving_default
# A note on naming:
# The "serve" tag set (i.e. serving functionality) is the only one exported by tf.saved_model.save
# All the other names are defined by the user in the following lines of code:
#      def my_serve(self, images):
#                         ******
#        return {'digits': classes}
#                 ******
#      tf.saved_model.save(..., signatures={'serving_default': serving_model.my_serve})
#                                            ***************

Deploy the model

This uses the command-line interface. You can do the same thing through the AI Platform UI at https://console.cloud.google.com/mlengine/models


In [ ]:
# Create the model
if NEW_MODEL:
  !gcloud ai-platform models create {MODEL_NAME} --project={PROJECT} --regions=us-central1

In [ ]:
# Create a version of this model (you can add --async at the end of the line to make this call non-blocking)
# Additional config flags are available: https://cloud.google.com/ml-engine/reference/rest/v1/projects.models.versions
# You can also deploy a model that is stored locally by providing a --staging-bucket=... parameter
!echo "Deployment takes a couple of minutes. You can watch your deployment here: https://console.cloud.google.com/mlengine/models/{MODEL_NAME}"
!gcloud ai-platform versions create {MODEL_VERSION} --model={MODEL_NAME} --origin={export_path} --project={PROJECT} --runtime-version=1.14 --python-version=3.5

Test the deployed model

Your model is now available as a REST API. Let us try calling it. The cells below use the "gcloud ai-platform" command line tool, but any tool that can send a JSON payload to a REST endpoint will work (a raw REST sketch follows the prediction cell below).


In [ ]:
# prepare digits to send to online prediction endpoint
digits_float32 = np.concatenate((font_digits, validation_digits[:100-N])) # pixel values in [0.0, 1.0] float range
digits_uint8 = np.round(digits_float32*255).astype(np.uint8) # pixel values in [0, 255] int range
labels = np.concatenate((font_labels, validation_labels[:100-N]))
with open("digits.json", "w") as f:
  for digit in digits_uint8:
    # the format for AI Platform online predictions is: one JSON object per line
    data = json.dumps({"images": digit.tolist()})  # "images" because that was the name you gave this parameter in the serving function my_serve
    f.write(data+'\n')

In [ ]:
# Request online predictions from the deployed model (REST API) using the "gcloud ai-platform" command line.
predictions = !gcloud ai-platform predict --model={MODEL_NAME} --json-instances digits.json --project={PROJECT} --version {MODEL_VERSION}
print(predictions)

predictions = np.stack([json.loads(p) for p in predictions[2:]]) # the first lines are the gcloud output header (output layer name): drop them, parse the rest
display_top_unrecognized(digits_float32, predictions, labels, N, 100//N)
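
Because the endpoint is a plain REST API, the gcloud CLI is optional. Below is a minimal sketch (an illustration, not from the original notebook) using the google-auth and requests packages to call the AI Platform online prediction endpoint directly; it assumes application default credentials are available in the environment, as they are after the Colab auth cell.

In [ ]:
# Minimal REST sketch (illustrative): call the online prediction endpoint
# with an authenticated HTTP POST instead of the gcloud CLI.
import google.auth
import google.auth.transport.requests
import requests

creds, _ = google.auth.default(scopes=["https://www.googleapis.com/auth/cloud-platform"])
creds.refresh(google.auth.transport.requests.Request())

url = "https://ml.googleapis.com/v1/projects/{}/models/{}/versions/{}:predict".format(
    PROJECT, MODEL_NAME, MODEL_VERSION)
payload = {"instances": [{"images": digit.tolist()} for digit in digits_uint8[:3]]}
response = requests.post(url, json=payload,
                         headers={"Authorization": "Bearer " + creds.token})
print(response.json())  # e.g. {'predictions': [{'digits': 7}, ...]}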

What's next

  • Learn about Cloud TPUs that Google designed and optimized specifically to speed up and scale up ML workloads for training and inference and to enable ML engineers and researchers to iterate more quickly.
  • Explore the range of Cloud TPU tutorials and Colabs to find other examples that can be used when implementing your ML project.

On Google Cloud Platform, in addition to GPUs and TPUs available on pre-configured deep learning VMs, you will find AutoML (beta) for training custom models without writing code and Cloud ML Engine, which allows you to run parallel trainings and hyperparameter tuning of your custom models on powerful distributed hardware.

License


author: Martin Gorner
twitter: @martin_gorner


Copyright 2019 Google LLC

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.


This is not an official Google product but sample code provided for educational purposes.