Kyle, Joe, Mark: Milestone 4, Deep Network

Our deep network is a CNN adapted from an example on the Keras website, trained with a binary cross-entropy loss and sigmoidal activation on the output layer. We resized all of our images to 300x185x3 upon loading them. More details about the model are given later in the document.

We ran into some issues during training related to the number of images an AWS instance can hold in memory at any given time: on a "p2.xlarge" instance, roughly 15,000 images fit in memory (12GB on the Tesla K80). Since we had many more images than this, we trained iteratively, running 10 epochs on a sample of 10k images, drawing another sample, and training again to improve test performance. We also standardized each image to a mean of 0 and a standard deviation of 1.
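
As a rough outline, the resampling loop looks something like the sketch below. This is illustrative rather than the exact cells we ran, and load_posters() is a hypothetical helper that wraps the image-loading code shown in cell In [3].

# Illustrative outline of the iterative training described above; load_posters()
# is a hypothetical stand-in for the loading code in cell In [3].
n_rounds = 5  # assumption: number of resampling rounds
for r in range(n_rounds):
    sample_ids = np.random.choice(np.squeeze(ids), size=10000, replace=False)
    X_chunk, y_chunk = load_posters(sample_ids)
    model.fit_generator(datagen.flow(X_chunk, y_chunk, batch_size=32),
                        steps_per_epoch=len(X_chunk) // 32, epochs=10)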

After 100 epochs of training on a small subset of our data, our model achieved a training binary accuracy of 0.9939. As can be seen from the matplotlib visualizations, four of our seven genres were predicted as almost entirely zero, while the predicted labels for the other three genres were more accurate. Rather than relying on overall accuracy, we examined the average precision and recall across test observations when tuning the model's parameters. Our main limiting factor in this round was the time required to test new parameters without "breaking the bank" on AWS credits. We hope to continue improving this model prior to the final paper.
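
For reference, a minimal sketch of that precision/recall check, assuming a 0.5 decision threshold on the sigmoid outputs and that scikit-learn is available (illustrative, not the exact evaluation cell we used):

# Sketch: precision and recall averaged across test observations,
# thresholding the sigmoid outputs at 0.5 (assumes scikit-learn).
from sklearn.metrics import precision_score, recall_score
probs = model.predict(X_test)
preds = (probs > 0.5).astype(int)
print('precision: {:.3f}'.format(precision_score(y_test, preds, average='samples')))
print('recall:    {:.3f}'.format(recall_score(y_test, preds, average='samples')))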


In [1]:
import tensorflow as tf
import numpy as np
import pandas as pd
from scipy import ndimage

import keras
from keras.models import Sequential
from keras.layers import Dense, Dropout, Flatten, Activation, Conv2D, MaxPooling2D
from keras.optimizers import SGD, RMSprop
from keras.utils import plot_model
from keras.utils.vis_utils import model_to_dot
from keras.preprocessing.image import ImageDataGenerator

from IPython.display import SVG
import matplotlib.pyplot as plt

%matplotlib inline


Using TensorFlow backend.

In [2]:
# Load data
%cd ~/data/
labs = pd.read_csv('multilabels.csv')
ids = pd.read_csv('features_V1.csv', usecols=[0])

# Take care of some weirdness that led to duplicate entries
labs = pd.concat([ids,labs], axis=1, ignore_index=True)
labs = labs.drop_duplicates(subset=[0])

ids = labs.pop(0).as_matrix()
labs = labs.as_matrix()


/home/ubuntu/data

In [3]:
# Split train/test - 15k is about the limit of what we can hold in memory (12GB on Tesla K80)
n_train = 10000
n_test = 5000

rnd_ids = np.random.choice(np.squeeze(ids), size=n_train+n_test, replace=False)
train_ids = rnd_ids[:n_train]
test_ids = rnd_ids[n_train:]

# Pull in multilabels
y_train = labs[np.nonzero(np.in1d(np.squeeze(ids),train_ids))[0]]
y_test = labs[np.nonzero(np.in1d(np.squeeze(ids),test_ids))[0]]



# Read in images - need to do some goofy stuff here to handle the highly irregular image sizes and formats
X_train = np.zeros([n_train, 600, 185, 3])
ct = 0
for i in train_ids:
    IM = ndimage.imread('posters/{}.jpg'.format(i))
    try:
        # color images: keep the first three channels (drops any alpha channel)
        X_train[ct,:IM.shape[0],:,:] = IM[:,:,:3]
    except:
        # grayscale images have no channel axis; place them in the first channel
        X_train[ct,:IM.shape[0],:,0] = IM
    ct += 1
    if ct % 100 == 0:
        print 'training data {i}/{n} loaded'.format(i=ct, n=n_train)
X_train = X_train[:,:300,:,:] # trim excess off edges
print 'training data loaded'


X_test = np.zeros([n_test, 600, 185, 3])
ct = 0
for i in test_ids:
    IM = ndimage.imread('posters/{}.jpg'.format(i))
    try:
        # color images: keep the first three channels (drops any alpha channel)
        X_test[ct,:IM.shape[0],:,:] = IM[:,:,:3]
    except:
        # grayscale images have no channel axis; place them in the first channel
        X_test[ct,:IM.shape[0],:,0] = IM
    ct += 1
    if ct % 100 == 0:
        print 'test data {i}/{n} loaded'.format(i=ct, n=n_test)
X_test = X_test[:,:300,:,:] # trim excess off edges
print 'test data loaded'

# Create dataGenerator to feed image batches - 
# this is nice because it also standardizes training data
datagen = ImageDataGenerator(
    samplewise_center=True,
    samplewise_std_normalization=True)


# fit() computes dataset-wide statistics (mean, std, and principal components
# if ZCA whitening is applied); with only samplewise options enabled it is not
# strictly required, but we keep the call for consistency
datagen.fit(X_train)


training data 100/10000 loaded
training data 200/10000 loaded
...
training data 10000/10000 loaded
training data loaded
test data 100/5000 loaded
test data 200/5000 loaded
...
test data 5000/5000 loaded
test data loaded

In [7]:
# Build CNN model

model = Sequential()

# input: 300x185 images with 3 channels -> (300, 185, 3) tensors.
# this applies 32 convolution filters of size 3x3 each.
model.add(Conv2D(32, (3, 3), activation='relu', input_shape=(300, 185, 3)))
model.add(Conv2D(32, (3, 3), activation='relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Dropout(0.25))

model.add(Conv2D(64, (3, 3), activation='relu'))
model.add(Conv2D(64, (3, 3), activation='relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Dropout(0.25))

model.add(Flatten())
model.add(Dense(256, activation='relu'))
model.add(Dropout(0.5))
model.add(Dense(7, activation='sigmoid'))

#sgd = SGD(lr=0.01, decay=1e-6, momentum=0.9, nesterov=True)
#model.compile(loss='binary_crossentropy', optimizer=sgd)



model.compile(loss='binary_crossentropy',
              optimizer='rmsprop',
              metrics=['binary_accuracy'])

model.summary()

# Visualize network graph
#SVG(model_to_dot(model).create(prog='dot', format='svg'))


_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
conv2d_5 (Conv2D)            (None, 298, 183, 32)      896       
_________________________________________________________________
conv2d_6 (Conv2D)            (None, 296, 181, 32)      9248      
_________________________________________________________________
max_pooling2d_3 (MaxPooling2 (None, 148, 90, 32)       0         
_________________________________________________________________
dropout_4 (Dropout)          (None, 148, 90, 32)       0         
_________________________________________________________________
conv2d_7 (Conv2D)            (None, 146, 88, 64)       18496     
_________________________________________________________________
conv2d_8 (Conv2D)            (None, 144, 86, 64)       36928     
_________________________________________________________________
max_pooling2d_4 (MaxPooling2 (None, 72, 43, 64)        0         
_________________________________________________________________
dropout_5 (Dropout)          (None, 72, 43, 64)        0         
_________________________________________________________________
flatten_2 (Flatten)          (None, 198144)            0         
_________________________________________________________________
dense_3 (Dense)              (None, 256)               50725120  
_________________________________________________________________
dropout_6 (Dropout)          (None, 256)               0         
_________________________________________________________________
dense_4 (Dense)              (None, 7)                 1799      
=================================================================
Total params: 50,792,487.0
Trainable params: 50,792,487.0
Non-trainable params: 0.0
_________________________________________________________________

Model construction

We decided to implement a convolutional neural network adapted from the Keras "VGG-like convnet" tutorial. It is essentially a much-simplified version of the pre-trained VGG-16 model that we fine-tuned.

The network consists of two convolutional layers with ReLU activation followed by max pooling, a repeat of this motif, and then two fully connected layers: the first with ReLU activation and the output layer with sigmoidal activation. The model is regularized using dropout after each max-pooling layer and between the two fully connected layers.

We trained the network using a binary cross-entropy loss function and the RMSprop optimizer. We tried a number of other optimizers, including SGD with momentum, Adam, and Nadam; none showed much difference in performance, and all were slightly slower to learn than RMSprop.

Because we use the RMSprop optimizer, we plan to tune only its learning rate; the Keras documentation suggests leaving the other parameters at their defaults. We first test the model on a small set of 512 images to make sure it works, then move on to training it on a set of 10,000 images.
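
A rough sketch of that learning-rate sweep is shown below; build_model() is a hypothetical helper that returns a fresh, uncompiled copy of the architecture from cell In [7] (a sketch of it appears near the end of this notebook), so that runs do not train on top of one another.

# Illustrative learning-rate sweep; build_model() is a hypothetical helper
# returning a fresh copy of the In [7] architecture.
for lr in [0.01, 0.001, 0.0001]:
    m = build_model()
    m.compile(loss='binary_crossentropy',
              optimizer=RMSprop(lr=lr, rho=0.9, epsilon=1e-08, decay=0.0),
              metrics=['binary_accuracy'])
    m.fit_generator(datagen.flow(X_train, y_train, batch_size=32),
                    steps_per_epoch=len(X_train) // 32, epochs=10)
    print('lr={}: {}'.format(lr, m.evaluate(X_test, y_test, batch_size=32)))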


In [11]:
# Fit the model on a small subset of the training data, n=512 images
model.fit_generator(datagen.flow(X_train, y_train, batch_size=32),
                    steps_per_epoch=len(X_train) / 32, epochs=100)

score = model.evaluate(X_test, y_test, batch_size=32)


Epoch 1/100
16/16 [==============================] - 7s - loss: 0.9580 - binary_accuracy: 0.7522      
Epoch 2/100
16/16 [==============================] - 6s - loss: 0.5059 - binary_accuracy: 0.7799     
Epoch 3/100
16/16 [==============================] - 6s - loss: 0.4889 - binary_accuracy: 0.7793     
Epoch 4/100
16/16 [==============================] - 6s - loss: 0.4813 - binary_accuracy: 0.7877     
Epoch 5/100
16/16 [==============================] - 6s - loss: 0.4676 - binary_accuracy: 0.7997     
Epoch 6/100
16/16 [==============================] - 6s - loss: 0.4140 - binary_accuracy: 0.8172     
Epoch 7/100
16/16 [==============================] - 6s - loss: 0.2996 - binary_accuracy: 0.8753     
Epoch 8/100
16/16 [==============================] - 6s - loss: 0.2085 - binary_accuracy: 0.9132     
Epoch 9/100
16/16 [==============================] - 7s - loss: 0.1439 - binary_accuracy: 0.9475     
Epoch 10/100
16/16 [==============================] - 7s - loss: 0.1091 - binary_accuracy: 0.9595     
Epoch 11/100
16/16 [==============================] - 7s - loss: 0.0982 - binary_accuracy: 0.9662     
Epoch 12/100
16/16 [==============================] - 7s - loss: 0.0696 - binary_accuracy: 0.9738     
Epoch 13/100
16/16 [==============================] - 7s - loss: 0.0578 - binary_accuracy: 0.9796     
Epoch 14/100
16/16 [==============================] - 7s - loss: 0.0497 - binary_accuracy: 0.9824     
Epoch 15/100
16/16 [==============================] - 7s - loss: 0.0421 - binary_accuracy: 0.9866     
Epoch 16/100
16/16 [==============================] - 7s - loss: 0.0604 - binary_accuracy: 0.9830     
Epoch 17/100
16/16 [==============================] - 7s - loss: 0.0473 - binary_accuracy: 0.9830     
Epoch 18/100
16/16 [==============================] - 7s - loss: 0.0483 - binary_accuracy: 0.9849     
Epoch 19/100
16/16 [==============================] - 7s - loss: 0.0465 - binary_accuracy: 0.9860     
Epoch 20/100
16/16 [==============================] - 7s - loss: 0.0441 - binary_accuracy: 0.9855     
Epoch 21/100
16/16 [==============================] - 7s - loss: 0.0332 - binary_accuracy: 0.9874     
Epoch 22/100
16/16 [==============================] - 7s - loss: 0.0318 - binary_accuracy: 0.9888     
Epoch 23/100
16/16 [==============================] - 7s - loss: 0.0395 - binary_accuracy: 0.9891     
Epoch 24/100
16/16 [==============================] - 7s - loss: 0.0396 - binary_accuracy: 0.9874     
Epoch 25/100
16/16 [==============================] - 7s - loss: 0.0214 - binary_accuracy: 0.9905     
Epoch 26/100
16/16 [==============================] - 7s - loss: 0.0295 - binary_accuracy: 0.9888     
Epoch 27/100
16/16 [==============================] - 7s - loss: 0.0252 - binary_accuracy: 0.9908     
Epoch 28/100
16/16 [==============================] - 7s - loss: 0.0368 - binary_accuracy: 0.9874     
Epoch 29/100
16/16 [==============================] - 7s - loss: 0.0283 - binary_accuracy: 0.9902     
Epoch 30/100
16/16 [==============================] - 7s - loss: 0.0234 - binary_accuracy: 0.9919     
Epoch 31/100
16/16 [==============================] - 7s - loss: 0.0226 - binary_accuracy: 0.9916     
Epoch 32/100
16/16 [==============================] - 7s - loss: 0.0148 - binary_accuracy: 0.9944     
Epoch 33/100
16/16 [==============================] - 7s - loss: 0.0219 - binary_accuracy: 0.9936     
Epoch 34/100
16/16 [==============================] - 7s - loss: 0.0245 - binary_accuracy: 0.9914     
Epoch 35/100
16/16 [==============================] - 7s - loss: 0.0190 - binary_accuracy: 0.9916     
Epoch 36/100
16/16 [==============================] - 7s - loss: 0.0453 - binary_accuracy: 0.9880     
Epoch 37/100
16/16 [==============================] - 7s - loss: 0.0216 - binary_accuracy: 0.9933     
Epoch 38/100
16/16 [==============================] - 7s - loss: 0.0191 - binary_accuracy: 0.9933     
Epoch 39/100
16/16 [==============================] - 7s - loss: 0.0192 - binary_accuracy: 0.9922     
Epoch 40/100
16/16 [==============================] - 7s - loss: 0.0188 - binary_accuracy: 0.9936     
Epoch 41/100
16/16 [==============================] - 7s - loss: 0.0226 - binary_accuracy: 0.9925     
Epoch 42/100
16/16 [==============================] - 6s - loss: 0.0147 - binary_accuracy: 0.9953         
Epoch 43/100
16/16 [==============================] - 6s - loss: 0.0179 - binary_accuracy: 0.9939     
Epoch 44/100
16/16 [==============================] - 6s - loss: 0.0166 - binary_accuracy: 0.9947     
Epoch 45/100
16/16 [==============================] - 6s - loss: 0.0183 - binary_accuracy: 0.9933     
Epoch 46/100
16/16 [==============================] - 7s - loss: 0.0183 - binary_accuracy: 0.9944     
Epoch 47/100
16/16 [==============================] - 6s - loss: 0.0147 - binary_accuracy: 0.9944     
Epoch 48/100
16/16 [==============================] - 6s - loss: 0.0211 - binary_accuracy: 0.9930         
Epoch 49/100
16/16 [==============================] - 6s - loss: 0.0173 - binary_accuracy: 0.9933     
Epoch 50/100
16/16 [==============================] - 6s - loss: 0.0220 - binary_accuracy: 0.9916     
Epoch 51/100
16/16 [==============================] - 6s - loss: 0.0179 - binary_accuracy: 0.9941     
Epoch 52/100
16/16 [==============================] - 6s - loss: 0.0154 - binary_accuracy: 0.9944     
Epoch 53/100
16/16 [==============================] - 6s - loss: 0.0195 - binary_accuracy: 0.9936     
Epoch 54/100
16/16 [==============================] - 6s - loss: 0.0141 - binary_accuracy: 0.9936     
Epoch 55/100
16/16 [==============================] - 6s - loss: 0.0141 - binary_accuracy: 0.9947         
Epoch 56/100
16/16 [==============================] - 6s - loss: 0.0145 - binary_accuracy: 0.9961     
Epoch 57/100
16/16 [==============================] - 6s - loss: 0.0222 - binary_accuracy: 0.9941     
Epoch 58/100
16/16 [==============================] - 6s - loss: 0.0164 - binary_accuracy: 0.9941     
Epoch 59/100
16/16 [==============================] - 6s - loss: 0.0129 - binary_accuracy: 0.9947     
Epoch 60/100
16/16 [==============================] - 6s - loss: 0.0159 - binary_accuracy: 0.9947     
Epoch 61/100
16/16 [==============================] - 6s - loss: 0.0123 - binary_accuracy: 0.9944     
Epoch 62/100
16/16 [==============================] - 6s - loss: 0.0118 - binary_accuracy: 0.9953     
Epoch 63/100
16/16 [==============================] - 6s - loss: 0.0202 - binary_accuracy: 0.9941     
Epoch 64/100
16/16 [==============================] - 6s - loss: 0.0100 - binary_accuracy: 0.9955     
Epoch 65/100
16/16 [==============================] - 6s - loss: 0.0124 - binary_accuracy: 0.9958     
Epoch 66/100
16/16 [==============================] - 6s - loss: 0.0325 - binary_accuracy: 0.9925     
Epoch 67/100
16/16 [==============================] - 6s - loss: 0.0134 - binary_accuracy: 0.9947     
Epoch 68/100
16/16 [==============================] - 6s - loss: 0.0132 - binary_accuracy: 0.9950     
Epoch 69/100
16/16 [==============================] - 6s - loss: 0.0164 - binary_accuracy: 0.9936     
Epoch 70/100
16/16 [==============================] - 6s - loss: 0.0210 - binary_accuracy: 0.9936     
Epoch 71/100
16/16 [==============================] - 6s - loss: 0.0139 - binary_accuracy: 0.9955     
Epoch 72/100
16/16 [==============================] - 6s - loss: 0.0135 - binary_accuracy: 0.9961         
Epoch 73/100
16/16 [==============================] - 6s - loss: 0.0289 - binary_accuracy: 0.9925     
Epoch 74/100
16/16 [==============================] - 6s - loss: 0.0093 - binary_accuracy: 0.9955         
Epoch 75/100
16/16 [==============================] - 6s - loss: 0.0137 - binary_accuracy: 0.9955         
Epoch 76/100
16/16 [==============================] - 6s - loss: 0.0149 - binary_accuracy: 0.9941     
Epoch 77/100
16/16 [==============================] - 6s - loss: 0.0091 - binary_accuracy: 0.9967     
Epoch 78/100
16/16 [==============================] - 6s - loss: 0.0100 - binary_accuracy: 0.9955         
Epoch 79/100
16/16 [==============================] - 6s - loss: 0.0111 - binary_accuracy: 0.9950         
Epoch 80/100
16/16 [==============================] - 6s - loss: 0.0153 - binary_accuracy: 0.9947     
Epoch 81/100
16/16 [==============================] - 6s - loss: 0.0213 - binary_accuracy: 0.9955         
Epoch 82/100
16/16 [==============================] - 6s - loss: 0.0144 - binary_accuracy: 0.9953     
Epoch 83/100
16/16 [==============================] - 6s - loss: 0.0120 - binary_accuracy: 0.9958     
Epoch 84/100
16/16 [==============================] - 6s - loss: 0.0094 - binary_accuracy: 0.9967         
Epoch 85/100
16/16 [==============================] - 6s - loss: 0.0153 - binary_accuracy: 0.9947     
Epoch 86/100
16/16 [==============================] - 6s - loss: 0.0139 - binary_accuracy: 0.9955         
Epoch 87/100
16/16 [==============================] - 6s - loss: 0.0175 - binary_accuracy: 0.9955         
Epoch 88/100
16/16 [==============================] - 6s - loss: 0.0156 - binary_accuracy: 0.9947     
Epoch 89/100
16/16 [==============================] - 6s - loss: 0.0417 - binary_accuracy: 0.9922     
Epoch 90/100
16/16 [==============================] - 6s - loss: 0.0113 - binary_accuracy: 0.9961     
Epoch 91/100
16/16 [==============================] - 6s - loss: 0.0127 - binary_accuracy: 0.9964         
Epoch 92/100
16/16 [==============================] - 6s - loss: 0.0128 - binary_accuracy: 0.9947     
Epoch 93/100
16/16 [==============================] - 6s - loss: 0.0122 - binary_accuracy: 0.9955         
Epoch 94/100
16/16 [==============================] - 6s - loss: 0.0112 - binary_accuracy: 0.9964     
Epoch 95/100
16/16 [==============================] - 6s - loss: 0.0093 - binary_accuracy: 0.9967         
Epoch 96/100
16/16 [==============================] - 6s - loss: 0.0096 - binary_accuracy: 0.9964         
Epoch 97/100
16/16 [==============================] - 6s - loss: 0.0148 - binary_accuracy: 0.9950     
Epoch 98/100
16/16 [==============================] - 6s - loss: 0.0099 - binary_accuracy: 0.9961         
Epoch 99/100
16/16 [==============================] - 6s - loss: 0.0118 - binary_accuracy: 0.9958         
Epoch 100/100
16/16 [==============================] - 6s - loss: 0.0248 - binary_accuracy: 0.9939     
64/64 [==============================] - 0s     

In [15]:
model.predict(X_test[:50,:,:,:])


Out[15]:
array([[  0.00000000e+00,   0.00000000e+00,   1.00000000e+00,
          0.00000000e+00,   0.00000000e+00,   0.00000000e+00,
          0.00000000e+00],
       [  9.04261395e-26,   9.99990225e-01,   1.00000000e+00,
          0.00000000e+00,   0.00000000e+00,   0.00000000e+00,
          0.00000000e+00],
       [  0.00000000e+00,   0.00000000e+00,   1.04094781e-02,
          0.00000000e+00,   0.00000000e+00,   0.00000000e+00,
          0.00000000e+00],
       [  0.00000000e+00,   0.00000000e+00,   9.97779191e-01,
          0.00000000e+00,   0.00000000e+00,   0.00000000e+00,
          0.00000000e+00],
       [  0.00000000e+00,   1.00000000e+00,   7.04561935e-26,
          0.00000000e+00,   0.00000000e+00,   0.00000000e+00,
          0.00000000e+00],
       [  2.42011310e-33,   0.00000000e+00,   0.00000000e+00,
          0.00000000e+00,   0.00000000e+00,   0.00000000e+00,
          0.00000000e+00],
       [  0.00000000e+00,   2.44924887e-17,   9.84471858e-01,
          0.00000000e+00,   0.00000000e+00,   0.00000000e+00,
          0.00000000e+00],
       [  0.00000000e+00,   1.00000000e+00,   0.00000000e+00,
          0.00000000e+00,   0.00000000e+00,   0.00000000e+00,
          0.00000000e+00],
       [  0.00000000e+00,   0.00000000e+00,   0.00000000e+00,
          0.00000000e+00,   0.00000000e+00,   0.00000000e+00,
          0.00000000e+00],
       [  0.00000000e+00,   0.00000000e+00,   1.00000000e+00,
          0.00000000e+00,   0.00000000e+00,   0.00000000e+00,
          0.00000000e+00],
       [  0.00000000e+00,   0.00000000e+00,   0.00000000e+00,
          0.00000000e+00,   0.00000000e+00,   0.00000000e+00,
          0.00000000e+00],
       [  0.00000000e+00,   1.00248814e-08,   0.00000000e+00,
          0.00000000e+00,   0.00000000e+00,   1.00000000e+00,
          0.00000000e+00],
       [  0.00000000e+00,   0.00000000e+00,   1.00000000e+00,
          0.00000000e+00,   0.00000000e+00,   0.00000000e+00,
          0.00000000e+00],
       [  4.79507822e-35,   8.74605293e-14,   0.00000000e+00,
          0.00000000e+00,   0.00000000e+00,   0.00000000e+00,
          0.00000000e+00],
       [  0.00000000e+00,   0.00000000e+00,   1.00000000e+00,
          0.00000000e+00,   0.00000000e+00,   0.00000000e+00,
          0.00000000e+00],
       [  0.00000000e+00,   0.00000000e+00,   1.00000000e+00,
          0.00000000e+00,   0.00000000e+00,   0.00000000e+00,
          0.00000000e+00],
       [  0.00000000e+00,   0.00000000e+00,   1.00000000e+00,
          0.00000000e+00,   0.00000000e+00,   0.00000000e+00,
          0.00000000e+00],
       [  0.00000000e+00,   8.49936128e-01,   9.21058817e-27,
          0.00000000e+00,   0.00000000e+00,   0.00000000e+00,
          0.00000000e+00],
       [  0.00000000e+00,   0.00000000e+00,   1.00000000e+00,
          0.00000000e+00,   0.00000000e+00,   0.00000000e+00,
          0.00000000e+00],
       [  1.49368998e-05,   1.00000000e+00,   0.00000000e+00,
          0.00000000e+00,   0.00000000e+00,   0.00000000e+00,
          0.00000000e+00],
       [  0.00000000e+00,   1.00000000e+00,   0.00000000e+00,
          0.00000000e+00,   0.00000000e+00,   0.00000000e+00,
          0.00000000e+00],
       [  0.00000000e+00,   0.00000000e+00,   1.00000000e+00,
          0.00000000e+00,   0.00000000e+00,   0.00000000e+00,
          0.00000000e+00],
       [  0.00000000e+00,   4.30223422e-20,   0.00000000e+00,
          0.00000000e+00,   0.00000000e+00,   0.00000000e+00,
          0.00000000e+00],
       [  1.98304372e-12,   0.00000000e+00,   0.00000000e+00,
          0.00000000e+00,   0.00000000e+00,   0.00000000e+00,
          0.00000000e+00],
       [  0.00000000e+00,   1.00000000e+00,   0.00000000e+00,
          0.00000000e+00,   0.00000000e+00,   0.00000000e+00,
          0.00000000e+00],
       [  0.00000000e+00,   0.00000000e+00,   9.99999881e-01,
          0.00000000e+00,   0.00000000e+00,   0.00000000e+00,
          0.00000000e+00],
       [  1.94511796e-20,   0.00000000e+00,   0.00000000e+00,
          0.00000000e+00,   0.00000000e+00,   0.00000000e+00,
          0.00000000e+00],
       [  0.00000000e+00,   1.04908239e-31,   0.00000000e+00,
          0.00000000e+00,   0.00000000e+00,   0.00000000e+00,
          0.00000000e+00],
       [  0.00000000e+00,   1.70657435e-31,   8.37717836e-30,
          0.00000000e+00,   0.00000000e+00,   0.00000000e+00,
          0.00000000e+00],
       [  0.00000000e+00,   7.05975456e-38,   0.00000000e+00,
          0.00000000e+00,   0.00000000e+00,   0.00000000e+00,
          0.00000000e+00],
       [  0.00000000e+00,   0.00000000e+00,   1.00000000e+00,
          0.00000000e+00,   0.00000000e+00,   0.00000000e+00,
          0.00000000e+00],
       [  1.00000000e+00,   0.00000000e+00,   5.16739895e-09,
          0.00000000e+00,   0.00000000e+00,   0.00000000e+00,
          0.00000000e+00],
       [  0.00000000e+00,   0.00000000e+00,   1.00000000e+00,
          0.00000000e+00,   0.00000000e+00,   0.00000000e+00,
          0.00000000e+00],
       [  0.00000000e+00,   0.00000000e+00,   1.00000000e+00,
          0.00000000e+00,   0.00000000e+00,   0.00000000e+00,
          0.00000000e+00],
       [  0.00000000e+00,   0.00000000e+00,   7.03969825e-29,
          0.00000000e+00,   0.00000000e+00,   0.00000000e+00,
          0.00000000e+00],
       [  0.00000000e+00,   0.00000000e+00,   0.00000000e+00,
          0.00000000e+00,   0.00000000e+00,   0.00000000e+00,
          0.00000000e+00],
       [  0.00000000e+00,   0.00000000e+00,   1.00000000e+00,
          0.00000000e+00,   0.00000000e+00,   0.00000000e+00,
          0.00000000e+00],
       [  0.00000000e+00,   0.00000000e+00,   0.00000000e+00,
          0.00000000e+00,   0.00000000e+00,   0.00000000e+00,
          0.00000000e+00],
       [  0.00000000e+00,   0.00000000e+00,   1.00793104e-31,
          0.00000000e+00,   0.00000000e+00,   0.00000000e+00,
          0.00000000e+00],
       [  1.95060620e-38,   1.00000000e+00,   0.00000000e+00,
          0.00000000e+00,   0.00000000e+00,   0.00000000e+00,
          0.00000000e+00],
       [  1.00000000e+00,   1.64890795e-10,   0.00000000e+00,
          0.00000000e+00,   0.00000000e+00,   0.00000000e+00,
          0.00000000e+00],
       [  0.00000000e+00,   0.00000000e+00,   1.00000000e+00,
          0.00000000e+00,   0.00000000e+00,   0.00000000e+00,
          0.00000000e+00],
       [  1.00000000e+00,   0.00000000e+00,   0.00000000e+00,
          0.00000000e+00,   0.00000000e+00,   0.00000000e+00,
          0.00000000e+00],
       [  0.00000000e+00,   1.00000000e+00,   0.00000000e+00,
          0.00000000e+00,   0.00000000e+00,   0.00000000e+00,
          0.00000000e+00],
       [  0.00000000e+00,   0.00000000e+00,   1.00000000e+00,
          0.00000000e+00,   0.00000000e+00,   0.00000000e+00,
          0.00000000e+00],
       [  0.00000000e+00,   2.10250813e-20,   0.00000000e+00,
          0.00000000e+00,   0.00000000e+00,   0.00000000e+00,
          0.00000000e+00],
       [  0.00000000e+00,   0.00000000e+00,   7.91063784e-27,
          0.00000000e+00,   0.00000000e+00,   0.00000000e+00,
          0.00000000e+00],
       [  0.00000000e+00,   1.00000000e+00,   1.00000000e+00,
          0.00000000e+00,   0.00000000e+00,   0.00000000e+00,
          0.00000000e+00],
       [  0.00000000e+00,   0.00000000e+00,   1.00000000e+00,
          0.00000000e+00,   0.00000000e+00,   0.00000000e+00,
          0.00000000e+00],
       [  0.00000000e+00,   0.00000000e+00,   0.00000000e+00,
          0.00000000e+00,   0.00000000e+00,   0.00000000e+00,
          0.00000000e+00]], dtype=float32)

In [25]:
plt.pcolor(model.predict(X_test))


Out[25]:
<matplotlib.collections.PolyCollection at 0x7fa02066ca10>

In [26]:
plt.pcolor(y_test)


Out[26]:
<matplotlib.collections.PolyCollection at 0x7fa0201b3a50>

In [34]:
# Now fit the model using a much bigger training set, n=1e4
model.fit_generator(datagen.flow(X_train, y_train, batch_size=32),
                    steps_per_epoch=len(X_train) / 32, epochs=10)


# Note: this may be wrong, since evaluate() does not apply the
# same samplewise standardization as the data generator:
score = model.evaluate(X_test, y_test, batch_size=32)


# Maybe the right way to do it looks something like:
#score = model.predict_generator(datagen.flow(X_test, y_test), steps=100)
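
# A possibly cleaner alternative (untested here) would be to evaluate through
# the same generator, so the test images receive the same samplewise
# standardization as the training batches:
#score = model.evaluate_generator(datagen.flow(X_test, y_test, batch_size=32),
#                                 steps=len(X_test) // 32)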


Epoch 1/10
312/312 [==============================] - 134s - loss: 0.4772 - binary_accuracy: 0.7865     
Epoch 2/10
312/312 [==============================] - 134s - loss: 0.4728 - binary_accuracy: 0.7870     
Epoch 3/10
312/312 [==============================] - 133s - loss: 0.4698 - binary_accuracy: 0.7873     
Epoch 4/10
312/312 [==============================] - 133s - loss: 0.4633 - binary_accuracy: 0.7932     
Epoch 5/10
312/312 [==============================] - 134s - loss: 0.4423 - binary_accuracy: 0.8010     
Epoch 6/10
312/312 [==============================] - 134s - loss: 0.3984 - binary_accuracy: 0.8224     
Epoch 7/10
312/312 [==============================] - 134s - loss: 0.3490 - binary_accuracy: 0.8467     
Epoch 8/10
312/312 [==============================] - 134s - loss: 0.3048 - binary_accuracy: 0.8684     
Epoch 9/10
312/312 [==============================] - 134s - loss: 0.2735 - binary_accuracy: 0.8835     
Epoch 10/10
312/312 [==============================] - 134s - loss: 0.2442 - binary_accuracy: 0.8967     
5000/5000 [==============================] - 23s     

In [35]:
# Performance on the test set
print score


[2.5414306640625002, 0.76908571910858159]

We can see that performance on the test set is much worse, with an overall binary accuracy of 0.77, well below the final training-set accuracy of 0.89. It seems that we have not trained the model on a large enough set of images, or we are simply overfitting the data we have. We can try loading a new set of training images and continuing to train the model to see if test performance improves, as sketched below.
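
A sketch of that next step, assuming the new sample excludes the held-out test ids (and again using the hypothetical load_posters() helper from the earlier sketch; the compiled model keeps its weights, so fit_generator simply continues training):

# Sketch: draw a fresh training sample that avoids the test ids, then keep training.
remaining_ids = np.setdiff1d(np.squeeze(ids), test_ids)
new_train_ids = np.random.choice(remaining_ids, size=n_train, replace=False)
X_new, y_new = load_posters(new_train_ids)  # hypothetical helper
model.fit_generator(datagen.flow(X_new, y_new, batch_size=32),
                    steps_per_epoch=len(X_new) // 32, epochs=10)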


In [ ]:
# Continue fitting the model for another 5 epochs and see if much changes - we could just be overfitting at this point
model.fit_generator(datagen.flow(X_train, y_train, batch_size=32),
                    steps_per_epoch=len(X_train) / 32, epochs=5)

score = model.evaluate(X_test, y_test, batch_size=32)


Epoch 1/5
312/312 [==============================] - 134s - loss: 0.2249 - binary_accuracy: 0.9064     
Epoch 2/5
312/312 [==============================] - 134s - loss: 0.2117 - binary_accuracy: 0.9144     
Epoch 3/5
312/312 [==============================] - 134s - loss: 0.2017 - binary_accuracy: 0.9183     
Epoch 4/5
312/312 [==============================] - 134s - loss: 0.1971 - binary_accuracy: 0.9218     
Epoch 5/5
 93/312 [=======>......................] - ETA: 94s - loss: 0.1819 - binary_accuracy: 0.9285 

Note: every time we test a different hyperparameter or a different number of epochs, we rebuild the model from scratch, because Keras will otherwise continue training the weights of the existing model. When testing different parameters we do not want to stack training runs on top of one another; a helper like the sketch below makes this convenient.
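
A sketch of such a helper, wrapping the architecture from cell In [7] so each experiment starts from freshly initialized weights (a convenience we describe here rather than a cell we ran):

# Sketch: rebuild the In [7] architecture from scratch for each experiment.
def build_model():
    m = Sequential()
    m.add(Conv2D(32, (3, 3), activation='relu', input_shape=(300, 185, 3)))
    m.add(Conv2D(32, (3, 3), activation='relu'))
    m.add(MaxPooling2D(pool_size=(2, 2)))
    m.add(Dropout(0.25))
    m.add(Conv2D(64, (3, 3), activation='relu'))
    m.add(Conv2D(64, (3, 3), activation='relu'))
    m.add(MaxPooling2D(pool_size=(2, 2)))
    m.add(Dropout(0.25))
    m.add(Flatten())
    m.add(Dense(256, activation='relu'))
    m.add(Dropout(0.5))
    m.add(Dense(7, activation='sigmoid'))
    return m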


In [36]:
# RMSprop optimizer: tune the learning rate, leaving the other parameters at their defaults as suggested by the documentation
# try learning rate of 0.01
model.compile(loss='binary_crossentropy',
              optimizer=RMSprop(lr=0.01, rho=0.9, epsilon=1e-08, decay=0.0),
              metrics=['binary_accuracy'])

In [37]:
# Now fit the model using a much bigger training set, n=1e4
model.fit_generator(datagen.flow(X_train, y_train, batch_size=32),
                    steps_per_epoch=len(X_train) / 32, epochs=10)

score = model.evaluate(X_test, y_test, batch_size=32)


Epoch 1/10
312/312 [==============================] - 132s - loss: 3.5036 - binary_accuracy: 0.7815     
Epoch 2/10
312/312 [==============================] - 132s - loss: 3.4621 - binary_accuracy: 0.7844     
Epoch 3/10
312/312 [==============================] - 132s - loss: 3.4754 - binary_accuracy: 0.7837     
Epoch 4/10
312/312 [==============================] - 132s - loss: 3.4469 - binary_accuracy: 0.7855     
Epoch 5/10
312/312 [==============================] - 132s - loss: 3.4714 - binary_accuracy: 0.7839     
Epoch 6/10
312/312 [==============================] - 132s - loss: 3.4645 - binary_accuracy: 0.7843     
Epoch 7/10
312/312 [==============================] - 132s - loss: 3.4678 - binary_accuracy: 0.7841     
Epoch 8/10
312/312 [==============================] - 132s - loss: 3.4746 - binary_accuracy: 0.7837     
Epoch 9/10
312/312 [==============================] - 132s - loss: 3.4789 - binary_accuracy: 0.7834     
Epoch 10/10
312/312 [==============================] - 132s - loss: 3.4771 - binary_accuracy: 0.7835     
5000/5000 [==============================] - 23s     

In [38]:
print score


[3.4490453636169431, 0.78522857685089109]

In [16]:
# RMSprop optimizer: tune the learning rate, leaving the other parameters at their defaults as suggested by the documentation
# try learning rate of 0.0001
model.compile(loss='binary_crossentropy',
              optimizer=RMSprop(lr=0.0001, rho=0.9, epsilon=1e-08, decay=0.0),
              metrics=['binary_accuracy'])

In [17]:
# Now fit the model using a much bigger training set, n=1e4
model.fit_generator(datagen.flow(X_train, y_train, batch_size=32),
                    steps_per_epoch=len(X_train) / 32, epochs=10)


Epoch 1/10
312/312 [==============================] - 132s - loss: 0.3281 - binary_accuracy: 0.8557     
Epoch 2/10
312/312 [==============================] - 133s - loss: 0.3127 - binary_accuracy: 0.8629     
Epoch 3/10
312/312 [==============================] - 133s - loss: 0.3083 - binary_accuracy: 0.8658     
Epoch 4/10
312/312 [==============================] - 133s - loss: 0.2964 - binary_accuracy: 0.8719