HyperParameter Tuning

Problem:

Build simple CNN models on MNIST and use scikit-learn's GridSearchCV to find the best hyperparameter combination.


In [1]:
import numpy as np
np.random.seed(1337)  # for reproducibility

In [2]:
from keras.datasets import mnist
from keras.models import Sequential
from keras.layers import Dense, Dropout, Activation, Flatten
from keras.layers import Conv2D, MaxPooling2D
from keras.utils import np_utils
from keras.wrappers.scikit_learn import KerasClassifier
from keras import backend as K


Using TensorFlow backend.

In [3]:
from sklearn.model_selection import GridSearchCV

Data Preparation


In [4]:
nb_classes = 10

# input image dimensions
img_rows, img_cols = 28, 28

In [5]:
# load training data and do basic data normalization
(X_train, y_train), (X_test, y_test) = mnist.load_data()

if K.image_data_format() == 'channels_first':
    X_train = X_train.reshape(X_train.shape[0], 1, img_rows, img_cols)
    X_test = X_test.reshape(X_test.shape[0], 1, img_rows, img_cols)
    input_shape = (1, img_rows, img_cols)
else:
    X_train = X_train.reshape(X_train.shape[0], img_rows, img_cols, 1)
    X_test = X_test.reshape(X_test.shape[0], img_rows, img_cols, 1)
    input_shape = (img_rows, img_cols, 1)

In [6]:
X_train = X_train.astype('float32')
X_test = X_test.astype('float32')
X_train /= 255
X_test /= 255

# convert class vectors to binary class matrices
y_train = np_utils.to_categorical(y_train, nb_classes)
y_test = np_utils.to_categorical(y_test, nb_classes)
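
A quick sanity check (not part of the original notebook) helps catch reshaping mistakes before the arrays are handed to the grid search; the expected shapes below assume the TensorFlow channels-last ordering:

print('X_train shape:', X_train.shape)   # expected (60000, 28, 28, 1) with channels_last
print('y_train shape:', y_train.shape)   # expected (60000, 10) after one-hot encoding
print('pixel range:', X_train.min(), '-', X_train.max())  # expected 0.0 - 1.0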

Build Model


In [13]:
def make_model(dense_layer_sizes, filters, kernel_size, pool_size):
    '''Creates a model comprised of 2 convolutional layers followed by dense layers

    dense_layer_sizes: list of dense layer sizes, one number per dense layer
    filters: number of convolutional filters in each convolutional layer
    kernel_size: convolutional kernel size
    pool_size: size of the pooling area for max pooling
    '''

    model = Sequential()

    model.add(Conv2D(filters, (kernel_size, kernel_size),
                     padding='valid', input_shape=input_shape))
    model.add(Activation('relu'))
    model.add(Conv2D(filters, (kernel_size, kernel_size)))
    model.add(Activation('relu'))
    model.add(MaxPooling2D(pool_size=(pool_size, pool_size)))
    model.add(Dropout(0.25))

    model.add(Flatten())
    for layer_size in dense_layer_sizes:
        model.add(Dense(layer_size))
        model.add(Activation('relu'))
    model.add(Dropout(0.5))
    model.add(Dense(nb_classes))
    model.add(Activation('softmax'))

    model.compile(loss='categorical_crossentropy',
                  optimizer='adadelta',
                  metrics=['accuracy'])

    return model

In [14]:
dense_size_candidates = [[32], [64], [32, 32], [64, 64]]
my_classifier = KerasClassifier(make_model, batch_size=32)
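
KerasClassifier forwards any keyword arguments whose names match the parameters of make_model (or of fit) to the right place at training time. As a quick sketch with illustrative values, one candidate configuration can also be built directly to inspect its architecture before launching the full search:

# Sketch: build one candidate configuration by hand and inspect it.
# The argument values here are illustrative, not the tuned ones.
sample_model = make_model(dense_layer_sizes=[64, 64], filters=8,
                          kernel_size=3, pool_size=2)
sample_model.summary()  # prints layer output shapes and parameter counts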

GridSearch HyperParameters


In [15]:
validator = GridSearchCV(my_classifier,
                         param_grid={'dense_layer_sizes': dense_size_candidates,
                                     # epochs is available for tuning even when it is not
                                     # an argument to the model-building function
                                     'epochs': [3, 6],
                                     'filters': [8],
                                     'kernel_size': [3],
                                     'pool_size': [2]},
                         scoring='neg_log_loss',
                         n_jobs=1)
validator.fit(X_train, y_train)


Epoch 1/3
40000/40000 [==============================] - 10s - loss: 0.8961 - acc: 0.6953
Epoch 2/3
40000/40000 [==============================] - 9s - loss: 0.5362 - acc: 0.8299     
Epoch 3/3
40000/40000 [==============================] - 10s - loss: 0.4425 - acc: 0.8594    
Epoch 1/3
40000/40000 [==============================] - 11s - loss: 0.7593 - acc: 0.7543    
Epoch 2/3
40000/40000 [==============================] - 10s - loss: 0.4489 - acc: 0.8597    
Epoch 3/3
40000/40000 [==============================] - 10s - loss: 0.3841 - acc: 0.8814    
Epoch 1/3
40000/40000 [==============================] - 10s - loss: 0.9089 - acc: 0.6946    
Epoch 2/3
40000/40000 [==============================] - 9s - loss: 0.5560 - acc: 0.8228     
Epoch 3/3
40000/40000 [==============================] - 10s - loss: 0.4597 - acc: 0.8556    
Epoch 1/6
40000/40000 [==============================] - 11s - loss: 0.8415 - acc: 0.7162    
Epoch 2/6
40000/40000 [==============================] - 10s - loss: 0.4929 - acc: 0.8423    
Epoch 3/6
40000/40000 [==============================] - 9s - loss: 0.4172 - acc: 0.8703     
Epoch 4/6
40000/40000 [==============================] - 10s - loss: 0.3819 - acc: 0.8812    
Epoch 5/6
40000/40000 [==============================] - 10s - loss: 0.3491 - acc: 0.8919    
Epoch 6/6
40000/40000 [==============================] - 10s - loss: 0.3284 - acc: 0.8985    
Epoch 1/6
40000/40000 [==============================] - 11s - loss: 0.7950 - acc: 0.7349    
Epoch 2/6
40000/40000 [==============================] - 10s - loss: 0.4913 - acc: 0.8428    
Epoch 3/6
40000/40000 [==============================] - 10s - loss: 0.4081 - acc: 0.8709    
Epoch 4/6
40000/40000 [==============================] - 10s - loss: 0.3613 - acc: 0.8870    
Epoch 5/6
40000/40000 [==============================] - 10s - loss: 0.3293 - acc: 0.8968    
Epoch 6/6
40000/40000 [==============================] - 10s - loss: 0.3024 - acc: 0.9058    
Epoch 1/6
40000/40000 [==============================] - 11s - loss: 0.9822 - acc: 0.6735    
Epoch 2/6
40000/40000 [==============================] - 10s - loss: 0.6270 - acc: 0.8009    
Epoch 3/6
40000/40000 [==============================] - 9s - loss: 0.5045 - acc: 0.8409     
Epoch 4/6
40000/40000 [==============================] - 10s - loss: 0.4396 - acc: 0.8599    
Epoch 5/6
40000/40000 [==============================] - 10s - loss: 0.3978 - acc: 0.8775    
Epoch 6/6
40000/40000 [==============================] - 10s - loss: 0.3605 - acc: 0.8871    
Epoch 1/3
40000/40000 [==============================] - 11s - loss: 0.6851 - acc: 0.7777    
Epoch 2/3
40000/40000 [==============================] - 10s - loss: 0.3989 - acc: 0.8776    
Epoch 3/3
40000/40000 [==============================] - 10s - loss: 0.3225 - acc: 0.9021    
Epoch 1/3
40000/40000 [==============================] - 11s - loss: 0.5846 - acc: 0.8164    
Epoch 2/3
40000/40000 [==============================] - 10s - loss: 0.3243 - acc: 0.9053    
Epoch 3/3
40000/40000 [==============================] - 10s - loss: 0.2697 - acc: 0.9213    
Epoch 1/3
40000/40000 [==============================] - 11s - loss: 0.6339 - acc: 0.8017    
Epoch 2/3
40000/40000 [==============================] - 10s - loss: 0.3417 - acc: 0.8975    
Epoch 3/3
40000/40000 [==============================] - 10s - loss: 0.2783 - acc: 0.9184    
Epoch 1/6
40000/40000 [==============================] - 11s - loss: 0.6652 - acc: 0.7854    
Epoch 2/6
40000/40000 [==============================] - 10s - loss: 0.3693 - acc: 0.8911    
Epoch 3/6
40000/40000 [==============================] - 10s - loss: 0.2923 - acc: 0.9130    
Epoch 4/6
40000/40000 [==============================] - 10s - loss: 0.2479 - acc: 0.9274    
Epoch 5/6
40000/40000 [==============================] - 10s - loss: 0.2176 - acc: 0.9360    
Epoch 6/6
40000/40000 [==============================] - 10s - loss: 0.1994 - acc: 0.9416    
Epoch 1/6
40000/40000 [==============================] - 11s - loss: 0.6463 - acc: 0.7952    
Epoch 2/6
40000/40000 [==============================] - 10s - loss: 0.3648 - acc: 0.8898    
Epoch 3/6
40000/40000 [==============================] - 10s - loss: 0.2880 - acc: 0.9154    
Epoch 4/6
40000/40000 [==============================] - 10s - loss: 0.2497 - acc: 0.9249    
Epoch 5/6
40000/40000 [==============================] - 10s - loss: 0.2154 - acc: 0.9357    
Epoch 6/6
40000/40000 [==============================] - 10s - loss: 0.1946 - acc: 0.9417    
Epoch 1/6
40000/40000 [==============================] - 11s - loss: 0.6212 - acc: 0.8012    
Epoch 2/6
40000/40000 [==============================] - 10s - loss: 0.3341 - acc: 0.9008    
Epoch 3/6
40000/40000 [==============================] - 10s - loss: 0.2706 - acc: 0.9195    
Epoch 4/6
40000/40000 [==============================] - 10s - loss: 0.2343 - acc: 0.9307    
Epoch 5/6
40000/40000 [==============================] - 10s - loss: 0.2109 - acc: 0.9383    
Epoch 6/6
40000/40000 [==============================] - 10s - loss: 0.1961 - acc: 0.9420    
Epoch 1/3
40000/40000 [==============================] - 12s - loss: 0.9322 - acc: 0.6835    
Epoch 2/3
40000/40000 [==============================] - 10s - loss: 0.5578 - acc: 0.8202    
Epoch 3/3
40000/40000 [==============================] - 11s - loss: 0.4651 - acc: 0.8518    
40000/40000 [==============================] - 4s     
Epoch 1/3
40000/40000 [==============================] - 11s - loss: 0.7615 - acc: 0.7467    
Epoch 2/3
40000/40000 [==============================] - 10s - loss: 0.4369 - acc: 0.8634    
Epoch 3/3
40000/40000 [==============================] - 10s - loss: 0.3646 - acc: 0.8865    
Epoch 1/3
40000/40000 [==============================] - 12s - loss: 0.7744 - acc: 0.7471    
Epoch 2/3
40000/40000 [==============================] - 11s - loss: 0.4294 - acc: 0.8674    
Epoch 3/3
40000/40000 [==============================] - 11s - loss: 0.3620 - acc: 0.8873    
Epoch 1/6
40000/40000 [==============================] - 12s - loss: 0.8007 - acc: 0.7354    
Epoch 2/6
40000/40000 [==============================] - 10s - loss: 0.4769 - acc: 0.8499    
Epoch 3/6
40000/40000 [==============================] - 11s - loss: 0.4020 - acc: 0.8743    
Epoch 4/6
40000/40000 [==============================] - 11s - loss: 0.3551 - acc: 0.8905    
Epoch 5/6
40000/40000 [==============================] - 11s - loss: 0.3256 - acc: 0.8993    
Epoch 6/6
40000/40000 [==============================] - 11s - loss: 0.3005 - acc: 0.9067    
Epoch 1/6
40000/40000 [==============================] - 12s - loss: 0.8505 - acc: 0.7123    
Epoch 2/6
40000/40000 [==============================] - 10s - loss: 0.5156 - acc: 0.8321    
Epoch 3/6
40000/40000 [==============================] - 11s - loss: 0.4208 - acc: 0.8660    
Epoch 4/6
40000/40000 [==============================] - 11s - loss: 0.3614 - acc: 0.8854    
Epoch 5/6
40000/40000 [==============================] - 11s - loss: 0.3258 - acc: 0.8980    
Epoch 6/6
40000/40000 [==============================] - 11s - loss: 0.3044 - acc: 0.9046    
Epoch 1/6
40000/40000 [==============================] - 12s - loss: 0.7670 - acc: 0.7494    
Epoch 2/6
40000/40000 [==============================] - 11s - loss: 0.4593 - acc: 0.8574    
Epoch 3/6
40000/40000 [==============================] - 11s - loss: 0.3898 - acc: 0.8799
Epoch 4/6
40000/40000 [==============================] - 10s - loss: 0.3514 - acc: 0.8907    
Epoch 5/6
40000/40000 [==============================] - 10s - loss: 0.3124 - acc: 0.9020    
Epoch 6/6
40000/40000 [==============================] - 11s - loss: 0.2981 - acc: 0.9097    
Epoch 1/3
40000/40000 [==============================] - 12s - loss: 0.5547 - acc: 0.8239    
Epoch 2/3
40000/40000 [==============================] - 11s - loss: 0.2752 - acc: 0.9204    
Epoch 3/3
40000/40000 [==============================] - 11s - loss: 0.2183 - acc: 0.9359    
Epoch 1/3
40000/40000 [==============================] - 12s - loss: 0.5718 - acc: 0.8172    
Epoch 2/3
40000/40000 [==============================] - 11s - loss: 0.3141 - acc: 0.9054    
Epoch 3/3
40000/40000 [==============================] - 11s - loss: 0.2536 - acc: 0.9247    
Epoch 1/3
40000/40000 [==============================] - 12s - loss: 0.5111 - acc: 0.8399    
Epoch 2/3
40000/40000 [==============================] - 11s - loss: 0.2469 - acc: 0.9270    
Epoch 3/3
40000/40000 [==============================] - 11s - loss: 0.1992 - acc: 0.9422    
20000/20000 [==============================] - 2s     
40000/40000 [==============================] - 4s     
Epoch 1/6
40000/40000 [==============================] - 12s - loss: 0.6041 - acc: 0.8066    
Epoch 2/6
40000/40000 [==============================] - 11s - loss: 0.2951 - acc: 0.9132    
Epoch 3/6
40000/40000 [==============================] - 11s - loss: 0.2343 - acc: 0.9315    
Epoch 4/6
40000/40000 [==============================] - 11s - loss: 0.1995 - acc: 0.9418    
Epoch 5/6
40000/40000 [==============================] - 11s - loss: 0.1779 - acc: 0.9487    
Epoch 6/6
40000/40000 [==============================] - 11s - loss: 0.1612 - acc: 0.9540    
Epoch 1/6
40000/40000 [==============================] - 12s - loss: 0.6137 - acc: 0.8069    
Epoch 2/6
40000/40000 [==============================] - 11s - loss: 0.3075 - acc: 0.9096    
Epoch 3/6
40000/40000 [==============================] - 11s - loss: 0.2309 - acc: 0.9325    
Epoch 4/6
40000/40000 [==============================] - 11s - loss: 0.1935 - acc: 0.9443    
Epoch 5/6
40000/40000 [==============================] - 11s - loss: 0.1679 - acc: 0.9518    
Epoch 6/6
40000/40000 [==============================] - 11s - loss: 0.1576 - acc: 0.9551    
Epoch 1/6
40000/40000 [==============================] - 12s - loss: 0.5143 - acc: 0.8400    
Epoch 2/6
40000/40000 [==============================] - 11s - loss: 0.2743 - acc: 0.9205    
Epoch 3/6
40000/40000 [==============================] - 11s - loss: 0.2248 - acc: 0.9350    
Epoch 4/6
40000/40000 [==============================] - 11s - loss: 0.1964 - acc: 0.9428    
Epoch 5/6
40000/40000 [==============================] - 11s - loss: 0.1736 - acc: 0.9496    
Epoch 6/6
40000/40000 [==============================] - 11s - loss: 0.1643 - acc: 0.9521    
Epoch 1/6
60000/60000 [==============================] - 18s - loss: 0.4674 - acc: 0.8567    
Epoch 2/6
60000/60000 [==============================] - 16s - loss: 0.2417 - acc: 0.9293    
Epoch 3/6
60000/60000 [==============================] - 16s - loss: 0.1966 - acc: 0.9428    
Epoch 4/6
60000/60000 [==============================] - 17s - loss: 0.1695 - acc: 0.9519    
Epoch 5/6
60000/60000 [==============================] - 16s - loss: 0.1504 - acc: 0.9571    
Epoch 6/6
60000/60000 [==============================] - 15s - loss: 0.1393 - acc: 0.9597    
Out[15]:
GridSearchCV(cv=None, error_score='raise',
       estimator=<keras.wrappers.scikit_learn.KerasClassifier object at 0x7f434a86ce48>,
       fit_params={}, iid=True, n_jobs=1,
       param_grid={'filters': [8], 'pool_size': [2], 'epochs': [3, 6], 'dense_layer_sizes': [[32], [64], [32, 32], [64, 64]], 'kernel_size': [3]},
       pre_dispatch='2*n_jobs', refit=True, return_train_score=True,
       scoring='neg_log_loss', verbose=0)
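
Besides best_params_, the fitted GridSearchCV object exposes the full cross-validation table through cv_results_. Below is a small sketch of how the eight candidate configurations could be ranked (assumes pandas is available; not part of the original run). Note also that n_jobs=1 keeps the whole search in a single process, which sidesteps the serialisation problems Keras models typically have with scikit-learn's parallel workers.

# Sketch: rank every candidate by its mean cross-validated score.
import pandas as pd

results = pd.DataFrame(validator.cv_results_)
cols = ['param_dense_layer_sizes', 'param_epochs',
        'mean_test_score', 'std_test_score', 'rank_test_score']
print(results[cols].sort_values('rank_test_score'))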

In [16]:
print('The parameters of the best model are: ')
print(validator.best_params_)

# validator.best_estimator_ returns sklearn-wrapped version of best model.
# validator.best_estimator_.model returns the (unwrapped) keras model
best_model = validator.best_estimator_.model
metric_names = best_model.metrics_names
metric_values = best_model.evaluate(X_test, y_test)
for metric, value in zip(metric_names, metric_values):
    print(metric, ': ', value)


The parameters of the best model are: 
{'filters': 8, 'pool_size': 2, 'epochs': 6, 'dense_layer_sizes': [64, 64], 'kernel_size': 3}
loss :  0.0577878101223
acc :  0.9822
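
Since refit=True, the best configuration has already been retrained on the full training set, so the underlying Keras model can be persisted with the standard Keras API and reused without rerunning the search. A minimal sketch (the file name is an arbitrary example, and HDF5 saving requires h5py):

# Sketch: save the refit best model and load it back later.
best_model.save('best_mnist_cnn.h5')  # 'best_mnist_cnn.h5' is just an example path

from keras.models import load_model
restored = load_model('best_mnist_cnn.h5')
print(restored.evaluate(X_test, y_test))  # should reproduce the scores above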

There's more:

GridSearchCV in scikit-learn performs an exhaustive search, considering every possible combination of the hyperparameters we want to optimise.

If we instead want an optimised and bounded search over the hyperparameter space, I strongly suggest taking a look at:
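
Not necessarily the tool the sentence above refers to, but as a minimal in-library sketch of a bounded search: scikit-learn's RandomizedSearchCV samples a fixed number of configurations from the parameter space instead of enumerating every combination (the parameter values below are illustrative):

# Sketch: a bounded search that tries only n_iter randomly sampled configurations.
from sklearn.model_selection import RandomizedSearchCV

random_validator = RandomizedSearchCV(
    my_classifier,
    param_distributions={'dense_layer_sizes': dense_size_candidates,
                         'epochs': [3, 6],
                         'filters': [8, 16],      # illustrative, slightly wider space
                         'kernel_size': [3, 5],
                         'pool_size': [2]},
    n_iter=5,                 # only 5 of the 32 possible combinations are tried
    scoring='neg_log_loss',
    n_jobs=1)
# random_validator.fit(X_train, y_train)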