

In [0]:
# Copyright 2018 The TensorFlow Hub Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

A Simple Classification Model using TPUEstimator on Colab TPU

Overview

This notebook shows how to use TPUEstimator to build a simple classification model. The model can train, evaluate, and generate predictions on Cloud TPUs. It uses the iris dataset to predict flower species, and it also shows how to substitute your own data for the pre-loaded dataset. The model uses 4 input features (SepalLength, SepalWidth, PetalLength, PetalWidth) to classify each flower as one of three species (Setosa, Versicolor, Virginica).

The model trains for 1,000 steps (the train_steps setting below) and completes in approximately 2 minutes.

This notebook is hosted on GitHub. To view it in its original repository, after opening the notebook, select File > View on GitHub.

Note: You will need a GCP account and a GCS bucket for this notebook to run!

Instructions

  Train on TPU

  1. Create a Cloud Storage bucket for your TensorBoard logs at http://console.cloud.google.com/storage and fill in the bucket name in the "Resolve TPU Address and authenticate GCS bucket" cell below.

  2. On the main menu, click Runtime and select Change runtime type. Set "TPU" as the hardware accelerator.

  3. Click Runtime again and select Run all (note: the "Resolve TPU Address and authenticate GCS bucket" cell requires a bucket name). You can also run the cells manually with Shift+Enter.

TPUs are located in Google Cloud; for optimal performance, they read data directly from Google Cloud Storage (GCS).
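
Because the TPU runs in Google Cloud rather than on the Colab VM, the Estimator's model_dir, checkpoints, and TensorBoard logs must also live on GCS. As a minimal sketch (not one of the notebook's cells, using a placeholder bucket name), TensorFlow reads and writes gs:// paths just like local paths once the authentication cell below has run:

import tensorflow as tf

# 'my-bucket' is a placeholder; substitute the bucket you created in step 1.
path = 'gs://my-bucket/tpuestimator-dnn/hello.txt'
with tf.gfile.GFile(path, 'w') as f:
  f.write('TensorFlow can write directly to GCS.')
print(tf.gfile.Exists(path))  # True once the write succeeds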

Learning objectives

In this Colab, you will learn how to:

  • Define and build a TPUEstimator model with 2 hidden layers and 10 nodes in each layer.
  • Run the TPUEstimator model to train, evaluate, and generate predictions on Cloud TPU.

Data, model, and training

Imports


In [0]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import json
import numpy as np
import os
import pandas as pd
import pprint
import tensorflow as tf
import time

Resolve TPU Address and authenticate GCS bucket


In [0]:
use_tpu = True #@param {type:"boolean"}
bucket = '' #@param {type:"string"}

assert bucket, 'Must specify an existing GCS bucket name'
print('Using bucket: {}'.format(bucket))

if use_tpu:
    assert 'COLAB_TPU_ADDR' in os.environ, 'Missing TPU; did you request a TPU in Notebook Settings?'

MODEL_DIR = 'gs://{}/{}'.format(bucket, time.strftime('tpuestimator-dnn/%Y-%m-%d-%H-%M-%S'))
print('Using model dir: {}'.format(MODEL_DIR))

from google.colab import auth
auth.authenticate_user()

if 'COLAB_TPU_ADDR' in os.environ:
  TF_MASTER = 'grpc://{}'.format(os.environ['COLAB_TPU_ADDR'])

  # Upload credentials to TPU.
  with tf.Session(TF_MASTER) as sess:
    with open('/content/adc.json', 'r') as f:
      auth_info = json.load(f)
    tf.contrib.cloud.configure_gcs(sess, credentials=auth_info)
  # Now credentials are set for all future sessions on this TPU.
else:
  TF_MASTER = ''

with tf.Session(TF_MASTER) as session:
  print('List of devices:')
  pprint.pprint(session.list_devices())

FLAGS used as model params


In [0]:
# Model specific parameters

# TPU address
tpu_address = TF_MASTER

# Estimators model_dir
model_dir = MODEL_DIR

# This is the global batch size, not the per-shard batch.
batch_size = 128

# Total number of training steps.
train_steps = 1000

# Total number of evaluation steps. If '0', evaluation after training is skipped.
eval_steps = 4

# Number of iterations (training steps) to run on the TPU per training loop.
iterations = 500
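
A note on these values: batch_size is the global batch, which TPUEstimator splits evenly across the TPU cores (a single Cloud TPU device exposes 8 cores), and iterations controls how many training steps run on the TPU before control returns to the host. A quick sketch of the per-shard arithmetic, assuming the 8-core device Colab provides:

num_shards = 8                              # cores on one Cloud TPU device
per_shard_batch = batch_size // num_shards  # examples each core sees per step
print('global batch: {}, per-shard batch: {}'.format(batch_size, per_shard_batch))
# global batch: 128, per-shard batch: 16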

Get input data and define input functions


In [0]:
TRAIN_URL = "http://download.tensorflow.org/data/iris_training.csv"
TEST_URL = "http://download.tensorflow.org/data/iris_test.csv"

CSV_COLUMN_NAMES = ['SepalLength', 'SepalWidth',
                    'PetalLength', 'PetalWidth', 'Species']
SPECIES = ['Setosa', 'Versicolor', 'Virginica']

PREDICTION_INPUT_DATA = {
    'SepalLength': [6.9, 5.1, 5.9],
    'SepalWidth': [3.1, 3.3, 3.0],
    'PetalLength': [5.4, 1.7, 4.2],
    'PetalWidth': [2.1, 0.5, 1.5],
}

PREDICTION_OUTPUT_DATA = ['Virginica', 'Setosa', 'Versicolor']

def maybe_download():
    train_path = tf.keras.utils.get_file(TRAIN_URL.split('/')[-1], TRAIN_URL)
    test_path = tf.keras.utils.get_file(TEST_URL.split('/')[-1], TEST_URL)

    return train_path, test_path

def load_data(y_name='Species'):
    """Returns the iris dataset as (train_x, train_y), (test_x, test_y)."""
    train_path, test_path = maybe_download()

    train = pd.read_csv(train_path, names=CSV_COLUMN_NAMES, header=0,
                        dtype={'SepalLength': np.float32, 'SepalWidth': np.float32,
                               'PetalLength': np.float32, 'PetalWidth': np.float32,
                               'Species': np.int32})
    train_x, train_y = train, train.pop(y_name)

    test = pd.read_csv(test_path, names=CSV_COLUMN_NAMES, header=0,
                       dtype={'SepalLength': np.float32, 'SepalWidth': np.float32,
                              'PetalLength': np.float32, 'PetalWidth': np.float32,
                              'Species': np.int32})
    test_x, test_y = test, test.pop(y_name)

    return (train_x, train_y), (test_x, test_y)


def train_input_fn(features, labels, batch_size):
    """An input function for training"""

    # Convert the inputs to a Dataset.
    dataset = tf.data.Dataset.from_tensor_slices((dict(features), labels))

    # Shuffle, repeat, and batch the examples.
    dataset = dataset.shuffle(1000).repeat()

    dataset = dataset.apply(
            tf.contrib.data.batch_and_drop_remainder(batch_size))

    # Return the dataset.
    return dataset


def eval_input_fn(features, labels, batch_size):
    """An input function for evaluation"""
    features=dict(features)
    inputs = (features, labels)

    # Convert the inputs to a Dataset.
    dataset = tf.data.Dataset.from_tensor_slices(inputs)
    dataset = dataset.shuffle(1000).repeat()

    dataset = dataset.apply(
            tf.contrib.data.batch_and_drop_remainder(batch_size))

    # Return the dataset.
    return dataset


def predict_input_fn(features, batch_size):
    """An input function for prediction"""

    dataset = tf.data.Dataset.from_tensor_slices(features)
    dataset = dataset.batch(batch_size)
    return dataset
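
TPU execution requires fully defined (static) tensor shapes, including the batch dimension, which is why the training and evaluation input functions drop any partial final batch. As an aside, on TensorFlow 1.10 and later the built-in drop_remainder argument of Dataset.batch gives the same static shape; a minimal stand-alone check:

import tensorflow as tf

# batch(..., drop_remainder=True) produces a fully defined batch dimension,
# equivalent to tf.contrib.data.batch_and_drop_remainder used above.
ds = tf.data.Dataset.range(10).batch(4, drop_remainder=True)
print(ds.output_shapes)  # (4,) -- static, as the TPU requires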

Model and metric function


In [0]:
def metric_fn(labels, logits):
    """Function to return metrics for evaluation"""

    predicted_classes = tf.argmax(logits, 1)
    accuracy = tf.metrics.accuracy(labels=labels,
                                   predictions=predicted_classes,
                                   name='acc_op')
    return {'accuracy': accuracy}


def my_model(features, labels, mode, params):
    """DNN with three hidden layers, and dropout of 0.1 probability."""

    # Create three fully connected layers each layer having a dropout
    # probability of 0.1.
    net = tf.feature_column.input_layer(features, params['feature_columns'])
    for units in params['hidden_units']:
        net = tf.layers.dense(net, units=units, activation=tf.nn.relu)

    # Compute logits (1 per class).
    logits = tf.layers.dense(net, params['n_classes'], activation=None)

    # Compute predictions.
    predicted_classes = tf.argmax(logits, 1)
    if mode == tf.estimator.ModeKeys.PREDICT:
        predictions = {
            'class_ids': predicted_classes[:, tf.newaxis],
            'probabilities': tf.nn.softmax(logits),
            'logits': logits,
        }
        return tf.contrib.tpu.TPUEstimatorSpec(mode, predictions=predictions)

    # Compute loss.
    loss = tf.losses.sparse_softmax_cross_entropy(labels=labels,
                                                  logits=logits)

    if mode == tf.estimator.ModeKeys.EVAL:
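        # TPUEstimatorSpec takes metrics as a (metric_fn, tensors) tuple;
        # the metrics are computed on the host rather than on the TPU.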
        return tf.contrib.tpu.TPUEstimatorSpec(
            mode=mode, loss=loss, eval_metrics=(metric_fn, [labels, logits]))

    # Create training op.
    if mode == tf.estimator.ModeKeys.TRAIN:
        optimizer = tf.train.AdagradOptimizer(learning_rate=0.1)
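        # CrossShardOptimizer averages gradients across the TPU cores
        # (an all-reduce) before applying the update.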
        if use_tpu:
            optimizer = tf.contrib.tpu.CrossShardOptimizer(optimizer)
        train_op = optimizer.minimize(loss, global_step=tf.train.get_global_step())
        return tf.contrib.tpu.TPUEstimatorSpec(mode, loss=loss, train_op=train_op)

Main function


In [0]:
def main():
    # Fetch the data
    (train_x, train_y), (test_x, test_y) = load_data()

    # Feature columns describe how to use the input.
    my_feature_columns = []
    for key in train_x.keys():
        my_feature_columns.append(tf.feature_column.numeric_column(key=key))

    # Resolve TPU cluster and runconfig for this.
    tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(
            tpu_address)

    run_config = tf.contrib.tpu.RunConfig(
            model_dir=model_dir,
            cluster=tpu_cluster_resolver,
            session_config=tf.ConfigProto(
                allow_soft_placement=True, log_device_placement=True),
            tpu_config=tf.contrib.tpu.TPUConfig(iterations),
            )

    # Build 2 hidden layer DNN with 10, 10 units respectively.
    classifier = tf.contrib.tpu.TPUEstimator(
        model_fn=my_model,
        use_tpu=use_tpu,
        train_batch_size=batch_size,
        eval_batch_size=batch_size,
        predict_batch_size=batch_size,
        config=run_config,
        params={
            'feature_columns': my_feature_columns,
            # Two hidden layers of 10 nodes each.
            'hidden_units': [10, 10],
            # The model must choose between 3 classes.
            'n_classes': 3,
            'use_tpu': use_tpu,
        })

    # Train the Model.
    classifier.train(
            input_fn = lambda params: train_input_fn(
                train_x, train_y, params["batch_size"]),
            max_steps=train_steps)

    # Evaluate the model.
    eval_result = classifier.evaluate(
        input_fn = lambda params: eval_input_fn(
            test_x, test_y, params["batch_size"]),
        steps=eval_steps)

    print('\nTest set accuracy: {accuracy:0.3f}\n'.format(**eval_result))

    # Generate predictions from the model
    predictions = classifier.predict(
        input_fn = lambda params: predict_input_fn(
            PREDICTION_INPUT_DATA, params["batch_size"]))

    for pred_dict, expec in zip(predictions, PREDICTION_OUTPUT_DATA):
        template = ('\nPrediction is "{}" ({:.1f}%), expected "{}"')

        class_id = pred_dict['class_ids'][0]
        probability = pred_dict['probabilities'][class_id]

        print(template.format(SPECIES[class_id],
                              100 * probability, expec))

Run the model and view the predictions


In [0]:
main()

What's next

  • Learn how to run the same model using another high-level TensorFlow API, such as Keras, rather than TPUEstimator.
  • Learn about Cloud TPUs, which Google designed and optimized to speed up and scale up ML workloads for training and inference, and to let ML engineers and researchers iterate more quickly.
  • Explore the range of Cloud TPU tutorials and Colabs to find other examples you can use when implementing your ML project.

On Google Cloud Platform, in addition to the GPUs and TPUs available on pre-configured deep learning VMs, you will find AutoML (beta) for training custom models without writing code and Cloud ML Engine, which allows you to run parallel training jobs and hyperparameter tuning of your custom models on powerful distributed hardware.