This tutorial describes how to set up high-performance simulation using a TFF runtime running on Kubernetes. The model is the same as in the previous tutorial, High-performance simulations with TFF. The only difference is that here we use a worker pool instead of a local executor.
This tutorial uses Google Cloud's GKE to create the Kubernetes cluster, but all the steps after the cluster is created can be used with any Kubernetes installation.
Note: This tutorial assumes the user has an existing GCP project.
The following step only needs to be done once. The cluster can be re-used for future workloads.
Follow the GKE instructions to create a container cluster. The rest of this tutorial assumes that the cluster is named tff-cluster, but the actual name isn't important.
Stop following the instructions when you get to "Step 5: Deploy your application".
The commands to interact with GCP can be run locally or in the Google Cloud Shell. We recommend the Google Cloud Shell since it doesn't require additional setup.
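For reference, cluster creation from the CLI might look like the following. This is a sketch only: the zone and node count are assumptions to adapt to your project, not values prescribed by the GKE instructions.

```shell
# Create a small GKE cluster named tff-cluster (zone and size are examples).
gcloud container clusters create tff-cluster \
    --zone us-central1-a \
    --num-nodes 3

# Fetch credentials so that kubectl talks to the new cluster.
gcloud container clusters get-credentials tff-cluster --zone us-central1-a
```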
$ kubectl create deployment tff-workers --image=gcr.io/tensorflow-federated/remote-executor-service:{{version}}
Replace {{version}} with a release of TFF, e.g. 0.11.0 or latest.
$ kubectl expose deployment tff-workers --type=LoadBalancer --port 80 --target-port 8000
Note: This exposes your deployment to the internet and is for demo purposes only. For production use, a firewall and authentication are strongly recommended.
Look up the IP address of the load balancer in the Google Cloud Console. You'll need it later to connect the training loop to the worker app.
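Alternatively, the address can be read from the command line; a sketch using kubectl (the EXTERNAL-IP column may show "pending" until the load balancer finishes provisioning):

```shell
# Show the service created by `kubectl expose`; the load balancer's
# address appears in the EXTERNAL-IP column once it is provisioned.
kubectl get service tff-workers
```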
In [0]:
#@test {"skip": true}
!pip install --upgrade tensorflow_federated
In [0]:
import collections
import time
import tensorflow as tf
import tensorflow_federated as tff
source, _ = tff.simulation.datasets.emnist.load_data()
def map_fn(example):
  return collections.OrderedDict(
      x=tf.reshape(example['pixels'], [-1, 784]), y=example['label'])

def client_data(n):
  ds = source.create_tf_dataset_for_client(source.client_ids[n])
  return ds.repeat(10).batch(20).map(map_fn)

train_data = [client_data(n) for n in range(10)]
input_spec = train_data[0].element_spec

def model_fn():
  model = tf.keras.models.Sequential([
      tf.keras.layers.Input(shape=(784,)),
      tf.keras.layers.Dense(units=10, kernel_initializer='zeros'),
      tf.keras.layers.Softmax(),
  ])
  return tff.learning.from_keras_model(
      model,
      input_spec=input_spec,
      loss=tf.keras.losses.SparseCategoricalCrossentropy(),
      metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])

trainer = tff.learning.build_federated_averaging_process(
    model_fn, client_optimizer_fn=lambda: tf.keras.optimizers.SGD(0.02))
def evaluate(num_rounds=10):
  state = trainer.initialize()
  # Use `round_num` rather than `round` to avoid shadowing the builtin.
  for round_num in range(num_rounds):
    t1 = time.time()
    state, metrics = trainer.next(state, train_data)
    t2 = time.time()
    print('Round {}: loss {}, round time {}'.format(
        round_num, metrics.loss, t2 - t1))
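As a quick sanity check of the preprocessing above: map_fn flattens each 28x28 EMNIST image into a length-784 vector. A minimal sketch of the same reshape logic using NumPy (standing in for tf.reshape, so this cell is illustrative rather than part of the tutorial pipeline):

```python
import numpy as np

# A fake batch of two 28x28 "images", standing in for example['pixels'].
pixels = np.zeros((2, 28, 28))

# Same shape logic as tf.reshape(example['pixels'], [-1, 784]):
# -1 lets the leading (batch) dimension be inferred.
flat = pixels.reshape(-1, 784)

print(flat.shape)  # (2, 784)
```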
In [0]:
import grpc
ip_address = '0.0.0.0' #@param {type:"string"}
port = 80 #@param {type:"integer"}
client_ex = []
for i in range(10):
  channel = grpc.insecure_channel('{}:{}'.format(ip_address, port))
  client_ex.append(tff.framework.RemoteExecutor(channel, rpc_mode='STREAMING'))
factory = tff.framework.worker_pool_executor_factory(client_ex)
context = tff.framework.ExecutionContext(factory)
tff.framework.set_default_context(context)
In [0]:
evaluate()