Capstone Project: ResNet-50 for Cats vs. Dogs


In [1]:
from keras.models import Model, model_from_json
from keras.layers import Dense, Dropout, Activation, Flatten, BatchNormalization, merge, Input
from keras.layers import Convolution2D, MaxPooling2D, ZeroPadding2D, AveragePooling2D
from keras import backend as K
from keras.utils.data_utils import get_file
import cv2

import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline


Using TensorFlow backend.

Data preprocessing

  • The images in the train folder are divided into a training set and a validation set.
  • Within both the training set and the validation set, the images are separated into two folders, cat and dog, according to their labels.

(the two steps above were done in Preprocessing train dataset.ipynb; a sketch of the idea follows below)
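
For reference, the split could be implemented along these lines. This is only a sketch: the source folder name train, the destination names mytrain_ox and myvalid_ox, and the 10% validation ratio are assumptions; the actual logic lives in Preprocessing train dataset.ipynb.

import os
import random
import shutil

def split_train_folder(src='train', train_dir='mytrain_ox', valid_dir='myvalid_ox', valid_ratio=0.1):
    # create <dir>/cat and <dir>/dog for both destination sets
    for d in (train_dir, valid_dir):
        for label in ('cat', 'dog'):
            if not os.path.exists(os.path.join(d, label)):
                os.makedirs(os.path.join(d, label))
    filenames = os.listdir(src)
    random.shuffle(filenames)
    n_valid = int(len(filenames) * valid_ratio)
    for i, name in enumerate(filenames):
        # Kaggle images are named cat.<id>.jpg / dog.<id>.jpg
        label = 'cat' if name.startswith('cat') else 'dog'
        dst = valid_dir if i < n_valid else train_dir
        shutil.copy(os.path.join(src, name), os.path.join(dst, label, name))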

  • The RGB values of the images are rescaled to the range 0–1.
  • The images are resized to 224×224.

In [2]:
from keras.preprocessing.image import ImageDataGenerator
image_width = 224
image_height = 224
image_size = (image_width, image_height)

train_datagen = ImageDataGenerator(rescale=1.0/255)

train_generator = train_datagen.flow_from_directory(
        'mytrain_ox',  # this is the target directory
        target_size=image_size,  # all images will be resized to 224x224
        batch_size=16,
        class_mode='binary')

validation_datagen = ImageDataGenerator(rescale=1.0/255)
validation_generator = validation_datagen.flow_from_directory(
        'myvalid_ox',  # this is the target directory
        target_size=image_size,  # all images will be resized to 224x224
        batch_size=16,
        class_mode='binary')


Found 6614 images belonging to 2 classes.
Found 735 images belonging to 2 classes.
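
With class_mode='binary', the labels follow the subfolder names in alphabetical order, so cat maps to 0 and dog to 1. If in doubt, the mapping can be checked on the generator:

print(train_generator.class_indices)  # expected: {'cat': 0, 'dog': 1}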

Show 16 random images from the training set


In [3]:
x, y = train_generator.next()

plt.figure(figsize=(16, 8))
for i, (img, label) in enumerate(zip(x, y)):
    plt.subplot(3, 6, i+1)
    if label == 1:
        plt.title('dog')
    else:
        plt.title('cat')
    plt.axis('off')
    plt.imshow(img, interpolation="nearest")


Build the structure of ResNet-50 for Cats vs. Dogs

  1. Define identity block.
  2. Define convolution block.
  3. Build the structure of ResNet-50 without top layer.
  4. Load weights.
  5. Add the top layer to ResNet-50.
  6. Set up the trainable attributes.
  7. Compile the model.

1. Define identity block.

The identity_block is the block that has no conv layer in its shortcut.

Arguments

  • input_tensor: input tensor
  • kernel_size: default 3, the kernel size of the middle conv layer in the main path
  • filters: list of integers, the numbers of filters of the 3 conv layers in the main path
  • stage: integer, current stage label, used for generating layer names
  • block: 'a','b'..., current block label, used for generating layer names

In [4]:
def identity_block(input_tensor, kernel_size, filters, stage, block):
    
    nb_filter1, nb_filter2, nb_filter3 = filters
    if K.image_dim_ordering() == 'tf':
        bn_axis = 3
    else:
        bn_axis = 1
    conv_name_base = 'res' + str(stage) + block + '_branch'
    bn_name_base = 'bn' + str(stage) + block + '_branch'

    x = Convolution2D(nb_filter1, 1, 1, name=conv_name_base + '2a')(input_tensor)
    x = BatchNormalization(axis=bn_axis, name=bn_name_base + '2a')(x)
    x = Activation('relu')(x)

    x = Convolution2D(nb_filter2, kernel_size, kernel_size,
                      border_mode='same', name=conv_name_base + '2b')(x)
    x = BatchNormalization(axis=bn_axis, name=bn_name_base + '2b')(x)
    x = Activation('relu')(x)

    x = Convolution2D(nb_filter3, 1, 1, name=conv_name_base + '2c')(x)
    x = BatchNormalization(axis=bn_axis, name=bn_name_base + '2c')(x)

    x = merge([x, input_tensor], mode='sum')  # add the identity shortcut to the main path
    x = Activation('relu')(x)
    return x

2. Define convolution block.

The conv_block is the block that has a conv layer in its shortcut.

Arguments

  • input_tensor: input tensor
  • kernel_size: default 3, the kernel size of the middle conv layer in the main path
  • filters: list of integers, the numbers of filters of the 3 conv layers in the main path
  • stage: integer, current stage label, used for generating layer names
  • block: 'a','b'..., current block label, used for generating layer names

    Note that from stage 3 on, the first conv layer in the main path has subsample=(2, 2), and the shortcut has subsample=(2, 2) as well.


In [5]:
def conv_block(input_tensor, kernel_size, filters, stage, block, strides=(2, 2)):

    nb_filter1, nb_filter2, nb_filter3 = filters
    if K.image_dim_ordering() == 'tf':
        bn_axis = 3
    else:
        bn_axis = 1
    conv_name_base = 'res' + str(stage) + block + '_branch'
    bn_name_base = 'bn' + str(stage) + block + '_branch'

    x = Convolution2D(nb_filter1, 1, 1, subsample=strides,
                      name=conv_name_base + '2a')(input_tensor)
    x = BatchNormalization(axis=bn_axis, name=bn_name_base + '2a')(x)
    x = Activation('relu')(x)

    x = Convolution2D(nb_filter2, kernel_size, kernel_size, border_mode='same',
                      name=conv_name_base + '2b')(x)
    x = BatchNormalization(axis=bn_axis, name=bn_name_base + '2b')(x)
    x = Activation('relu')(x)

    x = Convolution2D(nb_filter3, 1, 1, name=conv_name_base + '2c')(x)
    x = BatchNormalization(axis=bn_axis, name=bn_name_base + '2c')(x)

    # projection shortcut: a strided 1x1 conv so the shapes match the main path
    shortcut = Convolution2D(nb_filter3, 1, 1, subsample=strides,
                             name=conv_name_base + '1')(input_tensor)
    shortcut = BatchNormalization(axis=bn_axis, name=bn_name_base + '1')(shortcut)

    x = merge([x, shortcut], mode='sum')  # add the projection shortcut to the main path
    x = Activation('relu')(x)
    return x

3. Build the structure of ResNet-50 without the top layer.


In [6]:
img_input = Input(shape=(image_width, image_height, 3))

x = ZeroPadding2D((3, 3))(img_input)
x = Convolution2D(64, 7, 7, subsample=(2, 2), name='conv1')(x)
x = BatchNormalization(axis=3, name='bn_conv1')(x)
x = Activation('relu')(x)
x = MaxPooling2D((3, 3), strides=(2, 2))(x)

x = conv_block(x, 3, [64, 64, 256], stage=2, block='a', strides=(1, 1))
x = identity_block(x, 3, [64, 64, 256], stage=2, block='b')
x = identity_block(x, 3, [64, 64, 256], stage=2, block='c')

x = conv_block(x, 3, [128, 128, 512], stage=3, block='a')
x = identity_block(x, 3, [128, 128, 512], stage=3, block='b')
x = identity_block(x, 3, [128, 128, 512], stage=3, block='c')
x = identity_block(x, 3, [128, 128, 512], stage=3, block='d')

x = conv_block(x, 3, [256, 256, 1024], stage=4, block='a')
x = identity_block(x, 3, [256, 256, 1024], stage=4, block='b')
x = identity_block(x, 3, [256, 256, 1024], stage=4, block='c')
x = identity_block(x, 3, [256, 256, 1024], stage=4, block='d')
x = identity_block(x, 3, [256, 256, 1024], stage=4, block='e')
x = identity_block(x, 3, [256, 256, 1024], stage=4, block='f')

x = conv_block(x, 3, [512, 512, 2048], stage=5, block='a')
x = identity_block(x, 3, [512, 512, 2048], stage=5, block='b')
x = identity_block(x, 3, [512, 512, 2048], stage=5, block='c')

base_model = Model(img_input, x)
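
A quick, optional sanity check: the feature map that the heatmap section relies on later should be 7×7×2048.

print(base_model.output_shape)  # expected: (None, 7, 7, 2048)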

4. Load weights.


In [7]:
TF_WEIGHTS_PATH_NO_TOP = 'https://github.com/fchollet/deep-learning-models/releases/download/\
v0.2/resnet50_weights_tf_dim_ordering_tf_kernels_notop.h5'
weights_path = get_file('resnet50_weights_tf_dim_ordering_tf_kernels_notop.h5',
                        TF_WEIGHTS_PATH_NO_TOP,
                        cache_subdir='models',
                        md5_hash='a268eb855778b3df3c7506639542a6af')
base_model.load_weights(weights_path)

5. Add the top layer to ResNet-50.


In [8]:
x = AveragePooling2D((7, 7), name='avg_pool')(base_model.output)
x = Flatten()(x)
x = Dropout(0.5)(x)
x = Dense(1, activation='sigmoid', name='output')(x)

model = Model(input=base_model.input, output=x)

6. Set up the trainable attributes.

Freeze all weights except those of the top layers.


In [9]:
top_num = 4  # the four top layers added above: avg_pool, Flatten, Dropout, Dense
for layer in model.layers[:-top_num]:
    layer.trainable = False  # freeze the pretrained ResNet-50 body

for layer in model.layers[-top_num:]:
    layer.trainable = True  # train only the new top layers
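
As an optional sanity check, only the four top layers should remain trainable:

print([layer.name for layer in model.layers if layer.trainable])
# expect the four top layers, e.g. the avg_pool, flatten, dropout and output layers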

7. Compile the model.


In [10]:
model.compile(loss='binary_crossentropy', optimizer='adadelta', metrics=['accuracy'])

Train ResNet-50 for Cats vs. Dogs and save the best model.


In [11]:
from keras.callbacks import ModelCheckpoint, TensorBoard
best_model = ModelCheckpoint("resnet_best.h5", monitor='val_acc', verbose=0, save_best_only=True)
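
The TensorBoard callback below writes its logs to ./logs, so training can be monitored by running tensorboard --logdir=./logs in a shell and opening the reported URL in a browser.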

In [12]:
model.fit_generator(
        train_generator,
        samples_per_epoch=2048,
        nb_epoch=40,
        validation_data=validation_generator,
        nb_val_samples=1024,
        callbacks=[best_model, TensorBoard(log_dir='./logs', histogram_freq=1)])


WARNING:tensorflow:Passing a `GraphDef` to the SummaryWriter is deprecated. Pass a `Graph` object instead, such as `sess.graph`.
Epoch 1/40
2048/2048 [==============================] - 47s - loss: 0.4672 - acc: 0.7866 - val_loss: 0.2376 - val_acc: 0.9172
Epoch 2/40
2048/2048 [==============================] - 43s - loss: 0.2919 - acc: 0.8818 - val_loss: 0.1559 - val_acc: 0.9480
Epoch 3/40
2048/2048 [==============================] - 41s - loss: 0.2190 - acc: 0.9121 - val_loss: 0.1455 - val_acc: 0.9470
Epoch 4/40
2038/2048 [============================>.] - ETA: 0s - loss: 0.2006 - acc: 0.9200
/usr/local/lib/python2.7/site-packages/keras/engine/training.py:1470: UserWarning: Epoch comprised more than `samples_per_epoch` samples, which might affect learning results. Set `samples_per_epoch` correctly to avoid this warning.
  warnings.warn('Epoch comprised more than '
2054/2048 [==============================] - 38s - loss: 0.2007 - acc: 0.9202 - val_loss: 0.1108 - val_acc: 0.9673
Epoch 5/40
2048/2048 [==============================] - 38s - loss: 0.1842 - acc: 0.9224 - val_loss: 0.1269 - val_acc: 0.9586
Epoch 6/40
2048/2048 [==============================] - 38s - loss: 0.1537 - acc: 0.9419 - val_loss: 0.0872 - val_acc: 0.9702
Epoch 7/40
2054/2048 [==============================] - 38s - loss: 0.1453 - acc: 0.9455 - val_loss: 0.0953 - val_acc: 0.9663
Epoch 8/40
2048/2048 [==============================] - 38s - loss: 0.1402 - acc: 0.9492 - val_loss: 0.0723 - val_acc: 0.9769
Epoch 9/40
2048/2048 [==============================] - 38s - loss: 0.1456 - acc: 0.9429 - val_loss: 0.0844 - val_acc: 0.9711
Epoch 10/40
2054/2048 [==============================] - 38s - loss: 0.1433 - acc: 0.9421 - val_loss: 0.0773 - val_acc: 0.9692
Epoch 11/40
2048/2048 [==============================] - 38s - loss: 0.1156 - acc: 0.9580 - val_loss: 0.0707 - val_acc: 0.9827
Epoch 12/40
2048/2048 [==============================] - 37s - loss: 0.1147 - acc: 0.9521 - val_loss: 0.0780 - val_acc: 0.9740
Epoch 13/40
2054/2048 [==============================] - 38s - loss: 0.1223 - acc: 0.9479 - val_loss: 0.0735 - val_acc: 0.9788
Epoch 14/40
2048/2048 [==============================] - 37s - loss: 0.1315 - acc: 0.9453 - val_loss: 0.0654 - val_acc: 0.9769
Epoch 15/40
2048/2048 [==============================] - 37s - loss: 0.1242 - acc: 0.9517 - val_loss: 0.0682 - val_acc: 0.9769
Epoch 16/40
2048/2048 [==============================] - 38s - loss: 0.1032 - acc: 0.9634 - val_loss: 0.0642 - val_acc: 0.9769
Epoch 17/40
2054/2048 [==============================] - 38s - loss: 0.1016 - acc: 0.9611 - val_loss: 0.0635 - val_acc: 0.9750
Epoch 18/40
2048/2048 [==============================] - 37s - loss: 0.1241 - acc: 0.9531 - val_loss: 0.0648 - val_acc: 0.9779
Epoch 19/40
2048/2048 [==============================] - 38s - loss: 0.1029 - acc: 0.9585 - val_loss: 0.0642 - val_acc: 0.9788
Epoch 20/40
2054/2048 [==============================] - 38s - loss: 0.1100 - acc: 0.9542 - val_loss: 0.0640 - val_acc: 0.9759
Epoch 21/40
2048/2048 [==============================] - 37s - loss: 0.1097 - acc: 0.9561 - val_loss: 0.0645 - val_acc: 0.9740
Epoch 22/40
2048/2048 [==============================] - 37s - loss: 0.0978 - acc: 0.9614 - val_loss: 0.0629 - val_acc: 0.9798
Epoch 23/40
2054/2048 [==============================] - 38s - loss: 0.1194 - acc: 0.9489 - val_loss: 0.0578 - val_acc: 0.9779
Epoch 24/40
2048/2048 [==============================] - 37s - loss: 0.0933 - acc: 0.9663 - val_loss: 0.0523 - val_acc: 0.9798
Epoch 25/40
2048/2048 [==============================] - 37s - loss: 0.1202 - acc: 0.9536 - val_loss: 0.0632 - val_acc: 0.9731
Epoch 26/40
2054/2048 [==============================] - 38s - loss: 0.1078 - acc: 0.9581 - val_loss: 0.0585 - val_acc: 0.9817
Epoch 27/40
2048/2048 [==============================] - 38s - loss: 0.1031 - acc: 0.9580 - val_loss: 0.0595 - val_acc: 0.9769
Epoch 28/40
2048/2048 [==============================] - 38s - loss: 0.1150 - acc: 0.9556 - val_loss: 0.0554 - val_acc: 0.9807
Epoch 29/40
2048/2048 [==============================] - 37s - loss: 0.0906 - acc: 0.9624 - val_loss: 0.0543 - val_acc: 0.9779
Epoch 30/40
2054/2048 [==============================] - 38s - loss: 0.1028 - acc: 0.9659 - val_loss: 0.0552 - val_acc: 0.9769
Epoch 31/40
2048/2048 [==============================] - 38s - loss: 0.0913 - acc: 0.9653 - val_loss: 0.0500 - val_acc: 0.9808
Epoch 32/40
2048/2048 [==============================] - 38s - loss: 0.1122 - acc: 0.9595 - val_loss: 0.0580 - val_acc: 0.9817
Epoch 33/40
2054/2048 [==============================] - 38s - loss: 0.0950 - acc: 0.9635 - val_loss: 0.0584 - val_acc: 0.9778
Epoch 34/40
2048/2048 [==============================] - 38s - loss: 0.0892 - acc: 0.9663 - val_loss: 0.0545 - val_acc: 0.9836
Epoch 35/40
2048/2048 [==============================] - 38s - loss: 0.0940 - acc: 0.9663 - val_loss: 0.0542 - val_acc: 0.9778
Epoch 36/40
2054/2048 [==============================] - 38s - loss: 0.0921 - acc: 0.9640 - val_loss: 0.0514 - val_acc: 0.9798
Epoch 37/40
2048/2048 [==============================] - 38s - loss: 0.0948 - acc: 0.9663 - val_loss: 0.0515 - val_acc: 0.9827
Epoch 38/40
2048/2048 [==============================] - 37s - loss: 0.0983 - acc: 0.9634 - val_loss: 0.0528 - val_acc: 0.9827
Epoch 39/40
2054/2048 [==============================] - 38s - loss: 0.0968 - acc: 0.9640 - val_loss: 0.0532 - val_acc: 0.9808
Epoch 40/40
2048/2048 [==============================] - 38s - loss: 0.0776 - acc: 0.9692 - val_loss: 0.0616 - val_acc: 0.9788
Out[12]:
<keras.callbacks.History at 0x14b820dd0>
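
Save the model architecture as JSON; the best weights were already written to resnet_best.h5 by the ModelCheckpoint callback.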

In [13]:
with open('resnet.json', 'w') as f:
    f.write(model.to_json())

Using the best model to predict images


In [14]:
with open('resnet.json', 'r') as f:
    model = model_from_json(f.read())
model.load_weights('resnet_best.h5')

Visualization

Show 16 random images from the validation set together with the classifier's predictions


In [21]:
x, y = validation_generator.next()
plt.figure(figsize=(16, 8))
for i in range(16):
    prediction = model.predict(np.expand_dims(x[i], axis=0))[0]  # probability of 'dog'
    
    plt.subplot(3, 6, i+1)
    if prediction < 0.5:
        plt.title('cat %.2f%%' % (100 - prediction*100))
    else:
        plt.title('dog %.2f%%' % (prediction*100))
    
    plt.axis('off')
    plt.imshow(x[i])


Draw feature heatmap

The shape of the output of the base model is (7, 7, 2048).

The shape of the weights of the fully connected layer is (2048, 1).

To draw the heatmap, I compute the Class Activation Map (CAM) from the output of the base model and then use OpenCV to visualize the result.

$cam = (P - 0.5) \cdot output \cdot w$

  • cam: the class activation map
  • P: the predicted probability of "dog"
  • output: the output of the base model
  • w: the weights of the fully connected layer
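
The (P - 0.5) factor is negative for cat predictions, which flips the sign of the map, so the highlighted regions always support the predicted class. To make the shapes concrete, here is a minimal NumPy sketch with random values standing in for the real tensors:

import numpy as np

output = np.random.rand(7, 7, 2048)  # stand-in for the base model output
w = np.random.rand(2048, 1)          # stand-in for the fully connected weights
P = 0.9                              # e.g. the predicted probability of "dog"

cam = (P - 0.5) * np.matmul(output, w)[:, :, 0]  # (7, 7) activation map
print(cam.shape)  # (7, 7); it is upsampled to 224x224 before being overlaid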

In [22]:
layer_dict = dict([(layer.name, layer) for layer in model.layers])
weights = model.layers[-1].get_weights()[0]  # (2048, 1) weights of the sigmoid layer
# 'merge_16' is the last merge layer, i.e. the output of the base model
model2 = Model(input=model.input, output=[layer_dict['merge_16'].output, model.output])

In [28]:
x, y = validation_generator.next()
plt.figure(figsize=(16, 8))
for i in range(16):
    img = (x[i]*255).astype(np.uint8)
    
    [base_model_outputs, prediction] = model2.predict(np.expand_dims(x[i], axis=0))
    prediction = prediction[0]
    base_model_outputs = base_model_outputs[0]
    
    plt.subplot(3, 6, i+1)
    if prediction < 0.5:
        plt.title('cat %.2f%%' % (100 - prediction*100))
    else:
        plt.title('dog %.2f%%' % (prediction*100))
    
    cam = (prediction - 0.5) * np.matmul(base_model_outputs, weights)  # (7, 7, 1) class activation map

    # normalize to 0~1, then shift and stretch so the weakest 20% falls below 0
    cam -= cam.min()
    cam /= cam.max()
    cam -= 0.2
    cam /= 0.8
    
    cam = cv2.resize(cam, (224, 224))
    heatmap = cv2.applyColorMap(np.uint8(255*cam), cv2.COLORMAP_JET)
    heatmap[np.where(cam <= 0.2)] = 0
    
    out = cv2.addWeighted(img, 0.8, heatmap[:,:,::-1], 0.4, 0)  # blend; [:,:,::-1] flips the heatmap from BGR to RGB
    
    plt.axis('off')
    plt.imshow(out)


