Fine-tuning a CNN with MXNet: Cats vs Dogs (Kaggle Redux)

In this tutorial we'll learn how to build a model that classifies whether an image shows a cat or a dog. We'll use a pre-trained ImageNet model from the MXNet model zoo. For practical problems we often don't have a large dataset, so it's difficult to train such general-purpose models from scratch. However, we can take advantage of models pre-trained on a large dataset like ImageNet, where the model has already learnt many useful image features.

The model used is based on the Convolutional Neural Network (CNN) architecture. CNNs consist of multiple layers of receptive fields that are modelled on biological visual receptors. At each layer, collections of neurons process portions of the input image, and the outputs are tiled to obtain a higher-level representation of the image. For more details on how CNNs work, check out the CS231n course and the MNIST example with MXNet.

To fine-tune a network we'll update all of the network’s weights and also replace the last fully-connected layer with one that has the new number of output classes. In most cases we train with a smaller learning rate, given that we typically have a small dataset. For more in-depth reading on fine-tuning with MXNet, check this tutorial.

Setting up a deep learning environment with AWS Deep Learning AMI for MXNet

In this tutorial, we are going to use the Deep Learning AMI. The Deep Learning AMI is a base Amazon Linux image provided by Amazon Web Services for use on Amazon Elastic Compute Cloud (Amazon EC2). It is designed to provide a stable, secure, and high-performance execution environment for deep learning applications running on Amazon EC2. It includes popular deep learning frameworks, including MXNet.

To set up a deep learning environment on AWS using the Deep Learning AMI, please read this post on the AWS AI Blog for detailed instructions.

Or you can choose to install MXNet on your own machine.

Prerequisites

Dataset: download and preprocessing

  1. Download data from https://www.kaggle.com/c/dogs-vs-cats-redux-kernels-edition/data

  2. Extract train.zip into a folder "data" and create two folders "train" and "valid"

  3. Create additional directories to get the directory structure shown below. We'll label dogs as class 0 and cats as class 1 (hence the prefixes).

    train/

    ├── 1cats
    └── 0dogs

    valid/

    ├── 1cats
    └── 0dogs
  4. First move all the cat images into train/1cats and dog images into train/0dogs.

  5. Now let's move a fraction of these images into the validation directory to create the validation set. You can run the code below as a Python script.

    import os
    import random
    import shutil

    cats_dir = 'train/1cats'
    dogs_dir = 'train/0dogs'

    all_cats = os.listdir(cats_dir)
    all_dogs = os.listdir(dogs_dir)
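    p = 20.0  # hold out 1/p of each class (here 1/20, i.e. 5%) for validation
    n_cats = int(len(all_cats) / p)
    n_dogs = int(len(all_dogs) / p)

    # Move a random sample of each class into the validation folders
    for f in random.sample(all_cats, n_cats):
        shutil.move(os.path.join(cats_dir, f), os.path.join('valid/1cats', f))

    for f in random.sample(all_dogs, n_dogs):
        shutil.move(os.path.join(dogs_dir, f), os.path.join('valid/0dogs', f))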
  6. Create image lists for the training and validation sets

    python ~/mxnet/tools/im2rec.py --list True --recursive True cats_dogs_train.lst data/train
    
    python ~/mxnet/tools/im2rec.py --list True --recursive True cats_dogs_val.lst data/valid
  7. Convert the images into MXNet RecordIO format

    python ~/mxnet/tools/im2rec.py --resize 224 --quality 90 --num-thread 16 cats_dogs_train.lst data/train
    
    python ~/mxnet/tools/im2rec.py --resize 224 --quality 90 --num-thread 16 cats_dogs_val.lst data/valid

    You should see cats_dogs_train.rec and cats_dogs_val.rec files created.
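
Creating the class directories and sorting the extracted images into them can also be scripted. The sketch below is one possible way to do it, assuming the extracted Kaggle images are named like cat.0.jpg and dog.0.jpg (check your copy of the archive). The function name sort_into_class_dirs and its arguments are illustrative and not part of MXNet.

```python
import os
import shutil


def sort_into_class_dirs(src_dir, root_dir):
    """Move cat.*/dog.* images from src_dir into root_dir/train/{1cats,0dogs}.

    Also creates the (empty) valid/1cats and valid/0dogs directories.
    Assumes file names start with 'cat.' or 'dog.' (an assumption about the
    archive layout); any other file is left where it is.
    Returns the number of files moved per class directory.
    """
    class_dirs = {'cat.': '1cats', 'dog.': '0dogs'}
    for split in ('train', 'valid'):
        for name in class_dirs.values():
            path = os.path.join(root_dir, split, name)
            if not os.path.isdir(path):
                os.makedirs(path)

    moved = {name: 0 for name in class_dirs.values()}
    for fname in sorted(os.listdir(src_dir)):
        src_path = os.path.join(src_dir, fname)
        if not os.path.isfile(src_path):
            continue  # skip the class directories themselves
        for prefix, name in class_dirs.items():
            if fname.startswith(prefix):
                shutil.move(src_path, os.path.join(root_dir, 'train', name, fname))
                moved[name] += 1
                break
    return moved
```

The first argument is wherever the extracted images live and the second is the folder that should contain train/ and valid/; both are up to you.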

Next we define the function which returns the data iterators.



In [9]:
# Data Iterators for cats vs dogs dataset

import mxnet as mx

def get_iterators(batch_size, data_shape=(3, 224, 224)):
    train = mx.io.ImageRecordIter(
        path_imgrec         = './cats_dogs_train.rec', 
        data_name           = 'data',
        label_name          = 'softmax_label',
        batch_size          = batch_size,
        data_shape          = data_shape,
        shuffle             = True,
        rand_crop           = True,
        rand_mirror         = True)
    val = mx.io.ImageRecordIter(
        path_imgrec         = './cats_dogs_val.rec',
        data_name           = 'data',
        label_name          = 'softmax_label',
        batch_size          = batch_size,
        data_shape          = data_shape,
        rand_crop           = False,
        rand_mirror         = False)
    return (train, val)
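
The iterators take their labels from the record files, which in turn come from the .lst files generated earlier. Before training, it can be worth confirming that dogs ended up as label 0 and cats as label 1. The sketch below counts images per label, assuming the tab-separated "index, label, path" layout written by im2rec.py (open the file to confirm); count_labels is a hypothetical helper, not an MXNet function.

```python
from collections import Counter


def count_labels(lst_path):
    """Count images per label in an im2rec-style .lst file.

    Assumes each line looks like 'index<TAB>label<TAB>relative/path'.
    """
    counts = Counter()
    with open(lst_path) as f:
        for line in f:
            fields = line.rstrip('\n').split('\t')
            if len(fields) < 3:
                continue  # skip blank or malformed lines
            label = int(float(fields[1]))  # labels may be written as floats
            counts[label] += 1
    return dict(counts)
```

Calling count_labels('cats_dogs_train.lst') should then return one count for label 0 and one for label 1, roughly equal in size if the split was balanced.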

Download pre-trained model from the model zoo (ResNet-152)

We then download a pre-trained 152-layer ResNet model and load it into memory.

Note: If load_checkpoint reports an error, we can remove the downloaded files and try get_model again.

In [17]:
# helper functions

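import os
try:
    from urllib.request import urlretrieve  # Python 3
except ImportError:
    from urllib import urlretrieve  # Python 2

def download(url):
    # Save under the file's own name in the current directory; skip if present
    filename = url.split("/")[-1]
    if not os.path.exists(filename):
        urlretrieve(url, filename)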
        
def get_model(prefix, epoch):
    download(prefix+'-symbol.json')
    download(prefix+'-%04d.params' % (epoch,))

get_model('http://data.mxnet.io/models/imagenet/resnet/152-layers/resnet-152', 0)
sym, arg_params, aux_params = mx.model.load_checkpoint('resnet-152', 0)
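
If the load fails because of a partial or corrupted download, the local files have to be removed before download() will fetch them again, since it skips files that already exist. A minimal sketch of that cleanup follows; remove_downloaded_model and its directory argument are hypothetical names introduced here, and the file naming simply mirrors what get_model requests.

```python
import os


def remove_downloaded_model(prefix, epoch, directory='.'):
    """Delete the local symbol and params files so they can be re-downloaded.

    prefix may be a URL or a plain prefix; as in download(), only the part
    after the last '/' is used as the local file name.
    Returns the paths that were actually removed.
    """
    local_prefix = prefix.split('/')[-1]
    candidates = [local_prefix + '-symbol.json',
                  local_prefix + '-%04d.params' % epoch]
    removed = []
    for fname in candidates:
        path = os.path.join(directory, fname)
        if os.path.exists(path):
            os.remove(path)
            removed.append(path)
    return removed
```

After remove_downloaded_model('resnet-152', 0), calling get_model again downloads fresh copies.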

Fine tuning the model

To fine-tune a network, we must first replace the last fully-connected layer with a new one that outputs the desired number of classes. We initialize its weights randomly. Then we continue training as normal. It’s common to use a smaller learning rate, based on the intuition that we may already be close to a good result.

We first define a function which replaces the last fully-connected layer for a given network.


In [6]:
def get_fine_tune_model(symbol, arg_params, num_classes, layer_name='flatten0'):
    """
    symbol: the pre-trained network symbol
    arg_params: the argument parameters of the pre-trained model
    num_classes: the number of classes for the fine-tune datasets
    layer_name: the layer name before the last fully-connected layer
    """
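    all_layers = symbol.get_internals()
    # Take the output of the layer just before the original classifier
    net = all_layers[layer_name+'_output']
    # Attach a new fully-connected layer sized for our classes, then softmax
    net = mx.symbol.FullyConnected(data=net, num_hidden=num_classes, name='fc1')
    net = mx.symbol.SoftmaxOutput(data=net, name='softmax')
    # Keep all pre-trained weights except those of the replaced 'fc1' layer
    new_args = {k: arg_params[k] for k in arg_params if 'fc1' not in k}
    return (net, new_args)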
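
To make the effect of the 'fc1' filter concrete, here is a toy illustration. The parameter names are illustrative and the string values are placeholders standing in for the NDArrays of a real checkpoint.

```python
# Placeholder strings stand in for the NDArray values of a real checkpoint.
arg_params = {
    'conv0_weight': 'pretrained',
    'bn0_gamma': 'pretrained',
    'fc1_weight': 'pretrained',
    'fc1_bias': 'pretrained',
}

# Same filter as in get_fine_tune_model: drop anything belonging to 'fc1'
new_args = {k: v for k, v in arg_params.items() if 'fc1' not in k}
dropped = sorted(set(arg_params) - set(new_args))

print(sorted(new_args))  # ['bn0_gamma', 'conv0_weight']
print(dropped)           # ['fc1_bias', 'fc1_weight']
```

The dropped entries are the ones the new fully-connected layer has to learn from a random initialization.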

Training the model

We now define a fit function that creates an MXNet module instance that we'll bind the data and symbols to.

init_params is called to randomly initialize the parameters.

set_params is called to replace all parameters, except those of the last fully-connected layer, with the values from the pre-trained model.

Note: change mx.gpu to mx.cpu to run training on CPU (much slower)


In [10]:
import logging
head = '%(asctime)-15s %(message)s'
logging.basicConfig(level=logging.DEBUG, format=head)

def fit(symbol, arg_params, aux_params, train, val, batch_size, num_gpus=1, num_epoch=1):
    devs = [mx.gpu(i) for i in range(num_gpus)] # replace mx.gpu by mx.cpu for CPU training
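    mod = mx.mod.Module(symbol=symbol, context=devs)
    mod.bind(data_shapes=train.provide_data, label_shapes=train.provide_label)
    # Randomly initialize everything, then overwrite with the pre-trained weights;
    # allow_missing=True leaves the new fc1 layer at its random initialization
    mod.init_params(initializer=mx.init.Xavier(rnd_type='gaussian', factor_type="in", magnitude=2))
    mod.set_params(arg_params, aux_params, allow_missing=True)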
    mod.fit(train, val, 
        num_epoch=num_epoch,
        batch_end_callback = mx.callback.Speedometer(batch_size, 10),        
        kvstore='device',
        optimizer='sgd',
        optimizer_params={'learning_rate':0.009},
        eval_metric='acc')
    
    return mod

Now that we have the helper functions set up, we can start training. It's recommended that you train on a GPU instance, preferably from the p2.* family. In this example we assume an AWS EC2 p2.xlarge, which has one NVIDIA K80 GPU.


In [ ]:
num_classes = 2 # This is binary classification (dogs vs cats)
batch_per_gpu = 4
num_gpus = 1
(new_sym, new_args) = get_fine_tune_model(sym, arg_params, num_classes)

batch_size = batch_per_gpu * num_gpus
(train, val) = get_iterators(batch_size)
mod = fit(new_sym, new_args, aux_params, train, val, batch_size, num_gpus)
metric = mx.metric.Accuracy()
mod_score = mod.score(val, metric)  # accuracy on the validation iterator
print(mod_score)

After 1 epoch we achieve 97.22% training accuracy.
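
For a sense of scale, the number of batches in that one epoch follows from the dataset size and the batch size. The figures below are assumptions: they take the Kaggle training set to contain 25,000 images (12,500 per class) and use the 1/20 hold-out from the split script, so substitute your own counts. How the iterator treats the final partial batch is not covered here.

```python
import math


def batches_per_epoch(num_images, batch_size):
    """Number of batches needed to see every image once (last one may be partial)."""
    return int(math.ceil(num_images / float(batch_size)))


total_images = 25000              # assumed size of the Kaggle training set
held_out = 2 * (12500 // 20)      # 1/20 of each class moved to validation
train_images = total_images - held_out

batch_size = 4 * 1                # batch_per_gpu * num_gpus, as in the cell above
print(train_images)                                 # 23750
print(batches_per_epoch(train_images, batch_size))  # 5938
```

With the Speedometer callback set to report every 10 batches, that corresponds to several hundred log lines per epoch.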

Let's save the newly trained model.


In [11]:
prefix = 'resnet-mxnet-catsvsdogs'
epoch = 1
mc = mod.save_checkpoint(prefix, epoch)


2017-06-01 01:33:56,263 Saved checkpoint to "resnet-mxnet-catsvsdogs-0001.params"
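
As the log line shows, the parameters are written to a file named after the prefix and the zero-padded epoch number. Before loading, it can help to check which epochs actually exist on disk. The helper below is a hypothetical sketch (list_checkpoint_epochs is not an MXNet function) that assumes the "<prefix>-NNNN.params" naming seen in the log.

```python
import os
import re


def list_checkpoint_epochs(prefix, directory='.'):
    """Return the sorted epoch numbers for which '<prefix>-NNNN.params' exists."""
    pattern = re.compile(r'^' + re.escape(prefix) + r'-(\d{4})\.params$')
    epochs = []
    for fname in os.listdir(directory):
        match = pattern.match(fname)
        if match:
            epochs.append(int(match.group(1)))
    return sorted(epochs)
```

After the save above, list_checkpoint_epochs('resnet-mxnet-catsvsdogs') should include epoch 1, and that is the number to use in the params file name when loading.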

Loading saved model


In [14]:
# load the model, make sure you have executed previous cells to train
import cv2
dshape = [('data', (1,3,224,224))]

def load_model(s_fname, p_fname):
    """
    Load model checkpoint from file.
    :return: (symbol, arg_params, aux_params)
    arg_params : dict of str to NDArray
        Model parameter, dict of name to NDArray of net's weights.
    aux_params : dict of str to NDArray
        Model parameter, dict of name to NDArray of net's auxiliary states.
    """
    symbol = mx.symbol.load(s_fname)
    save_dict = mx.nd.load(p_fname)
    arg_params = {}
    aux_params = {}
    for k, v in save_dict.items():
        tp, name = k.split(':', 1)
        if tp == 'arg':
            arg_params[name] = v
        if tp == 'aux':
            aux_params[name] = v
    return symbol, arg_params, aux_params

model_symbol = "resnet-mxnet-catsvsdogs-symbol.json"
# epoch 1 was saved above as ...-0001.params; adjust if you saved a different epoch
model_params = "resnet-mxnet-catsvsdogs-0001.params"
sym, arg_params, aux_params = load_model(model_symbol, model_params)
mod = mx.mod.Module(symbol=sym)

# bind the model and set training == False; Define the data shape
mod.bind(for_training=False, data_shapes=dshape)
mod.set_params(arg_params, aux_params)

Generate Predictions for an arbitrary image


In [16]:
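import numpy as np
try:
    from urllib.request import urlopen  # Python 3
except ImportError:
    from urllib2 import urlopen  # Python 2
from collections import namedtuple
Batch = namedtuple('Batch', ['data'])

def preprocess_image(img, show_img=False):
    '''
    Resize an image (H x W x C numpy array) to 224x224 and
    reorder it into a batch of shape (1, C, H, W).
    show_img is currently unused.
    '''
    # Note: cv2 decodes images in BGR channel order, while the training
    # iterator yields RGB; converting with cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    # first may give more faithful predictions.
    img = cv2.resize(img, (224, 224))
    img = np.swapaxes(img, 0, 2)  # (H, W, C) -> (C, W, H)
    img = np.swapaxes(img, 1, 2)  # (C, W, H) -> (C, H, W)
    img = img[np.newaxis, :]      # add the batch dimension
    return img

url = 'https://images-na.ssl-images-amazon.com/images/G/01/img15/pet-products/small-tiles/23695_pets_vertical_store_dogs_small_tile_8._CB312176604_.jpg'
req = urlopen(url)

# Decode the downloaded bytes into an image array
image = np.asarray(bytearray(req.read()), dtype="uint8")
image = cv2.imdecode(image, cv2.IMREAD_COLOR)
img = preprocess_image(image)

mod.forward(Batch([mx.nd.array(img)]))

# predict: one row of class probabilities (index 0 = dog, index 1 = cat)
prob = mod.get_outputs()[0].asnumpy()
print(prob)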


[[  9.99993086e-01   6.96629468e-06]]
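
To turn that probability row into a label, take the index of the largest entry and map it back through the class numbering used for the directories (0 = dogs, 1 = cats). A small sketch using the values printed above; the class_names list is introduced here for illustration.

```python
class_names = ['dog', 'cat']   # index 0 = dogs, index 1 = cats (directory prefixes)

prob = [[9.99993086e-01, 6.96629468e-06]]  # the output printed above

row = prob[0]
best = max(range(len(row)), key=lambda i: row[i])  # index of the largest probability
print('%s (p=%.6f)' % (class_names[best], row[best]))  # dog (p=0.999993)
```

With a NumPy array, prob.argmax() gives the same index, which is what the next section relies on.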

Inspecting incorrect labels


In [ ]:
# Generate predictions for entire validation dataset
import os
import cv2

path = 'data/valid/1cats/' # validation cats folder created earlier; change as needed
files = [path + f for f in os.listdir(path)]
incorrect_labels = []

# incorrect cat labels
for f in files:
    img = cv2.imread(f)
    img = preprocess_image(img)
    mod.forward(Batch([mx.nd.array(img)]))
    prob = mod.get_outputs()[0].asnumpy()
    if prob.argmax() != 1: # not a cat
        print(f)
        incorrect_labels.append(f)

In [ ]:
from matplotlib import pyplot as plt
%matplotlib inline
import numpy as np

# Plot helper
def plots(ims, figsize=(12,6), rows=1, interp=False, titles=None):
    if type(ims[0]) is np.ndarray:
        ims = np.array(ims).astype(np.uint8)
        if (ims.shape[-1] != 3):
            ims = ims.transpose((0,2,3,1))
    f = plt.figure(figsize=figsize)
    for i in range(len(ims)):
        sp = f.add_subplot(rows, len(ims)//rows, i+1)
        sp.axis('Off')
        if titles is not None:
            sp.set_title(titles[i], fontsize=16)
        plt.imshow(ims[i], interpolation=None if interp else 'none')

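# individual plot of the first misclassified image (if any)
if incorrect_labels:
    img_path = incorrect_labels[0]
    # cv2 loads images as BGR; convert to RGB so matplotlib shows true colors
    plots([cv2.cvtColor(cv2.imread(img_path), cv2.COLOR_BGR2RGB)])
    plt.show()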