AlexNet in Keras

In this notebook, we use an AlexNet-like deep convolutional neural network to classify flowers into the 17 categories of the Oxford Flowers dataset. Derived from this earlier notebook.

Set seed for reproducibility


In [1]:
import numpy as np
np.random.seed(42)
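
Note that seeding NumPy alone does not make a TensorFlow-backed run fully deterministic: Keras also draws on TensorFlow's own random number generator, and some GPU operations are inherently non-deterministic, so expect some run-to-run variation in the results below.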

Load dependencies


In [2]:
import keras
from keras.models import Sequential
from keras.layers import Dense, Dropout, Flatten, Conv2D, MaxPooling2D
from keras.layers import BatchNormalization # imported from keras.layers for compatibility across Keras versions
from keras.callbacks import TensorBoard # for part 3.5 on TensorBoard


Using TensorFlow backend.

Load and preprocess data


In [3]:
import tflearn.datasets.oxflower17 as oxflower17
X, Y = oxflower17.load_data(one_hot=True)
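
The tflearn loader downloads the Oxford 17-category archive, resizes each image to 224x224, and scales pixels to [0, 1]. If tflearn is unavailable, a minimal sketch of an equivalent loader follows; it assumes the official archive at http://www.robots.ox.ac.uk/~vgg/data/flowers/17/17flowers.tgz, whose 1,360 JPEGs are ordered by class, 80 per category:


In [ ]:
import os, tarfile, urllib.request
import numpy as np
from PIL import Image

URL = 'http://www.robots.ox.ac.uk/~vgg/data/flowers/17/17flowers.tgz'

def load_oxflower17(dirname='17flowers', size=(224, 224)):
    # download and unpack the archive on first use
    if not os.path.isdir(os.path.join(dirname, 'jpg')):
        fname, _ = urllib.request.urlretrieve(URL, '17flowers.tgz')
        with tarfile.open(fname) as tar:
            tar.extractall(dirname)  # archive contains a top-level jpg/ folder
    jpg_dir = os.path.join(dirname, 'jpg')
    files = sorted(f for f in os.listdir(jpg_dir) if f.endswith('.jpg'))
    X = np.zeros((len(files),) + size + (3,), dtype=np.float32)
    Y = np.zeros((len(files), 17), dtype=np.float32)
    for i, f in enumerate(files):
        img = Image.open(os.path.join(jpg_dir, f)).convert('RGB').resize(size)
        X[i] = np.asarray(img, dtype=np.float32) / 255.0  # scale pixels to [0, 1]
        Y[i, i // 80] = 1.0  # files are ordered by class, 80 images each
    return X, Y

Unlike the tflearn loader, this sketch does not shuffle the samples; shuffle X and Y before relying on validation_split below, since the last 10% of the unshuffled data would all come from the final two categories.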

Design neural network architecture


In [4]:
model = Sequential()

# first conv-pool block
model.add(Conv2D(96, kernel_size=(11, 11), strides=(4, 4), activation='relu', input_shape=(224, 224, 3)))
model.add(MaxPooling2D(pool_size=(3, 3), strides=(2, 2)))
model.add(BatchNormalization())

# second conv-pool block
model.add(Conv2D(256, kernel_size=(5, 5), activation='relu'))
model.add(MaxPooling2D(pool_size=(3, 3), strides=(2, 2)))
model.add(BatchNormalization())

# third conv-pool block; canonical AlexNet uses 384, 384, 256 filters here
model.add(Conv2D(256, kernel_size=(3, 3), activation='relu'))
model.add(Conv2D(384, kernel_size=(3, 3), activation='relu'))
model.add(Conv2D(384, kernel_size=(3, 3), activation='relu'))
model.add(MaxPooling2D(pool_size=(3, 3), strides=(2, 2)))
model.add(BatchNormalization())

# dense layers with dropout; the original AlexNet used relu activations here
model.add(Flatten())
model.add(Dense(4096, activation='tanh'))
model.add(Dropout(0.5))
model.add(Dense(4096, activation='tanh'))
model.add(Dropout(0.5))

# output layer: one softmax unit per flower category
model.add(Dense(17, activation='softmax'))

In [8]:
model.summary()


_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
conv2d_1 (Conv2D)            (None, 54, 54, 96)        34944     
_________________________________________________________________
max_pooling2d_1 (MaxPooling2 (None, 26, 26, 96)        0         
_________________________________________________________________
batch_normalization_1 (Batch (None, 26, 26, 96)        384       
_________________________________________________________________
conv2d_2 (Conv2D)            (None, 22, 22, 256)       614656    
_________________________________________________________________
max_pooling2d_2 (MaxPooling2 (None, 10, 10, 256)       0         
_________________________________________________________________
batch_normalization_2 (Batch (None, 10, 10, 256)       1024      
_________________________________________________________________
conv2d_3 (Conv2D)            (None, 8, 8, 256)         590080    
_________________________________________________________________
conv2d_4 (Conv2D)            (None, 6, 6, 384)         885120    
_________________________________________________________________
conv2d_5 (Conv2D)            (None, 4, 4, 384)         1327488   
_________________________________________________________________
max_pooling2d_3 (MaxPooling2 (None, 1, 1, 384)         0         
_________________________________________________________________
batch_normalization_3 (Batch (None, 1, 1, 384)         1536      
_________________________________________________________________
flatten_1 (Flatten)          (None, 384)               0         
_________________________________________________________________
dense_1 (Dense)              (None, 4096)              1576960   
_________________________________________________________________
dropout_1 (Dropout)          (None, 4096)              0         
_________________________________________________________________
dense_2 (Dense)              (None, 4096)              16781312  
_________________________________________________________________
dropout_2 (Dropout)          (None, 4096)              0         
_________________________________________________________________
dense_3 (Dense)              (None, 17)                69649     
=================================================================
Total params: 21,883,153
Trainable params: 21,881,681
Non-trainable params: 1,472
_________________________________________________________________
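
The output shapes and parameter counts above follow from the convolution arithmetic: with 'valid' padding, each spatial dimension shrinks to floor((n - k)/s) + 1, and a Conv2D layer holds (k_h * k_w * c_in + 1) * c_out weights. A quick sanity check of the first layer:


In [ ]:
# output width of the first conv layer: floor((224 - 11) / 4) + 1
print((224 - 11) // 4 + 1)     # 54, matching output shape (None, 54, 54, 96)

# weights and biases: (11 * 11 * 3 + 1) * 96
print((11 * 11 * 3 + 1) * 96)  # 34944, matching the summary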

Configure model


In [5]:
model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
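
The 'adam' string selects Keras's default Adam hyperparameters. Given how noisy the validation loss turns out to be in the run below, one knob worth trying is an explicit, smaller learning rate; a sketch (the rate shown is illustrative, not tuned):


In [ ]:
from keras.optimizers import Adam

model.compile(loss='categorical_crossentropy',
              optimizer=Adam(lr=1e-4),  # default is 1e-3; smaller rates often stabilize training
              metrics=['accuracy'])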

Configure TensorBoard (for part 5 of lesson 3)


In [6]:
tensorbrd = TensorBoard('logs/alexnet')
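
With the callback writing to logs/alexnet, training can be monitored live by running tensorboard --logdir logs from a shell and pointing a browser at localhost:6006.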

Train!


In [7]:
model.fit(X, Y, batch_size=64, epochs=100, verbose=1, validation_split=0.1, shuffle=True, 
          callbacks=[tensorbrd])


Train on 1224 samples, validate on 136 samples
Epoch 1/100
1224/1224 [==============================] - 35s - loss: 5.4212 - acc: 0.2067 - val_loss: 11.3663 - val_acc: 0.0956
Epoch 2/100
1224/1224 [==============================] - 35s - loss: 3.9938 - acc: 0.2492 - val_loss: 6.4582 - val_acc: 0.0735
Epoch 3/100
1224/1224 [==============================] - 34s - loss: 2.7325 - acc: 0.2949 - val_loss: 5.6394 - val_acc: 0.1324
Epoch 4/100
1224/1224 [==============================] - 34s - loss: 2.5455 - acc: 0.3031 - val_loss: 4.1848 - val_acc: 0.2206
Epoch 5/100
1224/1224 [==============================] - 34s - loss: 2.5307 - acc: 0.3317 - val_loss: 2.9730 - val_acc: 0.2426
Epoch 6/100
1224/1224 [==============================] - 35s - loss: 2.3763 - acc: 0.3448 - val_loss: 2.9469 - val_acc: 0.2426
Epoch 7/100
1224/1224 [==============================] - 35s - loss: 2.4945 - acc: 0.3448 - val_loss: 3.7506 - val_acc: 0.2574
Epoch 8/100
1224/1224 [==============================] - 36s - loss: 2.6480 - acc: 0.3660 - val_loss: 5.1129 - val_acc: 0.1691
Epoch 9/100
1224/1224 [==============================] - 34s - loss: 2.8148 - acc: 0.3268 - val_loss: 2.8049 - val_acc: 0.3603
Epoch 10/100
1224/1224 [==============================] - 34s - loss: 2.3363 - acc: 0.3905 - val_loss: 2.8927 - val_acc: 0.3603
Epoch 11/100
1224/1224 [==============================] - 34s - loss: 2.2786 - acc: 0.4060 - val_loss: 2.5592 - val_acc: 0.3162
Epoch 12/100
1224/1224 [==============================] - 34s - loss: 2.0467 - acc: 0.4281 - val_loss: 2.8987 - val_acc: 0.3382
Epoch 13/100
1224/1224 [==============================] - 35s - loss: 1.9947 - acc: 0.4592 - val_loss: 2.1224 - val_acc: 0.3676
Epoch 14/100
1224/1224 [==============================] - 34s - loss: 1.7445 - acc: 0.4910 - val_loss: 2.0347 - val_acc: 0.4412
Epoch 15/100
1224/1224 [==============================] - 35s - loss: 1.8929 - acc: 0.4935 - val_loss: 2.3847 - val_acc: 0.4338
Epoch 16/100
1224/1224 [==============================] - 35s - loss: 1.9688 - acc: 0.4583 - val_loss: 3.3796 - val_acc: 0.3529
Epoch 17/100
1224/1224 [==============================] - 35s - loss: 2.2440 - acc: 0.4436 - val_loss: 2.7138 - val_acc: 0.3750
Epoch 18/100
1224/1224 [==============================] - 35s - loss: 1.9742 - acc: 0.5000 - val_loss: 2.8813 - val_acc: 0.4044
Epoch 19/100
1224/1224 [==============================] - 35s - loss: 1.8282 - acc: 0.5049 - val_loss: 3.6857 - val_acc: 0.3750
Epoch 20/100
1224/1224 [==============================] - 35s - loss: 1.8215 - acc: 0.5049 - val_loss: 2.2820 - val_acc: 0.5294
Epoch 21/100
1224/1224 [==============================] - 35s - loss: 1.8518 - acc: 0.5212 - val_loss: 1.9001 - val_acc: 0.5074
Epoch 22/100
1224/1224 [==============================] - 35s - loss: 1.7738 - acc: 0.5147 - val_loss: 2.0581 - val_acc: 0.5074
Epoch 23/100
1224/1224 [==============================] - 34s - loss: 1.4318 - acc: 0.5784 - val_loss: 3.7218 - val_acc: 0.3456
Epoch 24/100
1224/1224 [==============================] - 34s - loss: 1.6964 - acc: 0.5482 - val_loss: 3.9307 - val_acc: 0.2574
Epoch 25/100
1224/1224 [==============================] - 35s - loss: 1.5907 - acc: 0.5621 - val_loss: 2.9847 - val_acc: 0.3897
Epoch 26/100
1224/1224 [==============================] - 35s - loss: 1.6202 - acc: 0.5752 - val_loss: 3.7071 - val_acc: 0.3676
Epoch 27/100
1224/1224 [==============================] - 34s - loss: 1.7010 - acc: 0.5547 - val_loss: 2.5451 - val_acc: 0.3897
Epoch 28/100
1224/1224 [==============================] - 34s - loss: 1.5992 - acc: 0.5605 - val_loss: 3.3575 - val_acc: 0.3824
Epoch 29/100
1224/1224 [==============================] - 34s - loss: 1.4056 - acc: 0.5989 - val_loss: 2.6031 - val_acc: 0.4559
Epoch 30/100
1224/1224 [==============================] - 26s - loss: 1.4243 - acc: 0.6217 - val_loss: 2.7693 - val_acc: 0.4485
Epoch 31/100
1224/1224 [==============================] - 24s - loss: 1.2810 - acc: 0.6176 - val_loss: 4.5931 - val_acc: 0.3603
Epoch 32/100
1224/1224 [==============================] - 25s - loss: 1.3800 - acc: 0.6381 - val_loss: 2.9544 - val_acc: 0.4265
Epoch 33/100
1224/1224 [==============================] - 25s - loss: 1.4526 - acc: 0.6087 - val_loss: 3.2366 - val_acc: 0.4559
Epoch 34/100
1224/1224 [==============================] - 26s - loss: 1.4056 - acc: 0.6332 - val_loss: 2.6401 - val_acc: 0.5294
Epoch 35/100
1224/1224 [==============================] - 35s - loss: 1.5018 - acc: 0.5948 - val_loss: 4.3320 - val_acc: 0.3603
Epoch 36/100
1224/1224 [==============================] - 34s - loss: 1.3303 - acc: 0.6381 - val_loss: 3.3895 - val_acc: 0.3971
Epoch 37/100
1224/1224 [==============================] - 34s - loss: 1.2792 - acc: 0.6422 - val_loss: 2.9808 - val_acc: 0.4485
Epoch 38/100
1224/1224 [==============================] - 34s - loss: 1.2895 - acc: 0.6356 - val_loss: 2.7670 - val_acc: 0.4632
Epoch 39/100
1224/1224 [==============================] - 30s - loss: 1.2147 - acc: 0.6520 - val_loss: 2.7203 - val_acc: 0.5147
Epoch 40/100
1224/1224 [==============================] - 25s - loss: 1.2939 - acc: 0.6340 - val_loss: 3.2740 - val_acc: 0.4485
Epoch 41/100
1224/1224 [==============================] - 25s - loss: 1.1160 - acc: 0.6797 - val_loss: 2.1987 - val_acc: 0.5074
Epoch 42/100
1224/1224 [==============================] - 25s - loss: 0.9781 - acc: 0.7190 - val_loss: 2.5899 - val_acc: 0.5074
Epoch 43/100
1224/1224 [==============================] - 25s - loss: 0.9695 - acc: 0.7288 - val_loss: 2.1273 - val_acc: 0.5147
Epoch 44/100
1224/1224 [==============================] - 31s - loss: 1.0449 - acc: 0.6912 - val_loss: 2.3803 - val_acc: 0.6176
Epoch 45/100
1224/1224 [==============================] - 35s - loss: 0.9135 - acc: 0.7516 - val_loss: 3.4901 - val_acc: 0.4706
Epoch 46/100
1224/1224 [==============================] - 34s - loss: 0.7137 - acc: 0.7672 - val_loss: 2.5690 - val_acc: 0.5368
Epoch 47/100
1224/1224 [==============================] - 34s - loss: 0.7864 - acc: 0.7606 - val_loss: 6.0906 - val_acc: 0.1471
Epoch 48/100
1224/1224 [==============================] - 34s - loss: 0.8475 - acc: 0.7435 - val_loss: 4.1929 - val_acc: 0.4632
Epoch 49/100
1224/1224 [==============================] - 34s - loss: 0.7135 - acc: 0.7876 - val_loss: 2.4818 - val_acc: 0.5809
Epoch 50/100
1224/1224 [==============================] - 34s - loss: 0.6861 - acc: 0.7859 - val_loss: 2.4039 - val_acc: 0.6103
Epoch 51/100
1224/1224 [==============================] - 34s - loss: 0.7028 - acc: 0.7966 - val_loss: 2.8229 - val_acc: 0.5441
Epoch 52/100
1224/1224 [==============================] - 34s - loss: 0.8514 - acc: 0.7647 - val_loss: 2.4944 - val_acc: 0.5147
Epoch 53/100
1224/1224 [==============================] - 34s - loss: 1.1566 - acc: 0.6846 - val_loss: 4.6325 - val_acc: 0.3529
Epoch 54/100
1224/1224 [==============================] - 35s - loss: 1.7106 - acc: 0.5850 - val_loss: 4.7924 - val_acc: 0.3603
Epoch 55/100
1224/1224 [==============================] - 34s - loss: 1.2370 - acc: 0.6789 - val_loss: 4.2017 - val_acc: 0.4338
Epoch 56/100
1224/1224 [==============================] - 34s - loss: 1.1042 - acc: 0.6716 - val_loss: 2.4918 - val_acc: 0.5074
Epoch 57/100
1224/1224 [==============================] - 34s - loss: 0.9409 - acc: 0.7198 - val_loss: 2.6726 - val_acc: 0.5956
Epoch 58/100
1224/1224 [==============================] - 34s - loss: 0.9597 - acc: 0.7149 - val_loss: 2.1746 - val_acc: 0.5368
Epoch 59/100
1224/1224 [==============================] - 31s - loss: 1.0154 - acc: 0.7116 - val_loss: 2.1224 - val_acc: 0.6176
Epoch 60/100
1224/1224 [==============================] - 25s - loss: 1.4147 - acc: 0.6667 - val_loss: 6.5234 - val_acc: 0.3015
Epoch 61/100
1224/1224 [==============================] - 25s - loss: 1.8682 - acc: 0.5670 - val_loss: 6.0193 - val_acc: 0.2132
Epoch 62/100
1224/1224 [==============================] - 25s - loss: 1.4930 - acc: 0.6078 - val_loss: 3.5747 - val_acc: 0.3971
Epoch 63/100
1224/1224 [==============================] - 26s - loss: 1.2548 - acc: 0.6650 - val_loss: 2.8592 - val_acc: 0.4412
Epoch 64/100
1224/1224 [==============================] - 32s - loss: 1.0849 - acc: 0.6895 - val_loss: 2.6487 - val_acc: 0.5368
Epoch 65/100
1224/1224 [==============================] - 34s - loss: 0.9805 - acc: 0.7132 - val_loss: 3.6187 - val_acc: 0.4485
Epoch 66/100
1224/1224 [==============================] - 34s - loss: 0.8874 - acc: 0.7418 - val_loss: 2.8937 - val_acc: 0.4926
Epoch 67/100
1224/1224 [==============================] - 34s - loss: 0.7834 - acc: 0.7598 - val_loss: 2.8119 - val_acc: 0.5221
Epoch 68/100
1224/1224 [==============================] - 34s - loss: 0.6975 - acc: 0.7819 - val_loss: 2.4226 - val_acc: 0.5074
Epoch 69/100
1224/1224 [==============================] - 34s - loss: 0.6471 - acc: 0.8039 - val_loss: 2.2897 - val_acc: 0.5735
Epoch 70/100
1224/1224 [==============================] - 34s - loss: 0.7631 - acc: 0.7819 - val_loss: 2.4562 - val_acc: 0.5882
Epoch 71/100
1224/1224 [==============================] - 34s - loss: 0.6553 - acc: 0.7908 - val_loss: 2.3580 - val_acc: 0.5809
Epoch 72/100
1224/1224 [==============================] - 35s - loss: 0.5675 - acc: 0.8235 - val_loss: 2.9143 - val_acc: 0.5294
Epoch 73/100
1224/1224 [==============================] - 27s - loss: 0.5287 - acc: 0.8431 - val_loss: 2.9401 - val_acc: 0.5441
Epoch 74/100
1224/1224 [==============================] - 25s - loss: 0.6145 - acc: 0.8203 - val_loss: 1.9691 - val_acc: 0.5882
Epoch 75/100
1224/1224 [==============================] - 24s - loss: 0.3952 - acc: 0.8685 - val_loss: 2.6215 - val_acc: 0.5882
Epoch 76/100
1224/1224 [==============================] - 25s - loss: 0.4491 - acc: 0.8578 - val_loss: 2.3133 - val_acc: 0.5735
Epoch 77/100
1224/1224 [==============================] - 25s - loss: 0.4666 - acc: 0.8554 - val_loss: 2.4880 - val_acc: 0.6324
Epoch 78/100
1224/1224 [==============================] - 30s - loss: 0.7113 - acc: 0.8145 - val_loss: 2.7856 - val_acc: 0.5809
Epoch 79/100
1224/1224 [==============================] - 25s - loss: 0.6339 - acc: 0.8096 - val_loss: 3.5978 - val_acc: 0.5147
Epoch 80/100
1224/1224 [==============================] - 25s - loss: 0.7559 - acc: 0.7827 - val_loss: 2.9329 - val_acc: 0.5368
Epoch 81/100
1224/1224 [==============================] - 25s - loss: 1.3393 - acc: 0.6814 - val_loss: 11.2519 - val_acc: 0.0956
Epoch 82/100
1224/1224 [==============================] - 24s - loss: 2.3512 - acc: 0.4894 - val_loss: 4.9423 - val_acc: 0.2647
Epoch 83/100
1224/1224 [==============================] - 32s - loss: 1.9400 - acc: 0.5425 - val_loss: 4.1003 - val_acc: 0.3235
Epoch 84/100
1224/1224 [==============================] - 35s - loss: 1.6382 - acc: 0.5940 - val_loss: 2.6543 - val_acc: 0.4926
Epoch 85/100
1224/1224 [==============================] - 34s - loss: 1.4111 - acc: 0.6324 - val_loss: 2.0040 - val_acc: 0.5735
Epoch 86/100
1224/1224 [==============================] - 34s - loss: 1.2437 - acc: 0.6912 - val_loss: 2.0457 - val_acc: 0.5294
Epoch 87/100
1224/1224 [==============================] - 25s - loss: 1.0179 - acc: 0.7067 - val_loss: 2.3981 - val_acc: 0.5956
Epoch 88/100
1224/1224 [==============================] - 25s - loss: 0.9751 - acc: 0.7353 - val_loss: 2.8484 - val_acc: 0.4632
Epoch 89/100
1224/1224 [==============================] - 25s - loss: 0.7592 - acc: 0.7745 - val_loss: 2.7576 - val_acc: 0.5221
Epoch 90/100
1224/1224 [==============================] - 24s - loss: 0.8512 - acc: 0.7623 - val_loss: 2.3349 - val_acc: 0.5882
Epoch 91/100
1224/1224 [==============================] - 27s - loss: 0.8372 - acc: 0.7500 - val_loss: 2.2602 - val_acc: 0.6471
Epoch 92/100
1224/1224 [==============================] - 33s - loss: 0.9406 - acc: 0.7361 - val_loss: 2.3562 - val_acc: 0.6250
Epoch 93/100
1224/1224 [==============================] - 25s - loss: 0.9881 - acc: 0.7500 - val_loss: 2.2908 - val_acc: 0.5221
Epoch 94/100
1224/1224 [==============================] - 26s - loss: 0.8543 - acc: 0.7590 - val_loss: 2.5853 - val_acc: 0.5441
Epoch 95/100
1224/1224 [==============================] - 25s - loss: 0.7622 - acc: 0.7949 - val_loss: 2.2524 - val_acc: 0.5956
Epoch 96/100
1224/1224 [==============================] - 25s - loss: 0.6826 - acc: 0.8105 - val_loss: 2.8272 - val_acc: 0.5368
Epoch 97/100
1224/1224 [==============================] - 29s - loss: 0.6542 - acc: 0.8219 - val_loss: 2.6472 - val_acc: 0.5662
Epoch 98/100
1224/1224 [==============================] - 34s - loss: 0.6940 - acc: 0.7974 - val_loss: 2.4740 - val_acc: 0.5662
Epoch 99/100
1224/1224 [==============================] - 34s - loss: 0.5167 - acc: 0.8374 - val_loss: 2.2321 - val_acc: 0.6103
Epoch 100/100
1224/1224 [==============================] - 35s - loss: 0.4271 - acc: 0.8619 - val_loss: 2.3397 - val_acc: 0.6324
Out[7]:
<keras.callbacks.History at 0x7f65dc6a7630>
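
By the end of the run, training accuracy (~0.86) is far ahead of validation accuracy (~0.63), and the validation loss drifts upward for much of training: with only 1,224 training samples, the model overfits. Two standard remedies are data augmentation and early stopping; a minimal sketch using Keras's built-in tools (the augmentation ranges and patience are illustrative, not tuned):


In [ ]:
from keras.preprocessing.image import ImageDataGenerator
from keras.callbacks import EarlyStopping

# reproduce the 10% holdout that validation_split used above
n_val = int(0.1 * len(X))
X_train, Y_train = X[:-n_val], Y[:-n_val]
X_val, Y_val = X[-n_val:], Y[-n_val:]

# random rotations, shifts, and flips enlarge the effective training set
datagen = ImageDataGenerator(rotation_range=20,
                             width_shift_range=0.1,
                             height_shift_range=0.1,
                             horizontal_flip=True)

# halt once validation loss stops improving
early_stop = EarlyStopping(monitor='val_loss', patience=10)

model.fit_generator(datagen.flow(X_train, Y_train, batch_size=64),
                    steps_per_epoch=len(X_train) // 64,
                    epochs=100,
                    validation_data=(X_val, Y_val),
                    callbacks=[early_stop, tensorbrd])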
