This sample shows how to use the distribution strategy APIs when writing a custom training loop on TPU:
TPUStrategy()
with strategy.scope(): ...
strategy.experimental_distribute_dataset(ds)
strategy.experimental_run_v2(step_fn)
strategy.reduce(...)
In [1]:
# Library imports. NOTE: this is notebook code — the %tensorflow_version line
# is an IPython/Colab magic, not plain Python; it only runs inside Colab.
import re, sys
if 'google.colab' in sys.modules: # Colab-only Tensorflow version selector
    %tensorflow_version 2.x
import tensorflow as tf
import numpy as np
from matplotlib import pyplot as plt
print("Tensorflow version " + tf.__version__)
In [2]:
# Detect hardware and pick the appropriate distribution strategy.
try:
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver() # TPU detection
    print('Running on TPU ', tpu.cluster_spec().as_dict()['worker'])
except ValueError:
    # TPUClusterResolver raises ValueError when no TPU is available
    tpu = None

if tpu:
    tf.config.experimental_connect_to_cluster(tpu)
    tf.tpu.experimental.initialize_tpu_system(tpu)
    strategy = tf.distribute.experimental.TPUStrategy(tpu)
else:
    # default strategy: runs on CPU or a single GPU, one replica
    strategy = tf.distribute.get_strategy()

print("REPLICAS: ", strategy.num_replicas_in_sync)
In [3]:
EPOCHS = 60

# Hyperparameters tuned per hardware configuration, selected by replica count.
if strategy.num_replicas_in_sync == 1: # GPU
    # This achieves 80% accuracy on a GPU (final loss 0.45)
    BATCH_SIZE = 16
    VALIDATION_BATCH_SIZE = 16
    # constant learning rate: ramp/sustain disabled, decay factor of 1
    START_LR = 0.01
    MAX_LR = 0.01
    MIN_LR = 0.01
    LR_RAMP = 0 # epochs
    LR_SUSTAIN = 0 # epochs
    LR_DECAY = 1
elif strategy.num_replicas_in_sync == 8: # single TPU
    # This achieves 80% accuracy on a TPU v3-8 (final loss 0.44)
    BATCH_SIZE = 16 * strategy.num_replicas_in_sync # use 32 on TPUv3
    VALIDATION_BATCH_SIZE = 256
    # peak LR scaled with the number of replicas (global batch is larger)
    START_LR = 0.01
    MAX_LR = 0.01 * strategy.num_replicas_in_sync
    MIN_LR = 0.001
    LR_RAMP = 0 # epochs
    LR_SUSTAIN = 13 # epochs
    LR_DECAY = .95
else: # TPU pod
    # This achieves 80% accuracy on a TPU v2-32 pod (final loss 0.54)
    BATCH_SIZE = 16 * strategy.num_replicas_in_sync # Global batch size.
    VALIDATION_BATCH_SIZE = 256
    # large global batch: warm up the LR over 5 epochs before decaying
    START_LR = 0.06
    MAX_LR = 0.012 * strategy.num_replicas_in_sync
    MIN_LR = 0.01
    LR_RAMP = 5 # epochs
    LR_SUSTAIN = 8 # epochs
    LR_DECAY = 0.95

CLASSES = ['daisy', 'dandelion', 'roses', 'sunflowers', 'tulips'] # do not change, maps to the labels in the data (folder names)
IMAGE_SIZE = [331, 331] # supported image sizes: 192x192, 331x331, 512x512
# make sure you load the appropriate dataset on the next line
#GCS_PATTERN = 'gs://flowers-public/tfrecords-jpeg-192x192-2/*.tfrec'
GCS_PATTERN = 'gs://flowers-public/tfrecords-jpeg-331x331/*.tfrec'
#GCS_PATTERN = 'gs://flowers-public/tfrecords-jpeg-512x512/*.tfrec'
VALIDATION_SPLIT = 0.19
# In a custom training loop you need an object that holds the current epoch
# so the optimizer's learning-rate callable can read it.
class LRSchedule():
    """Learning-rate schedule driven by an externally-updated epoch counter.

    The schedule is piecewise: linear ramp-up, constant plateau, then
    exponential decay toward a floor. The training loop calls set_epoch()
    and the optimizer is given the bound lr() method as its LR callable.
    """

    def __init__(self):
        # epoch is advanced by the training loop via set_epoch()
        self.epoch = 0

    def set_epoch(self, epoch):
        """Record the epoch the training loop is currently on."""
        self.epoch = epoch

    @staticmethod
    def lrfn(epoch, start_lr, max_lr, min_lr, rampup_epochs, sustain_epochs, exp_decay):
        """Compute the LR for `epoch` given the schedule hyperparameters."""
        if epoch < rampup_epochs:
            # linear ramp from start_lr to max_lr
            return (max_lr - start_lr)/rampup_epochs * epoch + start_lr
        if epoch < rampup_epochs + sustain_epochs:
            # hold constant at max_lr
            return max_lr
        # exponential decay from max_lr toward min_lr
        return (max_lr - min_lr) * exp_decay**(epoch-rampup_epochs-sustain_epochs) + min_lr

    @staticmethod
    def default_lrfn(epoch):
        """Schedule wired to the notebook-level hyperparameters."""
        return LRSchedule.lrfn(epoch, START_LR, MAX_LR, MIN_LR, LR_RAMP, LR_SUSTAIN, LR_DECAY)

    def lr(self):
        """Learning rate for the stored epoch (zero-arg optimizer callable)."""
        return self.default_lrfn(self.epoch)
