In [0]:
#@title Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
By Jessica Yung and Joan Puigcerver
In this colab, we will show you how to load one of our BiT models (a ResNet50 trained on ImageNet-21k), use it out-of-the-box and fine-tune it on a dataset of flowers.
This colab accompanies the TensorFlow blog post (TODO: link) and is based on the BigTransfer paper.
Our models can be found on TensorFlow Hub here. We also share code to fine-tune our models in TensorFlow 2, JAX and PyTorch in our GitHub repository.
In [0]:
#@title Imports
import tensorflow as tf
import tensorflow_hub as hub
import tensorflow_datasets as tfds
import time
from PIL import Image
import requests
from io import BytesIO
import matplotlib.pyplot as plt
import numpy as np
import os
import pathlib
In [3]:
#@title Construct imagenet logit-to-class-name dictionary (imagenet_int_to_str)
!wget https://storage.googleapis.com/bit_models/ilsvrc2012_wordnet_lemmas.txt
imagenet_int_to_str = {}
with open('ilsvrc2012_wordnet_lemmas.txt', 'r') as f:
  for i in range(1000):
    row = f.readline()
    row = row.rstrip()
    imagenet_int_to_str.update({i: row})
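As a quick sanity check, the dictionary should map each of the 1000 ILSVRC-2012 logit indices to a lemma string (the example value below is what we'd expect from the first line of the lemmas file; treat it as illustrative):

print(len(imagenet_int_to_str))  # 1000
print(imagenet_int_to_str[0])    # e.g. 'tench, Tinca_tinca'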
In [0]:
#@title tf_flowers label names (hidden)
tf_flowers_labels = ['dandelion', 'daisy', 'tulips', 'sunflowers', 'roses']
First, we will load a pre-trained BiT model. There are ten models you can choose from, spanning two upstream training datasets and five ResNet architectures.
In this tutorial, we will use a ResNet50x1 model trained on ImageNet-21k.
Models that output image features (pre-logit layer) can be found at
https://tfhub.dev/google/bit/m-{archi, e.g. r50x1}/1
whereas models that return outputs in the ImageNet (ILSVRC-2012) label space can be found at
https://tfhub.dev/google/bit/m-{archi, e.g. r50x1}/ilsvrc2012_classification/1
The available architectures are R50x1, R50x3, R101x1, R101x3 and R152x4. The architectures are all in lowercase in the links.
So, for example, if you want image features from a ResNet-50, you could use the model at https://tfhub.dev/google/bit/m-r50x1/1. This is also the model we'll use in this tutorial.
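If you want to construct these handles programmatically, a small hypothetical helper (ours, for illustration only; not part of any official API) makes the naming pattern explicit:

def bit_m_handle(arch, classification=False):
  # arch is one of: 'r50x1', 'r50x3', 'r101x1', 'r101x3', 'r152x4'.
  suffix = 'ilsvrc2012_classification/1' if classification else '1'
  return 'https://tfhub.dev/google/bit/m-{}/{}'.format(arch, suffix)

bit_m_handle('r50x1')  # -> 'https://tfhub.dev/google/bit/m-r50x1/1'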
In [0]:
# Load model into KerasLayer
model_url = "https://tfhub.dev/google/bit/m-r50x1/1"
module = hub.KerasLayer(model_url)
If you don’t yet have labels for your images (or just want to have some fun), you may be interested in using the model out-of-the-box, i.e. without fine-tuning it. For this, we will use a model fine-tuned on ImageNet so it has the interpretable ImageNet label space of 1k classes. Many common objects are not covered, but it gives a reasonable idea of what is in the image.
In [0]:
# Load model fine-tuned on ImageNet
model_url = "https://tfhub.dev/google/bit/m-r50x1/ilsvrc2012_classification/1"
imagenet_module = hub.KerasLayer(model_url)
Using the model is very simple:
logits = module(image)
Note that the BiT models expect inputs of shape [batch_size, height, width, 3] (i.e. a batch of RGB images) with values in the range [0, 1].
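To make that input contract concrete, here is a minimal sketch using a random batch (the 224x224 size is arbitrary, since the models accept any height and width; the 2048-dimensional feature output is what we'd expect for R50x1):

dummy_batch = tf.random.uniform([1, 224, 224, 3])  # float32 values in [0, 1)
features = module(dummy_batch)         # pre-logit features, shape [1, 2048]
logits = imagenet_module(dummy_batch)  # ImageNet logits, shape [1, 1000]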
In [0]:
#@title Helper functions for loading image (hidden)
def preprocess_image(image):
  image = np.array(image)
  # Reshape into shape [batch_size, height, width, num_channels].
  img_reshaped = tf.reshape(image, [1, image.shape[0], image.shape[1], image.shape[2]])
  # Use `convert_image_dtype` to convert to floats in the [0, 1] range.
  image = tf.image.convert_image_dtype(img_reshaped, tf.float32)
  return image

def load_image_from_url(url):
  """Returns an image with shape [1, height, width, num_channels]."""
  response = requests.get(url)
  image = Image.open(BytesIO(response.content))
  image = preprocess_image(image)
  return image
In [0]:
#@title Plotting helper functions (hidden)
#@markdown Credits to Xiaohua Zhai, Lucas Beyer and Alex Kolesnikov from Brain Zurich, Google Research
# Show the MAX_PREDS highest scoring labels:
MAX_PREDS = 5
# Do not show labels with lower score than this:
MIN_SCORE = 0.8
def show_preds(logits, image, correct_flowers_label=None, tf_flowers_logits=False):
  if len(logits.shape) > 1:
    logits = tf.reshape(logits, [-1])
  fig, axes = plt.subplots(1, 2, figsize=(7, 4), squeeze=False)
  ax1, ax2 = axes[0]
  ax1.axis('off')
  ax1.imshow(image)
  if correct_flowers_label is not None:
    ax1.set_title(tf_flowers_labels[correct_flowers_label])
  classes = []
  scores = []
  logits_max = np.max(logits)
  softmax_denominator = np.sum(np.exp(logits - logits_max))
  for index, j in enumerate(np.argsort(logits)[-MAX_PREDS::][::-1]):
    score = 1.0 / (1.0 + np.exp(-logits[j]))
    if score < MIN_SCORE: break
    if not tf_flowers_logits:
      # Predicting in the ImageNet label space.
      classes.append(imagenet_int_to_str[j])
    else:
      # Predicting in the tf_flowers label space.
      classes.append(tf_flowers_labels[j])
    scores.append(np.exp(logits[j] - logits_max) / softmax_denominator * 100)
  ax2.barh(np.arange(len(scores)) + 0.1, scores)
  ax2.set_xlim(0, 100)
  ax2.set_yticks(np.arange(len(scores)))
  ax2.yaxis.set_ticks_position('right')
  ax2.set_yticklabels(classes, rotation=0, fontsize=14)
  ax2.invert_xaxis()
  ax2.invert_yaxis()
  ax2.set_xlabel('Prediction probabilities', fontsize=11)
Try replacing the URL below with a link to an image of your choice!
In [0]:
# Load image (image provided is CC0 licensed)
img_url = "https://p0.pikrepo.com/preview/853/907/close-up-photo-of-gray-elephant.jpg"
image = load_image_from_url(img_url)
# Run model on image
logits = imagenet_module(image)
# Show image and predictions
show_preds(logits, image[0])
Here the model correctly classifies the photo as an elephant. It is likely an Asian elephant because of the size of its ears.
We will also try predicting on an image from tf_flowers, the dataset we're going to fine-tune on, which has also been used in other tutorials. This dataset contains 3670 images of 5 classes of flowers.
Note that the correct label of the image we're going to predict on (‘tulip’) is not a class in ImageNet and so the model cannot predict that at the moment - let’s see what it tries to do instead.
In [0]:
# Import tf_flowers data from tfds
dataset_name = 'tf_flowers'
ds, info = tfds.load(name=dataset_name, split=['train'], in_memory=False, with_info=True)
ds = ds[0]
num_examples = info.splits['train'].num_examples
NUM_CLASSES = 5
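It can be useful to peek at what tfds gives us before building the input pipelines; each dataset element is a dict with 'image' and 'label' keys (the printed output below is indicative):

print(info.features)  # e.g. FeaturesDict({'image': Image(...), 'label': ClassLabel(num_classes=5)})
print(num_examples)   # 3670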
In [0]:
#@title Alternative code for loading a dataset
#@markdown We provide alternative code for loading `tf_flowers` from a URL in this cell to make it easy for you to try loading your own datasets.
#@markdown This code is commented out by default and replaces the cell immediately above. Note that using this option may result in a different example image below.
"""
data_dir = tf.keras.utils.get_file(origin='https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz',
fname='flower_photos', untar=True)
data_dir = pathlib.Path(data_dir)
IMG_HEIGHT = 224
IMG_WIDTH = 224
CLASS_NAMES = tf_flowers_labels # from plotting helper functions above
NUM_CLASSES = len(CLASS_NAMES)
num_examples = len(list(data_dir.glob('*/*.jpg')))
def get_label(file_path):
# convert the path to a list of path components
parts = tf.strings.split(file_path, os.path.sep)
# The second to last is the class-directory
return tf.where(parts[-2] == CLASS_NAMES)[0][0]
def decode_img(img):
# convert the compressed string to a 3D uint8 tensor
img = tf.image.decode_jpeg(img, channels=3)
return img
def process_path(file_path):
label = get_label(file_path)
# load the raw data from the file as a string
img = tf.io.read_file(file_path)
img = decode_img(img)
features = {'image': img, 'label': label}
return features
list_ds = tf.data.Dataset.list_files(str(data_dir/'*/*'))
ds = list_ds.map(process_path, num_parallel_calls=tf.data.experimental.AUTOTUNE)
"""
In [0]:
# Split into train and test sets
# We have checked that the classes are reasonably balanced.
train_split = 0.9
num_train = int(train_split * num_examples)
ds_train = ds.take(num_train)
ds_test = ds.skip(num_train)
DATASET_NUM_TRAIN_EXAMPLES = num_examples
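The balance claim in the comment above is easy to verify yourself; here is a small sketch that counts the labels in the training split (it iterates the split once, so it takes a moment to run):

from collections import Counter
label_counts = Counter(int(features['label'].numpy()) for features in ds_train)
print({tf_flowers_labels[k]: v for k, v in label_counts.items()})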
In [0]:
for features in ds_train.take(1):
  image = features['image']
  image = preprocess_image(image)
  # Run model on image
  logits = imagenet_module(image)
  # Show image and predictions
  show_preds(logits, image[0], correct_flowers_label=features['label'].numpy())
In this case, 'tulip' is not a class in ImageNet, and the model predicts a reasonably similar-looking class, 'bell pepper'.
Now we are going to fine-tune the BiT model so it performs better on a specific dataset. We are going to use Keras for simplicity, and we are going to fine-tune the model on a dataset of flowers (tf_flowers).
We will use the model we loaded at the start (i.e. the one not fine-tuned on ImageNet) so the model is less biased towards ImageNet-like images.
There are two steps:

1. Create a new model with a new final layer (called the 'head').
2. Fine-tune this model using BiT-HyperRule, our hyperparameter heuristic (described below).

To create the new model, we:

1. Cut off the BiT model's original head. This leaves us with the 'pre-logits' output. We do not need to do this for the feature extractor models (i.e. those whose handles include feature_vectors), since for those models the head has already been cut off.
2. Add a new head with the number of outputs equal to the number of classes of our new task. Note that it is important that we initialise the head to all zeroes.
In [0]:
# Add new head to the BiT model
class MyBiTModel(tf.keras.Model):
  """BiT with a new head."""

  def __init__(self, num_classes, module):
    super().__init__()
    self.num_classes = num_classes
    self.head = tf.keras.layers.Dense(num_classes, kernel_initializer='zeros')
    self.bit_model = module

  def call(self, images):
    # No need to cut off a head since we are using the feature extractor model.
    bit_embedding = self.bit_model(images)
    return self.head(bit_embedding)

model = MyBiTModel(num_classes=NUM_CLASSES, module=module)
When we fine-tune the model, we use BiT-HyperRule, our heuristic for choosing hyperparameters for downstream fine-tuning. This is not a hyperparameter sweep - given a dataset, it specifies one set of hyperparameters that we’ve seen produce good results. You can often obtain better results by running a more expensive hyperparameter sweep, but BiT-HyperRule is an effective way of getting good initial results on your dataset.
Hyperparameter heuristic details
In BiT-HyperRule, we use a vanilla SGD optimiser with an initial learning rate of 0.003, momentum 0.9 and batch size 512. We decay the learning rate by a factor of 10 at 30%, 60% and 90% of the training steps.
As data preprocessing, we resize the image, take a random crop, and then do a random horizontal flip (details in table below). We do random crops and horizontal flips for all tasks except those where such actions destroy label semantics. E.g. we don’t apply random crops to counting tasks, or random horizontal flip to tasks where we’re meant to predict the orientation of an object.
Image area | Resize to | Take random crop of size |
---|---|---|
Smaller than 96 x 96 px | 160 x 160 px | 128 x 128 px |
At least 96 x 96 px | 512 x 512 px | 480 x 480 px |
Table 1: Downstream resizing and random cropping details. If images are larger, we resize them to a larger fixed size to take advantage of benefits from fine-tuning on higher resolution.
We also use MixUp for datasets with more than 20k examples. Since the dataset used in this tutorial has fewer than 20k examples, and for simplicity and speed, we do not include MixUp in this colab, but we do include it in our GitHub repo implementation.
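For reference, here is a minimal MixUp sketch (not used anywhere below; α = 0.1 follows the setting in the BiT paper, pairing each example with the reversed batch is one common implementation choice, and it assumes one-hot labels):

def mixup(images, labels, alpha=0.1):
  # Sample one mixing coefficient for the whole batch.
  beta = np.random.beta(alpha, alpha)
  # Mix each example with its counterpart in the reversed batch.
  images = beta * images + (1 - beta) * tf.reverse(images, axis=[0])
  labels = beta * labels + (1 - beta) * tf.reverse(labels, axis=[0])
  return images, labels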
In [0]:
#@title Set dataset-dependent hyperparameters
#@markdown Here we set dataset-dependent hyperparameters. For example, our dataset of flowers has 3670 images of varying size (a few hundred x a few hundred pixels), so the image size is larger than 96x96 and the dataset size is <20k examples. However, for speed reasons (since this is a tutorial and we are training on a single GPU), we will select the `<96x96 px` option and train on lower resolution images. As we will see, we can still attain strong results.
#@markdown **Algorithm details: how are the hyperparameters dataset-dependent?**
#@markdown It's quite intuitive - we resize images to a smaller fixed size if they are smaller than 96 x 96px and to a larger fixed size otherwise. The number of steps we fine-tune for is larger for larger datasets.
IMAGE_SIZE = "=\u003C96x96 px" #@param ["=<96x96 px","> 96 x 96 px"]
DATASET_SIZE = "\u003C20k examples" #@param ["<20k examples", "20k-500k examples", ">500k examples"]
if IMAGE_SIZE == "=<96x96 px":
RESIZE_TO = 160
CROP_TO = 128
else:
RESIZE_TO = 512
CROP_TO = 480
if DATASET_SIZE == "<20k examples":
SCHEDULE_LENGTH = 500
SCHEDULE_BOUNDARIES = [200, 300, 400]
elif DATASET_SIZE == "20k-500k examples":
SCHEDULE_LENGTH = 10000
SCHEDULE_BOUNDARIES = [3000, 6000, 9000]
else:
SCHEDULE_LENGTH = 20000
SCHEDULE_BOUNDARIES = [6000, 12000, 18000]
Tip: if you are running out of memory, decrease the batch size. When you do, linearly scale the schedule length and learning rate to match:
SCHEDULE_LENGTH = SCHEDULE_LENGTH * 512 / BATCH_SIZE
lr = 0.003 * BATCH_SIZE / 512
These adjustments are already coded in the cells below; you only have to change BATCH_SIZE. If you change the batch size, please re-run the cell above as well, to make sure the SCHEDULE_LENGTH you start from is the original value rather than one already scaled by a previous run.
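As a worked example of the scaling rule (illustrative numbers only):

# With BATCH_SIZE = 256 and the "<20k examples" setting (SCHEDULE_LENGTH = 500):
#   SCHEDULE_LENGTH = 500 * 512 / 256 = 1000 steps
#   lr              = 0.003 * 256 / 512 = 0.0015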
In [0]:
# Preprocessing helper functions
# Create data pipelines for training and testing:
BATCH_SIZE = 512
SCHEDULE_LENGTH = SCHEDULE_LENGTH * 512 / BATCH_SIZE
STEPS_PER_EPOCH = 10
def cast_to_tuple(features):
  return (features['image'], features['label'])

def preprocess_train(features):
  # Apply random crops and horizontal flips for all tasks
  # except those for which cropping or flipping destroys the label semantics
  # (e.g. predicting the orientation of an object).
  features['image'] = tf.image.random_flip_left_right(features['image'])
  features['image'] = tf.image.resize(features['image'], [RESIZE_TO, RESIZE_TO])
  features['image'] = tf.image.random_crop(features['image'], [CROP_TO, CROP_TO, 3])
  features['image'] = tf.cast(features['image'], tf.float32) / 255.0
  return features

def preprocess_test(features):
  features['image'] = tf.image.resize(features['image'], [RESIZE_TO, RESIZE_TO])
  features['image'] = tf.cast(features['image'], tf.float32) / 255.0
  return features
pipeline_train = (ds_train
.shuffle(10000)
.repeat(int(SCHEDULE_LENGTH * BATCH_SIZE / DATASET_NUM_TRAIN_EXAMPLES * STEPS_PER_EPOCH) + 1 + 50) # repeat dataset_size / num_steps
.map(preprocess_train, num_parallel_calls=8)
.batch(BATCH_SIZE)
.map(cast_to_tuple) # for keras model.fit
.prefetch(2))
pipeline_test = (ds_test.map(preprocess_test, num_parallel_calls=1)
.map(cast_to_tuple) # for keras model.fit
.batch(BATCH_SIZE)
.prefetch(2))
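Before training, it's worth sanity-checking one batch from the pipeline (a sketch; the shapes shown assume the defaults BATCH_SIZE = 512 and CROP_TO = 128):

images, labels = next(iter(pipeline_train))
print(images.shape, labels.shape)  # (512, 128, 128, 3) (512,)
print(float(tf.reduce_min(images)), float(tf.reduce_max(images)))  # values within [0, 1]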
The fine-tuning will take about 15 minutes. If you wish, you can manually set the number of epochs to be 10 instead of 50 for the tutorial, and you will likely still obtain a model with validation accuracy > 99%.
In [0]:
# Define optimiser and loss
lr = 0.003 * BATCH_SIZE / 512
# Decay learning rate by a factor of 10 at SCHEDULE_BOUNDARIES.
lr_schedule = tf.keras.optimizers.schedules.PiecewiseConstantDecay(
    boundaries=SCHEDULE_BOUNDARIES,
    values=[lr, lr * 0.1, lr * 0.01, lr * 0.001])
optimizer = tf.keras.optimizers.SGD(learning_rate=lr_schedule, momentum=0.9)
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
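The schedule object is callable on a step number, so you can check the decay points directly (a sketch; the values shown assume BATCH_SIZE = 512, i.e. lr = 0.003, and the <20k-examples boundaries of [200, 300, 400]):

for step in [0, 250, 350, 450]:
  print(step, float(lr_schedule(step)))  # ~0.003, 0.0003, 3e-05, 3e-06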
In [0]:
model.compile(optimizer=optimizer,
loss=loss_fn,
metrics=['accuracy'])
# Fine-tune model
history = model.fit(
    pipeline_train,
    batch_size=BATCH_SIZE,
    steps_per_epoch=STEPS_PER_EPOCH,
    epochs=int(SCHEDULE_LENGTH / STEPS_PER_EPOCH),  # TODO: replace with `epochs=10` here to shorten fine-tuning for tutorial if you wish
    validation_data=pipeline_test  # here we are only using this data to evaluate our performance
)
We see that our model attains 98-99% training and validation accuracy.
In [0]:
# Save fine-tuned model as SavedModel
export_module_dir = '/tmp/my_saved_bit_model/'
tf.saved_model.save(model, export_module_dir)
In [0]:
# Load saved model
saved_module = hub.KerasLayer(export_module_dir, trainable=True)
In [0]:
# Visualise predictions from new model
for features in ds_train.take(1):
  image = features['image']
  image = preprocess_image(image)
  image = tf.image.resize(image, [CROP_TO, CROP_TO])
  # Run model on image
  logits = saved_module(image)
  # Show image and predictions
  show_preds(logits, image[0], correct_flowers_label=features['label'].numpy(), tf_flowers_logits=True)
Voilà! We now have a model that predicts tulips as tulips and not bell peppers. :)