MNIST on TPU (Tensor Processing Unit) or GPU using tf.Keras and tf.data.Dataset

This sample trains an "MNIST" handwritten digit recognition model on a GPU or TPU backend using a Keras model. Data are handled using the tf.data.Dataset API. This is a very simple sample provided for educational purposes. Do not expect outstanding TPU performance on a dataset as small as MNIST.

  Train on GPU or TPU  

  1. Select a GPU or TPU backend (Runtime > Change runtime type)
  2. Runtime > Run All
    (Watch out: the "Colab-only auth" cell requires user input.
    The "Deploy" part at the end requires a Google Cloud project and a GCS bucket to be configured.)

  Deploy to AI Platform

At the bottom of this notebook you can deploy your trained model to AI Platform for a serverless, autoscaled, REST API experience. You will need a Google Cloud project and a GCS (Google Cloud Storage) bucket for this last part.

TPUs are located in Google Cloud. For optimal performance, they read data directly from Google Cloud Storage.

Imports


In [0]:
import os, re, time, json
import PIL.Image, PIL.ImageFont, PIL.ImageDraw
import numpy as np
import tensorflow as tf
from matplotlib import pyplot as plt
print("Tensorflow version " + tf.__version__)

In [0]:
#@title visualization utilities [RUN ME]
"""
This cell contains helper functions used for visualization
and downloads only. You can skip reading it. There is very
little useful Keras/Tensorflow code here.
"""

# Matplotlib config
plt.rc('image', cmap='gray_r')
plt.rc('grid', linewidth=0)
plt.rc('xtick', top=False, bottom=False, labelsize='large')
plt.rc('ytick', left=False, right=False, labelsize='large')
plt.rc('axes', facecolor='F8F8F8', titlesize="large", edgecolor='white')
plt.rc('text', color='a8151a')
plt.rc('figure', facecolor='F0F0F0')

# Matplotlib fonts
MATPLOTLIB_FONT_DIR = os.path.join(os.path.dirname(plt.__file__), "mpl-data/fonts/ttf")

# pull a batch from the datasets. This code is not very elegant; it gets much simpler in eager mode (TODO)
def dataset_to_numpy_util(training_dataset, validation_dataset, N):
  
  # get one batch from each: 10000 validation digits, N training digits
  batch_train_ds = training_dataset.apply(tf.data.experimental.unbatch()).batch(N)
  
  # eager execution: loop through datasets normally
  if tf.executing_eagerly():
    for validation_digits, validation_labels in validation_dataset:
      validation_digits = validation_digits.numpy()
      validation_labels = validation_labels.numpy()
      break
    for training_digits, training_labels in batch_train_ds:
      training_digits = training_digits.numpy()
      training_labels = training_labels.numpy()
      break
    
  else:
    v_images, v_labels = validation_dataset.make_one_shot_iterator().get_next()
    t_images, t_labels = batch_train_ds.make_one_shot_iterator().get_next()
    # Run once, get one batch. Session.run returns numpy results
    with tf.Session() as ses:
      (validation_digits, validation_labels,
       training_digits, training_labels) = ses.run([v_images, v_labels, t_images, t_labels])
  
  # these were one-hot encoded in the dataset
  validation_labels = np.argmax(validation_labels, axis=1)
  training_labels = np.argmax(training_labels, axis=1)
  
  return (training_digits, training_labels,
          validation_digits, validation_labels)

# create digits from local fonts for testing
def create_digits_from_local_fonts(n):
  font_labels = []
  img = PIL.Image.new('LA', (28*n, 28), color = (0,255)) # format 'LA': black in channel 0, alpha in channel 1
  font1 = PIL.ImageFont.truetype(os.path.join(MATPLOTLIB_FONT_DIR, 'DejaVuSansMono-Oblique.ttf'), 25)
  font2 = PIL.ImageFont.truetype(os.path.join(MATPLOTLIB_FONT_DIR, 'STIXGeneral.ttf'), 25)
  d = PIL.ImageDraw.Draw(img)
  for i in range(n):
    font_labels.append(i%10)
    d.text((7+i*28,0 if i<10 else -4), str(i%10), fill=(255,255), font=font1 if i<10 else font2)
  font_digits = np.array(img.getdata(), np.float32)[:,0] / 255.0 # black in channel 0, alpha in channel 1 (discarded)
  font_digits = np.reshape(np.stack(np.split(np.reshape(font_digits, [28, 28*n]), n, axis=1), axis=0), [n, 28*28])
  return font_digits, font_labels

# utility to display a row of digits with their predictions
def display_digits(digits, predictions, labels, title, n):
  plt.figure(figsize=(13,3))
  digits = np.reshape(digits, [n, 28, 28])
  digits = np.swapaxes(digits, 0, 1)
  digits = np.reshape(digits, [28, 28*n])
  plt.yticks([])
  plt.xticks([28*x+14 for x in range(n)], predictions)
  for i,t in enumerate(plt.gca().xaxis.get_ticklabels()):
    if predictions[i] != labels[i]: t.set_color('red') # bad predictions in red
  plt.imshow(digits)
  plt.grid(None)
  plt.title(title)
  
# utility to display multiple rows of digits, sorted by unrecognized/recognized status
def display_top_unrecognized(digits, predictions, labels, n, lines):
  idx = np.argsort(predictions==labels) # sort order: unrecognized first
  for i in range(lines):
    display_digits(digits[idx][i*n:(i+1)*n], predictions[idx][i*n:(i+1)*n], labels[idx][i*n:(i+1)*n],
                   "{} sample validation digits out of {} with bad predictions in red and sorted first".format(n*lines, len(digits)) if i==0 else "", n)
    
# utility to display training and validation curves
def display_training_curves(training, validation, title, subplot):
  if subplot%10==1: # set up the subplots on the first call
    plt.subplots(figsize=(10,10), facecolor='#F0F0F0')
    plt.tight_layout()
  ax = plt.subplot(subplot)
  ax.grid(linewidth=1, color='white')
  ax.plot(training)
  ax.plot(validation)
  ax.set_title('model '+ title)
  ax.set_ylabel(title)
  ax.set_xlabel('epoch')
  ax.legend(['train', 'valid.'])

(you can double-click on collapsed cells to view the non-essential code inside)
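The reshaping steps in the utilities above are compact but hard to read. The NumPy-only sketch below reproduces them on a small synthetic strip so they can be checked in isolation; the helper names strip_to_digits, digits_to_strip and unrecognized_first are hypothetical and not part of the notebook. It splits a 28 x 28n strip into n flattened digits (as create_digits_from_local_fonts does), lays digits side by side again (as display_digits does), and shows why np.argsort(predictions==labels) puts misrecognized digits first: False sorts before True.

```python
import numpy as np

SIZE = 28  # MNIST digits are 28x28 pixels

def strip_to_digits(strip, n):
    """Split a [28, 28*n] strip into n flattened digits of shape [n, 784]
    (same steps as create_digits_from_local_fonts)."""
    tiles = np.split(strip, n, axis=1)             # n arrays of shape [28, 28]
    return np.reshape(np.stack(tiles, axis=0), [n, SIZE * SIZE])

def digits_to_strip(digits, n):
    """Lay n flattened digits side by side in a [28, 28*n] strip
    (same steps as display_digits)."""
    d = np.reshape(digits, [n, SIZE, SIZE])
    d = np.swapaxes(d, 0, 1)                       # [28, n, 28]: pixel row first, then digit
    return np.reshape(d, [SIZE, SIZE * n])

def unrecognized_first(predictions, labels):
    """Indices sorted so that wrong predictions (False) come before right ones (True).
    kind='stable' keeps the original order inside each group."""
    return np.argsort(predictions == labels, kind='stable')

# Round trip on a synthetic strip of 3 "digits" with distinct pixel values
strip = np.arange(SIZE * SIZE * 3, dtype=np.float32).reshape(SIZE, SIZE * 3)
digits = strip_to_digits(strip, 3)
print(digits.shape)                                        # (3, 784)
print(np.array_equal(digits_to_strip(digits, 3), strip))   # True
print(unrecognized_first(np.array([1, 2, 3, 4]), np.array([1, 0, 3, 0])))  # [1 3 0 2]
```

The notebook itself calls np.argsort with the default sort kind, which is enough there because only the split between wrong and right predictions matters, not the order within each group.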

Colab-only auth for this notebook and the TPU


In [0]:
IS_COLAB_BACKEND = 'COLAB_GPU' in os.environ  # this is always set on Colab, the value is 0 or 1 depending on GPU presence
if IS_COLAB_BACKEND:
  from google.colab import auth
  # Authenticates the Colab machine and also the TPU using your
  # credentials so that they can access your private GCS buckets.
  auth.authenticate_user()

TPU or GPU detection


In [0]:
# Detect hardware
try:
  tpu = tf.distribute.cluster_resolver.TPUClusterResolver() # TPU detection
except ValueError:
  tpu = None
  gpus = tf.config.experimental.list_logical_devices("GPU")
    
# Select appropriate distribution strategy
if tpu:
  tf.config.experimental_connect_to_cluster(tpu)
  tf.tpu.experimental.initialize_tpu_system(tpu)
  strategy = tf.distribute.experimental.TPUStrategy(tpu, steps_per_run=128) # Going back and forth between TPU and host is expensive. Better to run 128 batches on the TPU before reporting back.
  print('Running on TPU ', tpu.cluster_spec().as_dict()['worker'])  
elif len(gpus) > 1:
  strategy = tf.distribute.MirroredStrategy([gpu.name for gpu in gpus])
  print('Running on multiple GPUs ', [gpu.name for gpu in gpus])
elif len(gpus) == 1:
  strategy = tf.distribute.get_strategy() # default strategy that works on CPU and single GPU
  print('Running on single GPU ', gpus[0].name)
else:
  strategy = tf.distribute.get_strategy() # default strategy that works on CPU and single GPU
  print('Running on CPU')
print("Number of accelerators: ", strategy.num_replicas_in_sync)


Running on TPU  ['10.122.37.122:8470']
Number of accelerators:  8

Parameters


In [0]:
BATCH_SIZE = 64 * strategy.num_replicas_in_sync # Global batch size.
# The global batch size will be automatically sharded across all
# replicas by the tf.data.Dataset API. A single TPU has 8 cores.
# The best practice is to scale the batch size by the number of
# replicas (cores). The learning rate should be increased as well.

LEARNING_RATE = 0.01
LEARNING_RATE_EXP_DECAY = 0.6 if strategy.num_replicas_in_sync == 1 else 0.7
# Learning rate computed later as LEARNING_RATE * LEARNING_RATE_EXP_DECAY**epoch
# 0.7 decay instead of 0.6 means a slower decay, i.e. the learning rate stays higher for longer.

training_images_file   = 'gs://mnist-public/train-images-idx3-ubyte'
training_labels_file   = 'gs://mnist-public/train-labels-idx1-ubyte'
validation_images_file = 'gs://mnist-public/t10k-images-idx3-ubyte'
validation_labels_file = 'gs://mnist-public/t10k-labels-idx1-ubyte'
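The arithmetic behind these parameters can be checked without a TPU. The pure-Python sketch below assumes 8 replicas (a single TPU with 8 cores, as in the detection output above) and the 60,000-image training set; schedule is a hypothetical helper that mirrors the LearningRateScheduler lambda defined later in the notebook.

```python
REPLICAS = 8              # assumption: strategy.num_replicas_in_sync on a single TPU
PER_REPLICA_BATCH = 64
TRAIN_SIZE = 60000        # number of MNIST training images

global_batch = PER_REPLICA_BATCH * REPLICAS     # what BATCH_SIZE evaluates to
steps_per_epoch = TRAIN_SIZE // global_batch    # full batches per epoch

def schedule(epoch, lr=0.01, decay=0.7):
    """Exponential decay: lr * decay**epoch (Keras counts epochs from 0)."""
    return lr * decay ** epoch

print(global_batch, steps_per_epoch)            # 512 117
for decay in (0.6, 0.7):
    # the 0.7 decay keeps the learning rate higher for longer
    print(decay, [round(schedule(e, decay=decay), 5) for e in range(0, 10, 3)])
```

Because the training dataset is repeated before it is batched, the 96 images left over after 117 batches of 512 are not discarded; they simply start the next epoch's batches.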

tf.data.Dataset: parse files and prepare training and validation datasets

Please read the best practices for building input pipelines with tf.data.Dataset


In [0]:
def read_label(tf_bytestring):
    label = tf.io.decode_raw(tf_bytestring, tf.uint8)
    label = tf.reshape(label, [])
    label = tf.one_hot(label, 10)
    return label
  
def read_image(tf_bytestring):
    image = tf.io.decode_raw(tf_bytestring, tf.uint8)
    image = tf.cast(image, tf.float32)/255.0  # pixel values in [0.0, 1.0], same scaling as the serving function
    image = tf.reshape(image, [28*28])
    return image
  
def load_dataset(image_file, label_file):
    imagedataset = tf.data.FixedLengthRecordDataset(image_file, 28*28, header_bytes=16)
    imagedataset = imagedataset.map(read_image, num_parallel_calls=16)
    labelsdataset = tf.data.FixedLengthRecordDataset(label_file, 1, header_bytes=8)
    labelsdataset = labelsdataset.map(read_label, num_parallel_calls=16)
    dataset = tf.data.Dataset.zip((imagedataset, labelsdataset))
    return dataset 
  
def get_training_dataset(image_file, label_file, batch_size):
    dataset = load_dataset(image_file, label_file)
    dataset = dataset.cache()  # this small dataset can be entirely cached in RAM
    dataset = dataset.shuffle(5000, reshuffle_each_iteration=True)
    dataset = dataset.repeat() # Mandatory for Keras for now
    dataset = dataset.batch(batch_size, drop_remainder=True) # drop_remainder is important on TPU, batch size must be fixed
    dataset = dataset.prefetch(-1)  # fetch next batches while training on the current one (-1: autotune prefetch buffer size)
    return dataset
  
def get_validation_dataset(image_file, label_file):
    dataset = load_dataset(image_file, label_file)
    dataset = dataset.cache() # this small dataset can be entirely cached in RAM
    dataset = dataset.batch(10000, drop_remainder=True) # 10000 items in eval dataset, all in one batch
    dataset = dataset.repeat() # Mandatory for Keras for now
    return dataset

# instantiate the datasets
training_dataset = get_training_dataset(training_images_file, training_labels_file, BATCH_SIZE)
validation_dataset = get_validation_dataset(validation_images_file, validation_labels_file)
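FixedLengthRecordDataset works here because the MNIST files use the simple IDX layout: image files start with a 16-byte header (magic number 2051, image count, rows, columns, all big-endian 32-bit integers) and label files with an 8-byte header (magic number 2049, label count), followed by raw uint8 records. The NumPy sketch below builds a tiny IDX pair in memory and decodes it in the spirit of read_image and read_label: fixed 784-byte image records, 1-byte labels, one-hot encoding. The helper names are hypothetical, it does not read the real gs:// files, and it divides by 255 so that pixel values span exactly [0, 1].

```python
import struct
import numpy as np

def make_idx(images, labels):
    """Serialize uint8 images [n, 28, 28] and labels [n] to IDX bytes."""
    n = len(labels)
    img_bytes = struct.pack('>IIII', 2051, n, 28, 28) + images.astype(np.uint8).tobytes()
    lbl_bytes = struct.pack('>II', 2049, n) + np.asarray(labels, np.uint8).tobytes()
    return img_bytes, lbl_bytes

def parse_idx(img_bytes, lbl_bytes):
    """Skip the headers (like header_bytes=16 and header_bytes=8 in load_dataset)
    and decode fixed-length records: 784 bytes per image, 1 byte per label."""
    images = np.frombuffer(img_bytes, np.uint8, offset=16).reshape(-1, 28 * 28)
    images = images.astype(np.float32) / 255.0          # pixel values in [0, 1]
    labels = np.frombuffer(lbl_bytes, np.uint8, offset=8)
    one_hot = np.eye(10, dtype=np.float32)[labels]      # like tf.one_hot(label, 10)
    return images, one_hot

rng = np.random.default_rng(0)
raw = rng.integers(0, 256, size=(5, 28, 28)).astype(np.uint8)
img_b, lbl_b = make_idx(raw, [3, 1, 4, 1, 5])
images, one_hot = parse_idx(img_b, lbl_b)
print(images.shape, one_hot.shape)      # (5, 784) (5, 10)
print(np.argmax(one_hot, axis=1))       # [3 1 4 1 5]
```

The final argmax is the same step dataset_to_numpy_util uses to turn the one-hot labels back into digit values for display.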

Let's have a look at the data


In [0]:
N = 24
(training_digits, training_labels,
 validation_digits, validation_labels) = dataset_to_numpy_util(training_dataset, validation_dataset, N)
display_digits(training_digits, training_labels, training_labels, "training digits and their labels", N)
display_digits(validation_digits[:N], validation_labels[:N], validation_labels[:N], "validation digits and their labels", N)
font_digits, font_labels = create_digits_from_local_fonts(N)


Keras model: 3 convolutional layers, 2 dense layers

If you are not sure what cross-entropy, dropout, softmax or batch-normalization mean, head here for a crash-course: Tensorflow and deep learning without a PhD


In [0]:
# This model trains to 99.4% accuracy in 10 epochs (with a batch size of 64)  

def make_model():
    model = tf.keras.Sequential(
      [
        tf.keras.layers.Reshape(input_shape=(28*28,), target_shape=(28, 28, 1), name="image"),

        tf.keras.layers.Conv2D(filters=12, kernel_size=3, padding='same', use_bias=False), # no bias necessary before batch norm
        tf.keras.layers.BatchNormalization(scale=False, center=True), # no batch norm scaling necessary before "relu"
        tf.keras.layers.Activation('relu'), # activation after batch norm

        tf.keras.layers.Conv2D(filters=24, kernel_size=6, padding='same', use_bias=False, strides=2),
        tf.keras.layers.BatchNormalization(scale=False, center=True),
        tf.keras.layers.Activation('relu'),

        tf.keras.layers.Conv2D(filters=32, kernel_size=6, padding='same', use_bias=False, strides=2),
        tf.keras.layers.BatchNormalization(scale=False, center=True),
        tf.keras.layers.Activation('relu'),

        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(200, use_bias=False),
        tf.keras.layers.BatchNormalization(scale=False, center=True),
        tf.keras.layers.Activation('relu'),
        tf.keras.layers.Dropout(0.4), # Dropout on dense layer only

        tf.keras.layers.Dense(10, activation='softmax')
      ])

    model.compile(optimizer='adam', # learning rate will be set by LearningRateScheduler
                  loss='categorical_crossentropy',
                  metrics=['accuracy'])
    return model
    
with strategy.scope():
    model = make_model()

# print model layers
model.summary()

# set up learning rate decay
lr_decay = tf.keras.callbacks.LearningRateScheduler(
    lambda epoch: LEARNING_RATE * LEARNING_RATE_EXP_DECAY**epoch,
    verbose=True)


W0805 18:33:32.316657 140549664556928 deprecation.py:506] From /usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/init_ops.py:1251: calling VarianceScaling.__init__ (from tensorflow.python.ops.init_ops) with dtype is deprecated and will be removed in a future version.
Instructions for updating:
Call initializer instance with the dtype argument instead of passing it to the constructor
Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
image (Reshape)              (None, 28, 28, 1)         0         
_________________________________________________________________
conv2d (Conv2D)              (None, 28, 28, 12)        108       
_________________________________________________________________
batch_normalization (BatchNo (None, 28, 28, 12)        36        
_________________________________________________________________
activation (Activation)      (None, 28, 28, 12)        0         
_________________________________________________________________
conv2d_1 (Conv2D)            (None, 14, 14, 24)        10368     
_________________________________________________________________
batch_normalization_1 (Batch (None, 14, 14, 24)        72        
_________________________________________________________________
activation_1 (Activation)    (None, 14, 14, 24)        0         
_________________________________________________________________
conv2d_2 (Conv2D)            (None, 7, 7, 32)          27648     
_________________________________________________________________
batch_normalization_2 (Batch (None, 7, 7, 32)          96        
_________________________________________________________________
activation_2 (Activation)    (None, 7, 7, 32)          0         
_________________________________________________________________
flatten (Flatten)            (None, 1568)              0         
_________________________________________________________________
dense (Dense)                (None, 200)               313600    
_________________________________________________________________
batch_normalization_3 (Batch (None, 200)               600       
_________________________________________________________________
activation_3 (Activation)    (None, 200)               0         
_________________________________________________________________
dropout (Dropout)            (None, 200)               0         
_________________________________________________________________
dense_1 (Dense)              (None, 10)                2010      
=================================================================
Total params: 354,538
Trainable params: 354,002
Non-trainable params: 536
_________________________________________________________________
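The parameter counts in this summary follow directly from the layer definitions. The pure-Python sketch below recomputes them (the helper names are hypothetical). The convolutions and the 200-unit dense layer have no bias, and each BatchNormalization(scale=False, center=True) layer holds a trainable offset (beta) plus a non-trainable moving mean and moving variance per channel. The two stride-2 convolutions shrink 28x28 to 7x7, hence 7*7*32 = 1568 flattened values.

```python
def conv_params(k, c_in, c_out):
    return k * k * c_in * c_out           # use_bias=False: kernel weights only

def bn_params(channels):
    # scale=False, center=True: beta (trainable) + moving mean and variance (non-trainable)
    return channels, 2 * channels         # (trainable, non_trainable)

weights = [conv_params(3, 1, 12),         # 108
           conv_params(6, 12, 24),        # 10368
           conv_params(6, 24, 32),        # 27648
           7 * 7 * 32 * 200,              # Dense(200, use_bias=False) on 1568 inputs
           200 * 10 + 10]                 # Dense(10) with bias
bn = [bn_params(c) for c in (12, 24, 32, 200)]

trainable = sum(weights) + sum(t for t, _ in bn)
non_trainable = sum(n for _, n in bn)
print(trainable + non_trainable, trainable, non_trainable)   # 354538 354002 536
```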

Train and validate the model


In [0]:
EPOCHS = 10
steps_per_epoch = 60000//BATCH_SIZE  # 60,000 items in this dataset
print("Steps per epoch: ", steps_per_epoch)
  
# Little wrinkle: in the present version of TensorFlow (1.14), switching a TPU
# between training and evaluation is slow (approx. 10 sec). For small models,
# it is recommended to run a single eval at the end.
history = model.fit(training_dataset,
                    steps_per_epoch=steps_per_epoch, epochs=EPOCHS,
                    callbacks=[lr_decay])

final_stats = model.evaluate(validation_dataset, steps=1)
print("Validation accuracy: ", final_stats[1])


Steps per epoch:  117
W0805 18:33:41.028837 140549664556928 deprecation.py:323] From /usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/training_distributed.py:411: Variable.load (from tensorflow.python.ops.variables) is deprecated and will be removed in a future version.
Instructions for updating:
Prefer Variable.assign which has equivalent behavior in 2.X.
Epoch 00001: LearningRateScheduler reducing learning rate to 0.01.
Epoch 1/10
117/117 [==============================] - 5s 40ms/step - loss: 0.0258 - acc: 0.9449

Epoch 00002: LearningRateScheduler reducing learning rate to 0.006999999999999999.
Epoch 2/10
117/117 [==============================] - 1s 9ms/step - loss: 0.0133 - acc: 0.9840

Epoch 00003: LearningRateScheduler reducing learning rate to 0.0049.
Epoch 3/10
117/117 [==============================] - 1s 8ms/step - loss: 0.0106 - acc: 0.9897

Epoch 00004: LearningRateScheduler reducing learning rate to 0.003429999999999999.
Epoch 4/10
117/117 [==============================] - 1s 8ms/step - loss: 0.0100 - acc: 0.9924

Epoch 00005: LearningRateScheduler reducing learning rate to 0.0024009999999999995.
Epoch 5/10
117/117 [==============================] - 1s 9ms/step - loss: 0.0094 - acc: 0.9946

Epoch 00006: LearningRateScheduler reducing learning rate to 0.0016806999999999994.
Epoch 6/10
117/117 [==============================] - 1s 8ms/step - loss: 0.0156 - acc: 0.9963

Epoch 00007: LearningRateScheduler reducing learning rate to 0.0011764899999999997.
Epoch 7/10
117/117 [==============================] - 1s 9ms/step - loss: 0.0134 - acc: 0.9972

Epoch 00008: LearningRateScheduler reducing learning rate to 0.0008235429999999996.
Epoch 8/10
117/117 [==============================] - 1s 9ms/step - loss: 0.0105 - acc: 0.9976

Epoch 00009: LearningRateScheduler reducing learning rate to 0.0005764800999999997.
Epoch 9/10
117/117 [==============================] - 1s 9ms/step - loss: 0.0041 - acc: 0.9979

Epoch 00010: LearningRateScheduler reducing learning rate to 0.0004035360699999998.
Epoch 10/10
117/117 [==============================] - 1s 9ms/step - loss: 0.0039 - acc: 0.9981
1/1 [==============================] - 4s 4s/step
Validation accuracy:  0.99380004

Visualize predictions


In [0]:
# recognize digits from local fonts
probabilities = model.predict(font_digits, steps=1)
predicted_labels = np.argmax(probabilities, axis=1)
display_digits(font_digits, predicted_labels, font_labels, "predictions from local fonts (bad predictions in red)", N)

# recognize validation digits
probabilities = model.predict(validation_digits, steps=1)
predicted_labels = np.argmax(probabilities, axis=1)
display_top_unrecognized(validation_digits, predicted_labels, validation_labels, N, 7)


Deploy the trained model to AI Platform model serving

Push your trained model to production on AI Platform for a serverless, autoscaled, REST API experience.

You will need a GCS (Google Cloud Storage) bucket and a GCP project for this. Models deployed on AI Platform autoscale to zero when not in use, so there will be no AI Platform charges after you are done testing. Google Cloud Storage does incur charges: once the model is deployed, the bucket is no longer needed, so empty it if you want to avoid them.

Configuration


In [0]:
PROJECT = "" #@param {type:"string"}
BUCKET = "gs://"  #@param {type:"string"}
NEW_MODEL = True #@param {type:"boolean"}
MODEL_NAME = "mnist" #@param {type:"string"}
MODEL_VERSION = "v1" #@param {type:"string"}

assert PROJECT, 'For this part, you need a GCP project. Head to http://console.cloud.google.com/ and create one.'
assert re.search(r'gs://.+', BUCKET), 'For this part, you need a GCS bucket. Head to http://console.cloud.google.com/storage and create one.'

Export the model for serving from AI Platform


In [0]:
# Wrap the model so that we can add a serving function
class ExportModel(tf.keras.Model):
  def __init__(self, model):
    super().__init__()
    self.model = model

  # The serving function performing data pre- and post-processing.
  # Pre-processing:  images are received in uint8 format and converted
  #                  to float32 before being sent through the model.
  # Post-processing: the Keras model outputs digit probabilities. We want
  #                  the detected digits. An additional tf.argmax is needed.
  # @tf.function turns the code in this function into a Tensorflow graph that
  # can be exported. This way, the model itself, as well as its pre- and post-
  # processing steps are exported in the SavedModel and deployed in a single step.
  @tf.function(input_signature=[tf.TensorSpec([None, 28*28], dtype=tf.uint8)])
  def my_serve(self, images):
    images = tf.cast(images, tf.float32)/255   # pre-processing
    probabilities = self.model(images)          # prediction from model
    classes = tf.argmax(probabilities, axis=-1) # post-processing
    return {'digits': classes}
    
# Must copy the model from TPU to CPU to be able to compose them.
restored_model = make_model()
restored_model.set_weights(model.get_weights()) # this copies the weights from TPU, does nothing on GPU

# create the ExportModel and export it to the Tensorflow standard SavedModel format
serving_model = ExportModel(restored_model)
export_path = os.path.join(BUCKET, 'keras_export', str(time.time()))
tf.keras.backend.set_learning_phase(0) # inference only
tf.saved_model.save(serving_model, export_path, signatures={'serving_default': serving_model.my_serve})

print("Model exported to: ", export_path)

# Note: in Tensorflow 2.0, it will also be possible to
# export to the SavedModel format using model.save():
# serving_model.save(export_path, save_format='tf')


Model exported to:  gs://ml1-demo-martin/keras_export/1565030988.5267901
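The serving signature changes the model's contract: clients send uint8 pixels and receive digit classes instead of probabilities. The NumPy sketch below mimics my_serve with a stand-in fake_model (hypothetical: a fixed random linear layer plus softmax, not the trained Keras model) so that the pre- and post-processing steps can be checked in isolation.

```python
import numpy as np

rng = np.random.default_rng(42)
W = rng.normal(size=(28 * 28, 10)).astype(np.float32)   # stand-in weights (hypothetical)

def fake_model(images_float):
    """Stand-in for the Keras model: float32 [batch, 784] -> probabilities [batch, 10]."""
    logits = images_float @ W
    logits -= logits.max(axis=1, keepdims=True)           # numerically stable softmax
    e = np.exp(logits)
    return e / e.sum(axis=1, keepdims=True)

def serve(images_uint8):
    """Same steps as my_serve: uint8 -> float32 in [0, 1] -> model -> argmax."""
    images = images_uint8.astype(np.float32) / 255        # pre-processing
    probabilities = fake_model(images)                    # prediction
    return {'digits': np.argmax(probabilities, axis=-1)}  # post-processing

batch = rng.integers(0, 256, size=(4, 28 * 28)).astype(np.uint8)
out = serve(batch)
print(out['digits'].shape)   # (4,)
```

The scaling in the pre-processing step has to match the one used when building the training data; otherwise the deployed model sees slightly different inputs than it was trained on.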

In [0]:
# saved_model_cli: a useful tool for troubleshooting SavedModels (the tool is part of the Tensorflow installation)
!saved_model_cli show --dir {export_path}
!saved_model_cli show --dir {export_path} --tag_set serve
!saved_model_cli show --dir {export_path} --tag_set serve --signature_def serving_default
# A note on naming:
# The "serve" tag set (i.e. serving functionality) is the only one exported by tf.saved_model.save
# All the other names are defined by the user in the following lines of code:
#      def my_serve(self, images):
#                         ******
#        return {'digits': classes}
#                 ******
#      tf.saved_model.save(..., signatures={'serving_default': serving_model.my_serve})
#                                            ***************


The given SavedModel contains the following tag-sets:
serve
The given SavedModel MetaGraphDef contains SignatureDefs with the following keys:
SignatureDef key: "__saved_model_init_op"
SignatureDef key: "serving_default"
The given SavedModel SignatureDef contains the following input(s):
  inputs['images'] tensor_info:
      dtype: DT_UINT8
      shape: (-1, 784)
      name: serving_default_images:0
The given SavedModel SignatureDef contains the following output(s):
  outputs['digits'] tensor_info:
      dtype: DT_INT64
      shape: (-1)
      name: StatefulPartitionedCall:0
Method name is: tensorflow/serving/predict

Deploy the model

This uses the command-line interface. You can do the same thing through the AI Platform UI at https://console.cloud.google.com/mlengine/models


In [0]:
# Create the model
if NEW_MODEL:
  !gcloud ai-platform models create {MODEL_NAME} --project={PROJECT} --regions=us-central1

In [0]:
# Create a version of this model (you can add --async at the end of the line to make this call non-blocking)
# Additional config flags are available: https://cloud.google.com/ml-engine/reference/rest/v1/projects.models.versions
# You can also deploy a model that is stored locally by providing a --staging-bucket=... parameter
!echo "Deployment takes a couple of minutes. You can watch your deployment here: https://console.cloud.google.com/mlengine/models/{MODEL_NAME}"
!gcloud ai-platform versions create {MODEL_VERSION} --model={MODEL_NAME} --origin={export_path} --project={PROJECT} --runtime-version=1.14 --python-version=3.5

Test the deployed model

Your model is now available as a REST API. Let us try to call it. The cells below use the "gcloud ai-platform" command-line tool, but any tool that can send a JSON payload to a REST endpoint will work.


In [0]:
# prepare digits to send to online prediction endpoint
digits_float32 = np.concatenate((font_digits, validation_digits[:100-N])) # pixel values in [0.0, 1.0] float range
digits_uint8 = np.round(digits_float32*255).astype(np.uint8) # pixel values in [0, 255] int range
labels = np.concatenate((font_labels, validation_labels[:100-N]))
with open("digits.json", "w") as f:
  for digit in digits_uint8:
    # the format for AI Platform online predictions is: one JSON object per line
    data = json.dumps({"images": digit.tolist()})  # "images" because that is the name you gave this parameter in the serving function my_serve
    f.write(data+'\n')
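gcloud ai-platform predict --json-instances expects newline-delimited JSON, one instance per line, which is what the cell above writes. If you call the REST endpoint directly instead, the same instances are wrapped in a single {"instances": [...]} request body. The sketch below builds both forms from a small synthetic batch; it does not send anything, and the helper names are hypothetical. In both cases the "images" key must match the input name of the serving signature.

```python
import json
import numpy as np

def to_json_lines(digits_uint8, key="images"):
    """One JSON object per line, as written to digits.json above."""
    return "\n".join(json.dumps({key: d.tolist()}) for d in digits_uint8) + "\n"

def to_rest_body(digits_uint8, key="images"):
    """A single JSON request body with an "instances" list."""
    return json.dumps({"instances": [{key: d.tolist()} for d in digits_uint8]})

# two synthetic 784-pixel "digits" with values in [0, 255]
digits = (np.arange(2 * 784).reshape(2, 784) % 256).astype(np.uint8)
lines = to_json_lines(digits)
body = to_rest_body(digits)
print(len(lines.splitlines()))                          # 2
print(json.loads(body)["instances"][1]["images"][:3])   # [16, 17, 18]
```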

In [0]:
# Request online predictions from deployed model (REST API) using the "gcloud ai-platform" command line.
predictions = !gcloud ai-platform predict --model={MODEL_NAME} --json-instances digits.json --project={PROJECT} --version {MODEL_VERSION}
print(predictions)

predictions = np.stack([json.loads(p) for p in predictions[1:]]) # first element is the name of the output layer: drop it, parse the rest
display_top_unrecognized(digits_float32, predictions, labels, N, 100//N)


['DIGITS', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '0', '1', '2', '3', '7', '2', '1', '0', '4', '1', '4', '9', '5', '9', '0', '6', '9', '0', '1', '5', '9', '7', '3', '4', '9', '6', '6', '5', '4', '0', '7', '4', '0', '1', '3', '1', '3', '4', '7', '2', '7', '1', '2', '1', '1', '7', '4', '2', '3', '5', '1', '2', '4', '4', '6', '3', '5', '5', '6', '0', '4', '1', '9', '5', '7', '8', '9', '3', '7', '4', '6', '4', '3', '0', '7', '0', '2', '9', '1', '7']
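The captured gcloud output is a list of strings: a header naming the output (DIGITS, from the 'digits' key returned by my_serve) followed by one prediction per entry. The sketch below parses a short, made-up sample of that shape and computes accuracy against known labels; parse_predictions and the sample values are hypothetical and are not taken from the run above.

```python
import json
import numpy as np

def parse_predictions(lines):
    """Drop the header entry and parse each remaining entry as JSON (here: an integer)."""
    return np.array([json.loads(p) for p in lines[1:]])

sample_output = ['DIGITS', '7', '2', '1', '0']   # made-up sample, same shape as above
labels = np.array([7, 2, 1, 8])                  # made-up ground truth
predictions = parse_predictions(sample_output)
accuracy = np.mean(predictions == labels)
print(predictions, accuracy)                     # [7 2 1 0] 0.75
```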

License


author: Martin Gorner
twitter: @martin_gorner


Copyright 2019 Google LLC

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.


This is not an official Google product but sample code provided for educational purposes.