In [ ]:
#@title Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
The tf.distribute APIs provide an easy way for users to scale their training from a single machine to multiple machines. When scaling their model, users also have to distribute their input across multiple devices. tf.distribute provides APIs that automatically distribute your input across devices.
This guide shows you the different ways in which you can create distributed datasets and iterators using tf.distribute APIs. Additionally, the following topics are covered:
- Usage, sharding, and batching options when using tf.distribute.Strategy.experimental_distribute_dataset and tf.distribute.Strategy.experimental_distribute_datasets_from_function.
- Differences between the tf.distribute.Strategy.experimental_distribute_dataset/tf.distribute.Strategy.experimental_distribute_datasets_from_function APIs and the tf.data APIs, as well as any limitations that users may come across in their usage.
This guide does not cover usage of distributed input with Keras APIs.
To use tf.distribute APIs to scale, it is recommended that users use tf.data.Dataset to represent their input. tf.distribute has been made to work efficiently with tf.data.Dataset (for example, automatic prefetching of data onto each accelerator device), with performance optimizations being regularly incorporated into the implementation. If you have a use case for using something other than tf.data.Dataset, please refer to the later section of this guide on inputs that are not a tf.data.Dataset.
In a non-distributed training loop, users first create a tf.data.Dataset instance and then iterate over the elements. For example:
In [ ]:
# Import TensorFlow
!pip install tf-nightly
import tensorflow as tf
# Helper libraries
import numpy as np
import os
print(tf.__version__)
In [ ]:
global_batch_size = 16
# Create a tf.data.Dataset object.
dataset = tf.data.Dataset.from_tensors(([1.], [1.])).repeat(100).batch(global_batch_size)
@tf.function
def train_step(inputs):
  features, labels = inputs
  return labels - 0.3 * features

# Iterate over the dataset using the for..in construct.
for inputs in dataset:
  print(train_step(inputs))
To allow users to use tf.distribute strategies with minimal changes to their existing code, two APIs were introduced that distribute a tf.data.Dataset instance and return a distributed dataset object. Users can then iterate over this distributed dataset instance and train their model as before. Let us now look at the two APIs, tf.distribute.Strategy.experimental_distribute_dataset and tf.distribute.Strategy.experimental_distribute_datasets_from_function, in more detail:
tf.distribute.Strategy.experimental_distribute_dataset takes a tf.data.Dataset instance as input and returns a tf.distribute.DistributedDataset instance. You should batch the input dataset with a value equal to the global batch size, which is the number of samples that you want to process across all devices in one step. You can iterate over this distributed dataset in a Pythonic fashion or create an iterator using iter. The returned object is not a tf.data.Dataset instance and does not support any other APIs that transform or inspect the dataset in any way.
This is the recommended API if you don't have specific ways in which you want to shard your input over different replicas.
In [ ]:
global_batch_size = 16
mirrored_strategy = tf.distribute.MirroredStrategy()
dataset = tf.data.Dataset.from_tensors(([1.], [1.])).repeat(100).batch(global_batch_size)
# Distribute input using the `experimental_distribute_dataset`.
dist_dataset = mirrored_strategy.experimental_distribute_dataset(dataset)
# 1 global batch of data fed to the model in 1 step.
print(next(iter(dist_dataset)))
tf.distribute rebatches the input tf.data.Dataset instance with a new batch size that is equal to the global batch size divided by the number of replicas in sync. The number of replicas in sync is equal to the number of devices taking part in the gradient allreduce during training. When a user calls next on the distributed iterator, a per-replica batch of data is returned on each replica. The rebatched dataset cardinality will always be a multiple of the number of replicas. Here are a few examples:
tf.data.Dataset.range(6).batch(4, drop_remainder=False)
Without distribution:
Batch 1: [0, 1, 2, 3]
Batch 2: [4, 5]
With distribution over 2 replicas:
Batch 1: Replica 1:[0, 1] Replica 2:[2, 3]
Batch 2: Replica 1: [4] Replica 2: [5]
The last batch ([4, 5]) is split between 2 replicas.
tf.data.Dataset.range(4).batch(4)
Without distribution:
Batch 1: [0, 1, 2, 3]
With distribution over 5 replicas:
Batch 1: Replica 1: [0] Replica 2: [1] Replica 3: [2] Replica 4: [3] Replica 5: []
tf.data.Dataset.range(8).batch(4)
Without distribution:
Batch 1: [0, 1, 2, 3]
Batch 2: [4, 5, 6, 7]
With distribution over 3 replicas:
Batch 1: Replica 1: [0, 1] Replica 2: [2, 3] Replica 3: []
Batch 2: Replica 1: [4, 5] Replica 2: [6, 7] Replica 3: []
Note: The above examples only illustrate how a global batch is split on different replicas. It is not advisable to depend on the actual values that might end up on each replica as it can change depending on the implementation.
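All three examples above are consistent with one simple rule: split each global batch into contiguous per-replica slices of ceil(batch length / number of replicas) elements, leaving trailing replicas empty. The plain-Python sketch below models only that rule; split_global_batch and batches are hypothetical helpers, not tf.distribute APIs, and, as the note above says, the real per-replica values may differ.

```python
import math


def split_global_batch(global_batch, num_replicas):
    """Toy model: split one global batch into per-replica pieces.

    Each replica receives a contiguous slice of ceil(len(batch) / num_replicas)
    elements; replicas past the end of the data get an empty list.
    """
    chunk = math.ceil(len(global_batch) / num_replicas)
    return [global_batch[i * chunk:(i + 1) * chunk] for i in range(num_replicas)]


def batches(n, batch_size):
    """Plain-Python stand-in for tf.data.Dataset.range(n).batch(batch_size)."""
    elements = list(range(n))
    return [elements[i:i + batch_size] for i in range(0, n, batch_size)]


# The three examples above: (dataset size, global batch size, replicas).
for n, batch_size, replicas in [(6, 4, 2), (4, 4, 5), (8, 4, 3)]:
    print(f"range({n}).batch({batch_size}) over {replicas} replicas:")
    for step, batch in enumerate(batches(n, batch_size), start=1):
        print(f"  Batch {step}:", split_global_batch(batch, replicas))
```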
Rebatching the dataset has a space complexity that increases linearly with the number of replicas. This means that, in multi-worker training, the input pipeline can run into out-of-memory (OOM) errors.
tf.distribute also autoshards the input dataset in multi-worker training. Each dataset is created on the CPU device of the worker. Autosharding a dataset over a set of workers means that each worker is assigned a subset of the entire dataset (if the right tf.data.experimental.AutoShardPolicy is set). This ensures that, at each step, each worker processes dataset elements that do not overlap with those processed by the other workers. Autosharding has a couple of different options that can be specified using tf.data.experimental.DistributeOptions.
In [ ]:
dataset = tf.data.Dataset.from_tensors(([1.],[1.])).repeat(64).batch(16)
options = tf.data.Options()
options.experimental_distribute.auto_shard_policy = tf.data.experimental.AutoShardPolicy.DATA
dataset = dataset.with_options(options)
The following options can be set for tf.data.experimental.AutoShardPolicy:
AUTO: This is the default option, which means an attempt will be made to shard by FILE. If a file-based dataset is not detected, tf.distribute will then fall back to sharding by DATA. Note that if the input dataset is file-based but the number of files is less than the number of workers, an error will be raised.
FILE: This is the option if you want to shard the input files over all the workers. If the number of files is less than the number of workers, an error will be raised. You should use this option if the number of input files is much larger than the number of workers and the data in the files is evenly distributed. The downside of this option is having idle workers if the data in the files is not evenly distributed. For example, let us distribute 2 files over 2 workers with 1 replica each. File 1 contains [0, 1, 2, 3, 4, 5] and File 2 contains [6, 7, 8, 9, 10, 11]. Let the total number of replicas in sync be 2 and the global batch size be 4.
Batch 1 = Replica 1: [0, 1]
Batch 2 = Replica 1: [2, 3]
Batch 3 = Replica 1: [4]
Batch 4 = Replica 1: [5]
Batch 1 = Replica 2: [6, 7]
Batch 2 = Replica 2: [8, 9]
Batch 3 = Replica 2: [10]
Batch 4 = Replica 2: [11]
DATA: This will autoshard the elements across all the workers. Each worker will read the entire dataset and only process the shard assigned to it; all other shards will be discarded. This is generally used if the number of input files is less than the number of workers and you want better sharding of data across all workers. The downside is that the entire dataset will be read on each worker. For example, let us distribute 1 file over 2 workers. File 1 contains [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]. Let the total number of replicas in sync be 2.
Batch 1 = Replica 1: [0, 1]
Batch 2 = Replica 1: [4, 5]
Batch 3 = Replica 1: [8, 9]
Batch 1 = Replica 2: [2, 3]
Batch 2 = Replica 2: [6, 7]
Batch 3 = Replica 2: [10, 11]
OFF: If you turn off autosharding, each worker will process all the data. For example, let us distribute 1 file over 2 workers. File 1 contains [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]. Let the total number of replicas in sync be 2. Then each worker will see the following distribution:
Batch 1 = Replica 1: [0, 1]
Batch 2 = Replica 1: [2, 3]
Batch 3 = Replica 1: [4, 5]
Batch 4 = Replica 1: [6, 7]
Batch 5 = Replica 1: [8, 9]
Batch 6 = Replica 1: [10, 11]
Batch 1 = Replica 2: [0, 1]
Batch 2 = Replica 2: [2, 3]
Batch 3 = Replica 2: [4, 5]
Batch 4 = Replica 2: [6, 7]
Batch 5 = Replica 2: [8, 9]
Batch 6 = Replica 2: [10, 11]
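The three example tables above can be reproduced with a small toy model: batch each worker's stream by the global batch size, split each batch into num_replicas_in_sync contiguous pieces, and then either give each worker only its own files (FILE), give worker w every num_workers-th piece of the full stream (DATA), or give every worker everything (OFF). This is a plain-Python sketch under those assumptions, not how tf.data implements auto-sharding; the helper names are hypothetical, and it assumes a global batch size of 4 for the DATA and OFF examples, which the text does not state.

```python
import math


def rebatch(elements, global_batch_size, num_replicas_in_sync):
    """Batch by the global batch size, then split each batch into
    num_replicas_in_sync contiguous pieces of ceil(len / replicas) elements."""
    pieces = []
    for start in range(0, len(elements), global_batch_size):
        batch = elements[start:start + global_batch_size]
        chunk = math.ceil(len(batch) / num_replicas_in_sync)
        for i in range(num_replicas_in_sync):
            piece = batch[i * chunk:(i + 1) * chunk]
            if piece:  # this toy model drops empty pieces
                pieces.append(piece)
    return pieces


def per_worker_batches(files, policy, num_workers, global_batch_size, num_replicas_in_sync):
    """Toy model of the batches each worker sees under an auto-shard policy."""
    if policy == "FILE":
        if len(files) < num_workers:
            raise ValueError("fewer input files than workers")
        # Worker w reads only files w, w + num_workers, ...
        return [
            rebatch([x for f in files[w::num_workers] for x in f],
                    global_batch_size, num_replicas_in_sync)
            for w in range(num_workers)
        ]
    everything = rebatch([x for f in files for x in f],
                         global_batch_size, num_replicas_in_sync)
    if policy == "DATA":
        # Every worker reads all data but keeps only every num_workers-th piece.
        return [everything[w::num_workers] for w in range(num_workers)]
    if policy == "OFF":
        # No sharding: every worker sees the full stream.
        return [everything for _ in range(num_workers)]
    raise ValueError(f"unknown policy: {policy}")


two_files = [list(range(0, 6)), list(range(6, 12))]
one_file = [list(range(12))]
for policy, files in [("FILE", two_files), ("DATA", one_file), ("OFF", one_file)]:
    for w, seen in enumerate(per_worker_batches(files, policy, 2, 4, 2), start=1):
        print(policy, "worker", w, seen)
```

With one replica per worker, "worker 1" and "worker 2" in the output correspond to Replica 1 and Replica 2 in the tables.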
tf.distribute.Strategy.experimental_distribute_datasets_from_function takes an input function and returns a tf.distribute.DistributedDataset instance. The input function that users pass in has a tf.distribute.InputContext argument and should return a tf.data.Dataset instance. With this API, tf.distribute does not make any further changes to the user's tf.data.Dataset instance returned from the input function. It is the responsibility of the user to batch and shard the dataset. tf.distribute calls the input function on the CPU device of each of the workers. Apart from allowing users to specify their own batching and sharding logic, this API also demonstrates better scalability and performance compared to tf.distribute.Strategy.experimental_distribute_dataset when used for multi-worker training.
In [ ]:
mirrored_strategy = tf.distribute.MirroredStrategy()
global_batch_size = 16

def dataset_fn(input_context):
  batch_size = input_context.get_per_replica_batch_size(global_batch_size)
  # Leave the source unbatched: it is batched once below, by the per-replica size.
  dataset = tf.data.Dataset.from_tensors(([1.],[1.])).repeat(64)
  dataset = dataset.shard(
      input_context.num_input_pipelines, input_context.input_pipeline_id)
  dataset = dataset.batch(batch_size)
  dataset = dataset.prefetch(2)  # This prefetches 2 batches per device.
  return dataset

dist_dataset = mirrored_strategy.experimental_distribute_datasets_from_function(dataset_fn)
The tf.data.Dataset instance returned by the input function should be batched using the per-replica batch size, which is the global batch size divided by the number of replicas taking part in sync training. This is because tf.distribute calls the input function on the CPU device of each of the workers, so the dataset created on a given worker must be ready to use by all the replicas on that worker.
The tf.distribute.InputContext object that is implicitly passed as an argument to the user's input function is created by tf.distribute under the hood. It has information about the number of workers, the current worker ID, and so on. The input function can use these properties of tf.distribute.InputContext to shard the dataset according to the policies set by the user.
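To make the batching and sharding arithmetic concrete, here is a plain-Python sketch of what an input function does with its context. ToyInputContext is a hypothetical stand-in that only mirrors the tf.distribute.InputContext members used in the input function above, plus num_replicas_in_sync, which the per-replica batch size depends on; the list slicing mimics Dataset.shard(num_shards, index), which keeps every num_shards-th element starting at index. The setup (2 workers with 2 replicas each, global batch size 8) is an assumption for illustration.

```python
from dataclasses import dataclass


@dataclass
class ToyInputContext:
    """Hypothetical stand-in for tf.distribute.InputContext (illustration only)."""
    num_input_pipelines: int   # e.g. the number of workers
    input_pipeline_id: int     # this worker's index
    num_replicas_in_sync: int  # replicas across all workers

    def get_per_replica_batch_size(self, global_batch_size):
        if global_batch_size % self.num_replicas_in_sync != 0:
            raise ValueError("global batch size must be divisible by the number of replicas")
        return global_batch_size // self.num_replicas_in_sync


def toy_dataset_fn(ctx, elements, global_batch_size):
    """Shard first, then batch by the per-replica batch size."""
    batch_size = ctx.get_per_replica_batch_size(global_batch_size)
    # Like Dataset.shard(n, i): keep every n-th element starting at index i.
    shard = elements[ctx.input_pipeline_id::ctx.num_input_pipelines]
    return [shard[i:i + batch_size] for i in range(0, len(shard), batch_size)]


# Assumed setup: 2 workers with 2 replicas each (4 replicas in sync).
data = list(range(16))
for worker in range(2):
    ctx = ToyInputContext(num_input_pipelines=2, input_pipeline_id=worker,
                          num_replicas_in_sync=4)
    print("worker", worker, toy_dataset_fn(ctx, data, global_batch_size=8))
```

Each worker ends up with 8 elements in per-replica batches of 2, so at each step its 2 replicas consume 4 elements and the 4 replicas together consume one global batch of 8.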
Note: Both tf.distribute.Strategy.experimental_distribute_dataset and tf.distribute.Strategy.experimental_distribute_datasets_from_function return tf.distribute.DistributedDataset instances that are not of type tf.data.Dataset. You can iterate over these instances (as shown in the Distributed Iterators section) and use the element_spec property.
Similar to non-distributed tf.data.Dataset instances, you will need to create an iterator on tf.distribute.DistributedDataset instances to iterate over them and access their elements.
The following are the ways in which you can create a tf.distribute.DistributedIterator and use it to train your model:
You can use a user-friendly Pythonic loop to iterate over the tf.distribute.DistributedDataset. The elements returned from the tf.distribute.DistributedIterator can be a single tf.Tensor or a tf.distribute.DistributedValues, which contains a value per replica. Placing the loop inside a tf.function will give a performance boost. However, break and return are currently not supported if the loop is placed inside a tf.function. We also don't support placing the loop inside a tf.function when using multi-worker strategies such as tf.distribute.experimental.MultiWorkerMirroredStrategy and tf.distribute.TPUStrategy. Placing the loop inside a tf.function works for single-worker tf.distribute.TPUStrategy but not when using TPU pods.
In [ ]:
global_batch_size = 16
mirrored_strategy = tf.distribute.MirroredStrategy()
dataset = tf.data.Dataset.from_tensors(([1.],[1.])).repeat(100).batch(global_batch_size)
dist_dataset = mirrored_strategy.experimental_distribute_dataset(dataset)
@tf.function
def train_step(inputs):
  features, labels = inputs
  return labels - 0.3 * features

for x in dist_dataset:
  # train_step trains the model using the dataset elements
  loss = mirrored_strategy.run(train_step, args=(x,))
  print("Loss is ", loss)
Use iter to create an explicit iterator
To iterate over the elements in a tf.distribute.DistributedDataset instance, you can create a tf.distribute.DistributedIterator using the iter API on it. With an explicit iterator, you can iterate for a fixed number of steps. To get the next element from a tf.distribute.DistributedIterator instance dist_iterator, you can call next(dist_iterator), dist_iterator.get_next(), or dist_iterator.get_next_as_optional(). The first two are essentially the same:
In [ ]:
num_epochs = 10
steps_per_epoch = 5
for epoch in range(num_epochs):
  dist_iterator = iter(dist_dataset)
  for step in range(steps_per_epoch):
    # train_step trains the model using the dataset elements
    loss = mirrored_strategy.run(train_step, args=(next(dist_iterator),))
    # which is the same as
    # loss = mirrored_strategy.run(train_step, args=(dist_iterator.get_next(),))
    print("Loss is ", loss)
With next() or tf.distribute.DistributedIterator.get_next(), if the tf.distribute.DistributedIterator has reached its end, an OutOfRange error will be thrown. The client can catch the error on the Python side and continue doing other work such as checkpointing and evaluation. However, this will not work if you are using a host training loop (i.e., running multiple steps per tf.function), which looks like:
@tf.function
def train_fn(iterator):
  for _ in tf.range(steps_per_loop):
    strategy.run(step_fn, args=(next(iterator),))
train_fn contains multiple steps by wrapping the step body inside a tf.range. In this case, different iterations in the loop with no dependency could start in parallel, so an OutOfRange error can be triggered in later iterations before the computation of previous iterations finishes. Once an OutOfRange error is thrown, all the ops in the function will be terminated right away. If you would like to avoid this, an alternative that does not throw an OutOfRange error is tf.distribute.DistributedIterator.get_next_as_optional(). get_next_as_optional returns a tf.experimental.Optional, which contains the next element, or no value if the tf.distribute.DistributedIterator has reached its end.
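The control flow of the get_next_as_optional pattern can be illustrated without TensorFlow. In the plain-Python sketch below, ToyOptional and ToyIterator are hypothetical stand-ins for tf.experimental.Optional and tf.distribute.DistributedIterator that expose only has_value, get_value, and get_next_as_optional; the loop checks has_value and breaks instead of letting an end-of-data error abort the whole function. The data mirrors the next cell (9 elements in batches of 4, 5 steps per loop) but ignores replicas.

```python
class ToyOptional:
    """Hypothetical stand-in for tf.experimental.Optional (illustration only)."""

    def __init__(self, value=None, has_value=False):
        self._value, self._has_value = value, has_value

    def has_value(self):
        return self._has_value

    def get_value(self):
        if not self._has_value:
            raise ValueError("Optional has no value")
        return self._value


class ToyIterator:
    """Wraps a Python iterator and adds a get_next_as_optional method."""

    def __init__(self, iterable):
        self._it = iter(iterable)

    def get_next_as_optional(self):
        try:
            return ToyOptional(next(self._it), has_value=True)
        except StopIteration:
            return ToyOptional()


def train_fn(iterator, steps_per_loop):
    processed = []
    for _ in range(steps_per_loop):
        optional_data = iterator.get_next_as_optional()
        if not optional_data.has_value():
            break  # stop cleanly instead of raising an end-of-data error
        processed.append(optional_data.get_value())
    return processed


# 9 elements in batches of 4 gives 3 batches, so the 5-step loop ends early.
batches = [list(range(9))[i:i + 4] for i in range(0, 9, 4)]
print(train_fn(ToyIterator(batches), steps_per_loop=5))
# [[0, 1, 2, 3], [4, 5, 6, 7], [8]]
```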
In [ ]:
# You can break the loop with get_next_as_optional by checking if the Optional contains value
global_batch_size = 4
steps_per_loop = 5
strategy = tf.distribute.MirroredStrategy(devices=["GPU:0", "CPU:0"])
dataset = tf.data.Dataset.range(9).batch(global_batch_size)
distributed_iterator = iter(strategy.experimental_distribute_dataset(dataset))
@tf.function
def train_fn(distributed_iterator):
  for _ in tf.range(steps_per_loop):
    optional_data = distributed_iterator.get_next_as_optional()
    if not optional_data.has_value():
      break
    per_replica_results = strategy.run(lambda x: x, args=(optional_data.get_value(),))
    tf.print(strategy.experimental_local_results(per_replica_results))

train_fn(distributed_iterator)
If you pass the elements of a distributed dataset to a tf.function and want a tf.TypeSpec guarantee, you can specify the input_signature argument of the tf.function. The output of a distributed dataset is tf.distribute.DistributedValues, which can represent the input to a single device or multiple devices. To get the tf.TypeSpec corresponding to this distributed value, you can use the element_spec property of the distributed dataset or distributed iterator object.
In [ ]:
global_batch_size = 16
epochs = 5
steps_per_epoch = 5
mirrored_strategy = tf.distribute.MirroredStrategy()
dataset = tf.data.Dataset.from_tensors(([1.],[1.])).repeat(100).batch(global_batch_size)
dist_dataset = mirrored_strategy.experimental_distribute_dataset(dataset)
@tf.function(input_signature=[dist_dataset.element_spec])
def train_step(per_replica_inputs):
  def step_fn(inputs):
    # inputs is a (features, labels) tuple; double each tensor.
    features, labels = inputs
    return 2 * features, 2 * labels

  return mirrored_strategy.run(step_fn, args=(per_replica_inputs,))

for _ in range(epochs):
  iterator = iter(dist_dataset)
  for _ in range(steps_per_epoch):
    output = train_step(next(iterator))
    tf.print(output)
Partial batches are encountered when tf.data.Dataset instances that users create contain batch sizes that are not evenly divisible by the number of replicas, or when the cardinality of the dataset instance is not divisible by the batch size. This means that when the dataset is distributed over multiple replicas, the next call on some iterators will result in an OutOfRangeError. To handle this use case, tf.distribute returns dummy batches of batch size 0 on replicas that do not have any more data to process.
For the single-worker case, if data is not returned by the next call on the iterator, dummy batches of batch size 0 are created and used along with the real data in the dataset. In the case of partial batches, the last global batch of data will contain real data alongside dummy batches of data. The stopping condition for processing data now checks whether any of the replicas have data. If there is no data on any of the replicas, an OutOfRange error is thrown.
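The single-worker behavior described above can be sketched in plain Python. In this toy model (distributed_steps is a hypothetical helper, not a tf.distribute API), each replica draws from its own list of batches, an exhausted replica receives an empty list standing in for the size-0 dummy batch, and iteration stops only when no replica has real data, which is the point where tf.distribute would raise its OutOfRange error.

```python
def distributed_steps(per_replica_batches):
    """Toy model of single-worker partial-batch handling.

    per_replica_batches[r] is the list of batches replica r would receive.
    A replica that has run out of data gets an empty (size-0) dummy batch;
    iteration stops only when no replica has real data left.
    """
    iterators = [iter(batches) for batches in per_replica_batches]
    while True:
        step = [next(it, []) for it in iterators]  # [] plays the dummy batch
        if not any(step):
            return  # no replica has data: where OutOfRange would be raised
        yield step


# Replica 1 has one more batch than replica 2.
for step in distributed_steps([[[0, 1], [4]], [[2, 3]]]):
    print(step)
# [[0, 1], [2, 3]]
# [[4], []]
```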
For the multi-worker case, the boolean value representing the presence of data on each of the workers is aggregated using cross-replica communication, and this is used to identify whether all the workers have finished processing the distributed dataset. Since this involves cross-worker communication, there is some performance penalty involved.
Partial batches are currently supported for all strategies except tf.distribute.experimental.MultiWorkerMirroredStrategy. When using tf.distribute.experimental.MultiWorkerMirroredStrategy, make sure to iterate over a specific number of steps for which enough data exists on all replicas.
Stateful dataset transformations are currently not supported with tf.distribute, and any stateful ops that the dataset may have are currently ignored. For example, if your dataset has a map_fn that uses tf.random.uniform to rotate an image, then you have a dataset graph that depends on state (i.e., the random seed) on the local machine where the Python process is being executed.
The order in which data is processed by workers when using tf.distribute.Strategy.experimental_distribute_dataset or tf.distribute.Strategy.experimental_distribute_datasets_from_function is not guaranteed. An ordering guarantee is typically required if you are using tf.distribute to scale prediction. You can, however, insert an index for each element in the batch and order outputs accordingly. The following snippet is an example of how to order outputs.
Note: tf.distribute.MirroredStrategy is used here for the sake of convenience. Reordering is only needed when using multiple workers; tf.distribute.MirroredStrategy distributes training on a single worker.
In [ ]:
mirrored_strategy = tf.distribute.MirroredStrategy()
dataset_size = 24
batch_size = 6
dataset = tf.data.Dataset.range(dataset_size).enumerate().batch(batch_size)
dist_dataset = mirrored_strategy.experimental_distribute_dataset(dataset)
def predict(index, inputs):
  outputs = 2 * inputs
  return index, outputs

result = {}
for index, inputs in dist_dataset:
  output_index, outputs = mirrored_strategy.run(predict, args=(index, inputs))
  indices = list(mirrored_strategy.experimental_local_results(output_index))
  rindices = []
  for a in indices:
    rindices.extend(a.numpy())
  outputs = list(mirrored_strategy.experimental_local_results(outputs))
  routputs = []
  for a in outputs:
    routputs.extend(a.numpy())
  for i, value in zip(rindices, routputs):
    result[i] = value

print(result)
Sometimes users cannot use a tf.data.Dataset to represent their input, and consequently cannot use the above-mentioned APIs to distribute the dataset to multiple devices. In such cases you can use raw tensors or inputs from a generator.
strategy.run accepts tf.distribute.DistributedValues, which is the output of next(iterator). To pass the tensor values, use experimental_distribute_values_from_function to construct tf.distribute.DistributedValues from raw tensors.
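experimental_distribute_values_from_function calls value_fn once for each replica, passing a context object; the value_fn in the cell below ignores that context and returns the same constant everywhere. The plain-Python sketch below shows how a value_fn could instead use the context to give each replica a different slice of a raw input. ToyValueContext and toy_distribute_values_from_function are hypothetical stand-ins; the field names replica_id_in_sync_group and num_replicas_in_sync are assumed to mirror those of tf.distribute.experimental.ValueContext.

```python
from collections import namedtuple

# Hypothetical stand-in for the context object passed to value_fn.
ToyValueContext = namedtuple(
    "ToyValueContext", ["replica_id_in_sync_group", "num_replicas_in_sync"])


def toy_distribute_values_from_function(value_fn, num_replicas):
    """Call value_fn once per replica and collect one value per replica."""
    return [value_fn(ToyValueContext(r, num_replicas)) for r in range(num_replicas)]


global_inputs = list(range(8))


def value_fn(ctx):
    # Give each replica its own contiguous slice of the raw inputs.
    per_replica = len(global_inputs) // ctx.num_replicas_in_sync
    start = ctx.replica_id_in_sync_group * per_replica
    return global_inputs[start:start + per_replica]


print(toy_distribute_values_from_function(value_fn, num_replicas=2))
# [[0, 1, 2, 3], [4, 5, 6, 7]]
```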
In [ ]:
mirrored_strategy = tf.distribute.MirroredStrategy()
worker_devices = mirrored_strategy.extended.worker_devices
def value_fn(ctx):
  return tf.constant(1.0)

distributed_values = mirrored_strategy.experimental_distribute_values_from_function(value_fn)
for _ in range(4):
  result = mirrored_strategy.run(lambda x: x, args=(distributed_values,))
  print(result)
If you have a generator function that you want to use, you can create a tf.data.Dataset instance using the from_generator API.
Note: This is currently not supported for tf.distribute.TPUStrategy.
In [ ]:
mirrored_strategy = tf.distribute.MirroredStrategy()
def input_gen():
  while True:
    yield np.random.rand(4)

# use Dataset.from_generator
dataset = tf.data.Dataset.from_generator(
    input_gen, output_types=(tf.float32), output_shapes=tf.TensorShape([4]))
dist_dataset = mirrored_strategy.experimental_distribute_dataset(dataset)
iterator = iter(dist_dataset)
for _ in range(4):
  mirrored_strategy.run(lambda x: x, args=(next(iterator),))