In [0]:
#@title Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

WordCount in TFF

This notebook demonstrates a basic analytics query (finding the most frequently occurring words in the Shakespeare dataset), implemented first in pure TensorFlow and then as a TFF computation.

The goal is to demonstrate how a simple analytics query can be expressed in TFF.


In [0]:
#@test {"skip": true}

# Note: If you are running a Jupyter notebook, and installing a locally built
# pip package, you may need to edit the following to point to the '.whl' file
# on your local filesystem.

!pip install --quiet --upgrade tensorflow_federated
!pip install --quiet --upgrade tf-nightly
!pip install --quiet --upgrade tensorflow-text



In [0]:
import tensorflow as tf
import tensorflow_federated as tff
import tensorflow_text as text

In [0]:
# https://www.tensorflow.org/federated/api_docs/python/tff/simulation/datasets/shakespeare/load_data
shake_train, shake_test = tff.simulation.datasets.shakespeare.load_data()

training_ds = shake_train.create_tf_dataset_from_all_clients()
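
Before preprocessing, it is worth a quick look at what load_data returns. A minimal, illustrative check (the exact client count depends on the dataset version):

In [0]:
print('Train clients: %d' % len(shake_train.client_ids))
example_ds = shake_train.create_tf_dataset_for_client(shake_train.client_ids[0])
print(example_ds.element_spec)  # each element is a dict with a string 'snippets' field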

Dataset Pre-processing


In [0]:
## Preprocess the dataset to split each line into individual words
TRIM_TO_LINES = 1000

tokenizer = text.UnicodeScriptTokenizer()
def preprocess(ds):
  # Trim dataset to improve example performance
  ds = ds.take(TRIM_TO_LINES)
  ds = ds.map(lambda l: tf.expand_dims(l['snippets'], 0))
  ds = ds.flat_map(lambda l: tf.data.Dataset.from_tensor_slices(
      tokenizer.tokenize(l)[0]))
  ds = ds.map(text.case_fold_utf8)
  ds = ds.filter(lambda w: tf.math.logical_not(
      text.wordshape(w, text.WordShape.IS_PUNCT_OR_SYMBOL)))
  ds = ds.shuffle(buffer_size=50000)
  return ds
 

dataset = preprocess(training_ds)
for i in dataset.take(3):
  print(i.numpy())


b'frown'
b'pages'
b'marcus'
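
To see what the tokenizer and case-folding do on their own, here is a small illustrative snippet (exact token boundaries depend on the tensorflow-text version; UnicodeScriptTokenizer splits at script boundaries, so punctuation becomes separate tokens, which the filter above then removes):

In [0]:
sample = tf.constant(['What, art thou drawn?'])
sample_tokens = tokenizer.tokenize(sample)[0]
print(sample_tokens.numpy())                       # e.g. [b'What' b',' b'art' b'thou' b'drawn' b'?']
print(text.case_fold_utf8(sample_tokens).numpy())  # the same tokens, lower-cased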

In [0]:
## Build a vocab dictionary by getting the set of unique words

vocab = dataset.apply(tf.data.experimental.unique())
vocab_list = [t.numpy() for t in vocab]
print('Final vocab length %d' % len(vocab_list))
print(vocab_list[:3])


Final vocab length 3889
[b'road', b'and', b'call']

Pure TensorFlow implementation


In [0]:
# Small values for quick debugging runs:
# BATCH_SIZE = 5
# TAKE = 2
# K = 5

# Full-scale values:
BATCH_SIZE = 10000
TAKE = -1  # -1: process all batches
K = 30

vocab_lookup = tf.lookup.StaticHashTable(
    tf.lookup.KeyValueTensorInitializer(
        vocab_list, tf.range(len(vocab_list), dtype=tf.int64)),
    default_value=-1)
vocab_size = tf.cast(vocab_lookup.size(), tf.int32)
print('Vocab size: %d' % vocab_size.numpy())

counts = tf.zeros([vocab_size])
for batch in dataset.batch(BATCH_SIZE).take(TAKE):
  indices = vocab_lookup.lookup(batch)
  onehot = tf.one_hot(indices, depth=vocab_size)
  counts += tf.reduce_sum(onehot, axis=0)
  print('.', end='')

# Compute the top-k words once, after all batches have been counted.
top_vals, top_indices = tf.math.top_k(counts, k=K)
top_words = tf.gather(vocab_list, top_indices)

print()
for word, count in zip(top_words, top_vals):
  print('%s: %d' % (word.numpy().decode('utf-8'), count))


Vocab size: 3889
...
the: 736
and: 677
i: 592
to: 575
of: 422
you: 364
d: 338
a: 319
my: 310
that: 303
in: 293
not: 267
is: 251
he: 250
s: 238
me: 232
it: 231
him: 219
with: 211
have: 199
his: 194
be: 191
thou: 185
for: 185
we: 183
this: 174
as: 170
your: 166
but: 166
so: 151
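
Note that the one-hot approach materializes a [batch_size, vocab_size] tensor for every batch. An equivalent, more memory-friendly formulation uses tf.math.bincount; a minimal sketch, reusing vocab_lookup from above:

In [0]:
counts_v2 = tf.zeros([len(vocab_list)], tf.int32)
for batch in dataset.batch(BATCH_SIZE).take(TAKE):
  indices = vocab_lookup.lookup(batch)
  counts_v2 += tf.math.bincount(tf.cast(indices, tf.int32),
                                minlength=len(vocab_list))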

TensorFlow Federated approach


In [0]:
## Dataset prep
TRIM_TO_CLIENTS = 25

client_ids = shake_train.client_ids[:TRIM_TO_CLIENTS]
client_datasets = [preprocess(shake_train.create_tf_dataset_for_client(client_id))
                   for client_id in client_ids]

print('Num clients: %d' % len(client_datasets))
for ds in client_datasets[:3]:
  for words in ds.batch(3).take(1):
    print(words.numpy())


Num clients: 25
[b'within' b'is' b'i']
[b'a' b'reports' b'is']
[b'octavia' b'how' b'widower']
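
A quick illustrative check of how much data each client holds after preprocessing (sizes vary per client and are capped by TRIM_TO_LINES):

In [0]:
for ds in client_datasets[:3]:
  print('Words: %d' % ds.reduce(0, lambda count, _: count + 1).numpy())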

In [0]:
## Initial decomposition to map-reduce style, not actually TFF just yet!
@tf.function
def client_map_step(ds):
  # N.B. vocab_size and vocab_lookup must be created inside the @tf.function
  vocab_size = len(vocab_list)
  vocab_lookup = tf.lookup.StaticHashTable(
      tf.lookup.KeyValueTensorInitializer(
          vocab_list, tf.range(len(vocab_list), dtype=tf.int64)),
      default_value=-1)
  
  @tf.function
  def _count_words_in_batch(accumulator, batch):
    indices = vocab_lookup.lookup(batch)
    onehot = tf.one_hot(indices, depth=tf.cast(vocab_size, tf.int32), dtype=tf.int32)
    return accumulator + tf.reduce_sum(onehot, axis=0)

  return ds.batch(BATCH_SIZE).take(TAKE).reduce(
      initial_state=tf.zeros([vocab_size], tf.int32),
      reduce_func=_count_words_in_batch)

@tf.function
def cross_client_reduce_step(client_aggregates):
  reduced = tf.math.add_n(client_aggregates)
  top_vals, top_indices = tf.math.top_k(reduced, k=K)
  top_words = tf.gather(vocab_list, top_indices)
  return top_words, top_vals

# Wire it all together
client_sums = list()
for client_ds in client_datasets:
  print('.', end='')
  client_sums.append(client_map_step(client_ds))
top_words, top_counts = cross_client_reduce_step(client_sums)


print()
for word, count in zip(top_words, top_counts):
  print('%s: %d' % (word.numpy().decode('utf-8'), count))


.........................
the: 614
and: 549
to: 480
i: 435
of: 357
you: 316
d: 278
a: 264
in: 243
that: 237
my: 232
he: 211
not: 203
is: 187
it: 186
with: 179
him: 175
s: 171
we: 165
me: 163
his: 163
for: 161
have: 160
your: 149
be: 145
this: 142
as: 140
thou: 129
but: 127
so: 119

The same decomposition translates almost directly to TFF: client_map_step becomes a tff.tf_computation applied via tff.federated_map, and the sum-then-top-k reduction becomes a tff.federated_aggregate.

In [0]:
@tff.federated_computation(
  tff.FederatedType(tff.SequenceType(tf.string), tff.CLIENTS))
def federated_top_k_words(client_datasets):
  tff_map = tff.tf_computation(
      client_map_step, client_datasets.type_signature.member)
  print(tff_map.type_signature)  # (string* -> int32[VOCAB_SIZE])

  client_aggregates = tff.federated_map(tff_map, client_datasets)
  print(client_aggregates.type_signature)  # {int32[VOCAB_SIZE]@CLIENTS}

  @tff.tf_computation()
  def build_zeros():
    return tf.zeros([len(vocab_list)], tf.int32)
  print(build_zeros.type_signature)  # ( -> int32[VOCAB_SIZE])

  @tff.tf_computation(tff_map.type_signature.result,
                      tff_map.type_signature.result)
  def accumulate(accum, delta):
    return accum + delta
  print(accumulate.type_signature)  # (<int32[VOCAB_SIZE],int32[VOCAB_SIZE]> -> int32[VOCAB_SIZE])

  @tff.tf_computation(accumulate.type_signature.result)
  def report(accum):
    top_vals, top_indices = tf.math.top_k(accum, k=K)
    top_words = tf.gather(vocab_list, top_indices)
    return top_words, top_vals
  print(report.type_signature)  # (int32[VOCAB_SIZE] -> <string[K],int32[]>)

  aggregate = tff.federated_aggregate(
      value=client_aggregates,
      zero=build_zeros(),
      accumulate=accumulate,
      merge=accumulate,
      report=report,
  )
  print(aggregate.type_signature)  # <string[K],int32[K]>@SERVER

  return aggregate 

print(federated_top_k_words.type_signature)  # ({string*}@CLIENTS -> <string[K],int32[K]>@SERVER)


(string* -> int32[3889])
{int32[3889]}@CLIENTS
( -> int32[3889])
(<int32[3889],int32[3889]> -> int32[3889])
(int32[3889] -> <string[30],int32[30]>)
<string[30],int32[30]>@SERVER
({string*}@CLIENTS -> <string[30],int32[30]>@SERVER)
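
Conceptually, tff.federated_aggregate folds each client's value into zero with accumulate, combines partial aggregates with merge, and post-processes the result with report. A plain-Python sketch of that contract (an illustration only, not how TFF actually executes it):

In [0]:
def simulate_federated_aggregate(client_values, zero, accumulate, merge, report):
  # With a single sequential accumulator, merge is never exercised; TFF uses
  # it to combine partial aggregates that were accumulated in parallel.
  accum = zero
  for value in client_values:
    accum = accumulate(accum, value)
  return report(accum)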

In [0]:
top_words, top_vals = federated_top_k_words(client_datasets)
for word, count in zip(top_words, top_vals):
  print('%s: %d' % (word.decode('utf-8'), count))


the: 614
and: 549
to: 480
i: 435
of: 357
you: 316
d: 278
a: 264
in: 243
that: 237
my: 232
he: 211
not: 203
is: 187
it: 186
with: 179
him: 175
s: 171
we: 165
me: 163
his: 163
for: 161
have: 160
your: 149
be: 145
this: 142
as: 140
thou: 129
but: 127
so: 119