In [0]:
#@title Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# See the License for the specific language governing permissions and
# limitations under the License.
Note: While you can use Estimators with the tf.distribute API, it's recommended to use Keras with tf.distribute instead; see multi-worker training with Keras. Estimator training with tf.distribute.Strategy has limited support.
This tutorial demonstrates how tf.distribute.Strategy can be used for distributed multi-worker training with tf.estimator. If you write your code using tf.estimator, and you're interested in scaling beyond a single machine with high performance, this tutorial is for you.
Before getting started, please read the distribution strategy guide. The multi-GPU training tutorial is also relevant, because this tutorial uses the same model.
In [0]:
import tensorflow_datasets as tfds
import tensorflow as tf
import os, json
This tutorial uses the MNIST dataset from TensorFlow Datasets. The code here is similar to the multi-GPU training tutorial with one key difference: when using Estimator for multi-worker training, it is necessary to shard the dataset by the number of workers to ensure model convergence. The input data is sharded by worker index, so that each worker processes a distinct 1/num_workers portion of the dataset.
In [0]:
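# Assumed values: shuffle buffer size and per-replica batch size.
BUFFER_SIZE = 10000
BATCH_SIZE = 64
def input_fn(mode, input_context=None):
  datasets, info = tfds.load(name='mnist', with_info=True, as_supervised=True)
  mnist_dataset = (datasets['train'] if mode == tf.estimator.ModeKeys.TRAIN else
                   datasets['test'])
  def scale(image, label):
    # Convert pixel values from [0, 255] integers to [0, 1] floats.
    image = tf.cast(image, tf.float32)
    image /= 255
    return image, label
  if input_context:
    # Give each input pipeline (worker) its own distinct slice of the data.
    mnist_dataset = mnist_dataset.shard(input_context.num_input_pipelines,
                                        input_context.input_pipeline_id)
  return mnist_dataset.map(scale).cache().shuffle(BUFFER_SIZE).batch(BATCH_SIZE)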
Another reasonable approach to achieve convergence would be to shuffle the dataset with distinct seeds at each worker.
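To make the sharding concrete, here is a minimal pure-Python sketch of the element selection that a shard by worker index performs: worker i keeps the elements whose position modulo the number of workers equals i. The helper name shard and the ten-element toy dataset are illustrative assumptions, not part of this tutorial's code.

```python
def shard(elements, num_shards, index):
    """Keep every num_shards-th element, starting at position `index`.

    Hypothetical helper that mimics sharding by worker index on a plain list.
    """
    return [x for position, x in enumerate(elements)
            if position % num_shards == index]

# Toy dataset and a two-worker cluster (illustrative values).
dataset = list(range(10))
num_workers = 2
shards = [shard(dataset, num_workers, i) for i in range(num_workers)]

for worker_index, worker_shard in enumerate(shards):
    print(f"worker {worker_index}: {worker_shard}")
```

The shards are disjoint and together cover the whole dataset, so no example is seen by more than one worker within an epoch.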
One of the key differences in this tutorial (compared to the multi-GPU training tutorial) is the multi-worker setup. The TF_CONFIG environment variable is the standard way to specify the cluster configuration to each worker that is part of the cluster.
There are two components of TF_CONFIG: cluster and task. cluster provides information about the entire cluster, namely the workers and parameter servers in the cluster. task provides information about the current task. The first component, cluster, is the same for all workers and parameter servers in the cluster, and the second component, task, is different on each worker and parameter server and specifies its own type and index. In this example, the task type is worker and the task index is 0.
For illustration purposes, this tutorial shows how to set a TF_CONFIG with 2 workers on localhost. In practice, you would create multiple workers on an external IP address and port, and set TF_CONFIG on each worker appropriately, i.e. modify the task index.
Warning: Do not execute the following code in Colab. TensorFlow's runtime will attempt to create a gRPC server at the specified IP address and port, which will likely fail.
os.environ['TF_CONFIG'] = json.dumps({
    'cluster': {
        'worker': ["localhost:12345", "localhost:23456"]
    },
    'task': {'type': 'worker', 'index': 0}
})
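As a rough sketch of how the per-worker configuration might be produced, the snippet below builds one TF_CONFIG string per worker from a shared host list. The helper make_tf_config is hypothetical (it is not a TensorFlow API), and the snippet only builds and inspects the JSON strings; it does not set the environment variable or start any server.

```python
import json

def make_tf_config(worker_hosts, task_index):
    """Return the TF_CONFIG JSON string for one worker (hypothetical helper)."""
    return json.dumps({
        # Identical on every worker: the full list of workers in the cluster.
        'cluster': {'worker': list(worker_hosts)},
        # Different on every worker: this worker's own type and index.
        'task': {'type': 'worker', 'index': task_index},
    })

workers = ["localhost:12345", "localhost:23456"]
configs = [make_tf_config(workers, i) for i in range(len(workers))]

for config in configs:
    parsed = json.loads(config)
    print(parsed['task'], parsed['cluster'])
```

Each worker process would then export its own string as TF_CONFIG before creating the strategy.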
Write the layers, the optimizer, and the loss function for training. This tutorial defines the model with Keras layers, similar to the multi-GPU training tutorial.
In [0]:
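# Assumed value: fixed learning rate for the optimizer below.
LEARNING_RATE = 1e-4
def model_fn(features, labels, mode):
  model = tf.keras.Sequential([
      tf.keras.layers.Conv2D(32, 3, activation='relu', input_shape=(28, 28, 1)),
      tf.keras.layers.MaxPooling2D(),
      tf.keras.layers.Flatten(),
      tf.keras.layers.Dense(64, activation='relu'),
      tf.keras.layers.Dense(10)  # 10 logits, one per MNIST digit class
  ])
  logits = model(features, training=False)
  if mode == tf.estimator.ModeKeys.PREDICT:
    predictions = {'logits': logits}
    # EstimatorSpec requires the mode and does not accept a labels argument.
    return tf.estimator.EstimatorSpec(mode, predictions=predictions)
  optimizer = tf.compat.v1.train.GradientDescentOptimizer(
      learning_rate=LEARNING_RATE)
  # Compute the per-example loss, then average it over the batch size.
  loss = tf.keras.losses.SparseCategoricalCrossentropy(
      from_logits=True, reduction=tf.keras.losses.Reduction.NONE)(labels, logits)
  loss = tf.reduce_sum(loss) * (1. / BATCH_SIZE)
  if mode == tf.estimator.ModeKeys.EVAL:
    return tf.estimator.EstimatorSpec(mode, loss=loss)
  return tf.estimator.EstimatorSpec(
      mode=mode,
      loss=loss,
      train_op=optimizer.minimize(
          loss, tf.compat.v1.train.get_or_create_global_step()))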
Note: Although the learning rate is fixed in this example, in general it may be necessary to adjust the learning rate based on the global batch size.
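One common heuristic for this adjustment is linear scaling: multiply the base learning rate by the ratio of the global batch size to the batch size the base rate was tuned for. This tutorial does not prescribe that rule; the sketch below is only an illustration, and the function name and arguments are assumptions.

```python
def scaled_learning_rate(base_lr, base_batch_size,
                         per_replica_batch_size, num_replicas):
    """Linear scaling heuristic (an assumption, not this tutorial's method).

    The global batch size is the per-replica batch size times the number of
    replicas that train in sync.
    """
    global_batch_size = per_replica_batch_size * num_replicas
    return base_lr * global_batch_size / base_batch_size

# Example: a rate tuned for batch size 64, now training on 2 replicas of 64.
lr = scaled_learning_rate(base_lr=1e-4, base_batch_size=64,
                          per_replica_batch_size=64, num_replicas=2)
print(lr)
```

Whether linear scaling helps depends on the model and optimizer, so treat the result as a starting point for tuning rather than a guaranteed setting.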
To train the model, use an instance of tf.distribute.experimental.MultiWorkerMirroredStrategy. MultiWorkerMirroredStrategy creates copies of all variables in the model's layers on each device across all workers. It uses CollectiveOps, a TensorFlow op for collective communication, to aggregate gradients and keep the variables in sync. The tf.distribute.Strategy guide has more details about this strategy.
In [0]:
strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()
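The idea of mirrored variables can be illustrated with a toy, pure-Python simulation: every replica starts from the same weights, the per-worker gradients are aggregated, and every replica applies the same aggregated update, so the copies never diverge. This is not how TensorFlow implements it; the function name is hypothetical, and aggregating by mean is an assumption made for the illustration.

```python
def synchronous_step(replica_weights, per_worker_grads, learning_rate):
    """Apply one synchronized gradient-descent step to every replica (toy model)."""
    num_workers = len(per_worker_grads)
    # Aggregate: element-wise mean of the gradients from all workers.
    mean_grad = [sum(g) / num_workers for g in zip(*per_worker_grads)]
    # Every replica applies the same aggregated gradient to its own copy.
    return [
        [w - learning_rate * g for w, g in zip(weights, mean_grad)]
        for weights in replica_weights
    ]

# Two workers holding identical copies of a two-element weight vector.
replicas = [[1.0, 2.0], [1.0, 2.0]]
# Each worker computed a different gradient on its own shard of the data.
grads = [[0.5, 1.0], [1.5, 3.0]]
updated = synchronous_step(replicas, grads, learning_rate=0.1)
print(updated)
```

Because each copy receives the identical update, the replicas remain equal after the step even though each worker saw different data.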
Next, specify the distribution strategy in the RunConfig for the estimator, and train and evaluate by invoking tf.estimator.train_and_evaluate. This tutorial distributes only the training by specifying the strategy via train_distribute. It is also possible to distribute the evaluation via eval_distribute.
In [0]:
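config = tf.estimator.RunConfig(train_distribute=strategy)
classifier = tf.estimator.Estimator(
    model_fn=model_fn, model_dir='/tmp/multiworker', config=config)
tf.estimator.train_and_evaluate(
    classifier, train_spec=tf.estimator.TrainSpec(input_fn=input_fn),
    eval_spec=tf.estimator.EvalSpec(input_fn=input_fn))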
You now have a model and a multi-worker capable Estimator powered by tf.distribute.Strategy. You can try the following techniques to optimize performance of multi-worker training:
if possible. The official ResNet model includes an example of how this can be done.
Use collective communication: MultiWorkerMirroredStrategy provides multiple collective communication implementations.
RING implements ring-based collectives using gRPC as the cross-host communication layer.
NCCL uses Nvidia's NCCL to implement collectives.
AUTO defers the choice to the runtime.
The best choice of collective implementation depends upon the number and kind of GPUs, and the network interconnect in the cluster. To override the automatic choice, specify a valid value to the communication parameter of MultiWorkerMirroredStrategy's constructor, e.g. communication=tf.distribute.experimental.CollectiveCommunication.NCCL.
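For intuition about ring-based collectives, the following pure-Python simulation runs the textbook ring all-reduce: a reduce-scatter phase followed by an all-gather phase, with each worker only ever sending to its neighbor in the ring. It is a sketch of the general algorithm, not TensorFlow's implementation; the function name is hypothetical and each vector is assumed to have exactly one element per worker to keep the chunking trivial.

```python
def ring_all_reduce(worker_vectors):
    """Simulate a sum all-reduce over a ring of workers (toy sketch).

    Assumes n workers, each holding a vector of exactly n numbers, so that
    every chunk is a single element.
    """
    n = len(worker_vectors)
    data = [list(v) for v in worker_vectors]
    assert all(len(v) == n for v in data), "need one element per worker"

    # Phase 1, reduce-scatter: after n - 1 steps, worker i holds the full
    # sum for chunk (i + 1) % n.
    for step in range(n - 1):
        # Collect all messages first so the sends happen "simultaneously".
        messages = [((i + 1) % n, (i - step) % n, data[i][(i - step) % n])
                    for i in range(n)]
        for dst, chunk, value in messages:
            data[dst][chunk] += value

    # Phase 2, all-gather: pass the completed chunks around the ring so
    # every worker ends up with every fully reduced chunk.
    for step in range(n - 1):
        messages = [((i + 1) % n, (i + 1 - step) % n,
                     data[i][(i + 1 - step) % n])
                    for i in range(n)]
        for dst, chunk, value in messages:
            data[dst][chunk] = value

    return data

result = ring_all_reduce([[1, 2, 3], [10, 20, 30], [100, 200, 300]])
print(result)
```

Each worker sends and receives only one chunk per step, which is why ring-based schemes spread the communication load evenly across the links between workers.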
Visit the Performance section in the guide to learn more about other strategies and tools you can use to optimize the performance of your TensorFlow models.