Fine-tuning a pre-trained Deep CNN on a GPU machine

This session is inspired by a blog post by François Chollet, the creator of the Keras library.

WARNING: executing this notebook requires a GPU, e.g. an NVIDIA K80, GTX 980 or later, with at least 6GB of GPU RAM.
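To check that TensorFlow actually sees a GPU before running the heavy cells, here is a quick sanity check (a minimal sketch; in TensorFlow versions before 2.1 the same function lives under tf.config.experimental):

import tensorflow as tf

# Should list at least one physical GPU device.
print(tf.config.list_physical_devices('GPU'))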

For this session we are going to use the dogs-vs-cats dataset.

To download the data yourself, create a password-based account on Kaggle, then, while logged in to Kaggle in your browser, click the download link of one of the data files to reach the form asking you to accept the terms and conditions of the challenge.

Then, in a shell session (possibly on a server), do the following:

pip3 install kaggle
# You need to download a new API key from https://www.kaggle.com/{my_name}/account
# and save it as `~/.kaggle/kaggle.json`.
mkdir -p ~/data/dogs-vs-cats
cd ~/data/dogs-vs-cats
kaggle competitions download -c dogs-vs-cats

This should download a single dogs-vs-cats.zip archive containing 3 files, among which train.zip and test1.zip (plus a CSV submission template we won't need).

Once this is done, we can extract the archive for the train set:


In [6]:
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np

import os
import os.path as op
import shutil
from zipfile import ZipFile

data_folder = op.expanduser('~/data/dogs-vs-cats')
train_folder = op.join(data_folder, 'train')

if not op.exists(train_folder):
    train_zip = op.join(data_folder, 'dogs-vs-cats.zip')
    print('Extracting %s...' % train_zip)
    ZipFile(train_zip).extractall(data_folder)
    
    ZipFile(op.join(data_folder, "train.zip")).extractall(data_folder)

The Keras image data helpers want images for different classes ('cat' and 'dog') to live in distinct subfolders. Let's rearrange the image files to follow that convention:


In [7]:
def rearrange_folders(folder):
    image_filenames = [op.join(folder, fn) for fn in os.listdir(folder)
                       if fn.endswith('.jpg')]
    if len(image_filenames) == 0:
        return
    print("Rearranging %d images in %s into one subfolder per class..."
          % (len(image_filenames), folder))
    for image_filename in image_filenames:
        subfolder, _ = image_filename.split('.', 1)
        subfolder = op.join(folder, subfolder)
        if not op.exists(subfolder):
            os.mkdir(subfolder)
        shutil.move(image_filename, subfolder)

rearrange_folders(train_folder)


Rearranging 25000 images in /Users/arthurdouillard/data/dogs-vs-cats/train into one subfolder per class...

Let's build a validation dataset by taking 500 images of cats and 500 images of dogs out of the training set:


In [8]:
n_validation = 500

validation_folder = op.join(data_folder, 'validation')
if not op.exists(validation_folder):
    os.mkdir(validation_folder)
    for class_name in ['dog', 'cat']:
        train_subfolder = op.join(train_folder, class_name)
        validation_subfolder = op.join(validation_folder, class_name)
        print("Populating %s..." % validation_subfolder)
        os.mkdir(validation_subfolder)
        images_filenames = sorted(os.listdir(train_subfolder))
        for image_filename in images_filenames[-n_validation:]:
            shutil.move(op.join(train_subfolder, image_filename),
                        validation_subfolder)
        print("Moved %d images" % len(os.listdir(validation_subfolder)))


Populating /Users/arthurdouillard/data/dogs-vs-cats/validation/dog...
Moved 500 images
Populating /Users/arthurdouillard/data/dogs-vs-cats/validation/cat...
Moved 500 images

Data Loading and Data Augmentation

Let's use Keras utilities to manually load the first image file of the cat folder. If Keras complains about the missing "PIL" library, make sure to install it with one of the following commands:

conda install pillow

# or

pip install pillow

You might need to restart the kernel of this notebook to get Keras to work.


In [9]:
from tensorflow.keras.preprocessing.image import array_to_img, img_to_array, load_img

img = load_img(op.join(train_folder, 'cat', 'cat.249.jpg'))
x = img_to_array(img)

print(x.shape)


(336, 344, 3)

In [10]:
plt.imshow(x.astype(np.uint8))
plt.axis('off');


Keras provides tools to generate many variations from a single image. This is useful to augment the dataset with label-preserving transformations: a rotated image of a cat is still an image of a cat.

Doing data augmentation at training time makes the neural network invariant to such label-preserving transformations and therefore helps reduce overfitting.


In [15]:
from tensorflow.keras.preprocessing.image import ImageDataGenerator

augmenting_datagen = ImageDataGenerator(
    rescale=1. / 255,
    rotation_range=40,
    width_shift_range=0.2,
    height_shift_range=0.2,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True,
    channel_shift_range=9,
    fill_mode='nearest'
)

In [12]:
plt.figure(figsize=(11, 5))
flow = augmenting_datagen.flow(x[np.newaxis, :, :, :])
for i, x_augmented in zip(range(15), flow):
    plt.subplot(3, 5, i + 1)
    plt.imshow(x_augmented[0])
    plt.axis('off')


The ImageDataGenerator object can then be pointed at the dataset folder to load the images and augment them on the fly, resizing / cropping them to fit the input dimensions of the classification neural network.


In [13]:
flow = augmenting_datagen.flow_from_directory(
    train_folder, batch_size=1, target_size=(224, 224))

plt.figure(figsize=(11, 5))
for i, (X, y) in zip(range(15), flow):
    plt.subplot(3, 5, i + 1)
    plt.imshow(X[0])
    plt.axis('off')


Found 24000 images belonging to 2 classes.

Loading a pre-trained computer vision model

Let us load a state-of-the-art model offering a good tradeoff between prediction speed, model size and predictive accuracy, namely a Residual Network with 54 parameterized layers (53 convolutional + 1 fully connected for the softmax):
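As a reminder, the "residual" structure means each block adds the output of a small stack of convolutions back to the block input through a shortcut connection. Here is a minimal sketch of a bottleneck block in the Keras functional API (illustrative only; batch normalization is omitted for brevity, and the real ResNet50 uses identity shortcuts whenever shapes already match, with strided projections at the start of each stage):

from tensorflow.keras.layers import Input, Conv2D, Activation, Add

def bottleneck_block(x, filters):
    # Projection shortcut so channel counts match for the addition
    # (the real network uses identity shortcuts when shapes already match).
    shortcut = Conv2D(4 * filters, 1)(x)
    y = Conv2D(filters, 1, activation='relu')(x)                   # 1x1, reduce channels
    y = Conv2D(filters, 3, padding='same', activation='relu')(y)   # 3x3
    y = Conv2D(4 * filters, 1)(y)                                  # 1x1, restore channels
    y = Add()([shortcut, y])                                       # the residual addition
    return Activation('relu')(y)

# Example: y = bottleneck_block(Input(shape=(56, 56, 256)), filters=64)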


In [16]:
from tensorflow.keras.applications.resnet50 import ResNet50, preprocess_input

full_imagenet_model = ResNet50(weights='imagenet')

In [17]:
full_imagenet_model.summary()


Model: "resnet50"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_2 (InputLayer)            [(None, 224, 224, 3) 0                                            
__________________________________________________________________________________________________
conv1_pad (ZeroPadding2D)       (None, 230, 230, 3)  0           input_2[0][0]                    
__________________________________________________________________________________________________
conv1_conv (Conv2D)             (None, 112, 112, 64) 9472        conv1_pad[0][0]                  
__________________________________________________________________________________________________
conv1_bn (BatchNormalization)   (None, 112, 112, 64) 256         conv1_conv[0][0]                 
__________________________________________________________________________________________________
conv1_relu (Activation)         (None, 112, 112, 64) 0           conv1_bn[0][0]                   
__________________________________________________________________________________________________
pool1_pad (ZeroPadding2D)       (None, 114, 114, 64) 0           conv1_relu[0][0]                 
__________________________________________________________________________________________________
pool1_pool (MaxPooling2D)       (None, 56, 56, 64)   0           pool1_pad[0][0]                  
__________________________________________________________________________________________________
conv2_block1_1_conv (Conv2D)    (None, 56, 56, 64)   4160        pool1_pool[0][0]                 
__________________________________________________________________________________________________
conv2_block1_1_bn (BatchNormali (None, 56, 56, 64)   256         conv2_block1_1_conv[0][0]        
__________________________________________________________________________________________________
conv2_block1_1_relu (Activation (None, 56, 56, 64)   0           conv2_block1_1_bn[0][0]          
__________________________________________________________________________________________________
conv2_block1_2_conv (Conv2D)    (None, 56, 56, 64)   36928       conv2_block1_1_relu[0][0]        
__________________________________________________________________________________________________
conv2_block1_2_bn (BatchNormali (None, 56, 56, 64)   256         conv2_block1_2_conv[0][0]        
__________________________________________________________________________________________________
conv2_block1_2_relu (Activation (None, 56, 56, 64)   0           conv2_block1_2_bn[0][0]          
__________________________________________________________________________________________________
conv2_block1_0_conv (Conv2D)    (None, 56, 56, 256)  16640       pool1_pool[0][0]                 
__________________________________________________________________________________________________
conv2_block1_3_conv (Conv2D)    (None, 56, 56, 256)  16640       conv2_block1_2_relu[0][0]        
__________________________________________________________________________________________________
conv2_block1_0_bn (BatchNormali (None, 56, 56, 256)  1024        conv2_block1_0_conv[0][0]        
__________________________________________________________________________________________________
conv2_block1_3_bn (BatchNormali (None, 56, 56, 256)  1024        conv2_block1_3_conv[0][0]        
__________________________________________________________________________________________________
conv2_block1_add (Add)          (None, 56, 56, 256)  0           conv2_block1_0_bn[0][0]          
                                                                 conv2_block1_3_bn[0][0]          
__________________________________________________________________________________________________
conv2_block1_out (Activation)   (None, 56, 56, 256)  0           conv2_block1_add[0][0]           
__________________________________________________________________________________________________
conv2_block2_1_conv (Conv2D)    (None, 56, 56, 64)   16448       conv2_block1_out[0][0]           
__________________________________________________________________________________________________
conv2_block2_1_bn (BatchNormali (None, 56, 56, 64)   256         conv2_block2_1_conv[0][0]        
__________________________________________________________________________________________________
conv2_block2_1_relu (Activation (None, 56, 56, 64)   0           conv2_block2_1_bn[0][0]          
__________________________________________________________________________________________________
conv2_block2_2_conv (Conv2D)    (None, 56, 56, 64)   36928       conv2_block2_1_relu[0][0]        
__________________________________________________________________________________________________
conv2_block2_2_bn (BatchNormali (None, 56, 56, 64)   256         conv2_block2_2_conv[0][0]        
__________________________________________________________________________________________________
conv2_block2_2_relu (Activation (None, 56, 56, 64)   0           conv2_block2_2_bn[0][0]          
__________________________________________________________________________________________________
conv2_block2_3_conv (Conv2D)    (None, 56, 56, 256)  16640       conv2_block2_2_relu[0][0]        
__________________________________________________________________________________________________
conv2_block2_3_bn (BatchNormali (None, 56, 56, 256)  1024        conv2_block2_3_conv[0][0]        
__________________________________________________________________________________________________
conv2_block2_add (Add)          (None, 56, 56, 256)  0           conv2_block1_out[0][0]           
                                                                 conv2_block2_3_bn[0][0]          
__________________________________________________________________________________________________
conv2_block2_out (Activation)   (None, 56, 56, 256)  0           conv2_block2_add[0][0]           
__________________________________________________________________________________________________
conv2_block3_1_conv (Conv2D)    (None, 56, 56, 64)   16448       conv2_block2_out[0][0]           
__________________________________________________________________________________________________
conv2_block3_1_bn (BatchNormali (None, 56, 56, 64)   256         conv2_block3_1_conv[0][0]        
__________________________________________________________________________________________________
conv2_block3_1_relu (Activation (None, 56, 56, 64)   0           conv2_block3_1_bn[0][0]          
__________________________________________________________________________________________________
conv2_block3_2_conv (Conv2D)    (None, 56, 56, 64)   36928       conv2_block3_1_relu[0][0]        
__________________________________________________________________________________________________
conv2_block3_2_bn (BatchNormali (None, 56, 56, 64)   256         conv2_block3_2_conv[0][0]        
__________________________________________________________________________________________________
conv2_block3_2_relu (Activation (None, 56, 56, 64)   0           conv2_block3_2_bn[0][0]          
__________________________________________________________________________________________________
conv2_block3_3_conv (Conv2D)    (None, 56, 56, 256)  16640       conv2_block3_2_relu[0][0]        
__________________________________________________________________________________________________
conv2_block3_3_bn (BatchNormali (None, 56, 56, 256)  1024        conv2_block3_3_conv[0][0]        
__________________________________________________________________________________________________
conv2_block3_add (Add)          (None, 56, 56, 256)  0           conv2_block2_out[0][0]           
                                                                 conv2_block3_3_bn[0][0]          
__________________________________________________________________________________________________
conv2_block3_out (Activation)   (None, 56, 56, 256)  0           conv2_block3_add[0][0]           
__________________________________________________________________________________________________
conv3_block1_1_conv (Conv2D)    (None, 28, 28, 128)  32896       conv2_block3_out[0][0]           
__________________________________________________________________________________________________
conv3_block1_1_bn (BatchNormali (None, 28, 28, 128)  512         conv3_block1_1_conv[0][0]        
__________________________________________________________________________________________________
conv3_block1_1_relu (Activation (None, 28, 28, 128)  0           conv3_block1_1_bn[0][0]          
__________________________________________________________________________________________________
conv3_block1_2_conv (Conv2D)    (None, 28, 28, 128)  147584      conv3_block1_1_relu[0][0]        
__________________________________________________________________________________________________
conv3_block1_2_bn (BatchNormali (None, 28, 28, 128)  512         conv3_block1_2_conv[0][0]        
__________________________________________________________________________________________________
conv3_block1_2_relu (Activation (None, 28, 28, 128)  0           conv3_block1_2_bn[0][0]          
__________________________________________________________________________________________________
conv3_block1_0_conv (Conv2D)    (None, 28, 28, 512)  131584      conv2_block3_out[0][0]           
__________________________________________________________________________________________________
conv3_block1_3_conv (Conv2D)    (None, 28, 28, 512)  66048       conv3_block1_2_relu[0][0]        
__________________________________________________________________________________________________
conv3_block1_0_bn (BatchNormali (None, 28, 28, 512)  2048        conv3_block1_0_conv[0][0]        
__________________________________________________________________________________________________
conv3_block1_3_bn (BatchNormali (None, 28, 28, 512)  2048        conv3_block1_3_conv[0][0]        
__________________________________________________________________________________________________
conv3_block1_add (Add)          (None, 28, 28, 512)  0           conv3_block1_0_bn[0][0]          
                                                                 conv3_block1_3_bn[0][0]          
__________________________________________________________________________________________________
conv3_block1_out (Activation)   (None, 28, 28, 512)  0           conv3_block1_add[0][0]           
__________________________________________________________________________________________________
conv3_block2_1_conv (Conv2D)    (None, 28, 28, 128)  65664       conv3_block1_out[0][0]           
__________________________________________________________________________________________________
conv3_block2_1_bn (BatchNormali (None, 28, 28, 128)  512         conv3_block2_1_conv[0][0]        
__________________________________________________________________________________________________
conv3_block2_1_relu (Activation (None, 28, 28, 128)  0           conv3_block2_1_bn[0][0]          
__________________________________________________________________________________________________
conv3_block2_2_conv (Conv2D)    (None, 28, 28, 128)  147584      conv3_block2_1_relu[0][0]        
__________________________________________________________________________________________________
conv3_block2_2_bn (BatchNormali (None, 28, 28, 128)  512         conv3_block2_2_conv[0][0]        
__________________________________________________________________________________________________
conv3_block2_2_relu (Activation (None, 28, 28, 128)  0           conv3_block2_2_bn[0][0]          
__________________________________________________________________________________________________
conv3_block2_3_conv (Conv2D)    (None, 28, 28, 512)  66048       conv3_block2_2_relu[0][0]        
__________________________________________________________________________________________________
conv3_block2_3_bn (BatchNormali (None, 28, 28, 512)  2048        conv3_block2_3_conv[0][0]        
__________________________________________________________________________________________________
conv3_block2_add (Add)          (None, 28, 28, 512)  0           conv3_block1_out[0][0]           
                                                                 conv3_block2_3_bn[0][0]          
__________________________________________________________________________________________________
conv3_block2_out (Activation)   (None, 28, 28, 512)  0           conv3_block2_add[0][0]           
__________________________________________________________________________________________________
conv3_block3_1_conv (Conv2D)    (None, 28, 28, 128)  65664       conv3_block2_out[0][0]           
__________________________________________________________________________________________________
conv3_block3_1_bn (BatchNormali (None, 28, 28, 128)  512         conv3_block3_1_conv[0][0]        
__________________________________________________________________________________________________
conv3_block3_1_relu (Activation (None, 28, 28, 128)  0           conv3_block3_1_bn[0][0]          
__________________________________________________________________________________________________
conv3_block3_2_conv (Conv2D)    (None, 28, 28, 128)  147584      conv3_block3_1_relu[0][0]        
__________________________________________________________________________________________________
conv3_block3_2_bn (BatchNormali (None, 28, 28, 128)  512         conv3_block3_2_conv[0][0]        
__________________________________________________________________________________________________
conv3_block3_2_relu (Activation (None, 28, 28, 128)  0           conv3_block3_2_bn[0][0]          
__________________________________________________________________________________________________
conv3_block3_3_conv (Conv2D)    (None, 28, 28, 512)  66048       conv3_block3_2_relu[0][0]        
__________________________________________________________________________________________________
conv3_block3_3_bn (BatchNormali (None, 28, 28, 512)  2048        conv3_block3_3_conv[0][0]        
__________________________________________________________________________________________________
conv3_block3_add (Add)          (None, 28, 28, 512)  0           conv3_block2_out[0][0]           
                                                                 conv3_block3_3_bn[0][0]          
__________________________________________________________________________________________________
conv3_block3_out (Activation)   (None, 28, 28, 512)  0           conv3_block3_add[0][0]           
__________________________________________________________________________________________________
conv3_block4_1_conv (Conv2D)    (None, 28, 28, 128)  65664       conv3_block3_out[0][0]           
__________________________________________________________________________________________________
conv3_block4_1_bn (BatchNormali (None, 28, 28, 128)  512         conv3_block4_1_conv[0][0]        
__________________________________________________________________________________________________
conv3_block4_1_relu (Activation (None, 28, 28, 128)  0           conv3_block4_1_bn[0][0]          
__________________________________________________________________________________________________
conv3_block4_2_conv (Conv2D)    (None, 28, 28, 128)  147584      conv3_block4_1_relu[0][0]        
__________________________________________________________________________________________________
conv3_block4_2_bn (BatchNormali (None, 28, 28, 128)  512         conv3_block4_2_conv[0][0]        
__________________________________________________________________________________________________
conv3_block4_2_relu (Activation (None, 28, 28, 128)  0           conv3_block4_2_bn[0][0]          
__________________________________________________________________________________________________
conv3_block4_3_conv (Conv2D)    (None, 28, 28, 512)  66048       conv3_block4_2_relu[0][0]        
__________________________________________________________________________________________________
conv3_block4_3_bn (BatchNormali (None, 28, 28, 512)  2048        conv3_block4_3_conv[0][0]        
__________________________________________________________________________________________________
conv3_block4_add (Add)          (None, 28, 28, 512)  0           conv3_block3_out[0][0]           
                                                                 conv3_block4_3_bn[0][0]          
__________________________________________________________________________________________________
conv3_block4_out (Activation)   (None, 28, 28, 512)  0           conv3_block4_add[0][0]           
__________________________________________________________________________________________________
conv4_block1_1_conv (Conv2D)    (None, 14, 14, 256)  131328      conv3_block4_out[0][0]           
__________________________________________________________________________________________________
conv4_block1_1_bn (BatchNormali (None, 14, 14, 256)  1024        conv4_block1_1_conv[0][0]        
__________________________________________________________________________________________________
conv4_block1_1_relu (Activation (None, 14, 14, 256)  0           conv4_block1_1_bn[0][0]          
__________________________________________________________________________________________________
conv4_block1_2_conv (Conv2D)    (None, 14, 14, 256)  590080      conv4_block1_1_relu[0][0]        
__________________________________________________________________________________________________
conv4_block1_2_bn (BatchNormali (None, 14, 14, 256)  1024        conv4_block1_2_conv[0][0]        
__________________________________________________________________________________________________
conv4_block1_2_relu (Activation (None, 14, 14, 256)  0           conv4_block1_2_bn[0][0]          
__________________________________________________________________________________________________
conv4_block1_0_conv (Conv2D)    (None, 14, 14, 1024) 525312      conv3_block4_out[0][0]           
__________________________________________________________________________________________________
conv4_block1_3_conv (Conv2D)    (None, 14, 14, 1024) 263168      conv4_block1_2_relu[0][0]        
__________________________________________________________________________________________________
conv4_block1_0_bn (BatchNormali (None, 14, 14, 1024) 4096        conv4_block1_0_conv[0][0]        
__________________________________________________________________________________________________
conv4_block1_3_bn (BatchNormali (None, 14, 14, 1024) 4096        conv4_block1_3_conv[0][0]        
__________________________________________________________________________________________________
conv4_block1_add (Add)          (None, 14, 14, 1024) 0           conv4_block1_0_bn[0][0]          
                                                                 conv4_block1_3_bn[0][0]          
__________________________________________________________________________________________________
conv4_block1_out (Activation)   (None, 14, 14, 1024) 0           conv4_block1_add[0][0]           
__________________________________________________________________________________________________
conv4_block2_1_conv (Conv2D)    (None, 14, 14, 256)  262400      conv4_block1_out[0][0]           
__________________________________________________________________________________________________
conv4_block2_1_bn (BatchNormali (None, 14, 14, 256)  1024        conv4_block2_1_conv[0][0]        
__________________________________________________________________________________________________
conv4_block2_1_relu (Activation (None, 14, 14, 256)  0           conv4_block2_1_bn[0][0]          
__________________________________________________________________________________________________
conv4_block2_2_conv (Conv2D)    (None, 14, 14, 256)  590080      conv4_block2_1_relu[0][0]        
__________________________________________________________________________________________________
conv4_block2_2_bn (BatchNormali (None, 14, 14, 256)  1024        conv4_block2_2_conv[0][0]        
__________________________________________________________________________________________________
conv4_block2_2_relu (Activation (None, 14, 14, 256)  0           conv4_block2_2_bn[0][0]          
__________________________________________________________________________________________________
conv4_block2_3_conv (Conv2D)    (None, 14, 14, 1024) 263168      conv4_block2_2_relu[0][0]        
__________________________________________________________________________________________________
conv4_block2_3_bn (BatchNormali (None, 14, 14, 1024) 4096        conv4_block2_3_conv[0][0]        
__________________________________________________________________________________________________
conv4_block2_add (Add)          (None, 14, 14, 1024) 0           conv4_block1_out[0][0]           
                                                                 conv4_block2_3_bn[0][0]          
__________________________________________________________________________________________________
conv4_block2_out (Activation)   (None, 14, 14, 1024) 0           conv4_block2_add[0][0]           
__________________________________________________________________________________________________
conv4_block3_1_conv (Conv2D)    (None, 14, 14, 256)  262400      conv4_block2_out[0][0]           
__________________________________________________________________________________________________
conv4_block3_1_bn (BatchNormali (None, 14, 14, 256)  1024        conv4_block3_1_conv[0][0]        
__________________________________________________________________________________________________
conv4_block3_1_relu (Activation (None, 14, 14, 256)  0           conv4_block3_1_bn[0][0]          
__________________________________________________________________________________________________
conv4_block3_2_conv (Conv2D)    (None, 14, 14, 256)  590080      conv4_block3_1_relu[0][0]        
__________________________________________________________________________________________________
conv4_block3_2_bn (BatchNormali (None, 14, 14, 256)  1024        conv4_block3_2_conv[0][0]        
__________________________________________________________________________________________________
conv4_block3_2_relu (Activation (None, 14, 14, 256)  0           conv4_block3_2_bn[0][0]          
__________________________________________________________________________________________________
conv4_block3_3_conv (Conv2D)    (None, 14, 14, 1024) 263168      conv4_block3_2_relu[0][0]        
__________________________________________________________________________________________________
conv4_block3_3_bn (BatchNormali (None, 14, 14, 1024) 4096        conv4_block3_3_conv[0][0]        
__________________________________________________________________________________________________
conv4_block3_add (Add)          (None, 14, 14, 1024) 0           conv4_block2_out[0][0]           
                                                                 conv4_block3_3_bn[0][0]          
__________________________________________________________________________________________________
conv4_block3_out (Activation)   (None, 14, 14, 1024) 0           conv4_block3_add[0][0]           
__________________________________________________________________________________________________
conv4_block4_1_conv (Conv2D)    (None, 14, 14, 256)  262400      conv4_block3_out[0][0]           
__________________________________________________________________________________________________
conv4_block4_1_bn (BatchNormali (None, 14, 14, 256)  1024        conv4_block4_1_conv[0][0]        
__________________________________________________________________________________________________
conv4_block4_1_relu (Activation (None, 14, 14, 256)  0           conv4_block4_1_bn[0][0]          
__________________________________________________________________________________________________
conv4_block4_2_conv (Conv2D)    (None, 14, 14, 256)  590080      conv4_block4_1_relu[0][0]        
__________________________________________________________________________________________________
conv4_block4_2_bn (BatchNormali (None, 14, 14, 256)  1024        conv4_block4_2_conv[0][0]        
__________________________________________________________________________________________________
conv4_block4_2_relu (Activation (None, 14, 14, 256)  0           conv4_block4_2_bn[0][0]          
__________________________________________________________________________________________________
conv4_block4_3_conv (Conv2D)    (None, 14, 14, 1024) 263168      conv4_block4_2_relu[0][0]        
__________________________________________________________________________________________________
conv4_block4_3_bn (BatchNormali (None, 14, 14, 1024) 4096        conv4_block4_3_conv[0][0]        
__________________________________________________________________________________________________
conv4_block4_add (Add)          (None, 14, 14, 1024) 0           conv4_block3_out[0][0]           
                                                                 conv4_block4_3_bn[0][0]          
__________________________________________________________________________________________________
conv4_block4_out (Activation)   (None, 14, 14, 1024) 0           conv4_block4_add[0][0]           
__________________________________________________________________________________________________
conv4_block5_1_conv (Conv2D)    (None, 14, 14, 256)  262400      conv4_block4_out[0][0]           
__________________________________________________________________________________________________
conv4_block5_1_bn (BatchNormali (None, 14, 14, 256)  1024        conv4_block5_1_conv[0][0]        
__________________________________________________________________________________________________
conv4_block5_1_relu (Activation (None, 14, 14, 256)  0           conv4_block5_1_bn[0][0]          
__________________________________________________________________________________________________
conv4_block5_2_conv (Conv2D)    (None, 14, 14, 256)  590080      conv4_block5_1_relu[0][0]        
__________________________________________________________________________________________________
conv4_block5_2_bn (BatchNormali (None, 14, 14, 256)  1024        conv4_block5_2_conv[0][0]        
__________________________________________________________________________________________________
conv4_block5_2_relu (Activation (None, 14, 14, 256)  0           conv4_block5_2_bn[0][0]          
__________________________________________________________________________________________________
conv4_block5_3_conv (Conv2D)    (None, 14, 14, 1024) 263168      conv4_block5_2_relu[0][0]        
__________________________________________________________________________________________________
conv4_block5_3_bn (BatchNormali (None, 14, 14, 1024) 4096        conv4_block5_3_conv[0][0]        
__________________________________________________________________________________________________
conv4_block5_add (Add)          (None, 14, 14, 1024) 0           conv4_block4_out[0][0]           
                                                                 conv4_block5_3_bn[0][0]          
__________________________________________________________________________________________________
conv4_block5_out (Activation)   (None, 14, 14, 1024) 0           conv4_block5_add[0][0]           
__________________________________________________________________________________________________
conv4_block6_1_conv (Conv2D)    (None, 14, 14, 256)  262400      conv4_block5_out[0][0]           
__________________________________________________________________________________________________
conv4_block6_1_bn (BatchNormali (None, 14, 14, 256)  1024        conv4_block6_1_conv[0][0]        
__________________________________________________________________________________________________
conv4_block6_1_relu (Activation (None, 14, 14, 256)  0           conv4_block6_1_bn[0][0]          
__________________________________________________________________________________________________
conv4_block6_2_conv (Conv2D)    (None, 14, 14, 256)  590080      conv4_block6_1_relu[0][0]        
__________________________________________________________________________________________________
conv4_block6_2_bn (BatchNormali (None, 14, 14, 256)  1024        conv4_block6_2_conv[0][0]        
__________________________________________________________________________________________________
conv4_block6_2_relu (Activation (None, 14, 14, 256)  0           conv4_block6_2_bn[0][0]          
__________________________________________________________________________________________________
conv4_block6_3_conv (Conv2D)    (None, 14, 14, 1024) 263168      conv4_block6_2_relu[0][0]        
__________________________________________________________________________________________________
conv4_block6_3_bn (BatchNormali (None, 14, 14, 1024) 4096        conv4_block6_3_conv[0][0]        
__________________________________________________________________________________________________
conv4_block6_add (Add)          (None, 14, 14, 1024) 0           conv4_block5_out[0][0]           
                                                                 conv4_block6_3_bn[0][0]          
__________________________________________________________________________________________________
conv4_block6_out (Activation)   (None, 14, 14, 1024) 0           conv4_block6_add[0][0]           
__________________________________________________________________________________________________
conv5_block1_1_conv (Conv2D)    (None, 7, 7, 512)    524800      conv4_block6_out[0][0]           
__________________________________________________________________________________________________
conv5_block1_1_bn (BatchNormali (None, 7, 7, 512)    2048        conv5_block1_1_conv[0][0]        
__________________________________________________________________________________________________
conv5_block1_1_relu (Activation (None, 7, 7, 512)    0           conv5_block1_1_bn[0][0]          
__________________________________________________________________________________________________
conv5_block1_2_conv (Conv2D)    (None, 7, 7, 512)    2359808     conv5_block1_1_relu[0][0]        
__________________________________________________________________________________________________
conv5_block1_2_bn (BatchNormali (None, 7, 7, 512)    2048        conv5_block1_2_conv[0][0]        
__________________________________________________________________________________________________
conv5_block1_2_relu (Activation (None, 7, 7, 512)    0           conv5_block1_2_bn[0][0]          
__________________________________________________________________________________________________
conv5_block1_0_conv (Conv2D)    (None, 7, 7, 2048)   2099200     conv4_block6_out[0][0]           
__________________________________________________________________________________________________
conv5_block1_3_conv (Conv2D)    (None, 7, 7, 2048)   1050624     conv5_block1_2_relu[0][0]        
__________________________________________________________________________________________________
conv5_block1_0_bn (BatchNormali (None, 7, 7, 2048)   8192        conv5_block1_0_conv[0][0]        
__________________________________________________________________________________________________
conv5_block1_3_bn (BatchNormali (None, 7, 7, 2048)   8192        conv5_block1_3_conv[0][0]        
__________________________________________________________________________________________________
conv5_block1_add (Add)          (None, 7, 7, 2048)   0           conv5_block1_0_bn[0][0]          
                                                                 conv5_block1_3_bn[0][0]          
__________________________________________________________________________________________________
conv5_block1_out (Activation)   (None, 7, 7, 2048)   0           conv5_block1_add[0][0]           
__________________________________________________________________________________________________
conv5_block2_1_conv (Conv2D)    (None, 7, 7, 512)    1049088     conv5_block1_out[0][0]           
__________________________________________________________________________________________________
conv5_block2_1_bn (BatchNormali (None, 7, 7, 512)    2048        conv5_block2_1_conv[0][0]        
__________________________________________________________________________________________________
conv5_block2_1_relu (Activation (None, 7, 7, 512)    0           conv5_block2_1_bn[0][0]          
__________________________________________________________________________________________________
conv5_block2_2_conv (Conv2D)    (None, 7, 7, 512)    2359808     conv5_block2_1_relu[0][0]        
__________________________________________________________________________________________________
conv5_block2_2_bn (BatchNormali (None, 7, 7, 512)    2048        conv5_block2_2_conv[0][0]        
__________________________________________________________________________________________________
conv5_block2_2_relu (Activation (None, 7, 7, 512)    0           conv5_block2_2_bn[0][0]          
__________________________________________________________________________________________________
conv5_block2_3_conv (Conv2D)    (None, 7, 7, 2048)   1050624     conv5_block2_2_relu[0][0]        
__________________________________________________________________________________________________
conv5_block2_3_bn (BatchNormali (None, 7, 7, 2048)   8192        conv5_block2_3_conv[0][0]        
__________________________________________________________________________________________________
conv5_block2_add (Add)          (None, 7, 7, 2048)   0           conv5_block1_out[0][0]           
                                                                 conv5_block2_3_bn[0][0]          
__________________________________________________________________________________________________
conv5_block2_out (Activation)   (None, 7, 7, 2048)   0           conv5_block2_add[0][0]           
__________________________________________________________________________________________________
conv5_block3_1_conv (Conv2D)    (None, 7, 7, 512)    1049088     conv5_block2_out[0][0]           
__________________________________________________________________________________________________
conv5_block3_1_bn (BatchNormali (None, 7, 7, 512)    2048        conv5_block3_1_conv[0][0]        
__________________________________________________________________________________________________
conv5_block3_1_relu (Activation (None, 7, 7, 512)    0           conv5_block3_1_bn[0][0]          
__________________________________________________________________________________________________
conv5_block3_2_conv (Conv2D)    (None, 7, 7, 512)    2359808     conv5_block3_1_relu[0][0]        
__________________________________________________________________________________________________
conv5_block3_2_bn (BatchNormali (None, 7, 7, 512)    2048        conv5_block3_2_conv[0][0]        
__________________________________________________________________________________________________
conv5_block3_2_relu (Activation (None, 7, 7, 512)    0           conv5_block3_2_bn[0][0]          
__________________________________________________________________________________________________
conv5_block3_3_conv (Conv2D)    (None, 7, 7, 2048)   1050624     conv5_block3_2_relu[0][0]        
__________________________________________________________________________________________________
conv5_block3_3_bn (BatchNormali (None, 7, 7, 2048)   8192        conv5_block3_3_conv[0][0]        
__________________________________________________________________________________________________
conv5_block3_add (Add)          (None, 7, 7, 2048)   0           conv5_block2_out[0][0]           
                                                                 conv5_block3_3_bn[0][0]          
__________________________________________________________________________________________________
conv5_block3_out (Activation)   (None, 7, 7, 2048)   0           conv5_block3_add[0][0]           
__________________________________________________________________________________________________
avg_pool (GlobalAveragePooling2 (None, 2048)         0           conv5_block3_out[0][0]           
__________________________________________________________________________________________________
probs (Dense)                   (None, 1000)         2049000     avg_pool[0][0]                   
==================================================================================================
Total params: 25,636,712
Trainable params: 25,583,592
Non-trainable params: 53,120
__________________________________________________________________________________________________

If you have the graphviz system package and the pydot_ng Python package installed, you can uncomment the following cell to display the structure of the network.


In [ ]:
# from IPython.display import SVG
# from tensorflow.keras.utils import model_to_dot

# model_viz = model_to_dot(full_imagenet_model,
#                          show_layer_names=False,
#                          show_shapes=True)
# SVG(model_viz.create(prog='dot', format='svg'))

Transfer learning

Let's remove the last dense classification layer, which is specific to the ImageNet classes, and use the previous layer (the global average pooling output) as a feature extractor:


In [18]:
from tensorflow.keras.models import Model

output = full_imagenet_model.layers[-2].output
base_model = Model(full_imagenet_model.input, output)
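A quick check (a small sketch) that the truncated model now outputs the 2048-dimensional pooled features instead of the 1000 ImageNet class probabilities:

print(base_model.output_shape)  # expected: (None, 2048)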

When using this model, we need to be careful to apply the same image preprocessing as was used during training, otherwise the marginal distribution of the input pixels will not be on the expected scale:


In [20]:
def preprocess_function(x):
    if x.ndim == 3:
        x = x[np.newaxis, :, :, :]
    return preprocess_input(x)
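For the record, the ResNet50 preprocess_input does not rescale pixels to [0, 1]: it converts images from RGB to BGR and subtracts per-channel ImageNet means, so the preprocessed values are roughly centered around zero. A quick sanity check (the exact value range depends on the image):

x_pre = preprocess_function(img_to_array(img))
print(x_pre.min(), x_pre.max())  # centered values, not in [0, 1]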

In [21]:
batch_size = 50

datagen = ImageDataGenerator(preprocessing_function=preprocess_function)

train_flow = datagen.flow_from_directory(
    train_folder,
    target_size=(224, 224),
    batch_size=batch_size,
    class_mode='binary',
    shuffle=True,
)

X, y = next(train_flow)
print(X.shape, y.shape)


Found 24000 images belonging to 2 classes.
(50, 224, 224, 3) (50,)

Exercise: write a function that iterates over the first 5000 images of the training set (batch after batch), extracts the activations of the last layer of base_model (by calling predict), and collects the results in a big numpy array with dimensions (5000, 2048) for the features and (5000,) for the matching image labels.


In [26]:
# %load solutions/dogs_vs_cats_extract_features.py
from time import time

features = []
labels = []

t0 = time()
count = 0
for X, y in train_flow:
    labels.append(y)
    features.append(base_model.predict(X))
    count += len(y)
    if count % 100 == 0:
        print("processed %d images at %d images/s"
              % (count, count / (time() - t0)))
    if count >= 5000:
        break

labels_train = np.concatenate(labels)
features_train = np.vstack(features)
np.save('labels_train.npy', labels_train)
np.save('features_train.npy', features_train)


processed 100 images at 7 images/s
processed 200 images at 6 images/s
processed 300 images at 6 images/s
processed 400 images at 6 images/s
processed 500 images at 7 images/s
processed 600 images at 7 images/s
processed 700 images at 7 images/s
processed 800 images at 7 images/s
processed 900 images at 7 images/s
processed 1000 images at 7 images/s
processed 1100 images at 7 images/s
processed 1200 images at 7 images/s
processed 1300 images at 7 images/s
processed 1400 images at 7 images/s
processed 1500 images at 8 images/s
processed 1600 images at 8 images/s
processed 1700 images at 8 images/s
processed 1800 images at 8 images/s
processed 1900 images at 8 images/s
processed 2000 images at 8 images/s
processed 2100 images at 8 images/s
processed 2200 images at 8 images/s
processed 2300 images at 8 images/s
processed 2400 images at 8 images/s
processed 2500 images at 8 images/s
processed 2600 images at 8 images/s
processed 2700 images at 8 images/s
processed 2800 images at 8 images/s
processed 2900 images at 8 images/s
processed 3000 images at 8 images/s
processed 3100 images at 8 images/s
processed 3200 images at 8 images/s
processed 3300 images at 8 images/s
processed 3400 images at 8 images/s
processed 3500 images at 8 images/s
processed 3600 images at 8 images/s
processed 3700 images at 8 images/s
processed 3800 images at 8 images/s
processed 3900 images at 8 images/s
processed 4000 images at 8 images/s
processed 4100 images at 8 images/s
processed 4200 images at 8 images/s
processed 4300 images at 8 images/s
processed 4400 images at 8 images/s
processed 4500 images at 8 images/s
processed 4600 images at 8 images/s
processed 4700 images at 8 images/s
processed 4800 images at 8 images/s
processed 4900 images at 8 images/s
processed 5000 images at 8 images/s

Let's load precomputed features if available:


In [27]:
print("Loading precomputed features")
labels_train = np.load('labels_train.npy')
features_train = np.load('features_train.npy')


Loading precomputed features

Let's train a simple linear model on those features. First let's check that the resulting small dataset has balanced classes:


In [28]:
print(labels_train.shape)


(5000,)

In [29]:
np.mean(labels_train)


Out[29]:
0.5062

In [30]:
n_samples, n_features = features_train.shape
print(n_features, "features extracted")


2048 features extracted

Let's define the classification model:


In [31]:
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Dropout
from tensorflow.keras.optimizers import Adam


top_model = Sequential()
top_model.add(Dense(1, input_dim=n_features, activation='sigmoid'))
top_model.compile(optimizer=Adam(lr=1e-4),
                  loss='binary_crossentropy', metrics=['accuracy'])

top_model.fit(features_train, labels_train,
              validation_split=0.1, verbose=2, epochs=15)


Train on 4500 samples, validate on 500 samples
Epoch 1/15
4500/4500 - 1s - loss: 0.4411 - accuracy: 0.8231 - val_loss: 0.2424 - val_accuracy: 0.9420
Epoch 2/15
4500/4500 - 0s - loss: 0.1883 - accuracy: 0.9622 - val_loss: 0.1419 - val_accuracy: 0.9660
Epoch 3/15
4500/4500 - 0s - loss: 0.1248 - accuracy: 0.9749 - val_loss: 0.1051 - val_accuracy: 0.9740
Epoch 4/15
4500/4500 - 0s - loss: 0.0966 - accuracy: 0.9789 - val_loss: 0.0862 - val_accuracy: 0.9740
Epoch 5/15
4500/4500 - 0s - loss: 0.0805 - accuracy: 0.9813 - val_loss: 0.0736 - val_accuracy: 0.9740
Epoch 6/15
4500/4500 - 0s - loss: 0.0700 - accuracy: 0.9824 - val_loss: 0.0669 - val_accuracy: 0.9760
Epoch 7/15
4500/4500 - 0s - loss: 0.0624 - accuracy: 0.9838 - val_loss: 0.0599 - val_accuracy: 0.9860
Epoch 8/15
4500/4500 - 0s - loss: 0.0571 - accuracy: 0.9851 - val_loss: 0.0561 - val_accuracy: 0.9840
Epoch 9/15
4500/4500 - 0s - loss: 0.0526 - accuracy: 0.9853 - val_loss: 0.0532 - val_accuracy: 0.9800
Epoch 10/15
4500/4500 - 0s - loss: 0.0490 - accuracy: 0.9860 - val_loss: 0.0524 - val_accuracy: 0.9780
Epoch 11/15
4500/4500 - 0s - loss: 0.0459 - accuracy: 0.9867 - val_loss: 0.0484 - val_accuracy: 0.9880
Epoch 12/15
4500/4500 - 0s - loss: 0.0436 - accuracy: 0.9884 - val_loss: 0.0456 - val_accuracy: 0.9880
Epoch 13/15
4500/4500 - 0s - loss: 0.0412 - accuracy: 0.9882 - val_loss: 0.0450 - val_accuracy: 0.9880
Epoch 14/15
4500/4500 - 0s - loss: 0.0392 - accuracy: 0.9887 - val_loss: 0.0438 - val_accuracy: 0.9880
Epoch 15/15
4500/4500 - 0s - loss: 0.0375 - accuracy: 0.9893 - val_loss: 0.0431 - val_accuracy: 0.9880
Out[31]:
<tensorflow.python.keras.callbacks.History at 0x13d5e5198>

Alright, so transfer learning alone already reaches ~0.98 / 0.99 accuracy. This is not too surprising as the cat and dog classes are already part of the ImageNet label set.

Note that this is already as good as or slightly better than the winning entry of the original Kaggle competition, which took place before pretrained ResNet models were available.

Our validation set has 1000 images, so an accuracy of 0.990 means only 10 classification errors.

Let's plug this classifier on top of the base model so we can make predictions on our held-out validation image folder:


In [32]:
model = Model(base_model.input, top_model(base_model.output))

In [33]:
flow = ImageDataGenerator().flow_from_directory(
        validation_folder, batch_size=1, target_size=(224, 224))

plt.figure(figsize=(12, 8))
for i, (X, y) in zip(range(15), flow):
    plt.subplot(3, 5, i + 1)
    plt.imshow(X[0] / 255)
    prediction = model.predict(preprocess_input(X))[0, 0]
    label = "dog" if y[:, 1] > 0.5 else "cat"
    plt.title("dog prob=%0.4f\ntrue label: %s"
              % (prediction, label))
    plt.axis('off')


Found 1000 images belonging to 2 classes.

Let's compute the validation score on the full validation set:


In [34]:
valgen = ImageDataGenerator(preprocessing_function=preprocess_function)
val_flow = valgen.flow_from_directory(
    validation_folder, batch_size=batch_size, target_size=(224, 224),
    shuffle=False, class_mode='binary')

all_correct = []
for i, (X, y) in zip(range(val_flow.n // batch_size), val_flow):
    predictions = model.predict(X).ravel()
    correct = list((predictions > 0.5) == y)
    all_correct.extend(correct)
    print("Processed %d images" % len(all_correct))
    
print("Validation accuracy: %0.4f" % np.mean(all_correct))


Found 1000 images belonging to 2 classes.
Processed 50 images
Processed 100 images
Processed 150 images
Processed 200 images
Processed 250 images
Processed 300 images
Processed 350 images
Processed 400 images
Processed 450 images
Processed 500 images
Processed 550 images
Processed 600 images
Processed 650 images
Processed 700 images
Processed 750 images
Processed 800 images
Processed 850 images
Processed 900 images
Processed 950 images
Processed 1000 images
Validation accuracy: 0.9910

Exercise: display the examples where the model makes its most confident mistakes.

To display images in a Jupyter notebook you can use:

from IPython.display import Image, display
import os.path as op

display(Image(op.join(validation_folder, image_name)))

The filenames of items sampled by a flow (without random shuffling) can be accessed via: val_flow.filenames.


In [ ]:


In [35]:
# %load solutions/dogs_vs_cats_worst_predictions.py
from IPython.display import Image, display

predicted_batches = []
label_batches = []
n_batches = val_flow.n // batch_size
for i, (X, y) in zip(range(n_batches), val_flow):
    predicted_batches.append(model.predict(X).ravel())
    label_batches.append(y)
    print("%d/%d" % (i + 1, n_batches))

predictions = np.concatenate(predicted_batches)
true_labels = np.concatenate(label_batches)
top_offenders = np.abs(predictions - true_labels).argsort()[::-1][:10]

image_names = np.array(val_flow.filenames, dtype=object)[top_offenders]
for img, pred in zip(image_names, predictions[top_offenders]):
    print("predicted dog probability: %0.4f" % pred)
    display(Image(op.join(validation_folder, img)))

# Analysis:
#
# The worst offender has the grid occlusion: this kind of grid is
# probably much more frequent for dogs in the rest of the training
# set. This is an unwanted bias of our dataset.
#
# To fix it we would probably need to add other images with similar
# occlusion patterns to teach the model to be invariant to them.
# This could be achieved with a dedicated data augmentation scheme.
#
# The image with both a dog and a cat could clearly be considered a
# labeling error: this kind of ambiguous image should be removed
# from the validation set to properly assess the generalization
# ability of the model.
#
# The other errors are harder to understand. Introspecting the
# gradients back to the pixel space could help understand what is
# misleading the model. It could be some element in the background
# that is statistically very correlated with dogs in the training set.


1/20
2/20
3/20
4/20
5/20
6/20
7/20
8/20
9/20
10/20
11/20
12/20
13/20
14/20
15/20
16/20
17/20
18/20
19/20
20/20
predicted dog probability: 0.9496
predicted dog probability: 0.8695
predicted dog probability: 0.1766
predicted dog probability: 0.7228
predicted dog probability: 0.3212
predicted dog probability: 0.6309
predicted dog probability: 0.6283
predicted dog probability: 0.6013
predicted dog probability: 0.5463
predicted dog probability: 0.4960

Fine-tuning

Let's identify the locations of the residual blocks (the layers that merge their inputs by addition in a residual architecture):


In [36]:
from tensorflow.keras.layers import Add

[(i, l.output_shape)
 for (i, l) in enumerate(model.layers)
 if isinstance(l, Add)]


Out[36]:
[(17, (None, 56, 56, 256)),
 (27, (None, 56, 56, 256)),
 (37, (None, 56, 56, 256)),
 (49, (None, 28, 28, 512)),
 (59, (None, 28, 28, 512)),
 (69, (None, 28, 28, 512)),
 (79, (None, 28, 28, 512)),
 (91, (None, 14, 14, 1024)),
 (101, (None, 14, 14, 1024)),
 (111, (None, 14, 14, 1024)),
 (121, (None, 14, 14, 1024)),
 (131, (None, 14, 14, 1024)),
 (141, (None, 14, 14, 1024)),
 (153, (None, 7, 7, 2048)),
 (163, (None, 7, 7, 2048)),
 (173, (None, 7, 7, 2048))]

Let's freeze the weights of the low-level layers and fine-tune only the top-level layers:


In [37]:
for i, layer in enumerate(model.layers):
    layer.trainable = i >= 151
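A quick sanity check (a small sketch) that only the last residual stage and the layers above it are left trainable:

trainable_names = [layer.name for layer in model.layers if layer.trainable]
print("%d trainable layers, from %s to %s"
      % (len(trainable_names), trainable_names[0], trainable_names[-1]))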

Let's fine-tune the top-level layers a bit to see if we can further improve the accuracy. Use the nvidia-smi command in a bash session on the server to monitor GPU usage while the model is training.


In [ ]:
from tensorflow.keras import optimizers

augmenting_datagen = ImageDataGenerator(
    rotation_range=20,
    width_shift_range=0.2,
    height_shift_range=0.2,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True,
    fill_mode='nearest',
    preprocessing_function=preprocess_function,
)
train_flow = augmenting_datagen.flow_from_directory(
    train_folder, target_size=(224, 224), batch_size=batch_size,
    class_mode='binary', shuffle=True, seed=0)

opt = optimizers.SGD(lr=1e-4, momentum=0.9)
model.compile(optimizer=opt, loss='binary_crossentropy',
              metrics=['accuracy'])


# compute the validation metrics every 5000 training samples
# (fit_generator counts steps in batches, hence the division by batch_size)
history = model.fit_generator(train_flow,
                              steps_per_epoch=5000 // batch_size,
                              epochs=30,
                              validation_data=val_flow,
                              validation_steps=val_flow.n // batch_size)

# Note: the pretrained model was already very good. Fine tuning
# does not really seem to help. It might be more interesting to
# introspect the quality of the labeling in the training set to
# check for images that are too ambiguous and should be removed
# from the training set.

Bonus exercise: train your own architecture from scratch using Adam and data augmentation. Start with a small architecture first, e.g. 4 convolution layers interleaved with 2 max pooling layers, followed by a Flatten and two fully connected layers; a minimal sketch of such a model follows.
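Here is a minimal sketch of the suggested layout (the filter counts and the 128-unit hidden layer are arbitrary starting points, not tuned values; with 224x224 inputs the Flatten layer gets large, so a smaller target_size may help):

from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense

small_model = Sequential([
    Conv2D(32, 3, activation='relu', input_shape=(224, 224, 3)),
    Conv2D(32, 3, activation='relu'),
    MaxPooling2D(2),
    Conv2D(64, 3, activation='relu'),
    Conv2D(64, 3, activation='relu'),
    MaxPooling2D(2),
    Flatten(),
    Dense(128, activation='relu'),
    Dense(1, activation='sigmoid'),   # binary dog-vs-cat output
])
small_model.compile(optimizer='adam', loss='binary_crossentropy',
                    metrics=['accuracy'])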

Bonus exercise: run this notebook on an instance with several GPUs (e.g. NC12 or NC24 instances on Azure) and try to speed up the training with: https://medium.com/@kuza55/transparent-multi-gpu-training-on-tensorflow-with-keras-8b0016fd9012 (a sketch of a more recent alternative follows).
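The article above relies on a historical multi-GPU model wrapper; with a recent TensorFlow, an alternative sketch (assuming TF 2.x) is to build and compile the model inside a tf.distribute.MirroredStrategy scope, which replicates it across all visible GPUs and averages the gradients:

import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()
print("Number of replicas: %d" % strategy.num_replicas_in_sync)

with strategy.scope():
    # Any model built and compiled here is mirrored on each GPU
    # (ResNet50 was imported earlier in this notebook).
    replicated_model = ResNet50(weights='imagenet')
    replicated_model.compile(optimizer='sgd',
                             loss='sparse_categorical_crossentropy')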


In [ ]: