This sample shows how to use the distribution strategy APIs when writing a custom training loop on TPU:

  • instantiate a TPUStrategy()
  • create the model and all other training objects in a strategy scope: with strategy.scope(): ...
  • distribute the dataset with strategy.experimental_distribute_dataset(ds)
  • run the training step distributed with strategy.experimental_run_v2(step_fn)
  • aggregate results returned by distributed workers with strategy.reduce(...)
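To make the data flow concrete, here is a toy, pure-Python model of what the distribute / run / reduce steps do conceptually. It does not use TensorFlow: `shard_batch`, `run_on_replicas` and `reduce_mean` are hypothetical helpers, and the real strategy also handles device placement, compilation and cross-replica communication, which this sketch ignores.

```python
from statistics import mean

NUM_REPLICAS = 8  # assumption: a TPU v3-8 exposes 8 replicas


def shard_batch(global_batch, num_replicas=NUM_REPLICAS):
    """Split a global batch into equal per-replica shards (conceptually, experimental_distribute_dataset).
    Any remainder that does not fill a shard is ignored in this toy version."""
    per_replica = len(global_batch) // num_replicas
    return [global_batch[i * per_replica:(i + 1) * per_replica] for i in range(num_replicas)]


def run_on_replicas(step_fn, shards):
    """Call step_fn once per shard (conceptually, experimental_run_v2) and collect per-replica results."""
    return [step_fn(shard) for shard in shards]


def reduce_mean(per_replica_values):
    """Average per-replica scalars (conceptually, strategy.reduce(ReduceOp.MEAN, ..., axis=None))."""
    return mean(per_replica_values)


def step_fn(shard):
    # Stand-in for a training step: the "loss" is simply the mean value of the shard.
    return mean(shard)


global_batch = [float(i) for i in range(128)]  # 16 examples per replica x 8 replicas
shards = shard_batch(global_batch)
per_replica_losses = run_on_replicas(step_fn, shards)
loss = reduce_mean(per_replica_losses)
print(len(shards), len(shards[0]), loss)
```

Because every shard has the same size, the mean of the per-replica means equals the mean over the whole global batch; with unequal shards it would not.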

Imports


In [1]:
import re, sys
if 'google.colab' in sys.modules: # Colab-only Tensorflow version selector
  %tensorflow_version 2.x
import tensorflow as tf
import numpy as np
from matplotlib import pyplot as plt
print("Tensorflow version " + tf.__version__)


Tensorflow version 2.1.0-dev20191028

TPU or GPU detection

TPUClusterResolver() automatically detects a connected TPU on all Google platforms: Colaboratory, AI Platform (ML Engine), Kubernetes and Deep Learning VMs, provided the TPU_NAME environment variable is set on the VM.


In [2]:
# Detect hardware, return appropriate distribution strategy
try:
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver()  # TPU detection
    print('Running on TPU ', tpu.cluster_spec().as_dict()['worker'])
except ValueError:
    tpu = None

if tpu:
    tf.config.experimental_connect_to_cluster(tpu)
    tf.tpu.experimental.initialize_tpu_system(tpu)
    strategy = tf.distribute.experimental.TPUStrategy(tpu)
else:
    strategy = tf.distribute.get_strategy()

print("REPLICAS: ", strategy.num_replicas_in_sync)


Running on TPU  ['192.168.78.2:8470']
INFO:tensorflow:Initializing the TPU system: martin-tpu-nightly
INFO:tensorflow:Clearing out eager caches
INFO:tensorflow:Finished initializing TPU system.
INFO:tensorflow:Found TPU system:
INFO:tensorflow:*** Num TPU Cores: 8
INFO:tensorflow:*** Num TPU Workers: 1
INFO:tensorflow:*** Num TPU Cores Per Worker: 8
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:localhost/replica:0/task:0/device:CPU:0, CPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:localhost/replica:0/task:0/device:XLA_CPU:0, XLA_CPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:CPU:0, CPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:0, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:1, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:2, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:3, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:4, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:5, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:6, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:7, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU_SYSTEM:0, TPU_SYSTEM, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:XLA_CPU:0, XLA_CPU, 0, 0)
REPLICAS:  8

Configuration and learning rate schedule


In [3]:
EPOCHS = 60

if strategy.num_replicas_in_sync == 1: # GPU
    # This achieves 80% accuracy on a GPU (final loss 0.45)
    BATCH_SIZE = 16
    VALIDATION_BATCH_SIZE = 16
    START_LR = 0.01
    MAX_LR = 0.01
    MIN_LR = 0.01
    LR_RAMP = 0 # epochs
    LR_SUSTAIN = 0 #epochs
    LR_DECAY = 1
    
elif strategy.num_replicas_in_sync == 8: # single TPU
    # This achieves 80% accuracy on a TPU v3-8 (final loss 0.44)
    BATCH_SIZE = 16 * strategy.num_replicas_in_sync # global batch size; use 32 per core on TPU v3
    VALIDATION_BATCH_SIZE = 256
    START_LR = 0.01
    MAX_LR = 0.01 * strategy.num_replicas_in_sync
    MIN_LR = 0.001
    LR_RAMP = 0 # epochs
    LR_SUSTAIN = 13 # epochs
    LR_DECAY = .95

else: # TPU pod
    # This achieves 80% accuracy on a TPU v2-32 pod (final loss 0.54)
    BATCH_SIZE = 16 * strategy.num_replicas_in_sync  # Global batch size.
    VALIDATION_BATCH_SIZE = 256
    START_LR = 0.06
    MAX_LR = 0.012 * strategy.num_replicas_in_sync
    MIN_LR = 0.01
    LR_RAMP = 5 # epochs
    LR_SUSTAIN = 8 # epochs
    LR_DECAY = 0.95

CLASSES = ['daisy', 'dandelion', 'roses', 'sunflowers', 'tulips'] # do not change, maps to the labels in the data (folder names)

IMAGE_SIZE = [331, 331] # supported image sizes: 192x192, 331x331, 512x512
                        # make sure you load the appropriate dataset on the next line
#GCS_PATTERN = 'gs://flowers-public/tfrecords-jpeg-192x192-2/*.tfrec'
GCS_PATTERN = 'gs://flowers-public/tfrecords-jpeg-331x331/*.tfrec'
#GCS_PATTERN = 'gs://flowers-public/tfrecords-jpeg-512x512/*.tfrec'
VALIDATION_SPLIT = 0.19

# In a custom training loop, you need
# an object to hold the current epoch value
class LRSchedule():
    def __init__(self):
        self.epoch = 0
        
    def set_epoch(self, epoch):
        self.epoch = epoch

    @staticmethod
    def lrfn(epoch, start_lr, max_lr, min_lr, rampup_epochs, sustain_epochs, exp_decay):
        if epoch < rampup_epochs:  # linear ramp from start_lr to max_lr
            lr = (max_lr - start_lr)/rampup_epochs * epoch + start_lr
        elif epoch < rampup_epochs + sustain_epochs:  # constant at max_lr
            lr = max_lr
        else:  # exponential decay from max_lr to min_lr
            lr = (max_lr - min_lr) * exp_decay**(epoch-rampup_epochs-sustain_epochs) + min_lr
        return lr
    
    @staticmethod
    def default_lrfn(epoch):
        return LRSchedule.lrfn(epoch, START_LR, MAX_LR, MIN_LR, LR_RAMP, LR_SUSTAIN, LR_DECAY)
    
    def lr(self):
        return self.default_lrfn(self.epoch)

print("Learning rate schedule:")
rng = [i for i in range(EPOCHS)]
plt.plot(rng, [LRSchedule.default_lrfn(x) for x in rng])
plt.show()


Learning rate schedule:
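The schedule is easy to check numerically. The function below copies `lrfn` from `LRSchedule` so that it runs without TensorFlow; the first parameter set is the single-TPU branch with 8 replicas, and the second is a hypothetical 32-replica pod run, included only to exercise the ramp.

```python
def lrfn(epoch, start_lr, max_lr, min_lr, rampup_epochs, sustain_epochs, exp_decay):
    if epoch < rampup_epochs:  # linear ramp from start_lr to max_lr
        return (max_lr - start_lr) / rampup_epochs * epoch + start_lr
    if epoch < rampup_epochs + sustain_epochs:  # constant at max_lr
        return max_lr
    # exponential decay from max_lr towards min_lr
    return (max_lr - min_lr) * exp_decay ** (epoch - rampup_epochs - sustain_epochs) + min_lr


# Single TPU (8 replicas): no ramp, 13 epochs at 0.01 * 8, then decay towards 0.001
tpu = dict(start_lr=0.01, max_lr=0.01 * 8, min_lr=0.001,
           rampup_epochs=0, sustain_epochs=13, exp_decay=0.95)
# Hypothetical 32-replica pod: 5-epoch ramp from 0.06 up to 0.012 * 32
pod = dict(start_lr=0.06, max_lr=0.012 * 32, min_lr=0.01,
           rampup_epochs=5, sustain_epochs=8, exp_decay=0.95)

for e in (0, 12, 13, 14, 59):
    print("single TPU, epoch", e, round(lrfn(e, **tpu), 5))
for e in (0, 1, 5):
    print("pod, epoch", e, round(lrfn(e, **pod), 5))
```

At the first decay epoch the exponent is 0, so the decay branch returns exactly `max_lr` and the curve stays continuous. The 0.08 value is the `lr` printed during the first epochs of the training run below.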

In [4]:
#@title display utilities [RUN ME]

def dataset_to_numpy_util(dataset, N):
  dataset = dataset.batch(N)
  
  if tf.executing_eagerly():
    # In eager mode, iterate over the Dataset directly.
    for images, labels in dataset:
      numpy_images = images.numpy()
      numpy_labels = labels.numpy()
      break
      
  else: # In graph mode (TF 1.x behaviour), get the TF node that
        # yields the next item and run it in a session.
    get_next_item = tf.compat.v1.data.make_one_shot_iterator(dataset).get_next()
    with tf.compat.v1.Session() as ses:
      numpy_images, numpy_labels = ses.run(get_next_item)

  return numpy_images, numpy_labels

def title_from_label_and_target(label, correct_label):
  label = np.argmax(label, axis=-1)  # one-hot to class number
  correct_label = np.argmax(correct_label, axis=-1) # one-hot to class number
  correct = (label == correct_label)
  return "{} [{}{}{}]".format(CLASSES[label], str(correct), ', should be ' if not correct else '',
                              CLASSES[correct_label] if not correct else ''), correct

def display_one_flower(image, title, subplot, red=False):
    plt.subplot(subplot)
    plt.axis('off')
    plt.imshow(image)
    plt.title(title, fontsize=16, color='red' if red else 'black')
    return subplot+1
  
def display_9_images_from_dataset(dataset):
  subplot=331
  plt.figure(figsize=(13,13))
  images, labels = dataset_to_numpy_util(dataset, 9)
  for i, image in enumerate(images):
    title = CLASSES[np.argmax(labels[i], axis=-1)]
    subplot = display_one_flower(image, title, subplot)
    if i >= 8:
      break
              
  plt.tight_layout()
  plt.subplots_adjust(wspace=0.1, hspace=0.1)
  plt.show()
  
def display_9_images_with_predictions(images, predictions, labels):
  subplot=331
  plt.figure(figsize=(13,13))
  for i, image in enumerate(images):
    title, correct = title_from_label_and_target(predictions[i], labels[i])
    subplot = display_one_flower(image, title, subplot, not correct)
    if i >= 8:
      break
              
  plt.tight_layout()
  plt.subplots_adjust(wspace=0.1, hspace=0.1)
  plt.show()
  
def display_training_curves(training, validation, title, subplot):
  if subplot%10==1: # set up the subplots on the first call
    plt.subplots(figsize=(10,10), facecolor='#F0F0F0')
    plt.tight_layout()
  ax = plt.subplot(subplot)
  ax.set_facecolor('#F8F8F8')
  ax.plot(training)
  ax.plot(validation)
  ax.set_title('model '+ title)
  ax.set_ylabel(title)
  #ax.set_ylim(0.28,1.05)
  ax.set_xlabel('epoch')
  ax.legend(['train', 'valid.'])
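For reference, this is the kind of title `title_from_label_and_target` builds for the prediction grid. The snippet repeats the function (with the "should be" typo fixed) so it runs standalone; the prediction vector is made up for illustration.

```python
import numpy as np

CLASSES = ['daisy', 'dandelion', 'roses', 'sunflowers', 'tulips']


def title_from_label_and_target(label, correct_label):
    label = np.argmax(label, axis=-1)  # one-hot or softmax output to class number
    correct_label = np.argmax(correct_label, axis=-1)  # one-hot to class number
    correct = (label == correct_label)
    return "{} [{}{}{}]".format(CLASSES[label], str(correct), ', should be ' if not correct else '',
                                CLASSES[correct_label] if not correct else ''), correct


prediction = np.array([0.05, 0.05, 0.20, 0.10, 0.60])  # hypothetical softmax output, argmax = tulips
truth = np.array([0., 0., 1., 0., 0.])                  # one-hot label for roses

wrong_title, wrong_ok = title_from_label_and_target(prediction, truth)
right_title, right_ok = title_from_label_and_target(truth, truth)
print(wrong_title)
print(right_title)
```

The boolean is returned alongside the string so that `display_9_images_with_predictions` can draw wrong predictions in red.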

Read images and labels from TFRecords


In [5]:
AUTOTUNE = tf.data.experimental.AUTOTUNE

def count_data_items(filenames):
    # trick: the number of data items is written in the name of
    # the .tfrec files, e.g. flowers00-230.tfrec = 230 data items
    n = [int(re.compile(r"-([0-9]*)\.").search(filename).group(1)) for filename in filenames]
    return np.sum(n)

def data_augment(image, one_hot_class):
    image = tf.image.random_flip_left_right(image)
    image = tf.image.random_saturation(image, 0, 2)
    return image, one_hot_class

def read_tfrecord(example):
    features = {
        "image": tf.io.FixedLenFeature([], tf.string), # tf.string means bytestring
        "class": tf.io.FixedLenFeature([], tf.int64),  # shape [] means scalar
        "one_hot_class": tf.io.VarLenFeature(tf.float32),
    }
    example = tf.io.parse_single_example(example, features)
    image = tf.image.decode_jpeg(example['image'], channels=3)
    image = tf.cast(image, tf.float32) / 255.0  # convert image to floats in [0, 1] range
    image = tf.reshape(image, [*IMAGE_SIZE, 3]) # force the image size so that the shape of the tensor is known to Tensorflow
    class_label = tf.cast(example['class'], tf.int32)
    one_hot_class = tf.sparse.to_dense(example['one_hot_class'])
    one_hot_class = tf.reshape(one_hot_class, [5])
    return image, one_hot_class

def load_dataset(filenames):
    # Read from TFRecords. For optimal performance, use TFRecordDataset with
    # num_parallel_reads to read from multiple TFRecord files at once,
    # and set the option experimental_deterministic = False
    # to allow order-altering optimizations.

    opt = tf.data.Options()
    opt.experimental_deterministic = False

    dataset = tf.data.TFRecordDataset(filenames, num_parallel_reads=16) # can be AUTOTUNE in TF 2.1
    dataset = dataset.with_options(opt) # attach the options to the dataset that is actually returned
    dataset = dataset.map(read_tfrecord, num_parallel_calls=AUTOTUNE)
    return dataset

def batch_dataset(filenames, batch_size, train):
    dataset = load_dataset(filenames)
    n = count_data_items(filenames)
    
    if train:
        dataset = dataset.repeat() # training dataset must repeat
        dataset = dataset.map(data_augment, num_parallel_calls=AUTOTUNE)
        dataset = dataset.shuffle(2048)
    else:
        # usually fewer validation files than workers so disable FILE auto-sharding on validation
        if strategy.num_replicas_in_sync > 1: # option not useful if there is no sharding (not harmful either)
            opt = tf.data.Options()
            opt.experimental_distribute.auto_shard_policy = tf.data.experimental.AutoShardPolicy.DATA
            dataset = dataset.with_options(opt)
        # validation dataset does not need to repeat
        # also no need to shuffle or apply data augmentation
    if train:
        dataset = dataset.batch(batch_size)
    else:
        # little wrinkle: drop_remainder is NOT necessary, but validation on the last
        # partial batch sometimes returns a "nan" loss (probably a bug). You can remove
        # this if you do not care about the validation loss.
        dataset = dataset.batch(batch_size, drop_remainder=True)
    dataset = dataset.prefetch(AUTOTUNE) # prefetch next batch while training (autotune prefetch buffer size)
    return dataset, n//batch_size

def get_training_dataset(filenames):
    dataset, steps = batch_dataset(filenames, BATCH_SIZE, train=True)
    return dataset, steps

def get_validation_dataset(filenames):
    dataset, steps = batch_dataset(filenames, VALIDATION_BATCH_SIZE, train=False)
    return dataset, steps
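The file-name trick in `count_data_items` can be checked without touching GCS. The sketch repeats the notebook's regex but uses the built-in `sum` instead of `np.sum`; the three file names are hypothetical, placed under the 331x331 prefix from `GCS_PATTERN` to show that the directory name does not confuse the regex.

```python
import re


def count_data_items(filenames):
    # same regex as the notebook: the first "-<digits>." in the path holds the item count
    n = [int(re.compile(r"-([0-9]*)\.").search(filename).group(1)) for filename in filenames]
    return sum(n)  # built-in sum instead of np.sum, to avoid the NumPy dependency


# hypothetical file names under the 331x331 dataset prefix used above
prefix = 'gs://flowers-public/tfrecords-jpeg-331x331/'
files = [prefix + 'flowers00-230.tfrec',
         prefix + 'flowers01-230.tfrec',
         prefix + 'flowers02-215.tfrec']
print(count_data_items(files))
```

The `-331x331` part of the directory is not matched because a literal dot must follow the digits. A file name without a `-<count>.` part would make `search()` return `None` and raise an `AttributeError`.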

In [6]:
# instantiate datasets
filenames = tf.io.gfile.glob(GCS_PATTERN)
split = len(filenames) - int(len(filenames) * VALIDATION_SPLIT)
train_filenames = filenames[:split]
valid_filenames = filenames[split:]

training_dataset, steps_per_epoch = get_training_dataset(train_filenames)
validation_dataset, validation_steps = get_validation_dataset(valid_filenames)

print("TRAINING   IMAGES: ", count_data_items(train_filenames), ", STEPS PER EPOCH: ", steps_per_epoch)
print("VALIDATION IMAGES: ", count_data_items(valid_filenames), ", STEPS PER EPOCH: ", validation_steps)

# numpy data to test predictions
some_flowers, some_labels = dataset_to_numpy_util(load_dataset(valid_filenames), 160)


TRAINING   IMAGES:  2990 , STEPS PER EPOCH:  11
VALIDATION IMAGES:  680 , STEPS PER EPOCH:  2
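The step counts printed above follow from integer division. The sketch below reproduces that arithmetic in plain Python; the count of 16 TFRecord files is a hypothetical value, since the notebook does not print it.

```python
VALIDATION_SPLIT = 0.19


def split_files(num_files, validation_split=VALIDATION_SPLIT):
    # same arithmetic as the cell above: the last int(num_files * split) files go to validation
    split = num_files - int(num_files * validation_split)
    return split, num_files - split  # (training files, validation files)


def steps_and_leftover(num_items, batch_size):
    # steps = n // batch_size, as in batch_dataset(); with drop_remainder=True
    # the leftover items never form a full batch and are skipped
    return num_items // batch_size, num_items % batch_size


print(split_files(16))                   # hypothetical count of 16 .tfrec files
print(steps_and_leftover(680, 256))      # validation set, VALIDATION_BATCH_SIZE = 256
print(steps_and_leftover(2990, 16 * 8))  # training set, 16 images per core x 8 cores
print(steps_and_leftover(2990, 32 * 8))  # training set, 32 images per core x 8 cores
```

With `drop_remainder=True`, 168 of the 680 validation images are skipped on every pass. The 11 training steps per epoch printed above equal 2990 // 256, which suggests the output was produced with the 32-per-core TPU v3 setting rather than the 16 x 8 = 128 global batch in the configuration cell (that would give 23 steps).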

In [7]:
display_9_images_from_dataset(load_dataset(train_filenames))


The model: SqueezeNet with 12 layers


In [8]:
def create_model():
    bnmomemtum=0.9 # with only a handful of batches per epoch, the batch norm running average period must be lowered
    def fire(x, squeeze, expand):
        y  = tf.keras.layers.Conv2D(filters=squeeze, kernel_size=1, activation=None, padding='same', use_bias=False)(x)
        y = tf.keras.layers.BatchNormalization(momentum=bnmomemtum, scale=False, center=True)(y)
        y = tf.keras.layers.Activation('relu')(y)
        y1 = tf.keras.layers.Conv2D(filters=expand//2, kernel_size=1, activation=None, padding='same', use_bias=False)(y)
        y1 = tf.keras.layers.BatchNormalization(momentum=bnmomemtum, scale=False, center=True)(y1)
        y1 = tf.keras.layers.Activation('relu')(y1)
        y3 = tf.keras.layers.Conv2D(filters=expand//2, kernel_size=3, activation=None, padding='same', use_bias=False)(y)
        y3 = tf.keras.layers.BatchNormalization(momentum=bnmomemtum, scale=False, center=True)(y3)
        y3 = tf.keras.layers.Activation('relu')(y3)
        return tf.keras.layers.concatenate([y1, y3])

    def fire_module(squeeze, expand):
        return lambda x: fire(x, squeeze, expand)

    x = tf.keras.layers.Input(shape=(*IMAGE_SIZE, 3)) # input is 331x331 pixels RGB
    y = tf.keras.layers.Conv2D(kernel_size=3, filters=32, padding='same', use_bias=True, activation='relu')(x)
    y = tf.keras.layers.BatchNormalization(momentum=bnmomemtum)(y)
    y = fire_module(24, 48)(y)
    y = tf.keras.layers.MaxPooling2D(pool_size=2)(y)
    y = fire_module(48, 96)(y)
    y = tf.keras.layers.MaxPooling2D(pool_size=2)(y)
    y = fire_module(64, 128)(y)
    y = tf.keras.layers.MaxPooling2D(pool_size=2)(y)
    y = fire_module(48, 96)(y)
    y = tf.keras.layers.MaxPooling2D(pool_size=2)(y)
    y = fire_module(24, 48)(y)
    y = tf.keras.layers.GlobalAveragePooling2D()(y)
    y = tf.keras.layers.Dropout(0.4)(y)
    y = tf.keras.layers.Dense(5, activation='softmax')(y)
    return tf.keras.Model(x, y)
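Every convolution uses `padding='same'`, so only the four `MaxPooling2D` layers change the spatial size. Keras pooling defaults to `padding='valid'` with strides equal to `pool_size`, which gives floor((n - 2) / 2) + 1 per layer. The helper below is illustrative and not part of the notebook.

```python
def pooled_sizes(size, num_pools, pool_size=2):
    # MaxPooling2D with padding='valid' and strides=pool_size:
    # output = floor((n - pool_size) / pool_size) + 1
    sizes = [size]
    for _ in range(num_pools):
        size = (size - pool_size) // pool_size + 1
        sizes.append(size)
    return sizes


for image_size in (192, 331, 512):  # the three supported dataset resolutions
    print(image_size, pooled_sizes(image_size, 4))
```

Because `GlobalAveragePooling2D` collapses whatever spatial map remains, the parameter count does not depend on `IMAGE_SIZE`; only the activation sizes (and therefore memory and compute) do.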

Instantiate all objects in the strategy scope


In [9]:
with strategy.scope():
    model = create_model()
    
    lr_schedule = LRSchedule()
    optimizer = tf.keras.optimizers.SGD(nesterov=True, momentum=0.9, learning_rate = lr_schedule.lr)
    train_accuracy = tf.keras.metrics.CategoricalAccuracy()
    valid_accuracy = tf.keras.metrics.CategoricalAccuracy()
    loss_fn = lambda labels, probabilities: tf.reduce_mean(tf.keras.losses.categorical_crossentropy(labels, probabilities))
        
    model.summary()


Model: "model"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_1 (InputLayer)            [(None, 331, 331, 3) 0                                            
__________________________________________________________________________________________________
conv2d (Conv2D)                 (None, 331, 331, 32) 896         input_1[0][0]                    
__________________________________________________________________________________________________
batch_normalization (BatchNorma (None, 331, 331, 32) 128         conv2d[0][0]                     
__________________________________________________________________________________________________
conv2d_1 (Conv2D)               (None, 331, 331, 24) 768         batch_normalization[0][0]        
__________________________________________________________________________________________________
batch_normalization_1 (BatchNor (None, 331, 331, 24) 72          conv2d_1[0][0]                   
__________________________________________________________________________________________________
activation (Activation)         (None, 331, 331, 24) 0           batch_normalization_1[0][0]      
__________________________________________________________________________________________________
conv2d_2 (Conv2D)               (None, 331, 331, 24) 576         activation[0][0]                 
__________________________________________________________________________________________________
conv2d_3 (Conv2D)               (None, 331, 331, 24) 5184        activation[0][0]                 
__________________________________________________________________________________________________
batch_normalization_2 (BatchNor (None, 331, 331, 24) 72          conv2d_2[0][0]                   
__________________________________________________________________________________________________
batch_normalization_3 (BatchNor (None, 331, 331, 24) 72          conv2d_3[0][0]                   
__________________________________________________________________________________________________
activation_1 (Activation)       (None, 331, 331, 24) 0           batch_normalization_2[0][0]      
__________________________________________________________________________________________________
activation_2 (Activation)       (None, 331, 331, 24) 0           batch_normalization_3[0][0]      
__________________________________________________________________________________________________
concatenate (Concatenate)       (None, 331, 331, 48) 0           activation_1[0][0]               
                                                                 activation_2[0][0]               
__________________________________________________________________________________________________
max_pooling2d (MaxPooling2D)    (None, 165, 165, 48) 0           concatenate[0][0]                
__________________________________________________________________________________________________
conv2d_4 (Conv2D)               (None, 165, 165, 48) 2304        max_pooling2d[0][0]              
__________________________________________________________________________________________________
batch_normalization_4 (BatchNor (None, 165, 165, 48) 144         conv2d_4[0][0]                   
__________________________________________________________________________________________________
activation_3 (Activation)       (None, 165, 165, 48) 0           batch_normalization_4[0][0]      
__________________________________________________________________________________________________
conv2d_5 (Conv2D)               (None, 165, 165, 48) 2304        activation_3[0][0]               
__________________________________________________________________________________________________
conv2d_6 (Conv2D)               (None, 165, 165, 48) 20736       activation_3[0][0]               
__________________________________________________________________________________________________
batch_normalization_5 (BatchNor (None, 165, 165, 48) 144         conv2d_5[0][0]                   
__________________________________________________________________________________________________
batch_normalization_6 (BatchNor (None, 165, 165, 48) 144         conv2d_6[0][0]                   
__________________________________________________________________________________________________
activation_4 (Activation)       (None, 165, 165, 48) 0           batch_normalization_5[0][0]      
__________________________________________________________________________________________________
activation_5 (Activation)       (None, 165, 165, 48) 0           batch_normalization_6[0][0]      
__________________________________________________________________________________________________
concatenate_1 (Concatenate)     (None, 165, 165, 96) 0           activation_4[0][0]               
                                                                 activation_5[0][0]               
__________________________________________________________________________________________________
max_pooling2d_1 (MaxPooling2D)  (None, 82, 82, 96)   0           concatenate_1[0][0]              
__________________________________________________________________________________________________
conv2d_7 (Conv2D)               (None, 82, 82, 64)   6144        max_pooling2d_1[0][0]            
__________________________________________________________________________________________________
batch_normalization_7 (BatchNor (None, 82, 82, 64)   192         conv2d_7[0][0]                   
__________________________________________________________________________________________________
activation_6 (Activation)       (None, 82, 82, 64)   0           batch_normalization_7[0][0]      
__________________________________________________________________________________________________
conv2d_8 (Conv2D)               (None, 82, 82, 64)   4096        activation_6[0][0]               
__________________________________________________________________________________________________
conv2d_9 (Conv2D)               (None, 82, 82, 64)   36864       activation_6[0][0]               
__________________________________________________________________________________________________
batch_normalization_8 (BatchNor (None, 82, 82, 64)   192         conv2d_8[0][0]                   
__________________________________________________________________________________________________
batch_normalization_9 (BatchNor (None, 82, 82, 64)   192         conv2d_9[0][0]                   
__________________________________________________________________________________________________
activation_7 (Activation)       (None, 82, 82, 64)   0           batch_normalization_8[0][0]      
__________________________________________________________________________________________________
activation_8 (Activation)       (None, 82, 82, 64)   0           batch_normalization_9[0][0]      
__________________________________________________________________________________________________
concatenate_2 (Concatenate)     (None, 82, 82, 128)  0           activation_7[0][0]               
                                                                 activation_8[0][0]               
__________________________________________________________________________________________________
max_pooling2d_2 (MaxPooling2D)  (None, 41, 41, 128)  0           concatenate_2[0][0]              
__________________________________________________________________________________________________
conv2d_10 (Conv2D)              (None, 41, 41, 48)   6144        max_pooling2d_2[0][0]            
__________________________________________________________________________________________________
batch_normalization_10 (BatchNo (None, 41, 41, 48)   144         conv2d_10[0][0]                  
__________________________________________________________________________________________________
activation_9 (Activation)       (None, 41, 41, 48)   0           batch_normalization_10[0][0]     
__________________________________________________________________________________________________
conv2d_11 (Conv2D)              (None, 41, 41, 48)   2304        activation_9[0][0]               
__________________________________________________________________________________________________
conv2d_12 (Conv2D)              (None, 41, 41, 48)   20736       activation_9[0][0]               
__________________________________________________________________________________________________
batch_normalization_11 (BatchNo (None, 41, 41, 48)   144         conv2d_11[0][0]                  
__________________________________________________________________________________________________
batch_normalization_12 (BatchNo (None, 41, 41, 48)   144         conv2d_12[0][0]                  
__________________________________________________________________________________________________
activation_10 (Activation)      (None, 41, 41, 48)   0           batch_normalization_11[0][0]     
__________________________________________________________________________________________________
activation_11 (Activation)      (None, 41, 41, 48)   0           batch_normalization_12[0][0]     
__________________________________________________________________________________________________
concatenate_3 (Concatenate)     (None, 41, 41, 96)   0           activation_10[0][0]              
                                                                 activation_11[0][0]              
__________________________________________________________________________________________________
max_pooling2d_3 (MaxPooling2D)  (None, 20, 20, 96)   0           concatenate_3[0][0]              
__________________________________________________________________________________________________
conv2d_13 (Conv2D)              (None, 20, 20, 24)   2304        max_pooling2d_3[0][0]            
__________________________________________________________________________________________________
batch_normalization_13 (BatchNo (None, 20, 20, 24)   72          conv2d_13[0][0]                  
__________________________________________________________________________________________________
activation_12 (Activation)      (None, 20, 20, 24)   0           batch_normalization_13[0][0]     
__________________________________________________________________________________________________
conv2d_14 (Conv2D)              (None, 20, 20, 24)   576         activation_12[0][0]              
__________________________________________________________________________________________________
conv2d_15 (Conv2D)              (None, 20, 20, 24)   5184        activation_12[0][0]              
__________________________________________________________________________________________________
batch_normalization_14 (BatchNo (None, 20, 20, 24)   72          conv2d_14[0][0]                  
__________________________________________________________________________________________________
batch_normalization_15 (BatchNo (None, 20, 20, 24)   72          conv2d_15[0][0]                  
__________________________________________________________________________________________________
activation_13 (Activation)      (None, 20, 20, 24)   0           batch_normalization_14[0][0]     
__________________________________________________________________________________________________
activation_14 (Activation)      (None, 20, 20, 24)   0           batch_normalization_15[0][0]     
__________________________________________________________________________________________________
concatenate_4 (Concatenate)     (None, 20, 20, 48)   0           activation_13[0][0]              
                                                                 activation_14[0][0]              
__________________________________________________________________________________________________
global_average_pooling2d (Globa (None, 48)           0           concatenate_4[0][0]              
__________________________________________________________________________________________________
dropout (Dropout)               (None, 48)           0           global_average_pooling2d[0][0]   
__________________________________________________________________________________________________
dense (Dense)                   (None, 5)            245         dropout[0][0]                    
==================================================================================================
Total params: 119,365
Trainable params: 118,053
Non-trainable params: 1,312
__________________________________________________________________________________________________
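The totals in the summary can be reproduced by hand, which is a useful check that the fire modules are wired as intended. The sketch below counts parameters in plain Python; the helper names are hypothetical. It relies on two Keras conventions: a `Conv2D` has k x k x C_in x C_out weights plus C_out biases when `use_bias=True`, and a `BatchNormalization` layer has a trainable gamma (if `scale`) and beta (if `center`) plus a non-trainable moving mean and moving variance per channel.

```python
def conv_params(in_ch, out_ch, kernel, bias):
    return kernel * kernel * in_ch * out_ch + (out_ch if bias else 0)


def bn_params(ch, scale=True, center=True):
    # (trainable gamma/beta, non-trainable moving mean + moving variance)
    return ch * (int(scale) + int(center)), 2 * ch


def fire_params(in_ch, squeeze, expand):
    trainable = non_trainable = 0
    # squeeze 1x1, expand 1x1, expand 3x3; each conv is followed by BN(scale=False, center=True)
    for cin, cout, k in ((in_ch, squeeze, 1), (squeeze, expand // 2, 1), (squeeze, expand // 2, 3)):
        trainable += conv_params(cin, cout, k, bias=False)
        t, nt = bn_params(cout, scale=False, center=True)
        trainable += t
        non_trainable += nt
    return trainable, non_trainable, expand  # expand = channels after concatenate


def squeezenet_params(fires=((24, 48), (48, 96), (64, 128), (48, 96), (24, 48)), num_classes=5):
    trainable = conv_params(3, 32, 3, bias=True)  # stem Conv2D
    t, non_trainable = bn_params(32)              # stem BatchNormalization (scale and center on)
    trainable += t
    channels = 32
    for squeeze, expand in fires:
        t, nt, channels = fire_params(channels, squeeze, expand)
        trainable += t
        non_trainable += nt
    trainable += channels * num_classes + num_classes  # final Dense layer
    return trainable, non_trainable


trainable, non_trainable = squeezenet_params()
print(trainable + non_trainable, trainable, non_trainable)
```

The non-trainable parameters are exactly the batch-norm moving statistics, which is why they are updated by the forward pass in training mode rather than by the optimizer.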

Step functions


In [10]:
@tf.function
def train_step(images, labels):
    with tf.GradientTape() as tape:
        probabilities = model(images, training=True)
        loss = loss_fn(labels, probabilities)
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    train_accuracy.update_state(labels, probabilities)
    return loss

@tf.function
def valid_step(images, labels):
    probabilities = model(images, training=False)
    loss = loss_fn(labels, probabilities)
    valid_accuracy.update_state(labels, probabilities)
    return loss
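A note on how `loss_fn` interacts with the strategy: in TF 2.x, gradients applied with `optimizer.apply_gradients` inside a replica step are summed across replicas. Because `loss_fn` averages only over each replica's own batch, the summed gradient is `num_replicas` times the gradient of the global-batch mean; the TensorFlow custom-training guide instead recommends dividing by the global batch size, for example with `tf.nn.compute_average_loss(per_example_loss, global_batch_size=...)`. The notebook scales `MAX_LR` with the replica count and does not say whether that choice accounts for this factor. The NumPy sketch below only illustrates the arithmetic on losses (differentiation is linear, so the same factor carries over to gradients); the per-example losses are random, hypothetical values.

```python
import numpy as np

NUM_REPLICAS = 8
PER_REPLICA_BATCH = 16
GLOBAL_BATCH = NUM_REPLICAS * PER_REPLICA_BATCH

rng = np.random.default_rng(0)
# hypothetical per-example losses, one row per replica
per_example = rng.uniform(0.0, 2.0, size=(NUM_REPLICAS, PER_REPLICA_BATCH))

global_mean = per_example.mean()

# what loss_fn does: mean over each replica's own batch, then the strategy sums across replicas
summed_replica_means = per_example.mean(axis=1).sum()

# compute_average_loss style: sum per replica, divide by the GLOBAL batch, then sum across replicas
summed_scaled = (per_example.sum(axis=1) / GLOBAL_BATCH).sum()

print(summed_replica_means / global_mean)  # close to NUM_REPLICAS
print(summed_scaled / global_mean)         # close to 1
```

The reported training loss is not affected: `strategy.reduce(ReduceOp.MEAN, ...)` averages the per-replica means, which equals the global mean when all replicas receive equally sized batches.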

Custom training loop


In [11]:
# distribute the dataset according to the strategy
train_dist_ds = strategy.experimental_distribute_dataset(training_dataset)
valid_dist_ds = strategy.experimental_distribute_dataset(validation_dataset)

print("Steps per epoch: ", steps_per_epoch)

epoch = -1
train_losses=[]
for step, (images, labels) in enumerate(train_dist_ds):

    # batch losses from all replicas
    loss = strategy.experimental_run_v2(train_step, args=(images, labels))
    # average across replicas (loss_fn already averaged each replica's loss over its own batch)
    loss = strategy.reduce(tf.distribute.ReduceOp.MEAN, loss, axis=None)
    # to inspect the raw per-replica losses instead, call strategy.experimental_local_results(loss)
    # on the value returned by experimental_run_v2, before reducing (loss.values does the same)

    # validation run at the end of each epoch
    if (step // steps_per_epoch) > epoch:
        valid_loss = []
        for images, labels in valid_dist_ds:
            batch_loss = strategy.experimental_run_v2(valid_step, args=(images, labels)) # one global batch, split across replicas
            batch_loss = strategy.reduce(tf.distribute.ReduceOp.MEAN, batch_loss, axis=None)
            valid_loss.append(batch_loss.numpy())
        valid_loss = np.mean(valid_loss)

        epoch = step // steps_per_epoch
        lr_schedule.set_epoch(epoch)
        print('\nEPOCH: ', epoch)
        print('loss: ', loss.numpy(), ', accuracy_: ', train_accuracy.result().numpy(), ' , val_loss: ', valid_loss, ' , val_acc_: ', valid_accuracy.result().numpy(), ' , lr: ', lr_schedule.lr())
        train_accuracy.reset_states()
        valid_accuracy.reset_states()
        if epoch >= EPOCHS:
            break
            
    train_losses.append(loss.numpy())  # store a plain float for plotting
    print('=', end='')


Steps per epoch:  11

EPOCH:  0
loss:  1.8574139 , accuracy_:  0.19140625  , val_loss:  1.6670775  , val_acc_:  0.21875  , lr:  0.08
===========
EPOCH:  1
loss:  1.370311 , accuracy_:  0.33132103  , val_loss:  62.40265  , val_acc_:  0.203125  , lr:  0.08
===========
EPOCH:  2
loss:  1.3508414 , accuracy_:  0.42294034  , val_loss:  10.712177  , val_acc_:  0.3515625  , lr:  0.08
===========
EPOCH:  3
loss:  1.2815936 , accuracy_:  0.45419034  , val_loss:  2.663825  , val_acc_:  0.3984375  , lr:  0.08
===========
EPOCH:  4
loss:  1.2429017 , accuracy_:  0.50177556  , val_loss:  1.9874109  , val_acc_:  0.45898438  , lr:  0.08
===========
EPOCH:  5
loss:  1.1355323 , accuracy_:  0.51740056  , val_loss:  1.2704945  , val_acc_:  0.5332031  , lr:  0.08
===========
EPOCH:  6
loss:  1.1954806 , accuracy_:  0.53728694  , val_loss:  1.8864973  , val_acc_:  0.453125  , lr:  0.08
===========
EPOCH:  7
loss:  1.150807 , accuracy_:  0.5465199  , val_loss:  1.120469  , val_acc_:  0.5957031  , lr:  0.08
===========
EPOCH:  8
loss:  1.0735602 , accuracy_:  0.559304  , val_loss:  1.8047378  , val_acc_:  0.53125  , lr:  0.08
===========
EPOCH:  9
loss:  1.0801067 , accuracy_:  0.5628551  , val_loss:  1.0923085  , val_acc_:  0.5957031  , lr:  0.08
===========
EPOCH:  10
loss:  0.9859637 , accuracy_:  0.57990056  , val_loss:  1.5855552  , val_acc_:  0.5390625  , lr:  0.08
===========
EPOCH:  11
loss:  1.025581 , accuracy_:  0.5827415  , val_loss:  1.0143851  , val_acc_:  0.6191406  , lr:  0.08
===========
EPOCH:  12
loss:  0.99369746 , accuracy_:  0.60546875  , val_loss:  1.1153538  , val_acc_:  0.6230469  , lr:  0.08
===========
EPOCH:  13
loss:  0.94948137 , accuracy_:  0.6090199  , val_loss:  0.92257625  , val_acc_:  0.6171875  , lr:  0.08
===========
EPOCH:  14
loss:  0.9516282 , accuracy_:  0.6147017  , val_loss:  0.9634201  , val_acc_:  0.61328125  , lr:  0.07604999999999999
===========
EPOCH:  15
loss:  0.9929214 , accuracy_:  0.6147017  , val_loss:  0.91528195  , val_acc_:  0.6542969  , lr:  0.0722975
===========
EPOCH:  16
loss:  0.85034007 , accuracy_:  0.6292614  , val_loss:  1.1177208  , val_acc_:  0.5625  , lr:  0.06873262499999999
===========
EPOCH:  17
loss:  0.92602944 , accuracy_:  0.6161222  , val_loss:  0.88137645  , val_acc_:  0.6503906  , lr:  0.06534599375
===========
EPOCH:  18
loss:  0.88779795 , accuracy_:  0.64666194  , val_loss:  0.8969081  , val_acc_:  0.6582031  , lr:  0.06212869406249998
===========
EPOCH:  19
loss:  0.9948027 , accuracy_:  0.6558949  , val_loss:  1.1083667  , val_acc_:  0.6152344  , lr:  0.05907225935937498
===========
EPOCH:  20
loss:  1.0186219 , accuracy_:  0.6352983  , val_loss:  0.9336281  , val_acc_:  0.6328125  , lr:  0.05616864639140623
===========
EPOCH:  21
loss:  1.0545349 , accuracy_:  0.6296165  , val_loss:  1.073566  , val_acc_:  0.5839844  , lr:  0.05341021407183592
===========
EPOCH:  22
loss:  0.9266959 , accuracy_:  0.6395597  , val_loss:  0.8530734  , val_acc_:  0.6640625  , lr:  0.05078970336824412
===========
EPOCH:  23
loss:  0.8740611 , accuracy_:  0.65234375  , val_loss:  1.0662813  , val_acc_:  0.6269531  , lr:  0.048300218199831914
===========
EPOCH:  24
loss:  0.8150981 , accuracy_:  0.66690344  , val_loss:  0.87279373  , val_acc_:  0.6699219  , lr:  0.04593520728984032
===========
EPOCH:  25
loss:  0.87887394 , accuracy_:  0.63778406  , val_loss:  0.86162126  , val_acc_:  0.6621094  , lr:  0.0436884469253483
===========
EPOCH:  26
loss:  0.79184717 , accuracy_:  0.6828835  , val_loss:  0.853451  , val_acc_:  0.6777344  , lr:  0.04155402457908088
===========
EPOCH:  27
loss:  0.85961646 , accuracy_:  0.6651278  , val_loss:  0.8895781  , val_acc_:  0.66796875  , lr:  0.03952632335012683
===========
EPOCH:  28
loss:  0.86511034 , accuracy_:  0.66015625  , val_loss:  1.0709796  , val_acc_:  0.6484375  , lr:  0.03760000718262049
===========
EPOCH:  29
loss:  0.83482003 , accuracy_:  0.6772017  , val_loss:  0.8657106  , val_acc_:  0.6816406  , lr:  0.035770006823489464
===========
EPOCH:  30
loss:  0.8544643 , accuracy_:  0.6715199  , val_loss:  0.95447123  , val_acc_:  0.6738281  , lr:  0.03403150648231499
===========
EPOCH:  31
loss:  0.8925582 , accuracy_:  0.671875  , val_loss:  0.89564705  , val_acc_:  0.6738281  , lr:  0.03237993115819924
===========
EPOCH:  32
loss:  0.8363841 , accuracy_:  0.6875  , val_loss:  0.8471222  , val_acc_:  0.6875  , lr:  0.030810934600289275
===========
EPOCH:  33
loss:  0.8462085 , accuracy_:  0.69602275  , val_loss:  0.83931756  , val_acc_:  0.6777344  , lr:  0.02932038787027481
===========
EPOCH:  34
loss:  0.78276515 , accuracy_:  0.6839489  , val_loss:  0.82418346  , val_acc_:  0.6933594  , lr:  0.02790436847676107
===========
EPOCH:  35
loss:  0.8355361 , accuracy_:  0.6899858  , val_loss:  0.9210977  , val_acc_:  0.6738281  , lr:  0.026559150052923013
===========
EPOCH:  36
loss:  0.80718756 , accuracy_:  0.70454544  , val_loss:  0.91811264  , val_acc_:  0.6699219  , lr:  0.025281192550276863
===========
EPOCH:  37
loss:  0.70264596 , accuracy_:  0.6974432  , val_loss:  0.80253756  , val_acc_:  0.6933594  , lr:  0.02406713292276302
===========
EPOCH:  38
loss:  0.7115408 , accuracy_:  0.7088068  , val_loss:  0.8474855  , val_acc_:  0.6894531  , lr:  0.02291377627662487
===========
EPOCH:  39
loss:  0.6314294 , accuracy_:  0.7198153  , val_loss:  0.8814893  , val_acc_:  0.70703125  , lr:  0.021818087462793623
===========
EPOCH:  40
loss:  0.7165874 , accuracy_:  0.73259944  , val_loss:  0.7393073  , val_acc_:  0.72265625  , lr:  0.02077718308965394
===========
EPOCH:  41
loss:  0.7348884 , accuracy_:  0.7286932  , val_loss:  0.96582544  , val_acc_:  0.6953125  , lr:  0.019788323935171243
===========
EPOCH:  42
loss:  0.785071 , accuracy_:  0.7141335  , val_loss:  0.89993995  , val_acc_:  0.6230469  , lr:  0.01884890773841268
===========
EPOCH:  43
loss:  0.7332948 , accuracy_:  0.7389915  , val_loss:  0.8410876  , val_acc_:  0.6972656  , lr:  0.017956462351492047
===========
EPOCH:  44
loss:  0.67525405 , accuracy_:  0.73828125  , val_loss:  0.8138516  , val_acc_:  0.6933594  , lr:  0.017108639233917443
===========
EPOCH:  45
loss:  0.83102876 , accuracy_:  0.7389915  , val_loss:  0.8350415  , val_acc_:  0.71875  , lr:  0.01630320727222157
===========
EPOCH:  46
loss:  0.63529277 , accuracy_:  0.73934656  , val_loss:  0.8200476  , val_acc_:  0.7167969  , lr:  0.015538046908610489
===========
EPOCH:  47
loss:  0.74751514 , accuracy_:  0.71022725  , val_loss:  0.94797283  , val_acc_:  0.68359375  , lr:  0.014811144563179963
===========
EPOCH:  48
loss:  0.7188057 , accuracy_:  0.7340199  , val_loss:  0.73836386  , val_acc_:  0.73046875  , lr:  0.014120587335020966
===========
EPOCH:  49
loss:  0.6821742 , accuracy_:  0.75390625  , val_loss:  0.8478819  , val_acc_:  0.72265625  , lr:  0.013464557968269918
===========
EPOCH:  50
loss:  0.6849743 , accuracy_:  0.74360794  , val_loss:  1.0432442  , val_acc_:  0.66796875  , lr:  0.01284133006985642
===========
EPOCH:  51
loss:  0.7552935 , accuracy_:  0.75142044  , val_loss:  0.80256677  , val_acc_:  0.7207031  , lr:  0.012249263566363598
===========
EPOCH:  52
loss:  0.68016076 , accuracy_:  0.7524858  , val_loss:  0.7880584  , val_acc_:  0.72265625  , lr:  0.01168680038804542
===========
EPOCH:  53
loss:  0.89600945 , accuracy_:  0.73615056  , val_loss:  0.8785901  , val_acc_:  0.6777344  , lr:  0.011152460368643147
===========
EPOCH:  54
loss:  0.8385702 , accuracy_:  0.7411222  , val_loss:  0.8847884  , val_acc_:  0.6640625  , lr:  0.01064483735021099
===========
EPOCH:  55
loss:  0.58544314 , accuracy_:  0.75958806  , val_loss:  0.78721035  , val_acc_:  0.73046875  , lr:  0.010162595482700439
===========
EPOCH:  56
loss:  0.6134448 , accuracy_:  0.74076706  , val_loss:  0.88842964  , val_acc_:  0.73046875  , lr:  0.009704465708565417
===========
EPOCH:  57
loss:  0.6251881 , accuracy_:  0.7659801  , val_loss:  0.7727301  , val_acc_:  0.7011719  , lr:  0.009269242423137147
===========
EPOCH:  58
loss:  0.6687313 , accuracy_:  0.75071025  , val_loss:  0.73691225  , val_acc_:  0.74609375  , lr:  0.008855780301980289
===========
EPOCH:  59
loss:  0.64336467 , accuracy_:  0.77166194  , val_loss:  0.7323053  , val_acc_:  0.76171875  , lr:  0.008462991286881274
===========
EPOCH:  60
loss:  0.6133585 , accuracy_:  0.7723722  , val_loss:  0.7338977  , val_acc_:  0.73828125  , lr:  0.008089841722537211
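Two things in this log follow from the loop's bookkeeping. First, the validation block fires whenever `step // steps_per_epoch` exceeds the last epoch seen: right after the very first training step (so the "EPOCH: 0" line reflects a single step), then every 11 steps, and the loop stops once the counter reaches `EPOCHS`, which suggests `EPOCHS` is 60 here. Second, the `lr` column is consistent with a constant 0.08 through epoch 13 followed by exponential decay by 0.95 per epoch towards 0.001. The sketch below replays both in plain Python. `lr_schedule` is defined in an earlier cell not shown, so `lr_for_epoch` and its constants are a hypothetical reconstruction inferred from the printed values, not the notebook's code.

```python
# Replays the loop's epoch bookkeeping with plain integers. steps_per_epoch=11
# matches the printed "Steps per epoch:  11"; the epochs argument stands in
# for EPOCHS (defined in an earlier cell, not shown).
def validation_steps(steps_per_epoch, epochs, max_steps=10_000):
    epoch = -1
    triggered = []  # (step, epoch) pairs at which the validation block runs
    for step in range(max_steps):
        if (step // steps_per_epoch) > epoch:
            epoch = step // steps_per_epoch
            triggered.append((step, epoch))
            if epoch >= epochs:
                break
    return triggered

# Hypothetical reconstruction of the lr column: constants inferred from the
# printed values, not read from the notebook's lr_schedule object.
LR_MAX = 0.08
LR_MIN = 0.001
LR_SUSTAIN_LAST_EPOCH = 13  # lr stays at LR_MAX through epoch 13
LR_EXP_DECAY = 0.95

def lr_for_epoch(epoch):
    if epoch <= LR_SUSTAIN_LAST_EPOCH:
        return LR_MAX
    return (LR_MAX - LR_MIN) * LR_EXP_DECAY ** (epoch - LR_SUSTAIN_LAST_EPOCH) + LR_MIN

print(validation_steps(11, 3))  # → [(0, 0), (11, 1), (22, 2), (33, 3)]
print(lr_for_epoch(14), lr_for_epoch(15), lr_for_epoch(60))
```

The printed values for epochs 14, 15 and 60 agree with the log to floating-point precision, which is what suggests this shape of schedule.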

In [12]:
print("Detailed training loss:")
plt.plot(train_losses)
plt.show()


Detailed training loss:

Predictions

(not distributed)


In [17]:
# shuffle the sample so that re-running this cell shows different flowers
permutation = np.random.permutation(len(some_flowers))
some_flowers, some_labels = (some_flowers[permutation], some_labels[permutation])

predictions = model.predict(some_flowers, batch_size=16)
  
print(np.array(CLASSES)[np.argmax(predictions, axis=-1)].tolist())


['sunflowers', 'roses', 'dandelion', 'dandelion', 'sunflowers', 'daisy', 'dandelion', 'sunflowers', 'tulips', 'dandelion', 'sunflowers', 'dandelion', 'roses', 'sunflowers', 'roses', 'dandelion', 'daisy', 'tulips', 'daisy', 'dandelion', 'daisy', 'tulips', 'dandelion', 'daisy', 'dandelion', 'dandelion', 'sunflowers', 'daisy', 'tulips', 'dandelion', 'tulips', 'dandelion', 'dandelion', 'dandelion', 'dandelion', 'dandelion', 'sunflowers', 'tulips', 'roses', 'roses', 'roses', 'daisy', 'dandelion', 'dandelion', 'daisy', 'roses', 'roses', 'dandelion', 'tulips', 'dandelion', 'tulips', 'roses', 'roses', 'dandelion', 'tulips', 'daisy', 'sunflowers', 'sunflowers', 'tulips', 'tulips', 'tulips', 'dandelion', 'dandelion', 'daisy', 'dandelion', 'dandelion', 'sunflowers', 'sunflowers', 'dandelion', 'dandelion', 'dandelion', 'dandelion', 'dandelion', 'dandelion', 'dandelion', 'tulips', 'sunflowers', 'sunflowers', 'roses', 'dandelion', 'roses', 'tulips', 'sunflowers', 'tulips', 'sunflowers', 'dandelion', 'dandelion', 'dandelion', 'tulips', 'dandelion', 'dandelion', 'dandelion', 'daisy', 'sunflowers', 'dandelion', 'dandelion', 'dandelion', 'dandelion', 'dandelion', 'roses', 'dandelion', 'dandelion', 'dandelion', 'daisy', 'sunflowers', 'roses', 'dandelion', 'sunflowers', 'roses', 'daisy', 'tulips', 'sunflowers', 'roses', 'sunflowers', 'dandelion', 'daisy', 'sunflowers', 'dandelion', 'dandelion', 'tulips', 'roses', 'tulips', 'tulips', 'tulips', 'dandelion', 'daisy', 'sunflowers', 'daisy', 'dandelion', 'dandelion', 'tulips', 'roses', 'dandelion', 'daisy', 'tulips', 'roses', 'sunflowers', 'daisy', 'roses', 'dandelion', 'sunflowers', 'tulips', 'tulips', 'dandelion', 'dandelion', 'dandelion', 'roses', 'tulips', 'roses', 'daisy', 'sunflowers', 'tulips', 'dandelion', 'dandelion', 'dandelion', 'tulips', 'dandelion', 'roses', 'sunflowers', 'dandelion']
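Two NumPy idioms carry the prediction cell: indexing images and labels with the same permutation keeps each pair aligned, and `np.argmax(..., axis=-1)` followed by fancy indexing into `np.array(CLASSES)` turns rows of probabilities into class names. A toy check with hypothetical stand-ins (integer ids instead of images, one-hot rows instead of model outputs; the `CLASSES` order is assumed):

```python
import numpy as np

# Hypothetical stand-ins: the CLASSES order is assumed, images are just ids.
CLASSES = ['daisy', 'dandelion', 'roses', 'sunflowers', 'tulips']
orig_images = np.arange(6)
orig_labels = np.array([0, 1, 2, 3, 4, 1])

# One permutation applied to both arrays keeps each image paired with its label.
permutation = np.random.permutation(len(orig_images))
images, labels = orig_images[permutation], orig_labels[permutation]

# Fake "predictions": one-hot rows putting all probability on the true label.
predictions = np.eye(len(CLASSES))[labels]

# argmax over the class axis, then fancy-index the class-name array.
predicted_names = np.array(CLASSES)[np.argmax(predictions, axis=-1)].tolist()
print(predicted_names)
```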

In [18]:
display_9_images_with_predictions(some_flowers, predictions, some_labels)


License


author: Martin Gorner
twitter: @martin_gorner


Copyright 2019 Google LLC

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.


This is not an official Google product but sample code provided for educational purposes.