In [ ]:
#@title Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
NOTE: This colab has been verified to work with the latest released version of the tensorflow_federated pip package, but the TensorFlow Federated project is still in pre-release development and may not work on master.
In this tutorial, we use the EMNIST dataset to demonstrate how to enable lossy compression algorithms to reduce communication cost in the Federated Averaging algorithm using the tff.learning.build_federated_averaging_process API and the tensor_encoding API. For more details on the Federated Averaging algorithm, see the paper Communication-Efficient Learning of Deep Networks from Decentralized Data.
Before we start, please run the following to make sure that your environment is correctly set up. If you don't see a greeting, please refer to the Installation guide for instructions.
In [ ]:
#@test {"skip": true}
!pip install --quiet --upgrade tensorflow_federated
!pip install --quiet --upgrade tensorflow-model-optimization
%load_ext tensorboard
In [ ]:
import functools
import numpy as np
import tensorflow as tf
import tensorflow_federated as tff
from tensorflow_model_optimization.python.core.internal import tensor_encoding as te
Verify that TFF is working.
In [ ]:
@tff.federated_computation
def hello_world():
  return 'Hello, World!'

hello_world()
Out[ ]:
In this section we load and preprocess the EMNIST dataset included in TFF. Please check out the Federated Learning for Image Classification tutorial for more details about the EMNIST dataset.
In [ ]:
# This value only applies to the EMNIST dataset; consider choosing appropriate
# values if switching to other datasets.
MAX_CLIENT_DATASET_SIZE = 418

CLIENT_EPOCHS_PER_ROUND = 1
CLIENT_BATCH_SIZE = 20
TEST_BATCH_SIZE = 500

emnist_train, emnist_test = tff.simulation.datasets.emnist.load_data(
    only_digits=True)

def reshape_emnist_element(element):
  return (tf.expand_dims(element['pixels'], axis=-1), element['label'])

def preprocess_train_dataset(dataset):
  """Preprocessing function for the EMNIST training dataset."""
  return (dataset
          # Shuffle according to the largest client dataset
          .shuffle(buffer_size=MAX_CLIENT_DATASET_SIZE)
          # Repeat to do multiple local epochs
          .repeat(CLIENT_EPOCHS_PER_ROUND)
          # Batch to a fixed client batch size
          .batch(CLIENT_BATCH_SIZE, drop_remainder=False)
          # Preprocessing step
          .map(reshape_emnist_element))

emnist_train = emnist_train.preprocess(preprocess_train_dataset)
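As a quick sanity check (an illustrative addition, not required for the rest of the tutorial), you can inspect one preprocessed batch from a single client; after reshaping, the images should have shape (batch_size, 28, 28, 1).
In [ ]:
# Illustrative sanity check: look at one preprocessed batch from a client.
example_dataset = emnist_train.create_tf_dataset_for_client(
    emnist_train.client_ids[0])
example_images, example_labels = next(iter(example_dataset))
print(example_images.shape)  # e.g. (20, 28, 28, 1)
print(example_labels.shape)  # e.g. (20,)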
Here we define a Keras model based on the original FedAvg CNN, and then wrap the Keras model in an instance of tff.learning.Model so that it can be consumed by TFF.
Note that we'll need a function which produces a model, instead of simply a model directly. In addition, the function cannot just capture a pre-constructed model; it must create the model in the context in which it is called. The reason is that TFF is designed to go to devices, and needs control over when resources are constructed so that they can be captured and packaged up.
In [ ]:
def create_original_fedavg_cnn_model(only_digits=True):
  """The CNN model used in https://arxiv.org/abs/1602.05629."""
  data_format = 'channels_last'

  max_pool = functools.partial(
      tf.keras.layers.MaxPooling2D,
      pool_size=(2, 2),
      padding='same',
      data_format=data_format)
  conv2d = functools.partial(
      tf.keras.layers.Conv2D,
      kernel_size=5,
      padding='same',
      data_format=data_format,
      activation=tf.nn.relu)

  model = tf.keras.models.Sequential([
      tf.keras.layers.InputLayer(input_shape=(28, 28, 1)),
      conv2d(filters=32),
      max_pool(),
      conv2d(filters=64),
      max_pool(),
      tf.keras.layers.Flatten(),
      tf.keras.layers.Dense(512, activation=tf.nn.relu),
      tf.keras.layers.Dense(10 if only_digits else 62),
      tf.keras.layers.Softmax(),
  ])

  return model

# Gets the type information of the input data. TFF is a strongly typed
# functional programming framework, and needs type information about inputs to
# the model.
input_spec = emnist_train.create_tf_dataset_for_client(
    emnist_train.client_ids[0]).element_spec

def tff_model_fn():
  keras_model = create_original_fedavg_cnn_model()
  return tff.learning.from_keras_model(
      keras_model=keras_model,
      input_spec=input_spec,
      loss=tf.keras.losses.SparseCategoricalCrossentropy(),
      metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])
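For contrast, here is a minimal sketch of the anti-pattern described above (illustrative only; prebuilt_model and broken_model_fn are hypothetical names, and the cell is fully commented out): capturing a model constructed outside the function is exactly what TFF cannot work with.
In [ ]:
# Anti-pattern (do NOT do this): capture a pre-constructed model in the closure.
# TFF may invoke the model function in different contexts, so the shared
# `prebuilt_model` would not be re-created where TFF needs it to be.
#
# prebuilt_model = create_original_fedavg_cnn_model()
#
# def broken_model_fn():
#   return tff.learning.from_keras_model(
#       keras_model=prebuilt_model,  # shared across invocations
#       input_spec=input_spec,
#       loss=tf.keras.losses.SparseCategoricalCrossentropy(),
#       metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])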
Now we are ready to construct a Federated Averaging algorithm and train the defined model on the EMNIST dataset.
First we need to build a Federated Averaging algorithm using the tff.learning.build_federated_averaging_process API.
In [ ]:
federated_averaging = tff.learning.build_federated_averaging_process(
    model_fn=tff_model_fn,
    client_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=0.02),
    server_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=1.0))
Now let's run the Federated Averaging algorithm. From the perspective of TFF, the execution of a Federated Learning algorithm looks like this:
1. Initialize the algorithm to get the initial server state, which contains the model parameters and any state the server optimizer needs.
2. Run the algorithm round by round. In each round, the server broadcasts the model to the sampled clients, each client trains on its own data, and the server aggregates the client updates into a new server state.
For more details, please see the Custom Federated Algorithms, Part 2: Implementing Federated Averaging tutorial.
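In its simplest form, the loop is just the following minimal sketch (fully commented out because NUM_ROUNDS and federated_train_data are placeholders; the full version with client sampling and metric logging is the train function defined below).
In [ ]:
# Minimal sketch of the training loop (placeholders, not meant to be run as-is).
#
# state = federated_averaging.initialize()
# for round_num in range(NUM_ROUNDS):            # NUM_ROUNDS: placeholder
#   state, metrics = federated_averaging.next(
#       state, federated_train_data)             # placeholder list of client datasets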
Training metrics are written to the TensorBoard directory so they can be displayed after training finishes.
In [ ]:
#@title Load utility functions

def format_size(size):
  """A helper function for creating a human-readable size."""
  size = float(size)
  for unit in ['B', 'KiB', 'MiB', 'GiB']:
    if size < 1024.0:
      return "{size:3.2f}{unit}".format(size=size, unit=unit)
    size /= 1024.0
  return "{size:.2f}{unit}".format(size=size, unit='TiB')

def set_sizing_environment():
  """Creates an environment that contains sizing information."""
  # Creates a sizing executor factory to output communication cost
  # after the training finishes. Note that sizing executor only provides an
  # estimate (not exact) of communication cost, and doesn't capture cases like
  # compression of over-the-wire representations. However, it's perfect for
  # demonstrating the effect of compression in this tutorial.
  sizing_factory = tff.framework.sizing_executor_factory()

  # TFF has a modular runtime you can configure yourself for various
  # environments and purposes, and this example just shows how to configure one
  # part of it to report the size of things.
  context = tff.framework.ExecutionContext(executor_fn=sizing_factory)
  tff.framework.set_default_context(context)

  return sizing_factory
In [ ]:
def train(federated_averaging_process, num_rounds, num_clients_per_round,
          summary_writer):
  """Trains the federated averaging process and outputs metrics."""
  # Create an environment to get communication cost.
  environment = set_sizing_environment()

  # Initialize the Federated Averaging algorithm to get the initial server state.
  state = federated_averaging_process.initialize()

  with summary_writer.as_default():
    for round_num in range(num_rounds):
      # Sample the clients participating in this round.
      sampled_clients = np.random.choice(
          emnist_train.client_ids,
          size=num_clients_per_round,
          replace=False)
      # Create a list of `tf.Dataset` instances from the data of sampled clients.
      sampled_train_data = [
          emnist_train.create_tf_dataset_for_client(client)
          for client in sampled_clients
      ]
      # Run one round of the algorithm based on the server state and client data
      # and output the new state and metrics.
      state, metrics = federated_averaging_process.next(state, sampled_train_data)

      # For more about size_info, please see
      # https://www.tensorflow.org/federated/api_docs/python/tff/framework/SizeInfo
      size_info = environment.get_size_info()
      broadcasted_bits = size_info.broadcast_bits[-1]
      aggregated_bits = size_info.aggregate_bits[-1]

      print('round {:2d}, metrics={}, broadcasted_bits={}, aggregated_bits={}'.format(
          round_num, metrics, format_size(broadcasted_bits),
          format_size(aggregated_bits)))

      # Add metrics to TensorBoard.
      for name, value in metrics['train']._asdict().items():
        tf.summary.scalar(name, value, step=round_num)

      # Add broadcasted and aggregated data size to TensorBoard.
      tf.summary.scalar('cumulative_broadcasted_bits', broadcasted_bits,
                        step=round_num)
      tf.summary.scalar('cumulative_aggregated_bits', aggregated_bits,
                        step=round_num)
      summary_writer.flush()
In [ ]:
# Clean the log directory to avoid conflicts.
!rm -R /tmp/logs/scalars/*
# Set up the log directory and writer for Tensorboard.
logdir = "/tmp/logs/scalars/original/"
summary_writer = tf.summary.create_file_writer(logdir)
train(federated_averaging_process=federated_averaging, num_rounds=10,
      num_clients_per_round=10, summary_writer=summary_writer)
Start TensorBoard with the root log directory specified above to display the training metrics. It can take a few seconds for the data to load. In addition to Loss and Accuracy, we also output the amount of broadcasted and aggregated data. Broadcasted data refers to tensors the server pushes to each client, while aggregated data refers to tensors each client returns to the server.
In [ ]:
%tensorboard --logdir /tmp/logs/scalars/ --port=0
Now let's implement functions that apply lossy compression algorithms to the broadcasted and aggregated data, using the tensor_encoding API.
First, we define two functions:
- broadcast_encoder_fn, which creates an instance of te.core.SimpleEncoder to encode tensors or variables in server-to-client communication (broadcast data).
- mean_encoder_fn, which creates an instance of te.core.GatherEncoder to encode tensors or variables in client-to-server communication (aggregation data).
It is important to note that we do not apply a compression method to the entire model at once. Instead, we decide how (and whether) to compress each variable of the model independently. The reason is that generally, small variables such as biases are more sensitive to inaccuracy, and being relatively small, the potential communication savings are also relatively small. Hence we do not compress small variables by default. In this example, we apply uniform quantization to 8 bits (256 buckets) to every variable with more than 10000 elements, and only apply identity to other variables.
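To make the policy concrete, the illustrative sketch below (not part of the original tutorial flow) lists each trainable variable of the model defined earlier and whether the 10000-element threshold would select it for quantization.
In [ ]:
# Illustrative sketch: report which variables of the FedAvg CNN exceed the
# 10000-element threshold and would therefore be uniformly quantized.
example_model = create_original_fedavg_cnn_model()
for variable in example_model.trainable_variables:
  num_elements = variable.shape.num_elements()
  policy = 'uniform_quantization(bits=8)' if num_elements > 10000 else 'identity'
  print('{:30s} {:>8d} -> {}'.format(variable.name, num_elements, policy))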
In [ ]:
def broadcast_encoder_fn(value):
  """Function for building encoded broadcast."""
  spec = tf.TensorSpec(value.shape, value.dtype)
  if value.shape.num_elements() > 10000:
    return te.encoders.as_simple_encoder(
        te.encoders.uniform_quantization(bits=8), spec)
  else:
    return te.encoders.as_simple_encoder(te.encoders.identity(), spec)

def mean_encoder_fn(value):
  """Function for building encoded mean."""
  spec = tf.TensorSpec(value.shape, value.dtype)
  if value.shape.num_elements() > 10000:
    return te.encoders.as_gather_encoder(
        te.encoders.uniform_quantization(bits=8), spec)
  else:
    return te.encoders.as_gather_encoder(te.encoders.identity(), spec)
TFF provides APIs to convert the encoder functions into a format that the tff.learning.build_federated_averaging_process API can consume. By using tff.learning.framework.build_encoded_broadcast_process_from_model and tff.learning.framework.build_encoded_mean_process_from_model, we can create two processes that can be passed into the broadcast_process and aggregation_process arguments of tff.learning.build_federated_averaging_process to create a Federated Averaging algorithm with a lossy compression algorithm.
In [ ]:
encoded_broadcast_process = (
    tff.learning.framework.build_encoded_broadcast_process_from_model(
        tff_model_fn, broadcast_encoder_fn))
encoded_mean_process = (
    tff.learning.framework.build_encoded_mean_process_from_model(
        tff_model_fn, mean_encoder_fn))

federated_averaging_with_compression = tff.learning.build_federated_averaging_process(
    tff_model_fn,
    client_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=0.02),
    server_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=1.0),
    broadcast_process=encoded_broadcast_process,
    aggregation_process=encoded_mean_process)
In [ ]:
logdir_for_compression = "/tmp/logs/scalars/compression/"
summary_writer_for_compression = tf.summary.create_file_writer(
    logdir_for_compression)

train(federated_averaging_process=federated_averaging_with_compression,
      num_rounds=10,
      num_clients_per_round=10,
      summary_writer=summary_writer_for_compression)
Start TensorBoard again to compare the training metrics between the two runs.
As you can see in TensorBoard, there is a significant reduction between the original and compression curves in the broadcasted_bits and aggregated_bits plots, while the two curves are quite similar in the loss and sparse_categorical_accuracy plots.
In conclusion, we implemented a compression algorithm that can achieve similar performance to the original Federated Averaging algorithm while significantly reducing the communication cost.
In [ ]:
%tensorboard --logdir /tmp/logs/scalars/ --port=0
To implement a custom compression algorithm and apply it to the training loop, you can:
1. Implement the new compression algorithm as a subclass of EncodingStageInterface, or its more general variant AdaptiveEncodingStageInterface, following this example.
2. Construct your new Encoder and specialize it for model broadcast or model update averaging, as shown in the cells above.
Potentially valuable open research questions include: non-uniform quantization, lossless compression such as Huffman coding, and mechanisms for adapting compression based on the information from previous training rounds.
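Before going all the way to a custom encoding stage, you can also experiment with the building blocks already used in this tutorial. For example, here is a minimal sketch (an illustrative assumption of this write-up, not part of the original tutorial) of a mean encoder that quantizes large variables more aggressively to 4 bits; aggressive_mean_encoder_fn is a hypothetical name.
In [ ]:
# Illustrative sketch: a more aggressive variant of mean_encoder_fn that uses
# 4-bit uniform quantization (16 buckets) for large variables. It reuses only
# the encoders introduced above; expect a stronger accuracy/compression
# trade-off than the 8-bit version.
def aggressive_mean_encoder_fn(value):
  spec = tf.TensorSpec(value.shape, value.dtype)
  if value.shape.num_elements() > 10000:
    return te.encoders.as_gather_encoder(
        te.encoders.uniform_quantization(bits=4), spec)
  else:
    return te.encoders.as_gather_encoder(te.encoders.identity(), spec)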
Recommended reading materials: