Implements a Y-Network (a two-branch CNN) on MNIST using the Keras Functional API. The two branches have identical layouts, except that the right branch uses dilated convolutions, and they do not share weights, so this is a Y-Network rather than a true Siamese network. The branch feature maps are concatenated before the Dense classifier.

Reaches ~99.4% test accuracy after 20 epochs of training.


In [1]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

from keras.layers import Dense, Dropout, Input
from keras.layers import Conv2D, MaxPooling2D, Flatten
from keras.models import Model
from keras.layers import concatenate
from keras.datasets import mnist
from keras.utils import to_categorical
from keras.utils import plot_model

# load MNIST dataset
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# convert sparse labels to one-hot (categorical) vectors
num_labels = len(np.unique(y_train))
y_train = to_categorical(y_train)
y_test = to_categorical(y_test)

# reshape and normalize input images
image_size = x_train.shape[1]
x_train = np.reshape(x_train,[-1, image_size, image_size, 1])
x_test = np.reshape(x_test,[-1, image_size, image_size, 1])
x_train = x_train.astype('float32') / 255
x_test = x_test.astype('float32') / 255

# network parameters
input_shape = (image_size, image_size, 1)
batch_size = 32
kernel_size = 3
dropout = 0.4
n_filters = 32

# left branch of Y network
left_inputs = Input(shape=input_shape)
x = left_inputs
filters = n_filters
# 3 layers of Conv2D-Dropout-MaxPooling2D
# number of filters doubles after each layer (32-64-128)
for i in range(3):
    x = Conv2D(filters=filters,
               kernel_size=kernel_size,
               padding='same',
               activation='relu')(x)
    x = Dropout(dropout)(x)
    x = MaxPooling2D()(x)
    filters *= 2

# right branch of Y network
right_inputs = Input(shape=input_shape)
y = right_inputs
filters = n_filters
# 3 layers of Conv2D-Dropout-MaxPooling2D
# number of filters doubles after each layer (32-64-128);
# dilation_rate=2 enlarges this branch's receptive field
for i in range(3):
    y = Conv2D(filters=filters,
               kernel_size=kernel_size,
               padding='same',
               activation='relu',
               dilation_rate=2)(y)
    y = Dropout(dropout)(y)
    y = MaxPooling2D()(y)
    filters *= 2

# merge left and right branches outputs
y = concatenate([x, y])
# flatten feature maps into a vector before the Dense classifier
y = Flatten()(y)
y = Dropout(dropout)(y)
outputs = Dense(num_labels, activation='softmax')(y)

# build the model in functional API
model = Model([left_inputs, right_inputs], outputs)
# verify the model using graph
plot_model(model, to_file='cnn-y-network.png', show_shapes=True)
# verify the model using layer text description
model.summary()

# categorical cross-entropy loss, Adam optimizer, accuracy metric
model.compile(loss='categorical_crossentropy',
              optimizer='adam',
              metrics=['accuracy'])

# train the model, feeding the same images to both branches
model.fit([x_train, x_train],
          y_train, 
          validation_data=([x_test, x_test], y_test),
          epochs=20,
          batch_size=batch_size)

# model accuracy on test dataset
score = model.evaluate([x_test, x_test], y_test, batch_size=batch_size)
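# score is [loss, accuracy], per the metrics set in compile() above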
print("\nTest accuracy: %.1f%%" % (100.0 * score[1]))


Using TensorFlow backend.
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_1 (InputLayer)            (None, 28, 28, 1)    0                                            
__________________________________________________________________________________________________
input_2 (InputLayer)            (None, 28, 28, 1)    0                                            
__________________________________________________________________________________________________
conv2d_1 (Conv2D)               (None, 28, 28, 32)   320         input_1[0][0]                    
__________________________________________________________________________________________________
conv2d_4 (Conv2D)               (None, 28, 28, 32)   320         input_2[0][0]                    
__________________________________________________________________________________________________
dropout_1 (Dropout)             (None, 28, 28, 32)   0           conv2d_1[0][0]                   
__________________________________________________________________________________________________
dropout_4 (Dropout)             (None, 28, 28, 32)   0           conv2d_4[0][0]                   
__________________________________________________________________________________________________
max_pooling2d_1 (MaxPooling2D)  (None, 14, 14, 32)   0           dropout_1[0][0]                  
__________________________________________________________________________________________________
max_pooling2d_4 (MaxPooling2D)  (None, 14, 14, 32)   0           dropout_4[0][0]                  
__________________________________________________________________________________________________
conv2d_2 (Conv2D)               (None, 14, 14, 64)   18496       max_pooling2d_1[0][0]            
__________________________________________________________________________________________________
conv2d_5 (Conv2D)               (None, 14, 14, 64)   18496       max_pooling2d_4[0][0]            
__________________________________________________________________________________________________
dropout_2 (Dropout)             (None, 14, 14, 64)   0           conv2d_2[0][0]                   
__________________________________________________________________________________________________
dropout_5 (Dropout)             (None, 14, 14, 64)   0           conv2d_5[0][0]                   
__________________________________________________________________________________________________
max_pooling2d_2 (MaxPooling2D)  (None, 7, 7, 64)     0           dropout_2[0][0]                  
__________________________________________________________________________________________________
max_pooling2d_5 (MaxPooling2D)  (None, 7, 7, 64)     0           dropout_5[0][0]                  
__________________________________________________________________________________________________
conv2d_3 (Conv2D)               (None, 7, 7, 128)    73856       max_pooling2d_2[0][0]            
__________________________________________________________________________________________________
conv2d_6 (Conv2D)               (None, 7, 7, 128)    73856       max_pooling2d_5[0][0]            
__________________________________________________________________________________________________
dropout_3 (Dropout)             (None, 7, 7, 128)    0           conv2d_3[0][0]                   
__________________________________________________________________________________________________
dropout_6 (Dropout)             (None, 7, 7, 128)    0           conv2d_6[0][0]                   
__________________________________________________________________________________________________
max_pooling2d_3 (MaxPooling2D)  (None, 3, 3, 128)    0           dropout_3[0][0]                  
__________________________________________________________________________________________________
max_pooling2d_6 (MaxPooling2D)  (None, 3, 3, 128)    0           dropout_6[0][0]                  
__________________________________________________________________________________________________
concatenate_1 (Concatenate)     (None, 3, 3, 256)    0           max_pooling2d_3[0][0]            
                                                                 max_pooling2d_6[0][0]            
__________________________________________________________________________________________________
flatten_1 (Flatten)             (None, 2304)         0           concatenate_1[0][0]              
__________________________________________________________________________________________________
dropout_7 (Dropout)             (None, 2304)         0           flatten_1[0][0]                  
__________________________________________________________________________________________________
dense_1 (Dense)                 (None, 10)           23050       dropout_7[0][0]                  
==================================================================================================
Total params: 208,394
Trainable params: 208,394
Non-trainable params: 0
__________________________________________________________________________________________________
Train on 60000 samples, validate on 10000 samples
Epoch 1/20
60000/60000 [==============================] - 351s 6ms/step - loss: 0.1769 - acc: 0.9435 - val_loss: 0.1412 - val_acc: 0.9904
Epoch 2/20
60000/60000 [==============================] - 361s 6ms/step - loss: 0.0664 - acc: 0.9795 - val_loss: 0.0923 - val_acc: 0.9903
Epoch 3/20
60000/60000 [==============================] - 359s 6ms/step - loss: 0.0528 - acc: 0.9835 - val_loss: 0.0772 - val_acc: 0.9908
Epoch 4/20
60000/60000 [==============================] - 286s 5ms/step - loss: 0.0471 - acc: 0.9854 - val_loss: 0.0836 - val_acc: 0.9930
Epoch 5/20
60000/60000 [==============================] - 244s 4ms/step - loss: 0.0429 - acc: 0.9861 - val_loss: 0.0736 - val_acc: 0.9931
Epoch 6/20
60000/60000 [==============================] - 242s 4ms/step - loss: 0.0393 - acc: 0.9878 - val_loss: 0.0507 - val_acc: 0.9931
Epoch 7/20
60000/60000 [==============================] - 183s 3ms/step - loss: 0.0364 - acc: 0.9883 - val_loss: 0.0434 - val_acc: 0.9934
Epoch 8/20
60000/60000 [==============================] - 141s 2ms/step - loss: 0.0364 - acc: 0.9890 - val_loss: 0.0471 - val_acc: 0.9931
Epoch 9/20
60000/60000 [==============================] - 139s 2ms/step - loss: 0.0358 - acc: 0.9892 - val_loss: 0.0384 - val_acc: 0.9938
Epoch 10/20
60000/60000 [==============================] - 138s 2ms/step - loss: 0.0342 - acc: 0.9891 - val_loss: 0.0500 - val_acc: 0.9915
Epoch 11/20
60000/60000 [==============================] - 138s 2ms/step - loss: 0.0327 - acc: 0.9895 - val_loss: 0.0328 - val_acc: 0.9919
Epoch 12/20
60000/60000 [==============================] - 137s 2ms/step - loss: 0.0328 - acc: 0.9896 - val_loss: 0.0364 - val_acc: 0.9928
Epoch 13/20
60000/60000 [==============================] - 139s 2ms/step - loss: 0.0334 - acc: 0.9898 - val_loss: 0.0322 - val_acc: 0.9938
Epoch 14/20
60000/60000 [==============================] - 143s 2ms/step - loss: 0.0306 - acc: 0.9908 - val_loss: 0.0346 - val_acc: 0.9942
Epoch 15/20
60000/60000 [==============================] - 180s 3ms/step - loss: 0.0319 - acc: 0.9903 - val_loss: 0.0417 - val_acc: 0.9942
Epoch 16/20
60000/60000 [==============================] - 148s 2ms/step - loss: 0.0306 - acc: 0.9906 - val_loss: 0.0308 - val_acc: 0.9930
Epoch 17/20
60000/60000 [==============================] - 149s 2ms/step - loss: 0.0299 - acc: 0.9906 - val_loss: 0.0300 - val_acc: 0.9945
Epoch 18/20
60000/60000 [==============================] - 137s 2ms/step - loss: 0.0316 - acc: 0.9904 - val_loss: 0.0333 - val_acc: 0.9923
Epoch 19/20
60000/60000 [==============================] - 141s 2ms/step - loss: 0.0279 - acc: 0.9917 - val_loss: 0.0255 - val_acc: 0.9942
Epoch 20/20
60000/60000 [==============================] - 137s 2ms/step - loss: 0.0297 - acc: 0.9910 - val_loss: 0.0249 - val_acc: 0.9938
10000/10000 [==============================] - 4s 377us/step

Test accuracy: 99.4%
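
Since the model takes two inputs, inference also needs the image batch passed twice. A minimal sketch of scoring a few test digits with the variables defined above (the slice size of 5 is arbitrary):

probs = model.predict([x_test[:5], x_test[:5]])  # same batch to both branches
print("predicted:", np.argmax(probs, axis=1))
print("actual:   ", np.argmax(y_test[:5], axis=1))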

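The left and right branches above keep separate weights; a true Siamese network shares them. A sketch of a weight-sharing variant, reusing the imports and variables defined above, builds the branch once as a sub-model and applies it to both inputs (illustrative only, not run here):

def build_branch(input_shape, n_filters=32, kernel_size=3, dropout=0.4):
    # one stack of Conv2D-Dropout-MaxPooling2D layers, wrapped as a sub-model
    inputs = Input(shape=input_shape)
    x = inputs
    filters = n_filters
    for i in range(3):
        x = Conv2D(filters=filters,
                   kernel_size=kernel_size,
                   padding='same',
                   activation='relu')(x)
        x = Dropout(dropout)(x)
        x = MaxPooling2D()(x)
        filters *= 2
    return Model(inputs, x)

branch = build_branch(input_shape)
left_in = Input(shape=input_shape)
right_in = Input(shape=input_shape)
# calling the same sub-model on both inputs makes them share weights
merged = concatenate([branch(left_in), branch(right_in)])
merged = Flatten()(merged)
merged = Dropout(dropout)(merged)
siamese_outputs = Dense(num_labels, activation='softmax')(merged)
siamese = Model([left_in, right_in], siamese_outputs)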