Copyright 2019 The Sonnet Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.


This tutorial assumes you have already completed (and understood!) the Sonnet 2 "Hello, world!" example (MLP on MNIST).

In this tutorial, we're going to scale things up with a bigger model and bigger dataset, and we're going to distribute the computation across multiple devices.


In [0]:
import sys
assert sys.version_info >= (3, 6), "Sonnet 2 requires Python >=3.6"

In [0]:
!pip install dm-sonnet tqdm

In [0]:
import sonnet as snt
import tensorflow as tf
import tensorflow_datasets as tfds

In [0]:
print("TensorFlow version: {}".format(tf.__version__))
print("    Sonnet version: {}".format(snt.__version__))

Finally, let's take a quick look at the GPUs we have available:

In [0]:
!grep Model: /proc/driver/nvidia/gpus/*/information | awk '{$1="";print$0}'

Distribution strategy

We need a strategy to distribute our computation across several devices. Since Google Colab only provides a single GPU, we'll split it into four virtual GPUs:

In [8]:
physical_gpus = tf.config.experimental.list_physical_devices("GPU")
physical_gpus

[PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]

In [0]:
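# Split the single physical GPU into four virtual GPUs of 2000 MB each.
tf.config.experimental.set_virtual_device_configuration(
    physical_gpus[0],
    [tf.config.experimental.VirtualDeviceConfiguration(memory_limit=2000)] * 4
)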

In [12]:
gpus = tf.config.experimental.list_logical_devices("GPU")
gpus

[LogicalDevice(name='/job:localhost/replica:0/task:0/device:GPU:0', device_type='GPU'),
 LogicalDevice(name='/job:localhost/replica:0/task:0/device:GPU:1', device_type='GPU'),
 LogicalDevice(name='/job:localhost/replica:0/task:0/device:GPU:2', device_type='GPU'),
 LogicalDevice(name='/job:localhost/replica:0/task:0/device:GPU:3', device_type='GPU')]

When using Sonnet optimizers, we must use either Replicator or TpuReplicator from snt.distribute, or we can use tf.distribute.OneDeviceStrategy. Replicator is equivalent to MirroredStrategy and TpuReplicator is equivalent to TPUStrategy.

In [0]:
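strategy = snt.distribute.Replicator(
    ["/device:GPU:{}".format(i) for i in range(4)],
    # All four virtual GPUs share one physical device, so reduce on GPU:0.
    tf.distribute.ReductionToOneDevice("GPU:0"))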


Dataset

The data pipeline is basically the same as in the MNIST example, but this time we're using CIFAR-10. CIFAR-10 contains 32x32 pixel color images in 10 different classes (airplanes, cars, birds, cats, deer, dogs, frogs, horses, ships, and trucks).

In [0]:
# NOTE: This is the batch size across all GPUs.
batch_size = 100 * 4

def process_batch(images, labels):
  images = tf.cast(images, dtype=tf.float32)
  images = ((images / 255.) - .5) * 2.
  return images, labels

def cifar10(split):
  dataset = tfds.load("cifar10", split=split, as_supervised=True)
  dataset = dataset.map(process_batch)
  dataset = dataset.batch(batch_size)
  dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
  dataset = dataset.cache()
  return dataset

cifar10_train = cifar10("train").shuffle(10)
cifar10_test = cifar10("test")
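
As a quick sanity check on the scaling in process_batch, the following pure-Python sketch applies the same arithmetic to a few individual pixel values. The helper name scale_pixel and the sample values are hypothetical and only stand in for the tensor expression above; the point is that the 8-bit range [0, 255] ends up in [-1, 1].

```python
def scale_pixel(value):
  """Applies the same arithmetic as `process_batch` to one pixel value."""
  return ((value / 255.) - .5) * 2.

# Hypothetical sample values: darkest, an intermediate value, and brightest.
scaled = [scale_pixel(v) for v in (0, 51, 255)]

# The darkest pixel maps to the lower bound and the brightest to the upper bound.
print(scaled[0], scaled[-1])
```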

Model & Optimizer

Conveniently, there is a pre-built model in snt.nets designed specifically for this dataset.

We must build our model and optimizer within the strategy scope, to ensure that any variables created are distributed correctly. Alternatively, we could enter the scope for the entire program using tf.distribute.experimental_set_strategy.

In [0]:
learning_rate = 0.1

with strategy.scope():
  model = snt.nets.Cifar10ConvNet()
  optimizer = snt.optimizers.Momentum(learning_rate, 0.9)

Training the model

The Sonnet optimizers are designed to be as clean and simple as possible, so they do not contain any code to deal with distributed execution. Distributed training therefore requires a few additional lines of code.

We must aggregate the gradients calculated on the different devices. This can be done using ReplicaContext.all_reduce.

Note that when using Replicator / TpuReplicator, it is the user's responsibility to ensure that variable values remain identical in all replicas. This is why the gradients are aggregated before they are applied.
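
To see why the aggregation matters, here is a toy simulation in plain Python rather than TensorFlow. It is only a sketch under stated assumptions: all_reduce_mean is a hypothetical stand-in for ReplicaContext.all_reduce with "mean", plain SGD replaces the Momentum optimizer for brevity, and the weights and gradients are made-up numbers.

```python
def all_reduce_mean(per_replica_grads):
  """Averages gradients elementwise; every replica receives the same result."""
  num_replicas = len(per_replica_grads)
  mean = [sum(g) / num_replicas for g in zip(*per_replica_grads)]
  return [list(mean) for _ in range(num_replicas)]

def sgd_apply(weights, grads, learning_rate):
  """Plain SGD update (an assumption; the tutorial itself uses Momentum)."""
  return [w - learning_rate * g for w, g in zip(weights, grads)]

# Four replicas start from identical weights...
replica_weights = [[1.0, -2.0] for _ in range(4)]
# ...but each sees a different slice of the batch, so local gradients differ.
local_grads = [[0.5, 1.0], [1.5, -1.0], [-0.5, 2.0], [2.5, 0.0]]

synced_grads = all_reduce_mean(local_grads)
with_sync = [sgd_apply(w, g, 0.1)
             for w, g in zip(replica_weights, synced_grads)]
without_sync = [sgd_apply(w, g, 0.1)
                for w, g in zip(replica_weights, local_grads)]

# With aggregation the replicas still agree; without it they drift apart.
print(all(w == with_sync[0] for w in with_sync))
print(all(w == without_sync[0] for w in without_sync))
```

Applying the local gradients directly would leave each replica with a different copy of the model after a single step, which is the situation the note above warns about.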

In [0]:
def step(images, labels):
  """Performs a single training step, returning the cross-entropy loss."""
  with tf.GradientTape() as tape:
    logits = model(images, is_training=True)["logits"]
    loss = tf.reduce_mean(
        tf.nn.sparse_softmax_cross_entropy_with_logits(labels=labels,
                                                       logits=logits))

  grads = tape.gradient(loss, model.trainable_variables)

  # Aggregate the gradients from the full batch.
  replica_ctx = tf.distribute.get_replica_context()
  grads = replica_ctx.all_reduce("mean", grads)

  optimizer.apply(grads, model.trainable_variables)
  return loss

def train_step(images, labels):
  # Run `step` once on every replica with that replica's share of the batch.
  per_replica_loss = strategy.run(step, args=(images, labels))
  return strategy.reduce("sum", per_replica_loss, axis=None)

def train_epoch(dataset):
  """Performs one epoch of training, returning the mean cross-entropy loss."""
  total_loss = 0.0
  num_batches = 0

  # Loop over the entire training set.
  for images, labels in dataset:
    total_loss += train_step(images, labels).numpy()
    num_batches += 1

  return total_loss / num_batches

cifar10_train_dist = strategy.experimental_distribute_dataset(cifar10_train)

for epoch in range(20):
  print("Training epoch", epoch, "...", end=" ")
  print("loss :=", train_epoch(cifar10_train_dist))
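
Because batch_size is the batch size across all GPUs, each of the four replicas only processes a quarter of every batch once the dataset has been distributed. The sketch below illustrates that arithmetic in plain Python. The helper split_global_batch is hypothetical, and the contiguous slicing is an assumption made for illustration; how TensorFlow actually assigns examples to replicas is an implementation detail.

```python
def split_global_batch(batch, num_replicas):
  """Splits one global batch into equal per-replica shards (toy version)."""
  per_replica = len(batch) // num_replicas
  return [batch[i * per_replica:(i + 1) * per_replica]
          for i in range(num_replicas)]

# Stand-in for one global batch of 100 * 4 example indices.
global_batch = list(range(100 * 4))
shards = split_global_batch(global_batch, num_replicas=4)

# Each replica gets an equal share of the global batch.
print([len(shard) for shard in shards])
```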

Evaluating the model

Note the use of the axis parameter with strategy.reduce to reduce across the batch dimension.
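
The difference between axis=None (used for the training loss) and axis=0 (used here) can be illustrated with a toy model in plain Python. This is a simplified sketch, not the TensorFlow implementation: reduce_sum is a hypothetical stand-in for strategy.reduce with "sum" that only handles scalars and one-dimensional lists, and the per-replica values are made up.

```python
def reduce_sum(per_replica_values, axis):
  """Toy stand-in for `strategy.reduce("sum", ...)`."""
  if axis == 0:
    # First collapse each replica's batch dimension...
    per_replica_values = [sum(values) for values in per_replica_values]
  # ...then combine the per-replica results into a single value.
  return sum(per_replica_values)

# axis=None: each replica already holds a scalar (such as a loss).
per_replica_loss = [0.5, 0.25, 0.75, 0.5]
total_loss = reduce_sum(per_replica_loss, axis=None)

# axis=0: each replica holds one 0/1 entry per example in its shard.
per_replica_correct = [[1, 0, 1], [1, 1, 1], [0, 0, 1], [1, 0, 0]]
total_correct = reduce_sum(per_replica_correct, axis=0)

num_examples = sum(len(values) for values in per_replica_correct)
print(total_loss, total_correct, total_correct / num_examples)
```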

In [0]:
num_cifar10_test_examples = 10000

def is_predicted(images, labels):
  logits = model(images, is_training=False)["logits"]
  # The reduction over the batch happens in `strategy.reduce`, below.
  return tf.cast(tf.equal(labels, tf.argmax(logits, axis=1)), dtype=tf.int32)

cifar10_test_dist = strategy.experimental_distribute_dataset(cifar10_test)

def evaluate():
  """Returns the top-1 accuracy over the entire test set."""
  total_correct = 0

  for images, labels in cifar10_test_dist:
    per_replica_correct = strategy.run(is_predicted, args=(images, labels))
    total_correct += strategy.reduce("sum", per_replica_correct, axis=0)

  return tf.cast(total_correct, tf.float32) / num_cifar10_test_examples

print("Testing...", end=" ")
print("top-1 accuracy =", evaluate().numpy())