# Visualize the learning-rate schedule over the whole training run.
print("Learning rate schedule:")
rng = [i for i in range(EPOCHS)]
plt.plot(rng, [LRSchedule.default_lrfn(x) for x in rng])
plt.show()
In [4]:
#@title display utilities [RUN ME]
def dataset_to_numpy_util(dataset, N):
    """Materialize the first N elements of a tf.data Dataset as numpy arrays.

    Returns (numpy_images, numpy_labels): one batch of N (image, label)
    pairs. Handles both eager execution (TF2 default) and graph mode
    (TF1-style one-shot iterator + Session).
    """
    dataset = dataset.batch(N)
    if tf.executing_eagerly():
        # In eager mode, iterate on the Dataset directly.
        for images, labels in dataset:
            numpy_images = images.numpy()
            numpy_labels = labels.numpy()
            break;
    else: # In non-eager mode, must get the TF node that
        # yields the next item and run it in a tf.Session.
        get_next_item = dataset.make_one_shot_iterator().get_next()
        with tf.Session() as ses:
            numpy_images, numpy_labels = ses.run(get_next_item)
    return numpy_images, numpy_labels
def title_from_label_and_target(label, correct_label):
    """Build a display title from a one-hot prediction and its one-hot target.

    Returns (title, correct): the title names the predicted class, shows
    True/False for correctness and, when wrong, also names the expected
    class. `correct` is the boolean itself.
    """
    label = np.argmax(label, axis=-1) # one-hot to class number
    correct_label = np.argmax(correct_label, axis=-1) # one-hot to class number
    correct = (label == correct_label)
    # bug fix: typo "shoud be" -> "should be" in the user-visible title
    return "{} [{}{}{}]".format(CLASSES[label], str(correct), ', should be ' if not correct else '',
                                CLASSES[correct_label] if not correct else ''), correct
def display_one_flower(image, title, subplot, red=False):
    """Render a single flower image into matplotlib subplot slot `subplot`.

    The title is drawn in red when `red` is truthy (used to flag wrong
    predictions). Returns the index of the next subplot slot.
    """
    plt.subplot(subplot)
    plt.axis('off')
    plt.imshow(image)
    title_color = 'red' if red else 'black'
    plt.title(title, fontsize=16, color=title_color)
    return subplot + 1
def display_9_images_from_dataset(dataset):
    """Show the first 9 images of `dataset` in a 3x3 grid, titled by label."""
    subplot = 331  # 3 rows x 3 cols, starting at cell 1
    plt.figure(figsize=(13, 13))
    images, labels = dataset_to_numpy_util(dataset, 9)
    # display at most 9 images, one per grid cell
    for i, image in enumerate(images[:9]):
        title = CLASSES[np.argmax(labels[i], axis=-1)]
        subplot = display_one_flower(image, title, subplot)
    plt.tight_layout()
    plt.subplots_adjust(wspace=0.1, hspace=0.1)
    plt.show()
def display_9_images_with_predictions(images, predictions, labels):
    """Show 9 images in a 3x3 grid, titled with their predicted class.

    Incorrect predictions are rendered with a red title.
    """
    subplot = 331  # 3 rows x 3 cols, starting at cell 1
    plt.figure(figsize=(13, 13))
    for i, image in enumerate(images[:9]):
        title, correct = title_from_label_and_target(predictions[i], labels[i])
        subplot = display_one_flower(image, title, subplot, red=not correct)
    plt.tight_layout()
    plt.subplots_adjust(wspace=0.1, hspace=0.1)
    plt.show()
