In [0]:
#@title Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
In [0]:
#@test {"skip": true}
# Note: If you are running a Jupyter notebook, and installing a locally built
# pip package, you may need to edit the following to point to the '.whl' file
# on your local filesystem.
!pip install --quiet --upgrade tensorflow_federated
!pip install --quiet --upgrade tf-nightly
!pip install --quiet --upgrade tensorflow-text
In [0]:
import tensorflow as tf
import tensorflow_federated as tff
import tensorflow_text as text
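In [0]:
# Optional sanity check (a minimal sketch, not part of the pipeline below):
# verify that the TFF runtime can execute a trivial federated computation
# after the installs above.
print(tff.federated_computation(lambda: 'Hello, World!')())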
In [0]:
# https://www.tensorflow.org/federated/api_docs/python/tff/simulation/datasets/shakespeare/load_data
shake_train, shake_test = tff.simulation.datasets.shakespeare.load_data()
training_ds = shake_train.create_tf_dataset_from_all_clients()
In [0]:
## Preprocess the dataset to split each line into individual words
TRIM_TO_LINES=1000
tokenizer = text.UnicodeScriptTokenizer()
def preprocess(ds):
  # Trim dataset to improve example performance
  ds = ds.take(TRIM_TO_LINES)
  ds = ds.map(lambda l: tf.expand_dims(l['snippets'], 0))
  ds = ds.flat_map(lambda l: tf.data.Dataset.from_tensor_slices(
      tokenizer.tokenize(l)[0]))
  ds = ds.map(text.case_fold_utf8)
  ds = ds.filter(lambda w: tf.math.logical_not(
      text.wordshape(w, text.WordShape.IS_PUNCT_OR_SYMBOL)))
  ds = ds.shuffle(buffer_size=50000)
  return ds

dataset = preprocess(training_ds)
for i in dataset.take(3):
  print(i.numpy())
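In [0]:
# A quick illustration (not part of the pipeline) of what the preprocessing
# steps do on a single string; the sample input below is made up for
# demonstration.
sample = tf.constant(['What, art thou mad?'])
sample_tokens = tokenizer.tokenize(sample)[0]       # split into word tokens
sample_tokens = text.case_fold_utf8(sample_tokens)  # lowercase / case-fold
print(sample_tokens.numpy())
# Tokens flagged True here are the ones the filter above drops.
print(text.wordshape(sample_tokens, text.WordShape.IS_PUNCT_OR_SYMBOL).numpy())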
In [0]:
## Build a vocab dictionary by getting the set of unique words
vocab = dataset.apply(tf.data.experimental.unique())
vocab_list = [t.numpy() for t in vocab]
print('Final vocab length %d' % len(vocab_list))
print(vocab_list[:3])
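In [0]:
# Aside (a sketch, not used below): if tf.lookup.index_table_from_tensor is
# unavailable in your TF build, an equivalent word -> index table can be built
# with the TF2 tf.lookup.StaticHashTable API.
alt_table = tf.lookup.StaticHashTable(
    tf.lookup.KeyValueTensorInitializer(
        keys=tf.constant(vocab_list),
        values=tf.range(len(vocab_list), dtype=tf.int64)),
    default_value=-1)
print(alt_table.lookup(tf.constant(vocab_list[:3])).numpy())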
In [0]:
# BATCH_SIZE = 5
# TAKE = 2
# K = 5
BATCH_SIZE = 10000
TAKE = -1
K = 30
vocab_lookup = tf.lookup.index_table_from_tensor(vocab_list)
vocab_size = tf.cast(vocab_lookup.size(), tf.int32)
print('Vocab size: %d' % vocab_size.numpy())
counts = tf.zeros([vocab_size])
for batch in dataset.batch(BATCH_SIZE).take(TAKE):
  indices = vocab_lookup.lookup(batch)
  onehot = tf.one_hot(indices, depth=vocab_size)
  counts += tf.reduce_sum(onehot, axis=0)
  top_vals, top_indices = tf.math.top_k(counts, k=K)
  top_words = tf.gather(vocab_list, top_indices)
  print('.', end='')
print()
for word, count in zip(top_words, top_vals):
  print('%s: %d' % (word.numpy().decode('utf-8'), count))
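In [0]:
# Aside (a minimal sketch, not used by the rest of the notebook): the per-batch
# histogram above can equivalently be computed with tf.math.bincount, which
# avoids materializing the full one-hot matrix.
toy_indices = tf.constant([0, 2, 2, 3])
print(tf.one_hot(toy_indices, depth=5).numpy().sum(axis=0))  # [1. 0. 2. 1. 0.]
print(tf.math.bincount(toy_indices, minlength=5).numpy())    # [1 0 2 1 0]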
In [0]:
## Dataset prep
TRIM_TO_CLIENTS = 25
client_ids = shake_train.client_ids[:TRIM_TO_CLIENTS]
client_datasets = [preprocess(shake_train.create_tf_dataset_for_client(id)) for id in client_ids]
print('Num clients: %d' % len(client_datasets))
for ds in client_datasets[:3]:
  for words in ds.batch(3).take(1):
    print(words.numpy())
In [0]:
## Initial decomposition to map-reduce style, not actually TFF just yet!
@tf.function
def client_map_step(ds):
  # N.B. vocab_size and vocab_lookup must be created inside the @tf.function
  vocab_size = len(vocab_list)
  vocab_lookup = tf.lookup.index_table_from_tensor(vocab_list)

  @tf.function
  def _count_words_in_batch(accumulator, batch):
    indices = vocab_lookup.lookup(batch)
    onehot = tf.one_hot(indices, depth=tf.cast(vocab_size, tf.int32),
                        dtype=tf.int32)
    return accumulator + tf.reduce_sum(onehot, axis=0)

  return ds.batch(BATCH_SIZE).take(TAKE).reduce(
      initial_state=tf.zeros([vocab_size], tf.int32),
      reduce_func=_count_words_in_batch)

@tf.function
def cross_client_reduce_step(client_aggregates):
  reduced = tf.math.add_n(client_aggregates)
  top_vals, top_indices = tf.math.top_k(reduced, k=K)
  top_words = tf.gather(vocab_list, top_indices)
  return top_words, top_vals

# Wire it all together
client_sums = list()
for client_ds in client_datasets:
  print('.', end='')
  client_sums.append(client_map_step(client_ds))
top_words, top_counts = cross_client_reduce_step(client_sums)
print()
for word, count in zip(top_words, top_counts):
  print('%s: %d' % (word.numpy().decode('utf-8'), count))
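In [0]:
# Aside: a minimal sketch of the tf.data.Dataset.reduce pattern that
# client_map_step relies on. It folds each element into a running accumulator,
# just as _count_words_in_batch folds each batch into the count vector.
toy_ds = tf.data.Dataset.range(5)  # elements are int64: 0, 1, 2, 3, 4
toy_total = toy_ds.reduce(
    initial_state=tf.constant(0, tf.int64),
    reduce_func=lambda accum, x: accum + x)
print(toy_total.numpy())  # 10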
In [0]:
@tff.federated_computation(
    tff.FederatedType(tff.SequenceType(tf.string), tff.CLIENTS))
def federated_top_k_words(client_datasets):
  tff_map = tff.tf_computation(
      client_map_step, client_datasets.type_signature.member)
  print(tff_map.type_signature)  # (string* -> int32[VOCAB_SIZE])

  client_aggregates = tff.federated_map(tff_map, client_datasets)
  print(client_aggregates.type_signature)  # {int32[VOCAB_SIZE]}@CLIENTS

  @tff.tf_computation()
  def build_zeros():
    return tf.zeros([len(vocab_list)], tf.int32)
  print(build_zeros.type_signature)  # ( -> int32[VOCAB_SIZE])

  @tff.tf_computation(tff_map.type_signature.result,
                      tff_map.type_signature.result)
  def accumulate(accum, delta):
    return accum + delta
  print(accumulate.type_signature)  # (<int32[VOCAB_SIZE],int32[VOCAB_SIZE]> -> int32[VOCAB_SIZE])

  @tff.tf_computation(accumulate.type_signature.result)
  def report(accum):
    top_vals, top_indices = tf.math.top_k(accum, k=K)
    top_words = tf.gather(vocab_list, top_indices)
    return top_words, top_vals
  print(report.type_signature)  # (int32[VOCAB_SIZE] -> <string[K],int32[K]>)

  aggregate = tff.federated_aggregate(
      value=client_aggregates,
      zero=build_zeros(),
      accumulate=accumulate,
      merge=accumulate,
      report=report,
  )
  print(aggregate.type_signature)  # <string[K],int32[K]>@SERVER

  return aggregate

print(federated_top_k_words.type_signature)  # ({string*}@CLIENTS -> <string[K],int32[K]>@SERVER)
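In [0]:
# Aside: a minimal, standalone sketch illustrating the roles of the zero /
# accumulate / merge / report arguments to tff.federated_aggregate, using plain
# int32 values instead of count vectors. Not used by the rest of the notebook.
@tff.tf_computation(tf.int32, tf.int32)
def toy_add(a, b):
  return a + b

@tff.tf_computation()
def toy_zero():
  return tf.constant(0)

@tff.tf_computation(tf.int32)
def toy_report(total):
  return total

@tff.federated_computation(tff.FederatedType(tf.int32, tff.CLIENTS))
def toy_federated_sum(values):
  return tff.federated_aggregate(
      value=values, zero=toy_zero(), accumulate=toy_add,
      merge=toy_add, report=toy_report)

print(toy_federated_sum([1, 2, 3]))  # expected: 6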
In [0]:
top_words, top_vals = federated_top_k_words(client_datasets)
for word, count in zip(top_words, top_vals):
  print('%s: %d' % (word.decode('utf-8'), count))