This sample shows how to use the distribution strategy APIs when writing a custom training loop on TPU:

  • instantiate a TPUStrategy()
  • create the model and all other training objects in a strategy scope: with strategy.scope(): ...
  • distribute the dataset with strategy.experimental_distribute_dataset(ds)
  • run the training step distributed with strategy.experimental_run_v2(step_fn)
  • aggregate results returned by distributed workers with strategy.reduce(...)
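As a plain-Python analogy of the last three steps in the list above, the sketch below splits a global batch across replicas, runs the same step function on each shard, and averages the per-replica results. No TensorFlow is involved: `split_batch`, `run_on_replicas`, and `reduce_mean` are hypothetical helpers named for illustration only, and the sketch shows the data flow rather than how TPUStrategy is actually implemented.

```python
# Conceptual sketch only: mimics the distribute / run / reduce data flow
# with plain Python lists. None of these helpers are TensorFlow APIs.

def split_batch(global_batch, num_replicas):
    """Split a global batch into equal per-replica shards."""
    shard_size = len(global_batch) // num_replicas
    return [global_batch[i * shard_size:(i + 1) * shard_size]
            for i in range(num_replicas)]

def run_on_replicas(step_fn, shards):
    """Run the same step function on every shard, one result per replica."""
    return [step_fn(shard) for shard in shards]

def reduce_mean(per_replica_values):
    """Aggregate the per-replica results into a single number."""
    return sum(per_replica_values) / len(per_replica_values)

def step_fn(shard):
    # Stand-in for a training step: returns the mean of the shard.
    return sum(shard) / len(shard)

global_batch = list(range(16))           # 16 "examples"
shards = split_batch(global_batch, 8)    # 8 replicas, 2 examples each
per_replica = run_on_replicas(step_fn, shards)
result = reduce_mean(per_replica)
print(per_replica)
print(result)
```

With 8 replicas and a global batch of 16, each replica sees 2 examples, which is why the configuration further down multiplies the per-replica batch size by strategy.num_replicas_in_sync.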


In [1]:
import re, sys
if 'google.colab' in sys.modules: # Colab-only Tensorflow version selector
  %tensorflow_version 2.x
import tensorflow as tf
import numpy as np
from matplotlib import pyplot as plt
AUTOTUNE = tf.data.experimental.AUTOTUNE  # used by the tf.data pipeline below
print("Tensorflow version " + tf.__version__)

Tensorflow version 2.1.0-dev20191028

TPU or GPU detection

TPUClusterResolver() automatically detects a connected TPU on all of Google's platforms: Colaboratory, AI Platform (ML Engine), Kubernetes, and Deep Learning VMs, provided the TPU_NAME environment variable is set on the VM.

In [2]:
# Detect hardware, return appropriate distribution strategy
try:
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver()  # TPU detection
    print('Running on TPU ', tpu.cluster_spec().as_dict()['worker'])
except ValueError:
    tpu = None

if tpu:
    tf.config.experimental_connect_to_cluster(tpu)
    tf.tpu.experimental.initialize_tpu_system(tpu)
    strategy = tf.distribute.experimental.TPUStrategy(tpu)
else:  # default strategy that works on CPU and single GPU
    strategy = tf.distribute.get_strategy()

print("REPLICAS: ", strategy.num_replicas_in_sync)

Running on TPU  ['']
INFO:tensorflow:Initializing the TPU system: martin-tpu-nightly
INFO:tensorflow:Clearing out eager caches
INFO:tensorflow:Finished initializing TPU system.
INFO:tensorflow:Found TPU system:
INFO:tensorflow:*** Num TPU Cores: 8
INFO:tensorflow:*** Num TPU Workers: 1
INFO:tensorflow:*** Num TPU Cores Per Worker: 8
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:localhost/replica:0/task:0/device:CPU:0, CPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:localhost/replica:0/task:0/device:XLA_CPU:0, XLA_CPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:CPU:0, CPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:0, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:1, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:2, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:3, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:4, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:5, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:6, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:7, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU_SYSTEM:0, TPU_SYSTEM, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:XLA_CPU:0, XLA_CPU, 0, 0)

Configuration and learning rate schedule

In [3]:

EPOCHS = 60  # assumption: value inferred from the training log below, which stops at epoch 60

if strategy.num_replicas_in_sync == 1: # GPU
    # This achieves 80% accuracy on a GPU (final loss 0.45)
    BATCH_SIZE = 16
    START_LR = 0.01
    MAX_LR = 0.01
    MIN_LR = 0.01
    LR_RAMP = 0 # epochs
    LR_SUSTAIN = 0 #epochs
    LR_DECAY = 1
elif strategy.num_replicas_in_sync == 8: # single TPU
    # This achieves 80% accuracy on a TPU v3-8 (final loss 0.44)
    BATCH_SIZE = 16 * strategy.num_replicas_in_sync # use 32 on TPUv3
    START_LR = 0.01
    MAX_LR = 0.01 * strategy.num_replicas_in_sync
    MIN_LR = 0.001
    LR_RAMP = 0 # epochs
    LR_SUSTAIN = 13 # epochs
    LR_DECAY = .95

else: # TPU pod
    # This achieves 80% accuracy on a TPU v2-32 pod (final loss 0.54)
    BATCH_SIZE = 16 * strategy.num_replicas_in_sync  # Global batch size.
    START_LR = 0.06
    MAX_LR = 0.012 * strategy.num_replicas_in_sync
    MIN_LR = 0.01
    LR_RAMP = 5 # epochs
    LR_SUSTAIN = 8 # epochs
    LR_DECAY = 0.95

CLASSES = ['daisy', 'dandelion', 'roses', 'sunflowers', 'tulips'] # do not change, maps to the labels in the data (folder names)

IMAGE_SIZE = [331, 331] # supported image sizes: 192x192, 331x331, 512x512
                        # make sure you load the appropriate dataset on the next line
#GCS_PATTERN = 'gs://flowers-public/tfrecords-jpeg-192x192-2/*.tfrec'
GCS_PATTERN = 'gs://flowers-public/tfrecords-jpeg-331x331/*.tfrec'
#GCS_PATTERN = 'gs://flowers-public/tfrecords-jpeg-512x512/*.tfrec'

# In a custom training loop, you need an object
# to hold the current epoch value.
class LRSchedule():
    def __init__(self):
        self.epoch = 0

    def set_epoch(self, epoch):
        self.epoch = epoch

    @staticmethod
    def lrfn(epoch, start_lr, max_lr, min_lr, rampup_epochs, sustain_epochs, exp_decay):
        if epoch < rampup_epochs:  # linear ramp from start_lr to max_lr
            lr = (max_lr - start_lr)/rampup_epochs * epoch + start_lr
        elif epoch < rampup_epochs + sustain_epochs:  # constant at max_lr
            lr = max_lr
        else:  # exponential decay from max_lr to min_lr
            lr = (max_lr - min_lr) * exp_decay**(epoch-rampup_epochs-sustain_epochs) + min_lr
        return lr

    @staticmethod
    def default_lrfn(epoch):
        return LRSchedule.lrfn(epoch, START_LR, MAX_LR, MIN_LR, LR_RAMP, LR_SUSTAIN, LR_DECAY)

    def lr(self):
        return self.default_lrfn(self.epoch)

print("Learning rate schedule:")
rng = list(range(EPOCHS))
plt.plot(rng, [LRSchedule.default_lrfn(x) for x in rng])

Learning rate schedule:
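To make the shape of the schedule concrete, here is a small standalone sketch that evaluates the same ramp / sustain / decay formula with the values from the "single TPU" branch of the configuration, assuming 8 replicas. It is a re-statement of the formula for illustration, not a replacement for the LRSchedule class.

```python
def lrfn(epoch, start_lr, max_lr, min_lr, rampup_epochs, sustain_epochs, exp_decay):
    """Ramp up linearly, hold at max_lr, then decay exponentially towards min_lr."""
    if epoch < rampup_epochs:
        return (max_lr - start_lr) / rampup_epochs * epoch + start_lr
    if epoch < rampup_epochs + sustain_epochs:
        return max_lr
    return (max_lr - min_lr) * exp_decay ** (epoch - rampup_epochs - sustain_epochs) + min_lr

# Values from the "single TPU" branch, assuming 8 replicas.
single_tpu = dict(start_lr=0.01, max_lr=0.01 * 8, min_lr=0.001,
                  rampup_epochs=0, sustain_epochs=13, exp_decay=0.95)

for epoch in (0, 12, 13, 14, 15):
    print(epoch, lrfn(epoch, **single_tpu))
```

With no ramp-up, the rate stays at 0.08 through epoch 13 (the decay exponent is still zero there) and then shrinks by a factor of 0.95 per epoch towards 0.001, giving roughly 0.07605 at epoch 14 and 0.0722975 at epoch 15.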

In [4]:
#@title display utilities [RUN ME]

def dataset_to_numpy_util(dataset, N):
  dataset = dataset.batch(N)
  if tf.executing_eagerly():
    # In eager mode, iterate in the Dataset directly.
    for images, labels in dataset:
      numpy_images = images.numpy()
      numpy_labels = labels.numpy()
      break  # only the first batch of N items is needed
  else: # In non-eager mode, must get the TF node that
        # yields the next item and run it in a session.
    get_next_item = tf.compat.v1.data.make_one_shot_iterator(dataset).get_next()
    with tf.compat.v1.Session() as ses:
      numpy_images, numpy_labels = ses.run(get_next_item)

  return numpy_images, numpy_labels

def title_from_label_and_target(label, correct_label):
  label = np.argmax(label, axis=-1)  # one-hot to class number
  correct_label = np.argmax(correct_label, axis=-1) # one-hot to class number
  correct = (label == correct_label)
  return "{} [{}{}{}]".format(CLASSES[label], str(correct), ', should be ' if not correct else '',
                              CLASSES[correct_label] if not correct else ''), correct

def display_one_flower(image, title, subplot, red=False):
    plt.subplot(subplot)
    plt.axis('off')
    plt.imshow(image)
    plt.title(title, fontsize=16, color='red' if red else 'black')
    return subplot+1

def display_9_images_from_dataset(dataset):
  subplot = 331  # 3x3 grid, first cell
  plt.figure(figsize=(13,13))
  images, labels = dataset_to_numpy_util(dataset, 9)
  for i, image in enumerate(images):
    title = CLASSES[np.argmax(labels[i], axis=-1)]
    subplot = display_one_flower(image, title, subplot)
    if i >= 8:
      break
  plt.tight_layout()
  plt.subplots_adjust(wspace=0.1, hspace=0.1)
  plt.show()

def display_9_images_with_predictions(images, predictions, labels):
  subplot = 331  # 3x3 grid, first cell
  plt.figure(figsize=(13,13))
  for i, image in enumerate(images):
    title, correct = title_from_label_and_target(predictions[i], labels[i])
    subplot = display_one_flower(image, title, subplot, not correct)
    if i >= 8:
      break
  plt.tight_layout()
  plt.subplots_adjust(wspace=0.1, hspace=0.1)
  plt.show()

def display_training_curves(training, validation, title, subplot):
  if subplot%10==1: # set up the subplots on the first call
    plt.subplots(figsize=(10,10), facecolor='#F0F0F0')
    plt.tight_layout()
  ax = plt.subplot(subplot)
  ax.plot(training)
  ax.plot(validation)
  ax.set_title('model '+ title)
  ax.set_ylabel(title)
  ax.set_xlabel('epoch')
  ax.legend(['train', 'valid.'])

Read images and labels from TFRecords

In [5]:

def count_data_items(filenames):
    # trick: the number of data items is written in the name of
    # the .tfrec files, i.e. flowers00-230.tfrec = 230 data items
    n = [int(re.compile(r"-([0-9]*)\.").search(filename).group(1)) for filename in filenames]
    return np.sum(n)

def data_augment(image, one_hot_class):
    image = tf.image.random_flip_left_right(image)
    image = tf.image.random_saturation(image, 0, 2)
    return image, one_hot_class

def read_tfrecord(example):
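    features = {
        "image": tf.io.FixedLenFeature([], tf.string), # tf.string means bytestring
        "class": tf.io.FixedLenFeature([], tf.int64),  # shape [] means scalar
        "one_hot_class": tf.io.VarLenFeature(tf.float32),
    }
    example = tf.io.parse_single_example(example, features)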
    image = tf.image.decode_jpeg(example['image'], channels=3)
    image = tf.cast(image, tf.float32) / 255.0  # convert image to floats in [0, 1] range
    image = tf.reshape(image, [*IMAGE_SIZE, 3]) # force the image size so that the shape of the tensor is known to Tensorflow
    class_label = tf.cast(example['class'], tf.int32)
    one_hot_class = tf.sparse.to_dense(example['one_hot_class'])
    one_hot_class = tf.reshape(one_hot_class, [5])
    return image, one_hot_class

def load_dataset(filenames):
    # read from TFRecords. For optimal performance, use TFRecordDataset with
    # num_parallel_reads to read from multiple TFRecord files at once
    # and set the option experimental_deterministic = False
    # to allow order-altering optimizations.

    opt = tf.data.Options()
    opt.experimental_deterministic = False

    dataset = tf.data.TFRecordDataset(filenames, num_parallel_reads=16) # can be AUTOTUNE in TF 2.1
    dataset = dataset.with_options(opt)
    dataset = dataset.map(read_tfrecord, num_parallel_calls=AUTOTUNE)
    return dataset

def batch_dataset(filenames, batch_size, train):
    dataset = load_dataset(filenames)
    n = count_data_items(filenames)
    if train:
        dataset = dataset.repeat() # training dataset must repeat
        dataset = dataset.map(data_augment, num_parallel_calls=AUTOTUNE)
        dataset = dataset.shuffle(2048)
    else:
        # usually fewer validation files than workers so disable FILE auto-sharding on validation
        if strategy.num_replicas_in_sync > 1: # option not useful if there is no sharding (not harmful either)
            opt = tf.data.Options()
            opt.experimental_distribute.auto_shard_policy = tf.data.experimental.AutoShardPolicy.DATA
            dataset = dataset.with_options(opt)
        # validation dataset does not need to repeat
        # also no need to shuffle or apply data augmentation
    if train:
        dataset = dataset.batch(batch_size)
    else:
        # little wrinkle: drop_remainder is NOT necessary but validation on the last
        # partial batch sometimes returns a "nan" loss (probably a bug). You can remove
        # this if you do not care about the validation loss.
        dataset = dataset.batch(batch_size, drop_remainder=True)
    dataset = dataset.prefetch(AUTOTUNE) # prefetch next batch while training (autotune prefetch buffer size)
    return dataset, n//batch_size

def get_training_dataset(filenames):
    dataset, steps = batch_dataset(filenames, BATCH_SIZE, train=True)
    return dataset, steps

def get_validation_dataset(filenames):
    dataset, steps = batch_dataset(filenames, VALIDATION_BATCH_SIZE, train=False)
    return dataset, steps
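The file-name trick used by count_data_items, and the way the step count follows from it, can be tried in isolation. The sketch below uses hypothetical file names that follow the flowersNN-COUNT.tfrec convention and an assumed global batch size of 256; neither is taken from the actual bucket listing.

```python
import re

def count_data_items(filenames):
    # The regex captures the digits between a "-" and the following ".".
    return sum(int(re.search(r"-([0-9]*)\.", name).group(1)) for name in filenames)

# Hypothetical file names that follow the flowersNN-COUNT.tfrec convention.
filenames = ["flowers%02d-230.tfrec" % i for i in range(13)]

n = count_data_items(filenames)
batch_size = 256                    # assumed global batch size, for illustration
steps_per_epoch = n // batch_size   # a partial last batch is not counted
print(n, steps_per_epoch)
```

Because of the integer division, the items that do not fill a complete batch are simply not counted in the number of steps per epoch.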

In [6]:
# instantiate datasets
filenames = tf.io.gfile.glob(GCS_PATTERN)
split = len(filenames) - int(len(filenames) * VALIDATION_SPLIT)
train_filenames = filenames[:split]
valid_filenames = filenames[split:]

training_dataset, steps_per_epoch = get_training_dataset(train_filenames)
validation_dataset, validation_steps = get_validation_dataset(valid_filenames)

print("TRAINING   IMAGES: ", count_data_items(train_filenames), ", STEPS PER EPOCH: ", steps_per_epoch)
print("VALIDATION IMAGES: ", count_data_items(valid_filenames), ", STEPS PER EPOCH: ", validation_steps)

# numpy data to test predictions
some_flowers, some_labels = dataset_to_numpy_util(load_dataset(valid_filenames), 160)
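Note that the split is made on whole files, not on individual images, so the effective validation share is rounded to a multiple of one file. A minimal sketch, assuming a hypothetical list of 16 files and a validation share of 0.19 (both values are illustrative assumptions, and `split_filenames` is a helper introduced here only for the example):

```python
def split_filenames(filenames, validation_split):
    """Split a list of files into training and validation subsets."""
    split = len(filenames) - int(len(filenames) * validation_split)
    return filenames[:split], filenames[split:]

# Hypothetical list of 16 files and an assumed 19% validation share.
filenames = ["flowers%02d-230.tfrec" % i for i in range(16)]
train_files, valid_files = split_filenames(filenames, 0.19)
print(len(train_files), len(valid_files))
```

Here int(16 * 0.19) is 3, so 3 files go to validation and 13 to training, an effective share of 3/16 rather than exactly 19%.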


In [7]:

The model: squeezenet with 12 layers

In [8]:
def create_model():
    bnmomemtum=0.9 # with only a handful of batches per epoch, the batch norm running average period must be lowered
    def fire(x, squeeze, expand):
        y  = tf.keras.layers.Conv2D(filters=squeeze, kernel_size=1, activation=None, padding='same', use_bias=False)(x)
        y = tf.keras.layers.BatchNormalization(momentum=bnmomemtum, scale=False, center=True)(y)
        y = tf.keras.layers.Activation('relu')(y)
        y1 = tf.keras.layers.Conv2D(filters=expand//2, kernel_size=1, activation=None, padding='same', use_bias=False)(y)
        y1 = tf.keras.layers.BatchNormalization(momentum=bnmomemtum, scale=False, center=True)(y1)
        y1 = tf.keras.layers.Activation('relu')(y1)
        y3 = tf.keras.layers.Conv2D(filters=expand//2, kernel_size=3, activation=None, padding='same', use_bias=False)(y)
        y3 = tf.keras.layers.BatchNormalization(momentum=bnmomemtum, scale=False, center=True)(y3)
        y3 = tf.keras.layers.Activation('relu')(y3)
        return tf.keras.layers.concatenate([y1, y3])

    def fire_module(squeeze, expand):
        return lambda x: fire(x, squeeze, expand)

    x = tf.keras.layers.Input(shape=(*IMAGE_SIZE, 3)) # input is 331x331 pixels RGB
    y = tf.keras.layers.Conv2D(kernel_size=3, filters=32, padding='same', use_bias=True, activation='relu')(x)
    y = tf.keras.layers.BatchNormalization(momentum=bnmomemtum)(y)
    y = fire_module(24, 48)(y)
    y = tf.keras.layers.MaxPooling2D(pool_size=2)(y)
    y = fire_module(48, 96)(y)
    y = tf.keras.layers.MaxPooling2D(pool_size=2)(y)
    y = fire_module(64, 128)(y)
    y = tf.keras.layers.MaxPooling2D(pool_size=2)(y)
    y = fire_module(48, 96)(y)
    y = tf.keras.layers.MaxPooling2D(pool_size=2)(y)
    y = fire_module(24, 48)(y)
    y = tf.keras.layers.GlobalAveragePooling2D()(y)
    y = tf.keras.layers.Dropout(0.4)(y)
    y = tf.keras.layers.Dense(5, activation='softmax')(y)
    return tf.keras.Model(x, y)

Instantiate all objects in the strategy scope

In [9]:
with strategy.scope():
    model = create_model()
    lr_schedule = LRSchedule()
    optimizer = tf.keras.optimizers.SGD(nesterov=True, momentum=0.9, learning_rate=lambda: lr_schedule.lr())
    train_accuracy = tf.keras.metrics.CategoricalAccuracy()
    valid_accuracy = tf.keras.metrics.CategoricalAccuracy()
    loss_fn = lambda labels, probabilities: tf.reduce_mean(tf.keras.losses.categorical_crossentropy(labels, probabilities))
    model.summary()

Model: "model"
Layer (type)                    Output Shape         Param #     Connected to                     
input_1 (InputLayer)            [(None, 331, 331, 3) 0                                            
conv2d (Conv2D)                 (None, 331, 331, 32) 896         input_1[0][0]                    
batch_normalization (BatchNorma (None, 331, 331, 32) 128         conv2d[0][0]                     
conv2d_1 (Conv2D)               (None, 331, 331, 24) 768         batch_normalization[0][0]        
batch_normalization_1 (BatchNor (None, 331, 331, 24) 72          conv2d_1[0][0]                   
activation (Activation)         (None, 331, 331, 24) 0           batch_normalization_1[0][0]      
conv2d_2 (Conv2D)               (None, 331, 331, 24) 576         activation[0][0]                 
conv2d_3 (Conv2D)               (None, 331, 331, 24) 5184        activation[0][0]                 
batch_normalization_2 (BatchNor (None, 331, 331, 24) 72          conv2d_2[0][0]                   
batch_normalization_3 (BatchNor (None, 331, 331, 24) 72          conv2d_3[0][0]                   
activation_1 (Activation)       (None, 331, 331, 24) 0           batch_normalization_2[0][0]      
activation_2 (Activation)       (None, 331, 331, 24) 0           batch_normalization_3[0][0]      
concatenate (Concatenate)       (None, 331, 331, 48) 0           activation_1[0][0]               
max_pooling2d (MaxPooling2D)    (None, 165, 165, 48) 0           concatenate[0][0]                
conv2d_4 (Conv2D)               (None, 165, 165, 48) 2304        max_pooling2d[0][0]              
batch_normalization_4 (BatchNor (None, 165, 165, 48) 144         conv2d_4[0][0]                   
activation_3 (Activation)       (None, 165, 165, 48) 0           batch_normalization_4[0][0]      
conv2d_5 (Conv2D)               (None, 165, 165, 48) 2304        activation_3[0][0]               
conv2d_6 (Conv2D)               (None, 165, 165, 48) 20736       activation_3[0][0]               
batch_normalization_5 (BatchNor (None, 165, 165, 48) 144         conv2d_5[0][0]                   
batch_normalization_6 (BatchNor (None, 165, 165, 48) 144         conv2d_6[0][0]                   
activation_4 (Activation)       (None, 165, 165, 48) 0           batch_normalization_5[0][0]      
activation_5 (Activation)       (None, 165, 165, 48) 0           batch_normalization_6[0][0]      
concatenate_1 (Concatenate)     (None, 165, 165, 96) 0           activation_4[0][0]               
max_pooling2d_1 (MaxPooling2D)  (None, 82, 82, 96)   0           concatenate_1[0][0]              
conv2d_7 (Conv2D)               (None, 82, 82, 64)   6144        max_pooling2d_1[0][0]            
batch_normalization_7 (BatchNor (None, 82, 82, 64)   192         conv2d_7[0][0]                   
activation_6 (Activation)       (None, 82, 82, 64)   0           batch_normalization_7[0][0]      
conv2d_8 (Conv2D)               (None, 82, 82, 64)   4096        activation_6[0][0]               
conv2d_9 (Conv2D)               (None, 82, 82, 64)   36864       activation_6[0][0]               
batch_normalization_8 (BatchNor (None, 82, 82, 64)   192         conv2d_8[0][0]                   
batch_normalization_9 (BatchNor (None, 82, 82, 64)   192         conv2d_9[0][0]                   
activation_7 (Activation)       (None, 82, 82, 64)   0           batch_normalization_8[0][0]      
activation_8 (Activation)       (None, 82, 82, 64)   0           batch_normalization_9[0][0]      
concatenate_2 (Concatenate)     (None, 82, 82, 128)  0           activation_7[0][0]               
max_pooling2d_2 (MaxPooling2D)  (None, 41, 41, 128)  0           concatenate_2[0][0]              
conv2d_10 (Conv2D)              (None, 41, 41, 48)   6144        max_pooling2d_2[0][0]            
batch_normalization_10 (BatchNo (None, 41, 41, 48)   144         conv2d_10[0][0]                  
activation_9 (Activation)       (None, 41, 41, 48)   0           batch_normalization_10[0][0]     
conv2d_11 (Conv2D)              (None, 41, 41, 48)   2304        activation_9[0][0]               
conv2d_12 (Conv2D)              (None, 41, 41, 48)   20736       activation_9[0][0]               
batch_normalization_11 (BatchNo (None, 41, 41, 48)   144         conv2d_11[0][0]                  
batch_normalization_12 (BatchNo (None, 41, 41, 48)   144         conv2d_12[0][0]                  
activation_10 (Activation)      (None, 41, 41, 48)   0           batch_normalization_11[0][0]     
activation_11 (Activation)      (None, 41, 41, 48)   0           batch_normalization_12[0][0]     
concatenate_3 (Concatenate)     (None, 41, 41, 96)   0           activation_10[0][0]              
max_pooling2d_3 (MaxPooling2D)  (None, 20, 20, 96)   0           concatenate_3[0][0]              
conv2d_13 (Conv2D)              (None, 20, 20, 24)   2304        max_pooling2d_3[0][0]            
batch_normalization_13 (BatchNo (None, 20, 20, 24)   72          conv2d_13[0][0]                  
activation_12 (Activation)      (None, 20, 20, 24)   0           batch_normalization_13[0][0]     
conv2d_14 (Conv2D)              (None, 20, 20, 24)   576         activation_12[0][0]              
conv2d_15 (Conv2D)              (None, 20, 20, 24)   5184        activation_12[0][0]              
batch_normalization_14 (BatchNo (None, 20, 20, 24)   72          conv2d_14[0][0]                  
batch_normalization_15 (BatchNo (None, 20, 20, 24)   72          conv2d_15[0][0]                  
activation_13 (Activation)      (None, 20, 20, 24)   0           batch_normalization_14[0][0]     
activation_14 (Activation)      (None, 20, 20, 24)   0           batch_normalization_15[0][0]     
concatenate_4 (Concatenate)     (None, 20, 20, 48)   0           activation_13[0][0]              
global_average_pooling2d (Globa (None, 48)           0           concatenate_4[0][0]              
dropout (Dropout)               (None, 48)           0           global_average_pooling2d[0][0]   
dense (Dense)                   (None, 5)            245         dropout[0][0]                    
Total params: 119,365
Trainable params: 118,053
Non-trainable params: 1,312
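The parameter counts in the summary can be checked by hand. The sketch below recomputes them from the layer definitions in create_model(), using the usual formulas: a convolution has kernel x kernel x in_channels x out_channels weights (plus one bias per output channel when use_bias=True), and batch normalization keeps 4 values per channel, or 3 when scale=False. The helper names are introduced here for illustration only.

```python
def conv_params(kernel, in_ch, out_ch, bias):
    return kernel * kernel * in_ch * out_ch + (out_ch if bias else 0)

def bn_params(channels, scale):
    # beta, moving mean, moving variance, plus gamma when scale=True
    return channels * (4 if scale else 3)

def fire_params(in_ch, squeeze, expand):
    half = expand // 2
    total = conv_params(1, in_ch, squeeze, bias=False) + bn_params(squeeze, scale=False)
    total += conv_params(1, squeeze, half, bias=False) + bn_params(half, scale=False)
    total += conv_params(3, squeeze, half, bias=False) + bn_params(half, scale=False)
    return total, 2 * half  # parameter count, output channels after concatenation

# First 3x3 convolution (with bias) and its batch normalization (scale=True).
total = conv_params(3, 3, 32, bias=True) + bn_params(32, scale=True)
channels = 32
for squeeze, expand in [(24, 48), (48, 96), (64, 128), (48, 96), (24, 48)]:
    params, channels = fire_params(channels, squeeze, expand)
    total += params
total += channels * 5 + 5  # final Dense(5) layer: weights + biases
print(total)
```

The result matches the "Total params" line of the summary, and the first fire module alone accounts for 768 + 72 + 576 + 72 + 5184 + 72 = 6744 of them.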

Step functions

In [10]:
def train_step(images, labels):
    with tf.GradientTape() as tape:
        probabilities = model(images, training=True)
        loss = loss_fn(labels, probabilities)
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    train_accuracy.update_state(labels, probabilities)
    return loss

def valid_step(images, labels):
    probabilities = model(images, training=False)
    loss = loss_fn(labels, probabilities)
    valid_accuracy.update_state(labels, probabilities)
    return loss
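For intuition about the loss values returned by these step functions, the following plain-Python sketch computes the mean categorical cross-entropy over a small batch. It is a simplified re-statement of what loss_fn computes (the Keras implementation additionally clips the probabilities), with made-up labels and predictions.

```python
import math

def categorical_crossentropy(one_hot_label, probabilities):
    """Cross-entropy of one example: -sum(label * log(probability))."""
    return -sum(y * math.log(p) for y, p in zip(one_hot_label, probabilities) if y > 0)

def mean_loss(labels, predictions):
    """Average the per-example losses over the batch."""
    losses = [categorical_crossentropy(y, p) for y, p in zip(labels, predictions)]
    return sum(losses) / len(losses)

labels = [[0, 0, 1, 0, 0], [1, 0, 0, 0, 0]]
uniform = [[0.2] * 5, [0.2] * 5]   # a model guessing uniformly over 5 classes
loss = mean_loss(labels, uniform)
print(loss)                        # close to ln(5), about 1.609
```

A model that has learned nothing and predicts 0.2 for each of the 5 classes scores ln(5), about 1.609, which is a useful reference point when reading the first lines of a training log.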

Custom training loop

In [11]:
# distribute the dataset according to the strategy
train_dist_ds = strategy.experimental_distribute_dataset(training_dataset)
valid_dist_ds = strategy.experimental_distribute_dataset(validation_dataset)

print("Steps per epoch: ", steps_per_epoch)

epoch = -1
for step, (images, labels) in enumerate(train_dist_ds):

    # batch losses from all replicas
    loss = strategy.experimental_run_v2(train_step, args=(images, labels))
    # reduced to a single number both across replicas and across the batch size
    loss = strategy.reduce(tf.distribute.ReduceOp.MEAN, loss, axis=None)
    # or use loss.values to access the raw set of losses returned from all replicas
    # (the official API is strategy.experimental_local_results(loss), does the same as loss.values)

    # validation run at the end of each epoch
    if (step // steps_per_epoch) > epoch:
        valid_loss = []
        for image, labels in valid_dist_ds:
            batch_loss = strategy.experimental_run_v2(valid_step, args=(image, labels)) # just one batch
            batch_loss = strategy.reduce(tf.distribute.ReduceOp.MEAN, batch_loss, axis=None)
            valid_loss.append(batch_loss.numpy())
        valid_loss = np.mean(valid_loss)

        epoch = step // steps_per_epoch
        lr_schedule.set_epoch(epoch)
        print('\nEPOCH: ', epoch)
        print('loss: ', loss.numpy(), ', accuracy_: ', train_accuracy.result().numpy(), ' , val_loss: ', valid_loss, ' , val_acc_: ', valid_accuracy.result().numpy(), ' , lr: ', lr_schedule.lr(), flush=True)
        # start the next epoch with fresh metrics
        train_accuracy.reset_states()
        valid_accuracy.reset_states()
        if epoch >= EPOCHS:
            break
    print('=', end='', flush=True)

Steps per epoch:  11

EPOCH:  0
loss:  1.8574139 , accuracy_:  0.19140625  , val_loss:  1.6670775  , val_acc_:  0.21875  , lr:  0.08
EPOCH:  1
loss:  1.370311 , accuracy_:  0.33132103  , val_loss:  62.40265  , val_acc_:  0.203125  , lr:  0.08
EPOCH:  2
loss:  1.3508414 , accuracy_:  0.42294034  , val_loss:  10.712177  , val_acc_:  0.3515625  , lr:  0.08
EPOCH:  3
loss:  1.2815936 , accuracy_:  0.45419034  , val_loss:  2.663825  , val_acc_:  0.3984375  , lr:  0.08
EPOCH:  4
loss:  1.2429017 , accuracy_:  0.50177556  , val_loss:  1.9874109  , val_acc_:  0.45898438  , lr:  0.08
EPOCH:  5
loss:  1.1355323 , accuracy_:  0.51740056  , val_loss:  1.2704945  , val_acc_:  0.5332031  , lr:  0.08
EPOCH:  6
loss:  1.1954806 , accuracy_:  0.53728694  , val_loss:  1.8864973  , val_acc_:  0.453125  , lr:  0.08
EPOCH:  7
loss:  1.150807 , accuracy_:  0.5465199  , val_loss:  1.120469  , val_acc_:  0.5957031  , lr:  0.08
EPOCH:  8
loss:  1.0735602 , accuracy_:  0.559304  , val_loss:  1.8047378  , val_acc_:  0.53125  , lr:  0.08
EPOCH:  9
loss:  1.0801067 , accuracy_:  0.5628551  , val_loss:  1.0923085  , val_acc_:  0.5957031  , lr:  0.08
EPOCH:  10
loss:  0.9859637 , accuracy_:  0.57990056  , val_loss:  1.5855552  , val_acc_:  0.5390625  , lr:  0.08
EPOCH:  11
loss:  1.025581 , accuracy_:  0.5827415  , val_loss:  1.0143851  , val_acc_:  0.6191406  , lr:  0.08
EPOCH:  12
loss:  0.99369746 , accuracy_:  0.60546875  , val_loss:  1.1153538  , val_acc_:  0.6230469  , lr:  0.08
EPOCH:  13
loss:  0.94948137 , accuracy_:  0.6090199  , val_loss:  0.92257625  , val_acc_:  0.6171875  , lr:  0.08
EPOCH:  14
loss:  0.9516282 , accuracy_:  0.6147017  , val_loss:  0.9634201  , val_acc_:  0.61328125  , lr:  0.07604999999999999
EPOCH:  15
loss:  0.9929214 , accuracy_:  0.6147017  , val_loss:  0.91528195  , val_acc_:  0.6542969  , lr:  0.0722975
EPOCH:  16
loss:  0.85034007 , accuracy_:  0.6292614  , val_loss:  1.1177208  , val_acc_:  0.5625  , lr:  0.06873262499999999
EPOCH:  17
loss:  0.92602944 , accuracy_:  0.6161222  , val_loss:  0.88137645  , val_acc_:  0.6503906  , lr:  0.06534599375
EPOCH:  18
loss:  0.88779795 , accuracy_:  0.64666194  , val_loss:  0.8969081  , val_acc_:  0.6582031  , lr:  0.06212869406249998
EPOCH:  19
loss:  0.9948027 , accuracy_:  0.6558949  , val_loss:  1.1083667  , val_acc_:  0.6152344  , lr:  0.05907225935937498
EPOCH:  20
loss:  1.0186219 , accuracy_:  0.6352983  , val_loss:  0.9336281  , val_acc_:  0.6328125  , lr:  0.05616864639140623
EPOCH:  21
loss:  1.0545349 , accuracy_:  0.6296165  , val_loss:  1.073566  , val_acc_:  0.5839844  , lr:  0.05341021407183592
EPOCH:  22
loss:  0.9266959 , accuracy_:  0.6395597  , val_loss:  0.8530734  , val_acc_:  0.6640625  , lr:  0.05078970336824412
EPOCH:  23
loss:  0.8740611 , accuracy_:  0.65234375  , val_loss:  1.0662813  , val_acc_:  0.6269531  , lr:  0.048300218199831914
EPOCH:  24
loss:  0.8150981 , accuracy_:  0.66690344  , val_loss:  0.87279373  , val_acc_:  0.6699219  , lr:  0.04593520728984032
EPOCH:  25
loss:  0.87887394 , accuracy_:  0.63778406  , val_loss:  0.86162126  , val_acc_:  0.6621094  , lr:  0.0436884469253483
EPOCH:  26
loss:  0.79184717 , accuracy_:  0.6828835  , val_loss:  0.853451  , val_acc_:  0.6777344  , lr:  0.04155402457908088
EPOCH:  27
loss:  0.85961646 , accuracy_:  0.6651278  , val_loss:  0.8895781  , val_acc_:  0.66796875  , lr:  0.03952632335012683
EPOCH:  28
loss:  0.86511034 , accuracy_:  0.66015625  , val_loss:  1.0709796  , val_acc_:  0.6484375  , lr:  0.03760000718262049
EPOCH:  29
loss:  0.83482003 , accuracy_:  0.6772017  , val_loss:  0.8657106  , val_acc_:  0.6816406  , lr:  0.035770006823489464
EPOCH:  30
loss:  0.8544643 , accuracy_:  0.6715199  , val_loss:  0.95447123  , val_acc_:  0.6738281  , lr:  0.03403150648231499
EPOCH:  31
loss:  0.8925582 , accuracy_:  0.671875  , val_loss:  0.89564705  , val_acc_:  0.6738281  , lr:  0.03237993115819924
EPOCH:  32
loss:  0.8363841 , accuracy_:  0.6875  , val_loss:  0.8471222  , val_acc_:  0.6875  , lr:  0.030810934600289275
EPOCH:  33
loss:  0.8462085 , accuracy_:  0.69602275  , val_loss:  0.83931756  , val_acc_:  0.6777344  , lr:  0.02932038787027481
EPOCH:  34
loss:  0.78276515 , accuracy_:  0.6839489  , val_loss:  0.82418346  , val_acc_:  0.6933594  , lr:  0.02790436847676107
EPOCH:  35
loss:  0.8355361 , accuracy_:  0.6899858  , val_loss:  0.9210977  , val_acc_:  0.6738281  , lr:  0.026559150052923013
EPOCH:  36
loss:  0.80718756 , accuracy_:  0.70454544  , val_loss:  0.91811264  , val_acc_:  0.6699219  , lr:  0.025281192550276863
EPOCH:  37
loss:  0.70264596 , accuracy_:  0.6974432  , val_loss:  0.80253756  , val_acc_:  0.6933594  , lr:  0.02406713292276302
EPOCH:  38
loss:  0.7115408 , accuracy_:  0.7088068  , val_loss:  0.8474855  , val_acc_:  0.6894531  , lr:  0.02291377627662487
EPOCH:  39
loss:  0.6314294 , accuracy_:  0.7198153  , val_loss:  0.8814893  , val_acc_:  0.70703125  , lr:  0.021818087462793623
EPOCH:  40
loss:  0.7165874 , accuracy_:  0.73259944  , val_loss:  0.7393073  , val_acc_:  0.72265625  , lr:  0.02077718308965394
EPOCH:  41
loss:  0.7348884 , accuracy_:  0.7286932  , val_loss:  0.96582544  , val_acc_:  0.6953125  , lr:  0.019788323935171243
EPOCH:  42
loss:  0.785071 , accuracy_:  0.7141335  , val_loss:  0.89993995  , val_acc_:  0.6230469  , lr:  0.01884890773841268
EPOCH:  43
loss:  0.7332948 , accuracy_:  0.7389915  , val_loss:  0.8410876  , val_acc_:  0.6972656  , lr:  0.017956462351492047
EPOCH:  44
loss:  0.67525405 , accuracy_:  0.73828125  , val_loss:  0.8138516  , val_acc_:  0.6933594  , lr:  0.017108639233917443
EPOCH:  45
loss:  0.83102876 , accuracy_:  0.7389915  , val_loss:  0.8350415  , val_acc_:  0.71875  , lr:  0.01630320727222157
EPOCH:  46
loss:  0.63529277 , accuracy_:  0.73934656  , val_loss:  0.8200476  , val_acc_:  0.7167969  , lr:  0.015538046908610489
EPOCH:  47
loss:  0.74751514 , accuracy_:  0.71022725  , val_loss:  0.94797283  , val_acc_:  0.68359375  , lr:  0.014811144563179963
EPOCH:  48
loss:  0.7188057 , accuracy_:  0.7340199  , val_loss:  0.73836386  , val_acc_:  0.73046875  , lr:  0.014120587335020966
EPOCH:  49
loss:  0.6821742 , accuracy_:  0.75390625  , val_loss:  0.8478819  , val_acc_:  0.72265625  , lr:  0.013464557968269918
EPOCH:  50
loss:  0.6849743 , accuracy_:  0.74360794  , val_loss:  1.0432442  , val_acc_:  0.66796875  , lr:  0.01284133006985642
EPOCH:  51
loss:  0.7552935 , accuracy_:  0.75142044  , val_loss:  0.80256677  , val_acc_:  0.7207031  , lr:  0.012249263566363598
EPOCH:  52
loss:  0.68016076 , accuracy_:  0.7524858  , val_loss:  0.7880584  , val_acc_:  0.72265625  , lr:  0.01168680038804542
EPOCH:  53
loss:  0.89600945 , accuracy_:  0.73615056  , val_loss:  0.8785901  , val_acc_:  0.6777344  , lr:  0.011152460368643147
EPOCH:  54
loss:  0.8385702 , accuracy_:  0.7411222  , val_loss:  0.8847884  , val_acc_:  0.6640625  , lr:  0.01064483735021099
EPOCH:  55
loss:  0.58544314 , accuracy_:  0.75958806  , val_loss:  0.78721035  , val_acc_:  0.73046875  , lr:  0.010162595482700439
EPOCH:  56
loss:  0.6134448 , accuracy_:  0.74076706  , val_loss:  0.88842964  , val_acc_:  0.73046875  , lr:  0.009704465708565417
EPOCH:  57
loss:  0.6251881 , accuracy_:  0.7659801  , val_loss:  0.7727301  , val_acc_:  0.7011719  , lr:  0.009269242423137147
EPOCH:  58
loss:  0.6687313 , accuracy_:  0.75071025  , val_loss:  0.73691225  , val_acc_:  0.74609375  , lr:  0.008855780301980289
EPOCH:  59
loss:  0.64336467 , accuracy_:  0.77166194  , val_loss:  0.7323053  , val_acc_:  0.76171875  , lr:  0.008462991286881274
EPOCH:  60
loss:  0.6133585 , accuracy_:  0.7723722  , val_loss:  0.7338977  , val_acc_:  0.73828125  , lr:  0.008089841722537211
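The training dataset repeats indefinitely, so the loop above has no natural notion of an epoch and derives it from the step counter. The sketch below isolates that bookkeeping in plain Python; `validation_steps` is a hypothetical helper and the numbers are chosen only for illustration.

```python
def validation_steps(total_steps, steps_per_epoch, max_epochs):
    """Return the (step, epoch) pairs at which the loop would run validation."""
    triggers = []
    epoch = -1
    for step in range(total_steps):
        if step // steps_per_epoch > epoch:
            epoch = step // steps_per_epoch
            triggers.append((step, epoch))
            if epoch >= max_epochs:
                break  # stop once the epoch limit is reached
    return triggers

triggers = validation_steps(total_steps=40, steps_per_epoch=11, max_epochs=3)
print(triggers)
```

With 11 steps per epoch, validation is triggered at steps 0, 11, 22 and 33. The first trigger happens right after the very first training step, which is why the first line of the log reports metrics computed on a single batch.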

In [12]:
print("Detailed training loss:")

Detailed training loss:


Predictions (not distributed)

In [17]:
# randomize the input so that you can execute multiple times to change results
permutation = np.random.permutation(8*20)
some_flowers, some_labels = (some_flowers[permutation], some_labels[permutation])

predictions = model.predict(some_flowers, batch_size=16)
print(np.array(CLASSES)[np.argmax(predictions, axis=-1)].tolist())

['sunflowers', 'roses', 'dandelion', 'dandelion', 'sunflowers', 'daisy', 'dandelion', 'sunflowers', 'tulips', 'dandelion', 'sunflowers', 'dandelion', 'roses', 'sunflowers', 'roses', 'dandelion', 'daisy', 'tulips', 'daisy', 'dandelion', 'daisy', 'tulips', 'dandelion', 'daisy', 'dandelion', 'dandelion', 'sunflowers', 'daisy', 'tulips', 'dandelion', 'tulips', 'dandelion', 'dandelion', 'dandelion', 'dandelion', 'dandelion', 'sunflowers', 'tulips', 'roses', 'roses', 'roses', 'daisy', 'dandelion', 'dandelion', 'daisy', 'roses', 'roses', 'dandelion', 'tulips', 'dandelion', 'tulips', 'roses', 'roses', 'dandelion', 'tulips', 'daisy', 'sunflowers', 'sunflowers', 'tulips', 'tulips', 'tulips', 'dandelion', 'dandelion', 'daisy', 'dandelion', 'dandelion', 'sunflowers', 'sunflowers', 'dandelion', 'dandelion', 'dandelion', 'dandelion', 'dandelion', 'dandelion', 'dandelion', 'tulips', 'sunflowers', 'sunflowers', 'roses', 'dandelion', 'roses', 'tulips', 'sunflowers', 'tulips', 'sunflowers', 'dandelion', 'dandelion', 'dandelion', 'tulips', 'dandelion', 'dandelion', 'dandelion', 'daisy', 'sunflowers', 'dandelion', 'dandelion', 'dandelion', 'dandelion', 'dandelion', 'roses', 'dandelion', 'dandelion', 'dandelion', 'daisy', 'sunflowers', 'roses', 'dandelion', 'sunflowers', 'roses', 'daisy', 'tulips', 'sunflowers', 'roses', 'sunflowers', 'dandelion', 'daisy', 'sunflowers', 'dandelion', 'dandelion', 'tulips', 'roses', 'tulips', 'tulips', 'tulips', 'dandelion', 'daisy', 'sunflowers', 'daisy', 'dandelion', 'dandelion', 'tulips', 'roses', 'dandelion', 'daisy', 'tulips', 'roses', 'sunflowers', 'daisy', 'roses', 'dandelion', 'sunflowers', 'tulips', 'tulips', 'dandelion', 'dandelion', 'dandelion', 'roses', 'tulips', 'roses', 'daisy', 'sunflowers', 'tulips', 'dandelion', 'dandelion', 'dandelion', 'tulips', 'dandelion', 'roses', 'sunflowers', 'dandelion']
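The list above is obtained by taking, for each image, the index of the largest predicted probability and looking it up in CLASSES. A minimal plain-Python equivalent of that np.argmax lookup, with made-up probabilities for two images:

```python
CLASSES = ['daisy', 'dandelion', 'roses', 'sunflowers', 'tulips']

def argmax(values):
    """Index of the largest value in a list."""
    return max(range(len(values)), key=lambda i: values[i])

# Hypothetical softmax outputs for two images.
predictions = [
    [0.05, 0.10, 0.05, 0.70, 0.10],
    [0.60, 0.20, 0.10, 0.05, 0.05],
]
predicted_names = [CLASSES[argmax(p)] for p in predictions]
print(predicted_names)
```

This only works because the order of CLASSES matches the label encoding in the data, which is why the configuration cell asks not to change it.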

In [18]:
display_9_images_with_predictions(some_flowers, predictions, some_labels)


author: Martin Gorner
twitter: @martin_gorner

Copyright 2019 Google LLC

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.

This is not an official Google product but sample code provided for educational purposes.