Simple Classification Model using TPUEstimator on Colab TPU

This notebook demonstrates how to use Cloud TPUs to build a simple classification model on the Iris dataset to predict flower species. The model uses four input features (SepalLength, SepalWidth, PetalLength, PetalWidth) to classify each flower as one of three species (Setosa, Versicolor, Virginica).
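
For concreteness, one labeled example pairs the four measurements with an integer index into the species list (0 = Setosa, 1 = Versicolor, 2 = Virginica); the values below reuse one of the prediction inputs defined later in this notebook:


In [ ]:
# One labeled iris example: four float features plus an integer class label
# (an index into ['Setosa', 'Versicolor', 'Virginica']).
example_features = {'SepalLength': 5.1, 'SepalWidth': 3.3,
                    'PetalLength': 1.7, 'PetalWidth': 0.5}
example_label = 0  # Setosa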

Note: You will need a GCP account and a GCS bucket for this notebook to run!

Imports


In [0]:
#  Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
"""An Example of a custom TPUEstimator for the Iris dataset."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import json
import numpy as np
import os
import pandas as pd
import pprint
import tensorflow as tf
import time

Resolve the TPU address and authenticate the GCS bucket


In [0]:
use_tpu = True #@param {type:"boolean"}
bucket = '' #@param {type:"string"}

assert bucket, 'Must specify an existing GCS bucket name'
print('Using bucket: {}'.format(bucket))

if use_tpu:
    assert 'COLAB_TPU_ADDR' in os.environ, 'Missing TPU; did you request a TPU in Notebook Settings?'

MODEL_DIR = 'gs://{}/{}'.format(bucket, time.strftime('tpuestimator-dnn/%Y-%m-%d-%H-%M-%S'))
print('Using model dir: {}'.format(MODEL_DIR))

from google.colab import auth
auth.authenticate_user()

if 'COLAB_TPU_ADDR' in os.environ:
  TF_MASTER = 'grpc://{}'.format(os.environ['COLAB_TPU_ADDR'])

  # Upload credentials to TPU.
  with tf.Session(TF_MASTER) as sess:
    with open('/content/adc.json', 'r') as f:
      auth_info = json.load(f)
    tf.contrib.cloud.configure_gcs(sess, credentials=auth_info)
  # Now credentials are set for all future sessions on this TPU.
else:
  TF_MASTER = ''

with tf.Session(TF_MASTER) as session:
  print('List of devices:')
  pprint.pprint(session.list_devices())
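
If you want to confirm that the credentials actually reached GCS before training starts, one optional check is to list the bucket with tf.gfile. This is a hypothetical sanity check, not part of the original notebook:


In [ ]:
# Optional sanity check (hypothetical): listing the bucket root verifies that
# the GCS credentials configured above are usable from this runtime.
print(tf.gfile.ListDirectory('gs://{}/'.format(bucket)))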

Flags used as model parameters


In [0]:
# Model-specific parameters

# TPU address
tpu_address = TF_MASTER

# Estimators model_dir
model_dir = MODEL_DIR

# This is the global batch size, not the per-shard batch.
batch_size = 128

# Total number of training steps.
train_steps = 1000

# Total number of evaluation steps. If '0', evaluation after training is skipped.
eval_steps = 4

# Number of iterations per TPU training loop.
iterations = 500
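
TPUEstimator splits the global batch across the TPU's cores, so with the 8-core TPU shown in the device list above, each shard processes 128 / 8 = 16 examples per step. A quick sketch of that arithmetic:


In [ ]:
# Per-shard batch size on an 8-core TPU (core count taken from the device
# listing above; adjust num_shards if your TPU differs).
num_shards = 8
per_shard_batch = batch_size // num_shards  # 128 // 8 == 16
print('Each TPU core processes {} examples per step.'.format(per_shard_batch))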

Get input data and define input functions


In [0]:
TRAIN_URL = "http://download.tensorflow.org/data/iris_training.csv"
TEST_URL = "http://download.tensorflow.org/data/iris_test.csv"

CSV_COLUMN_NAMES = ['SepalLength', 'SepalWidth',
                    'PetalLength', 'PetalWidth', 'Species']
SPECIES = ['Setosa', 'Versicolor', 'Virginica']

PREDICTION_INPUT_DATA = {
    'SepalLength': [6.9, 5.1, 5.9],
    'SepalWidth': [3.1, 3.3, 3.0],
    'PetalLength': [5.4, 1.7, 4.2],
    'PetalWidth': [2.1, 0.5, 1.5],
}

PREDICTION_OUTPUT_DATA = ['Virginica', 'Setosa', 'Versicolor']

def maybe_download():
    train_path = tf.keras.utils.get_file(TRAIN_URL.split('/')[-1], TRAIN_URL)
    test_path = tf.keras.utils.get_file(TEST_URL.split('/')[-1], TEST_URL)

    return train_path, test_path

def load_data(y_name='Species'):
    """Returns the iris dataset as (train_x, train_y), (test_x, test_y)."""
    train_path, test_path = maybe_download()

    dtypes = {'SepalLength': np.float32, 'SepalWidth': np.float32,
              'PetalLength': np.float32, 'PetalWidth': np.float32,
              'Species': np.int32}

    train = pd.read_csv(train_path, names=CSV_COLUMN_NAMES, header=0,
                        dtype=dtypes)
    train_x, train_y = train, train.pop(y_name)

    test = pd.read_csv(test_path, names=CSV_COLUMN_NAMES, header=0,
                       dtype=dtypes)
    test_x, test_y = test, test.pop(y_name)

    return (train_x, train_y), (test_x, test_y)


def train_input_fn(features, labels, batch_size):
    """An input function for training"""

    # Convert the inputs to a Dataset.
    dataset = tf.data.Dataset.from_tensor_slices((dict(features), labels))

    # Shuffle, repeat, and batch the examples. drop_remainder=True keeps the
    # batch size static, which the TPU requires.
    dataset = dataset.shuffle(1000).repeat()
    dataset = dataset.batch(batch_size, drop_remainder=True)

    # Return the dataset.
    return dataset


def eval_input_fn(features, labels, batch_size):
    """An input function for evaluation"""
    features = dict(features)
    inputs = (features, labels)

    # Convert the inputs to a Dataset. Repeat so the fixed number of
    # evaluation steps never runs out of data; drop the remainder to keep
    # batch shapes static for the TPU.
    dataset = tf.data.Dataset.from_tensor_slices(inputs)
    dataset = dataset.shuffle(1000).repeat()
    dataset = dataset.batch(batch_size, drop_remainder=True)

    # Return the dataset.
    return dataset


def predict_input_fn(features, batch_size):
    """An input function for prediction"""

    dataset = tf.data.Dataset.from_tensor_slices(features)
    dataset = dataset.batch(batch_size)
    return dataset
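
Before handing these functions to TPUEstimator, it can be reassuring to pull a single batch on the local CPU and check the shapes. The snippet below is a minimal sketch (assuming the cell above has been run) using a TF 1.x one-shot iterator:


In [ ]:
# Minimal local sanity check of the input pipeline (not part of the TPU run):
# pull one batch from train_input_fn and inspect the tensor shapes.
(sanity_x, sanity_y), _ = load_data()
sanity_ds = train_input_fn(sanity_x, sanity_y, batch_size=8)
feats, labs = sanity_ds.make_one_shot_iterator().get_next()
with tf.Session() as sess:
    batch_feats, batch_labs = sess.run([feats, labs])
print({name: v.shape for name, v in batch_feats.items()})  # each (8,)
print(batch_labs.shape)                                    # (8,)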

Model and metric functions


In [0]:
def metric_fn(labels, logits):
    """Function to return metrics for evaluation"""

    predicted_classes = tf.argmax(logits, 1)
    accuracy = tf.metrics.accuracy(labels=labels,
                                   predictions=predicted_classes,
                                   name='acc_op')
    return {'accuracy': accuracy}


def my_model(features, labels, mode, params):
    """DNN classifier with fully connected hidden layers given by params."""

    # Build a stack of fully connected ReLU layers, one per entry in
    # params['hidden_units'] (two layers of 10 units in this notebook).
    net = tf.feature_column.input_layer(features, params['feature_columns'])
    for units in params['hidden_units']:
        net = tf.layers.dense(net, units=units, activation=tf.nn.relu)

    # Compute logits (1 per class).
    logits = tf.layers.dense(net, params['n_classes'], activation=None)

    # Compute predictions.
    predicted_classes = tf.argmax(logits, 1)
    if mode == tf.estimator.ModeKeys.PREDICT:
        predictions = {
            'class_ids': predicted_classes[:, tf.newaxis],
            'probabilities': tf.nn.softmax(logits),
            'logits': logits,
        }
        return tf.contrib.tpu.TPUEstimatorSpec(mode, predictions=predictions)

    # Compute loss.
    loss = tf.losses.sparse_softmax_cross_entropy(labels=labels,
                                                  logits=logits)

    if mode == tf.estimator.ModeKeys.EVAL:
        return tf.contrib.tpu.TPUEstimatorSpec(
            mode=mode, loss=loss, eval_metrics=(metric_fn, [labels, logits]))

    # Create training op.
    if mode == tf.estimator.ModeKeys.TRAIN:
        optimizer = tf.train.AdagradOptimizer(learning_rate=0.1)
        if params['use_tpu']:
            optimizer = tf.contrib.tpu.CrossShardOptimizer(optimizer)
        train_op = optimizer.minimize(loss, global_step=tf.train.get_global_step())
        return tf.contrib.tpu.TPUEstimatorSpec(mode, loss=loss, train_op=train_op)
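
On a TPU, eval_metrics is passed as a (metric_fn, tensors) pair because the metric computation runs on the host CPU after the tensors come off the device. metric_fn itself is ordinary TF code, so you can exercise it in a plain session; a small sketch with a made-up batch of four examples:


In [ ]:
# Standalone sketch of metric_fn on a toy batch (values made up): argmax of
# these logits is [0, 1, 2, 0], so accuracy vs. labels [0, 1, 2, 1] is 0.75.
toy_labels = tf.constant([0, 1, 2, 1])
toy_logits = tf.constant([[2., 0., 0.],
                          [0., 2., 0.],
                          [0., 0., 2.],
                          [2., 0., 0.]])
acc_value, acc_update = metric_fn(toy_labels, toy_logits)['accuracy']
with tf.Session() as sess:
    sess.run(tf.local_variables_initializer())  # tf.metrics uses local vars
    sess.run(acc_update)
    print(sess.run(acc_value))  # 0.75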

Main function


In [0]:
def main():
    # Fetch the data
    (train_x, train_y), (test_x, test_y) = load_data()

    # Feature columns describe how to use the input.
    my_feature_columns = []
    for key in train_x.keys():
        my_feature_columns.append(tf.feature_column.numeric_column(key=key))

    # Resolve TPU cluster and runconfig for this.
    tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(
            tpu_address)

    run_config = tf.contrib.tpu.RunConfig(
            model_dir=model_dir,
            cluster=tpu_cluster_resolver,
            session_config=tf.ConfigProto(
                allow_soft_placement=True, log_device_placement=True),
            tpu_config=tf.contrib.tpu.TPUConfig(iterations_per_loop=iterations),
            )

    # Build 2 hidden layer DNN with 10, 10 units respectively.
    classifier = tf.contrib.tpu.TPUEstimator(
        model_fn=my_model,
        use_tpu=use_tpu,
        train_batch_size=batch_size,
        eval_batch_size=batch_size,
        predict_batch_size=batch_size,
        config=run_config,
        params={
            'feature_columns': my_feature_columns,
            # Two hidden layers of 10 nodes each.
            'hidden_units': [10, 10],
            # The model must choose between 3 classes.
            'n_classes': 3,
            'use_tpu': use_tpu,
        })

    # Train the Model.
    classifier.train(
            input_fn = lambda params: train_input_fn(
                train_x, train_y, params["batch_size"]),
            max_steps=train_steps)

    # Evaluate the model.
    eval_result = classifier.evaluate(
        input_fn = lambda params: eval_input_fn(
            test_x, test_y, params["batch_size"]),
        steps=eval_steps)

    print('\nTest set accuracy: {accuracy:0.3f}\n'.format(**eval_result))

    # Generate predictions from the model
    predictions = classifier.predict(
        input_fn = lambda params: predict_input_fn(
            PREDICTION_INPUT_DATA, params["batch_size"]))

    for pred_dict, expec in zip(predictions, PREDICTION_OUTPUT_DATA):
        template = ('\nPrediction is "{}" ({:.1f}%), expected "{}"')

        class_id = pred_dict['class_ids'][0]
        probability = pred_dict['probabilities'][class_id]

        print(template.format(SPECIES[class_id],
                              100 * probability, expec))
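
Note that the input_fn lambdas above accept a params dict rather than a batch size directly: TPUEstimator calls input_fn(params) itself and fills in params['batch_size'] with the train, eval, or predict batch size configured earlier. A simplified sketch of that contract (illustrative only, not the actual library internals):


In [ ]:
# Simplified sketch of how TPUEstimator invokes an input_fn (the real library
# adds more entries to params).
def invoke_input_fn(input_fn, phase_batch_size):
    # TPUEstimator supplies params['batch_size']; user code just reads it.
    params = {'batch_size': phase_batch_size}
    return input_fn(params)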

Run it!


In [9]:
main()


Downloading data from http://download.tensorflow.org/data/iris_training.csv
16384/2194 [==============================] - 0s 0us/step
Downloading data from http://download.tensorflow.org/data/iris_test.csv
16384/573 [==============================] - 0s 0us/step
INFO:tensorflow:Using config: {'_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
log_device_placement: true
cluster_def {
  job {
    name: "worker"
    tasks {
      value: "10.63.218.106:8470"
    }
  }
}
, '_keep_checkpoint_max': 5, '_task_type': 'worker', '_train_distribute': None, '_is_chief': True, '_cluster_spec': <tensorflow.python.training.server_lib.ClusterSpec object at 0x7fd42de33410>, '_model_dir': 'gs://amangu-test-bucket/tpuestimator-dnn/2018-10-31-00-05-49', '_protocol': None, '_save_checkpoints_steps': None, '_keep_checkpoint_every_n_hours': 10000, '_service': None, '_num_ps_replicas': 0, '_tpu_config': TPUConfig(iterations_per_loop=500, num_shards=None, num_cores_per_replica=None, per_host_input_for_training=2, tpu_job_name=None, initial_infeed_sleep_secs=None, input_partition_dims=None), '_tf_random_seed': None, '_save_summary_steps': 100, '_device_fn': None, '_cluster': <tensorflow.contrib.cluster_resolver.python.training.tpu_cluster_resolver.TPUClusterResolver object at 0x7fd432667b90>, '_experimental_distribute': None, '_num_worker_replicas': 1, '_task_id': 0, '_log_step_count_steps': None, '_evaluation_master': 'grpc://10.63.218.106:8470', '_eval_distribute': None, '_global_id_in_cluster': 0, '_master': 'grpc://10.63.218.106:8470'}
INFO:tensorflow:_TPUContext: eval_on_tpu True
INFO:tensorflow:Querying Tensorflow master (grpc://10.63.218.106:8470) for TPU system metadata.
INFO:tensorflow:Found TPU system:
INFO:tensorflow:*** Num TPU Cores: 8
INFO:tensorflow:*** Num TPU Workers: 1
INFO:tensorflow:*** Num TPU Cores Per Worker: 8
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:CPU:0, CPU, -1, 3204130715220914658)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:XLA_CPU:0, XLA_CPU, 17179869184, 1757131222346187914)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:XLA_GPU:0, XLA_GPU, 17179869184, 14189073471542166552)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:0, TPU, 17179869184, 18089629823860601296)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:1, TPU, 17179869184, 3346547477748453737)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:2, TPU, 17179869184, 6512962655626301676)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:3, TPU, 17179869184, 1886142505182871901)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:4, TPU, 17179869184, 49656553673112525)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:5, TPU, 17179869184, 1173330136390527977)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:6, TPU, 17179869184, 8029696445054773714)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:7, TPU, 17179869184, 867116612521745007)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU_SYSTEM:0, TPU_SYSTEM, 17179869184, 18125553718687585414)
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:TPU job name worker
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Saving checkpoints for 0 into gs://amangu-test-bucket/tpuestimator-dnn/2018-10-31-00-05-49/model.ckpt.
INFO:tensorflow:Initialized dataset iterators in 0 seconds
INFO:tensorflow:Installing graceful shutdown hook.
INFO:tensorflow:Creating heartbeat manager for ['/job:tpu_worker/replica:0/task:0/device:CPU:0']
INFO:tensorflow:Configuring worker heartbeat: shutdown_mode: WAIT_FOR_COORDINATOR

INFO:tensorflow:Init TPU system
INFO:tensorflow:Initialized TPU in 7 seconds
INFO:tensorflow:Starting infeed thread controller.
INFO:tensorflow:Starting outfeed thread controller.
INFO:tensorflow:Enqueue next (500) batch(es) of data to infeed.
INFO:tensorflow:Dequeue next (500) batch(es) of data from outfeed.
INFO:tensorflow:loss = 0.12402147, step = 500
INFO:tensorflow:Enqueue next (500) batch(es) of data to infeed.
INFO:tensorflow:Dequeue next (500) batch(es) of data from outfeed.
INFO:tensorflow:loss = 0.07178347, step = 1000 (0.769 sec)
INFO:tensorflow:global_step/sec: 650.223
INFO:tensorflow:examples/sec: 83228.5
INFO:tensorflow:Saving checkpoints for 1000 into gs://amangu-test-bucket/tpuestimator-dnn/2018-10-31-00-05-49/model.ckpt.
INFO:tensorflow:Stop infeed thread controller
INFO:tensorflow:Shutting down InfeedController thread.
INFO:tensorflow:InfeedController received shutdown signal, stopping.
INFO:tensorflow:Infeed thread finished, shutting down.
INFO:tensorflow:infeed marked as finished
INFO:tensorflow:Stop output thread controller
INFO:tensorflow:Shutting down OutfeedController thread.
INFO:tensorflow:OutfeedController received shutdown signal, stopping.
INFO:tensorflow:Outfeed thread finished, shutting down.
INFO:tensorflow:outfeed marked as finished
INFO:tensorflow:Shutdown TPU system.
INFO:tensorflow:Loss for final step: 0.07178347.
INFO:tensorflow:training_loop marked as finished
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Starting evaluation at 2018-10-31-00:06:43
INFO:tensorflow:TPU job name worker
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from gs://amangu-test-bucket/tpuestimator-dnn/2018-10-31-00-05-49/model.ckpt-1000
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Init TPU system
INFO:tensorflow:Initialized TPU in 9 seconds
INFO:tensorflow:Starting infeed thread controller.
INFO:tensorflow:Starting outfeed thread controller.
INFO:tensorflow:Initialized dataset iterators in 0 seconds
INFO:tensorflow:Enqueue next (4) batch(es) of data to infeed.
INFO:tensorflow:Dequeue next (4) batch(es) of data from outfeed.
INFO:tensorflow:Evaluation [4/4]
INFO:tensorflow:Stop infeed thread controller
INFO:tensorflow:Shutting down InfeedController thread.
INFO:tensorflow:InfeedController received shutdown signal, stopping.
INFO:tensorflow:Infeed thread finished, shutting down.
INFO:tensorflow:infeed marked as finished
INFO:tensorflow:Stop output thread controller
INFO:tensorflow:Shutting down OutfeedController thread.
INFO:tensorflow:OutfeedController received shutdown signal, stopping.
INFO:tensorflow:Outfeed thread finished, shutting down.
INFO:tensorflow:outfeed marked as finished
INFO:tensorflow:Shutdown TPU system.
INFO:tensorflow:Finished evaluation at 2018-10-31-00:06:54
INFO:tensorflow:Saving dict for global step 1000: accuracy = 0.9667969, global_step = 1000, loss = 0.13143887
INFO:tensorflow:Saving 'checkpoint_path' summary for global step 1000: gs://amangu-test-bucket/tpuestimator-dnn/2018-10-31-00-05-49/model.ckpt-1000
INFO:tensorflow:evaluation_loop marked as finished

Test set accuracy: 0.967

INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:TPU job name worker
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from gs://amangu-test-bucket/tpuestimator-dnn/2018-10-31-00-05-49/model.ckpt-1000
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Init TPU system
INFO:tensorflow:Initialized TPU in 13 seconds
INFO:tensorflow:Starting infeed thread controller.
INFO:tensorflow:Starting outfeed thread controller.
INFO:tensorflow:Initialized dataset iterators in 0 seconds
INFO:tensorflow:Enqueue next (1) batch(es) of data to infeed.
INFO:tensorflow:Dequeue next (1) batch(es) of data from outfeed.
INFO:tensorflow:Enqueue next (1) batch(es) of data to infeed.
INFO:tensorflow:Dequeue next (1) batch(es) of data from outfeed.
INFO:tensorflow:Stop infeed thread controller
INFO:tensorflow:Shutting down InfeedController thread.
INFO:tensorflow:InfeedController received shutdown signal, stopping.
INFO:tensorflow:Infeed thread finished, shutting down.
INFO:tensorflow:infeed marked as finished
INFO:tensorflow:Stop output thread controller
INFO:tensorflow:Shutting down OutfeedController thread.
INFO:tensorflow:OutfeedController received shutdown signal, stopping.
INFO:tensorflow:Outfeed thread finished, shutting down.
INFO:tensorflow:outfeed marked as finished
INFO:tensorflow:Shutdown TPU system.
INFO:tensorflow:prediction_loop marked as finished
INFO:tensorflow:prediction_loop marked as finished

Prediction is "Virginica" (92.4%), expected "Virginica"

Prediction is "Setosa" (98.3%), expected "Setosa"

Prediction is "Versicolor" (95.8%), expected "Versicolor"