In [1]:
from __future__ import print_function
import numpy as np
np.random.seed(1338) # for reproducibility
from keras.datasets import mnist
from keras.models import Graph
from keras.layers.core import Dense, Dropout, Activation, Flatten
from keras.layers.convolutional import Convolution2D, AveragePooling2D
from keras.utils import np_utils
In [2]:
batch_size = 128
nb_classes = 2
nb_epoch = 5
# input image dimensions
img_rows, img_cols = 28, 28
# number of convolutional filters to use
nb_filters = 64
# size of pooling area (nb_pool is not actually used below)
nb_pool = 2
# convolution kernel size
nb_conv = 3
In [3]:
#Loading the training and testing data
(X_train, y_train), (X_test, y_test) = mnist.load_data()
img_rows, img_cols = 28, 28
X_train = X_train.reshape(X_train.shape[0], 1, img_rows, img_cols)
X_test = X_test.reshape(X_test.shape[0], 1, img_rows, img_cols)
X_train = X_train.astype('float32')
X_test = X_test.astype('float32')
X_train /= 255
X_test /= 255
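A quick sanity check (a small sketch, not a cell from the original notebook) to confirm the reshape and rescaling above, using only the variables already defined:
print(X_train.shape, X_test.shape)   # expected: (60000, 1, 28, 28) and (10000, 1, 28, 28)
print(X_train.min(), X_train.max())  # pixel values should now lie in [0, 1]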
In [4]:
#Seed for reproducibility
np.random.seed(1338)
#Selecting 6000 random examples from the test data
test_rows = np.random.randint(0,X_test.shape[0],6000)
X_test = X_test[test_rows]
Y = y_test[test_rows]
#Converting the output to binary classification (Six=1, Not Six=0)
Y_test = Y == 6
Y_test = Y_test.astype(int)
#Selecting the 5918 examples where the output is 6
X_six = X_train[y_train == 6]
Y_six = y_train[y_train == 6]
#Selecting the examples where the output is not 6
X_not_six = X_train[y_train != 6]
Y_not_six = y_train[y_train != 6]
#Selecting 6000 random examples from the data where the output is not 6
random_rows = np.random.randint(0, X_not_six.shape[0], 6000)
X_not_six = X_not_six[random_rows]
Y_not_six = Y_not_six[random_rows]
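To verify the subsets just built (a minimal sketch, assuming the cells above have run; np.bincount is standard NumPy):
print(X_test.shape, Y_test.shape)          # roughly 6000 test examples and their 0/1 labels
print(np.bincount(Y_test))                 # counts of "not six" vs "six" in the test subset
print(X_six.shape[0], X_not_six.shape[0])  # 5918 sixes vs 6000 sampled non-sixes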
In [5]:
#Appending the examples where the output is 6 to the examples where the output is not 6
X_train = np.append(X_six,X_not_six)
#Reshaping the appended data to the appropriate form
X_train = X_train.reshape(X_six.shape[0] + X_not_six.shape[0], 1, img_rows, img_cols)
#Appending the labels and converting them to binary classification (Six=1, Not Six=0)
Y_labels = np.append(Y_six,Y_not_six)
Y_train = Y_labels == 6
Y_train = Y_train.astype(int)
In [6]:
#Converting the classes to their one-hot categorical form
Y_train = np_utils.to_categorical(Y_train, nb_classes)
Y_test = np_utils.to_categorical(Y_test, nb_classes)
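After the conversion the labels are one-hot encoded; a short check (again only a sketch using the existing variables):
print(Y_train.shape, Y_test.shape)  # expected: (11918, 2) and (6000, 2)
print(Y_train[0])                   # the first example is a "six", so its row is [0., 1.]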
In [7]:
def build_resnet():
    model = Graph()
    model.add_input(input_shape=(1, 28, 28), name="0")
    # First piece
    model.add_node(Convolution2D(nb_filters, nb_conv, nb_conv,
                                 input_shape=(1, img_rows, img_cols),
                                 activation="relu"), name="1", input="0")
    # (this branch is built but never consumed downstream)
    model.add_node(Convolution2D(nb_filters, nb_conv, nb_conv,
                                 activation="relu"), name="2", input="1")
    model.add_node(Convolution2D(nb_filters, nb_conv, nb_conv,
                                 subsample=(2, 2), activation="relu"), name="X", input="1")
    # residual module
    model.add_node(Convolution2D(nb_filters, nb_conv, nb_conv,
                                 activation="relu", border_mode="same"), name="r1", input="X")
    model.add_node(Convolution2D(nb_filters, nb_conv, nb_conv,
                                 activation="relu", border_mode="same"), name="r2", input="r1")
    # add shortcut "X" and residual output "r2" (element-wise sum), then convolve to form node "3"
    model.add_node(Convolution2D(nb_filters, nb_conv, nb_conv,
                                 subsample=(2, 2), activation="relu"),
                   name="3", inputs=["X", "r2"], merge_mode="sum")
    # classifier
    model.add_node(Convolution2D(2, nb_conv, nb_conv, activation="linear"), name="4", input="3")
    out_size = model._graph_nodes['4'].output_shape[-1]  # thanks shape inference
    model.add_node(AveragePooling2D((out_size, out_size)), name="pool", input="4")
    model.add_node(Flatten(), name="flat", input="pool")
    model.add_node(Activation("softmax"), name="out", input="flat", create_output=True)
    model.compile(loss={"out": 'categorical_crossentropy'}, optimizer='adam', metrics=['accuracy'])
    model.fit({"0": X_train, "out": Y_train}, nb_epoch=nb_epoch)
    score = model.evaluate({"0": X_test, "out": Y_test}, verbose=0)
    print('Test score:', score[0])
    print('Test accuracy:', score[1])
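The out_size trick above relies on the spatial sizes that the stacked convolutions produce; a small arithmetic sketch (plain Python, assuming the default "valid" border mode, where output = (input - kernel) // stride + 1) traces them:
def conv_out(size, kernel=3, stride=1):
    return (size - kernel) // stride + 1

s = conv_out(28)            # node "1": 28 -> 26
s = conv_out(s, stride=2)   # node "X": 26 -> 12
# "r1" and "r2" use border_mode="same", so the 12x12 size is preserved
s = conv_out(s, stride=2)   # node "3": 12 -> 5
s = conv_out(s)             # node "4": 5 -> 3, which is what out_size picks up
print(s)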
In [8]:
%timeit -n1 -r1 build_resnet()
In [9]:
#model.compile(loss={"out": 'categorical_crossentropy'}, optimizer='adam')
#model.fit({"0": X_train, "out": Y_train}, nb_epoch=5, show_accuracy=True)