Copyright 2019 The Sonnet Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.
In [0]:
import sys
assert sys.version_info >= (3, 6), "Sonnet 2 requires Python >=3.6"
In [0]:
!pip install dm-sonnet tqdm
In [0]:
import sonnet as snt
import tensorflow as tf
import tensorflow_datasets as tfds
In [0]:
print("TensorFlow version: {}".format(tf.__version__))
print(" Sonnet version: {}".format(snt.__version__))
Finally, let's take a quick look at the GPUs we have available:
In [0]:
!grep Model: /proc/driver/nvidia/gpus/*/information | awk '{$1="";print$0}'
In [8]:
physical_gpus = tf.config.experimental.list_physical_devices("GPU")
physical_gpus
Out[8]:
In [0]:
tf.config.experimental.set_virtual_device_configuration(
    physical_gpus[0],
    [tf.config.experimental.VirtualDeviceConfiguration(memory_limit=2000)] * 4
)
In [12]:
gpus = tf.config.experimental.list_logical_devices("GPU")
gpus
Out[12]:
When using Sonnet optimizers, we must use either Replicator or TpuReplicator from snt.distribute, or we can use tf.distribute.OneDeviceStrategy. Replicator is equivalent to MirroredStrategy and TpuReplicator is equivalent to TPUStrategy.
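For comparison, here is a minimal sketch of the tf.distribute.OneDeviceStrategy alternative mentioned above; the device string is an assumption, and the rest of this notebook uses Replicator instead.

# A minimal sketch, assuming a single logical GPU named "GPU:0" is available.
# The rest of this notebook uses snt.distribute.Replicator across four devices.
single_device_strategy = tf.distribute.OneDeviceStrategy("/device:GPU:0")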
In [0]:
strategy = snt.distribute.Replicator(
    ["/device:GPU:{}".format(i) for i in range(4)],
    tf.distribute.ReductionToOneDevice("GPU:0"))
In [0]:
# NOTE: This is the batch size across all GPUs.
batch_size = 100 * 4

def process_batch(images, labels):
  images = tf.cast(images, dtype=tf.float32)
  images = ((images / 255.) - .5) * 2.
  return images, labels

def cifar10(split):
  dataset = tfds.load("cifar10", split=split, as_supervised=True)
  dataset = dataset.map(process_batch)
  dataset = dataset.batch(batch_size)
  dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
  dataset = dataset.cache()
  return dataset

cifar10_train = cifar10("train").shuffle(10)
cifar10_test = cifar10("test")
Conveniently, there is a pre-built model in snt.nets designed specifically for this dataset.
We must build our model and optimizer within the strategy scope, to ensure that any variables created are distributed correctly. Alternatively, we could enter the scope for the entire program using tf.distribute.experimental_set_strategy.
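For reference, this is a minimal sketch of that program-wide alternative; it is not used in the rest of this notebook, which enters the scope explicitly instead.

# A minimal sketch of the program-wide alternative (not used below): every
# variable created after this call is placed inside the strategy's scope.
tf.distribute.experimental_set_strategy(strategy)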
In [0]:
learning_rate = 0.1

with strategy.scope():
  model = snt.nets.Cifar10ConvNet()
  optimizer = snt.optimizers.Momentum(learning_rate, 0.9)
The Sonnet optimizers are designed to be as clean and simple as possible: they do not contain any code to deal with distributed execution. Distributed training therefore requires a few additional lines of code.
We must aggregate the gradients calculated on the different devices. This can be done using ReplicaContext.all_reduce.

Note that when using Replicator / TpuReplicator it is the user's responsibility to ensure that the values remain identical in all replicas.
In [0]:
def step(images, labels):
  """Performs a single training step, returning the cross-entropy loss."""
  with tf.GradientTape() as tape:
    logits = model(images, is_training=True)["logits"]
    loss = tf.reduce_mean(
        tf.nn.sparse_softmax_cross_entropy_with_logits(labels=labels,
                                                       logits=logits))

  grads = tape.gradient(loss, model.trainable_variables)

  # Aggregate the gradients from the full batch.
  replica_ctx = tf.distribute.get_replica_context()
  grads = replica_ctx.all_reduce("mean", grads)

  optimizer.apply(grads, model.trainable_variables)
  return loss

@tf.function
def train_step(images, labels):
  per_replica_loss = strategy.run(step, args=(images, labels))
  return strategy.reduce("sum", per_replica_loss, axis=None)

def train_epoch(dataset):
  """Performs one epoch of training, returning the mean cross-entropy loss."""
  total_loss = 0.0
  num_batches = 0

  # Loop over the entire training set.
  for images, labels in dataset:
    total_loss += train_step(images, labels).numpy()
    num_batches += 1

  return total_loss / num_batches

cifar10_train_dist = strategy.experimental_distribute_dataset(cifar10_train)

for epoch in range(20):
  print("Training epoch", epoch, "...", end=" ")
  print("loss :=", train_epoch(cifar10_train_dist))
In [0]:
num_cifar10_test_examples = 10000

def is_predicted(images, labels):
  logits = model(images, is_training=False)["logits"]
  # The reduction over the batch happens in `strategy.reduce`, below.
  return tf.cast(tf.equal(labels, tf.argmax(logits, axis=1)), dtype=tf.int32)

cifar10_test_dist = strategy.experimental_distribute_dataset(cifar10_test)

@tf.function
def evaluate():
  """Returns the top-1 accuracy over the entire test set."""
  total_correct = 0
  for images, labels in cifar10_test_dist:
    per_replica_correct = strategy.run(is_predicted, args=(images, labels))
    total_correct += strategy.reduce("sum", per_replica_correct, axis=0)
  return tf.cast(total_correct, tf.float32) / num_cifar10_test_examples

print("Testing...", end=" ")
print("top-1 accuracy =", evaluate().numpy())