def display_training_curves(training, validation, title, subplot):
    """Plot train/validation curves for one metric into subplot `subplot`.

    The shared figure is created lazily on the first call (slot ending in 1).
    """
    if subplot % 10 == 1:
        # set up the subplots on the first call
        plt.subplots(figsize=(10, 10), facecolor='#F0F0F0')
        plt.tight_layout()
    ax = plt.subplot(subplot)
    ax.set_facecolor('#F8F8F8')
    ax.plot(training)
    ax.plot(validation)
    ax.set_title('model ' + title)
    ax.set_ylabel(title)
    ax.set_xlabel('epoch')
    ax.legend(['train', 'valid.'])
In [5]:
AUTOTUNE = tf.data.experimental.AUTOTUNE  # let tf.data tune parallelism / prefetch buffer sizes
def count_data_items(filenames):
    """Sum the item counts encoded in .tfrec filenames.

    Trick: each file name carries its record count after a dash,
    e.g. flowers00-230.tfrec holds 230 data items.
    """
    count_pattern = re.compile(r"-([0-9]*)\.")
    counts = [int(count_pattern.search(name).group(1)) for name in filenames]
    return np.sum(counts)
def data_augment(image, one_hot_class):
    """Per-sample training augmentation: random horizontal flip and random
    saturation jitter (factor drawn from [0, 2]). The label passes through."""
    image = tf.image.random_flip_left_right(image)
    image = tf.image.random_saturation(image, 0, 2)
    return image, one_hot_class
def read_tfrecord(example):
    """Parse one serialized TFRecord example into (image, one_hot_class).

    The JPEG image is decoded, scaled to floats in [0, 1] and reshaped to
    IMAGE_SIZE x 3 so the tensor shape is statically known to Tensorflow.
    """
    features = {
        "image": tf.io.FixedLenFeature([], tf.string), # tf.string means bytestring
        "class": tf.io.FixedLenFeature([], tf.int64),  # shape [] means scalar
        "one_hot_class": tf.io.VarLenFeature(tf.float32),
    }
    example = tf.io.parse_single_example(example, features)
    image = tf.image.decode_jpeg(example['image'], channels=3)
    image = tf.cast(image, tf.float32) / 255.0 # convert image to floats in [0, 1] range
    image = tf.reshape(image, [*IMAGE_SIZE, 3]) # force the image size so that the shape of the tensor is known to Tensorflow
    class_label = tf.cast(example['class'], tf.int32)  # parsed but not returned; labels come from one_hot_class
    one_hot_class = tf.sparse.to_dense(example['one_hot_class'])
    one_hot_class = tf.reshape(one_hot_class, [5])  # 5 flower classes
    return image, one_hot_class
def load_dataset(filenames):
    """Read the TFRecord files into a dataset of (image, one_hot_class) pairs.

    For optimal performance the files are read in parallel and the option
    experimental_deterministic=False is set to allow order-altering
    optimizations.
    """
    opt = tf.data.Options()
    opt.experimental_deterministic = False  # allow order-altering optimizations
    dataset = tf.data.TFRecordDataset(filenames, num_parallel_reads=16) # can be AUTOTUNE in TF 2.1
    # Bug fix: the options were previously applied to a throwaway
    # from_tensor_slices dataset that was immediately discarded, so the
    # non-deterministic-order optimization never took effect. Attach them
    # to the dataset that is actually used.
    dataset = dataset.with_options(opt)
    dataset = dataset.map(read_tfrecord, num_parallel_calls=AUTOTUNE)
    return dataset
def batch_dataset(filenames, batch_size, train):
    """Build a batched pipeline from TFRecord files.

    Returns (dataset, steps) where steps = item count // batch_size, i.e.
    the number of whole batches in one pass over the files.
    """
    dataset = load_dataset(filenames)
    n = count_data_items(filenames)

    if train:
        dataset = dataset.repeat() # training dataset must repeat
        dataset = dataset.map(data_augment, num_parallel_calls=AUTOTUNE)
        dataset = dataset.shuffle(2048)
    else:
        # usually fewer validation files than workers so disable FILE auto-sharding on validation
        if strategy.num_replicas_in_sync > 1: # option not useful if there is no sharding (not harmful either)
            opt = tf.data.Options()
            opt.experimental_distribute.auto_shard_policy = tf.data.experimental.AutoShardPolicy.DATA
            dataset = dataset.with_options(opt)
        # validation dataset does not need to repeat
        # also no need to shuffle or apply data augmentation

    if train:
        dataset = dataset.batch(batch_size)
    else:
        # little wrinkle: drop_remainder is NOT necessary but validation on the last
        # partial batch sometimes returns a "nan" loss (probably a bug). You can remove
        # this if you do not care about the validation loss.
        dataset = dataset.batch(batch_size, drop_remainder=True)

    dataset = dataset.prefetch(AUTOTUNE) # prefetch next batch while training (autotune prefetch buffer size)
    return dataset, n//batch_size
