In [0]:
#@title Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

Custom training with TPUs


This Colab will take you through using tf.distribute.experimental.TPUStrategy, a part of the tf.distribute.Strategy API that allows users to easily switch their model to using TPUs. As part of this tutorial, you will create a Keras model and take it through a custom training loop (instead of calling the fit method).

For full information on DistributionStrategy, please see the tf.distribute.Strategy documentation. For the purpose of this Colab, a DistributionStrategy is an object that indicates how TF should distribute your model across one or more workers or accelerators. The TPUStrategy provides information about your TPU configuration, and we can use it in combination with a Keras model to quickly train that model on the TPU!

If you're familiar with Keras models, you may be used to training them with the fit API. This API is supported under a DistributionStrategy, but to show the flexibility that DistributionStrategy offers, this notebook demonstrates a more customized approach, using an explicit training loop.
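
For contrast, here is a minimal sketch (not part of the original notebook) of that fit-based route: create and compile the model inside the strategy's scope, then call fit, which handles the loss, gradients, and metric updates for you. It assumes the strategy, create_model, x_train, and y_train defined in the cells below.

In [0]:
# A minimal, illustrative sketch of fit-based training under a distribution
# strategy; `strategy`, `create_model`, `x_train`, and `y_train` are the
# objects defined later in this notebook.
with strategy.scope():
  fit_model = create_model((28, 28, 1))
  fit_model.compile(
      optimizer='sgd',
      loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
      metrics=['accuracy'])
fit_model.fit(x_train, y_train, batch_size=32, epochs=5)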

Learning objectives

In this Colab, you will learn how to:

  • Build a Keras model to classify images
  • Set up custom loss, gradient, and training loop functions
  • Use TPUStrategy so that you can easily port your models from CPU or GPU to TPU


  Train on TPU  

  1. On the main menu, click Runtime and select Change runtime type. Set "TPU" as the hardware accelerator.
  2. Click Runtime again and select Runtime > Run All. You can also run the cells manually with Shift-ENTER.

Imports and hardware and software verification

In [0]:
from __future__ import absolute_import, division, print_function, unicode_literals

# Import TensorFlow
import tensorflow as tf

# Helper libraries
import numpy as np
import os

assert os.environ['COLAB_TPU_ADDR'], 'Make sure to select TPU from Edit > Notebook settings > Hardware accelerator'

version = tuple(int(v) for v in tf.__version__.split('.')[:2])
assert version >= (1, 14), 'Make sure that TensorFlow version is at least 1.14'

In [0]:
TPU_WORKER = 'grpc://' + os.environ['COLAB_TPU_ADDR']

Create model

Since you will be working with the MNIST data, a collection of 70,000 greyscale images of handwritten digits, a convolutional neural network is a natural fit for this labeled image data.

In [0]:
def create_model(input_shape):
  """Creates a simple convolutional neural network model using the Keras API"""
  return tf.keras.Sequential([
      tf.keras.layers.Conv2D(28, kernel_size=(3, 3), activation='relu',
                             input_shape=input_shape),
      tf.keras.layers.MaxPooling2D(pool_size=(2, 2)),
      # Flatten so the Dense layers receive a 1-D feature vector per example.
      tf.keras.layers.Flatten(),
      tf.keras.layers.Dense(128, activation=tf.nn.relu),
      # No softmax here: the loss function below applies softmax to these
      # raw logits itself.
      tf.keras.layers.Dense(10)
  ])
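
As a quick sanity check (illustrative, not part of the original notebook), you can build the model once and print its layer output shapes; this also confirms that the Flatten layer wires the convolutional features into the Dense classifier:

In [0]:
# Illustrative check: confirm the model builds and inspect output shapes.
create_model((28, 28, 1)).summary()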

Loss and gradient

When using the Keras fit API, the loss and gradient computations are implemented for you, but for the custom training loop, explicit loss and gradient functions must be defined.

In [0]:
def loss(model, x, y):
  """Returns the logits and the loss for an example (x, y)."""
  logits = model(x)
  return logits, tf.losses.sparse_softmax_cross_entropy(labels=y, logits=logits)

def grad(model, x, y):
  """Returns the logits, the loss, and the gradients for an example (x, y)."""
  logits, loss_value = loss(model, x, y)
  return logits, loss_value, tf.gradients(loss_value, model.trainable_variables)

Main function

Now that we have the model defined, we can load our data, set up the distribution strategy, and run our training loop.

In [0]:

resolver = tf.contrib.cluster_resolver.TPUClusterResolver(tpu=TPU_WORKER)
strategy = tf.contrib.distribute.TPUStrategy(resolver)

# Load MNIST training and test data
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0

# All MNIST examples are 28x28 pixel greyscale images (hence the 1
# for the number of channels).
input_shape = (28, 28, 1)

# Only specific data types are supported on the TPU, so it is important to
# pay attention to these; here everything is cast to float32 and int64.
x_train = x_train.reshape(x_train.shape[0], *input_shape).astype(np.float32)
x_test = x_test.reshape(x_test.shape[0], *input_shape).astype(np.float32)
y_train, y_test = y_train.astype(np.int64), y_test.astype(np.int64)

# The batch size must be divisible by the number of workers (8 workers),
# so batch sizes of 8, 16, 24, 32, ... are supported.
BATCH_SIZE = 32  # Any multiple of 8 works; 32 is the choice assumed here.
NUM_EPOCHS = 5   # Matches the 5 epochs described in the training section below.
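
# Added sanity check (not in the original notebook): the batch must split
# evenly across the strategy's replicas (the 8 TPU cores).
assert BATCH_SIZE % strategy.num_replicas_in_sync == 0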

train_steps_per_epoch = len(x_train) // BATCH_SIZE
test_steps_per_epoch = len(x_test) // BATCH_SIZE

Start by creating objects within the strategy's scope

A key concept to remember: model creation, optimizer creation, etc. must be written in the context of strategy.scope() in order to use TPUStrategy.

Also initialize metrics for the train and test sets. These Keras metrics are stateful: update_state() accumulates new values, result() reads the current aggregate, and reset_states() clears it, which is why the training loop below resets the metrics at the end of each epoch.

In [0]:
with strategy.scope():
  model = create_model(input_shape)

  optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.01)

  training_loss = tf.keras.metrics.Mean('training_loss', dtype=tf.float32)
  training_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
      'training_accuracy', dtype=tf.float32)
  test_loss = tf.keras.metrics.Mean('test_loss', dtype=tf.float32)
  test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
      'test_accuracy', dtype=tf.float32)

Define custom train and test steps

In [0]:
with strategy.scope():
  def train_step(inputs):
    """Each training step runs this custom function which calculates
    gradients and updates weights.
    x, y = inputs

    logits, loss_value, grads = grad(model, x, y)

    update_loss = training_loss.update_state(loss_value)
    update_accuracy = training_accuracy.update_state(y, logits)

    # Show that this is truly a custom training loop:
    # multiply all gradients by 2.
    grads = [g * 2 for g in grads]

    update_vars = optimizer.apply_gradients(
        zip(grads, model.trainable_variables))

    with tf.control_dependencies([update_vars, update_loss, update_accuracy]):
      return tf.identity(loss_value)

  def test_step(inputs):
    """Each training step runs this custom function"""
    x, y = inputs

    logits, loss_value = loss(model, x, y)

    update_loss = test_loss.update_state(loss_value)
    update_accuracy = test_accuracy.update_state(y, logits)

    with tf.control_dependencies([update_loss, update_accuracy]):
      return tf.identity(loss_value)
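
A note on the pattern above (our gloss, not part of the original text): because this notebook runs in graph mode, returning tf.identity(loss_value) under tf.control_dependencies guarantees that the variable and metric update ops actually execute whenever the returned loss is fetched. A minimal sketch of the same idiom:

In [0]:
# Minimal graph-mode sketch of the control-dependency idiom (illustrative).
counter = tf.Variable(0.0)
bump = counter.assign_add(1.0)
with tf.control_dependencies([bump]):
  out = tf.identity(counter)  # fetching `out` in a session also runs `bump`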

Do the training

To make the code a little easier to read, the full training loop calls two helper functions, run_train() and run_test(). The training runs for 5 epochs and trains to 96% accuracy.

In [0]:
def run_train():
  # Train: one pass over the training set, then report and reset the metrics.
  session.run(train_iterator_init)
  while True:
    try:
      session.run(dist_train)
    except tf.errors.OutOfRangeError:
      break
  print('Train loss: {:0.4f}\t Train accuracy: {:0.4f}%'.format(
      session.run(training_loss_result), session.run(training_accuracy_result) * 100))
  training_loss.reset_states()
  training_accuracy.reset_states()

def run_test():
  # Test: one pass over the test set, then report and reset the metrics.
  session.run(test_iterator_init)
  while True:
    try:
      session.run(dist_test)
    except tf.errors.OutOfRangeError:
      break
  print('Test loss: {:0.4f}\t Test accuracy: {:0.4f}%'.format(
      session.run(test_loss_result), session.run(test_accuracy_result) * 100))
  test_loss.reset_states()
  test_accuracy.reset_states()

In [0]:
with strategy.scope():
  training_loss_result = training_loss.result()
  training_accuracy_result = training_accuracy.result()
  test_loss_result = test_loss.result()
  test_accuracy_result = test_accuracy.result()
  config = tf.ConfigProto()
  config.allow_soft_placement = True
  cluster_spec = resolver.cluster_spec()
  if cluster_spec:
    config.cluster_def.CopyFrom(cluster_spec.as_cluster_def())

  print('Starting training...')

  # Do all the computations inside a Session (as opposed to eager mode)
  with tf.Session(target=resolver.master(), config=config) as session:
    all_variables = (
        tf.global_variables() + training_loss.variables +
        training_accuracy.variables + test_loss.variables +
        test_accuracy.variables)

    train_dataset = tf.data.Dataset.from_tensor_slices(
        (x_train, y_train)).batch(BATCH_SIZE, drop_remainder=True)
    train_iterator = strategy.make_dataset_iterator(train_dataset)

    test_dataset = tf.data.Dataset.from_tensor_slices(
        (x_test, y_test)).batch(BATCH_SIZE, drop_remainder=True)
    test_iterator = strategy.make_dataset_iterator(test_dataset)

    train_iterator_init = train_iterator.initialize()
    test_iterator_init = test_iterator.initialize()
    session.run([v.initializer for v in all_variables])

    dist_train = strategy.experimental_run(train_step, train_iterator).values
    dist_test = strategy.experimental_run(test_step, test_iterator).values

    # Custom training loop
    for epoch in range(0, NUM_EPOCHS):
      print('Starting epoch {}'.format(epoch))
      run_train()
      run_test()


What's next

  • Learn about Cloud TPUs, which Google designed and optimized specifically to speed up and scale up ML workloads for training and inference, and to enable ML engineers and researchers to iterate more quickly.
  • Explore the range of Cloud TPU tutorials and Colabs to find other examples that can be used when implementing your ML project.

On Google Cloud Platform, in addition to the GPUs and TPUs available on pre-configured deep learning VMs, you will find AutoML (beta) for training custom models without writing code, and Cloud ML Engine, which allows you to run parallel training and hyperparameter tuning of your custom models on powerful distributed hardware.