def get_training_dataset(filenames):
    """Batched, augmented, repeating training pipeline plus steps per epoch."""
    return batch_dataset(filenames, BATCH_SIZE, train=True)
def get_validation_dataset(filenames):
    """Batched, non-repeating validation pipeline plus its step count."""
    return batch_dataset(filenames, VALIDATION_BATCH_SIZE, train=False)
In [6]:
# instantiate datasets
filenames = tf.io.gfile.glob(GCS_PATTERN)
# file-level split: the last VALIDATION_SPLIT fraction of files is validation
split = len(filenames) - int(len(filenames) * VALIDATION_SPLIT)
train_filenames = filenames[:split]
valid_filenames = filenames[split:]
training_dataset, steps_per_epoch = get_training_dataset(train_filenames)
validation_dataset, validation_steps = get_validation_dataset(valid_filenames)
print("TRAINING IMAGES: ", count_data_items(train_filenames), ", STEPS PER EPOCH: ", steps_per_epoch)
print("VALIDATION IMAGES: ", count_data_items(valid_filenames), ", STEPS PER EPOCH: ", validation_steps)

# numpy data to test predictions: 160 samples from the validation files
some_flowers, some_labels = dataset_to_numpy_util(load_dataset(valid_filenames), 160)
In [7]:
# sanity check: show 9 training images with their labels
display_9_images_from_dataset(load_dataset(train_filenames))
In [8]:
def create_model():
    """Build a small SqueezeNet-style Keras classifier for 5 flower classes."""
    # with only a handful of batches per epoch, the batch norm running
    # average period must be lowered
    bn_momentum = 0.9

    def fire(x, squeeze, expand):
        # 1x1 squeeze bottleneck, then two parallel expand paths (1x1 and 3x3)
        # whose outputs are concatenated on the channel axis
        y = tf.keras.layers.Conv2D(filters=squeeze, kernel_size=1, activation=None, padding='same', use_bias=False)(x)
        y = tf.keras.layers.BatchNormalization(momentum=bn_momentum, scale=False, center=True)(y)
        y = tf.keras.layers.Activation('relu')(y)
        y1 = tf.keras.layers.Conv2D(filters=expand//2, kernel_size=1, activation=None, padding='same', use_bias=False)(y)
        y1 = tf.keras.layers.BatchNormalization(momentum=bn_momentum, scale=False, center=True)(y1)
        y1 = tf.keras.layers.Activation('relu')(y1)
        y3 = tf.keras.layers.Conv2D(filters=expand//2, kernel_size=3, activation=None, padding='same', use_bias=False)(y)
        y3 = tf.keras.layers.BatchNormalization(momentum=bn_momentum, scale=False, center=True)(y3)
        y3 = tf.keras.layers.Activation('relu')(y3)
        return tf.keras.layers.concatenate([y1, y3])

    def fire_module(squeeze, expand):
        # layer-like wrapper so fire() composes in a functional chain
        return lambda x: fire(x, squeeze, expand)

    x = tf.keras.layers.Input(shape=(*IMAGE_SIZE, 3)) # input is 331x331 pixels RGB
    y = tf.keras.layers.Conv2D(kernel_size=3, filters=32, padding='same', use_bias=True, activation='relu')(x)
    y = tf.keras.layers.BatchNormalization(momentum=bn_momentum)(y)
    # five fire modules; 2x2 max-pooling after each except the last
    fire_specs = [(24, 48), (48, 96), (64, 128), (48, 96), (24, 48)]
    for idx, (squeeze, expand) in enumerate(fire_specs):
        y = fire_module(squeeze, expand)(y)
        if idx < len(fire_specs) - 1:
            y = tf.keras.layers.MaxPooling2D(pool_size=2)(y)
    y = tf.keras.layers.GlobalAveragePooling2D()(y)
    y = tf.keras.layers.Dropout(0.4)(y)
    y = tf.keras.layers.Dense(5, activation='softmax')(y)
    return tf.keras.Model(x, y)
In [9]:
# Create model, optimizer, metrics and loss under the strategy scope so
# variables are placed/replicated correctly for distributed training.
with strategy.scope():
    model = create_model()
    lr_schedule = LRSchedule()
    # learning_rate is passed as a zero-arg callable (the schedule's bound
    # lr method); it is re-evaluated as lr_schedule.epoch advances
    optimizer = tf.keras.optimizers.SGD(nesterov=True, momentum=0.9, learning_rate = lr_schedule.lr)
    train_accuracy = tf.keras.metrics.CategoricalAccuracy()
    valid_accuracy = tf.keras.metrics.CategoricalAccuracy()
    # mean categorical cross-entropy over the (per-replica) batch
    loss_fn = lambda labels, probabilities: tf.reduce_mean(tf.keras.losses.categorical_crossentropy(labels, probabilities))
    model.summary()
In [10]:
@tf.function
def train_step(images, labels):
    """One training step: forward pass, loss, backprop, metric update.

    Runs per replica under the distribution strategy; returns the batch loss.
    """
    with tf.GradientTape() as tape:
        probs = model(images, training=True)
        batch_loss = loss_fn(labels, probs)
    gradients = tape.gradient(batch_loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))
    train_accuracy.update_state(labels, probs)
    return batch_loss
@tf.function
def valid_step(images, labels):
    """One validation step: forward pass only; updates accuracy, returns loss."""
    probs = model(images, training=False)
    batch_loss = loss_fn(labels, probs)
    valid_accuracy.update_state(labels, probs)
    return batch_loss
In [11]:
# distribute the dataset according to the strategy
train_dist_ds = strategy.experimental_distribute_dataset(training_dataset)
valid_dist_ds = strategy.experimental_distribute_dataset(validation_dataset)

print("Steps per epoch: ", steps_per_epoch)

# Custom distributed training loop. The training dataset repeats forever,
# so epochs are tracked manually via step // steps_per_epoch.
epoch = -1  # start at -1 so the epoch-boundary check fires on the first step
train_losses=[]
for step, (images, labels) in enumerate(train_dist_ds):

    # batch losses from all replicas
    loss = strategy.experimental_run_v2(train_step, args=(images, labels))
    # reduced to a single number both across replicas and across the batch size
    loss = strategy.reduce(tf.distribute.ReduceOp.MEAN, loss, axis=None)
    # or use loss.values to access the raw set of losses returned from all replicas
    # (the official API is strategy.experimental_local_results(loss), does the same as loss.values)

    # validation run at the beginning of each new epoch
    if (step // steps_per_epoch) > epoch:
        valid_loss = []
        for image, labels in valid_dist_ds:
            batch_loss = strategy.experimental_run_v2(valid_step, args=(image, labels)) # just one batch
            batch_loss = strategy.reduce(tf.distribute.ReduceOp.MEAN, batch_loss, axis=None)
            valid_loss.append(batch_loss.numpy())
        valid_loss = np.mean(valid_loss)

        # advance the epoch counter and feed it to the LR schedule
        epoch = step // steps_per_epoch
        lr_schedule.set_epoch(epoch)
        print('\nEPOCH: ', epoch)
        print('loss: ', loss.numpy(), ', accuracy_: ', train_accuracy.result().numpy(), ' , val_loss: ', valid_loss, ' , val_acc_: ', valid_accuracy.result().numpy(), ' , lr: ', lr_schedule.lr())
        # metrics accumulate across batches; reset them once per epoch
        train_accuracy.reset_states()
        valid_accuracy.reset_states()
        if epoch >= EPOCHS:
            break

    train_losses.append(loss)
    print('=', end='')
In [12]:
# per-step training loss collected by the training loop
print("Detailed training loss:")
plt.plot(train_losses)
plt.show()
In [17]:
# randomize the input so that you can execute multiple times to change results
# NOTE(review): 8*20 = 160 — presumably matches the 160 sample flowers loaded
# earlier; confirm if that count changes
permutation = np.random.permutation(8*20)
some_flowers, some_labels = (some_flowers[permutation], some_labels[permutation])

predictions = model.predict(some_flowers, batch_size=16)
# print the predicted class name for every sample flower
print(np.array(CLASSES)[np.argmax(predictions, axis=-1)].tolist())
In [18]:
# visualize 9 predictions; wrong ones are flagged with a red title
display_9_images_with_predictions(some_flowers, predictions, some_labels)
author: Martin Gorner
twitter: @martin_gorner
Copyright 2019 Google LLC
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.
This is not an official Google product but sample code provided for an educational purpose