MNIST Digit Classification - CNN


In [1]:
from __future__ import division, print_function
from sklearn.metrics import accuracy_score, confusion_matrix
from sklearn.preprocessing import OneHotEncoder
import matplotlib.pyplot as plt
import mxnet as mx
import numpy as np
import os
%matplotlib inline

In [2]:
DATA_DIR = "../../data"
TRAIN_FILE = os.path.join(DATA_DIR, "mnist_train.csv")
TEST_FILE = os.path.join(DATA_DIR, "mnist_test.csv")

MODEL_FILE = os.path.join(DATA_DIR, "mxnet-mnist-cn")

LEARNING_RATE = 0.001
INPUT_SIZE = 28
BATCH_SIZE = 128
NUM_CLASSES = 10
NUM_EPOCHS = 5

Prepare Data

Images have Theano-style dim ordering, i.e, (num_batches, num_channels, height, width).


In [8]:
def parse_file(filename):
    xdata, ydata = [], []
    fin = open(filename, "rb")
    i = 0
    for line in fin:
        if i % 10000 == 0:
            print("{:s}: {:d} lines read".format(
                os.path.basename(filename), i))
        cols = line.strip().split(",")
        ydata.append(int(cols[0]))
        xs1d = np.array([float(x) for x in cols[1:]])
        xs4d = np.reshape(xs1d, (1, INPUT_SIZE, INPUT_SIZE))
        xdata.append(xs4d)
        i += 1
    fin.close()
    print("{:s}: {:d} lines read".format(os.path.basename(filename), i))
    X = np.array(xdata)
    ohe = OneHotEncoder(n_values=NUM_CLASSES)
    Y = ohe.fit_transform([ydata]).todense().reshape(len(ydata), -1)
    return X, Y

Xtrain, Ytrain = parse_file(TRAIN_FILE)
Xtest, Ytest = parse_file(TEST_FILE)
print(Xtrain.shape, Ytrain.shape, Xtest.shape, Ytest.shape)


mnist_train.csv: 0 lines read
mnist_train.csv: 10000 lines read
mnist_train.csv: 20000 lines read
mnist_train.csv: 30000 lines read
mnist_train.csv: 40000 lines read
mnist_train.csv: 50000 lines read
mnist_train.csv: 60000 lines read
mnist_test.csv: 0 lines read
mnist_test.csv: 10000 lines read
(60000, 1, 28, 28) (60000, 10) (10000, 1, 28, 28) (10000, 10)

In [9]:
train_gen = mx.io.NDArrayIter(Xtrain, label=Ytrain, batch_size=BATCH_SIZE, shuffle=True)
val_gen = mx.io.NDArrayIter(Xtest, label=Ytest, batch_size=BATCH_SIZE)

Define Network


In [10]:
data = mx.symbol.Variable("data")
# CONV1: 5x5 kernel, channels 1 => 32, maxpool(2)
conv1 = mx.sym.Convolution(data=data, kernel=(5,5), num_filter=32)
relu1 = mx.sym.Activation(data=conv1, act_type="relu")
pool1 = mx.sym.Pooling(data=relu1, pool_type="max", kernel=(2,2), stride=(2,2))
# CONV2: 5x5 kernel, channels 32 => 64, maxpool(2)
conv2 = mx.sym.Convolution(data=pool1, kernel=(5,5), num_filter=64)
relu2 = mx.sym.Activation(data=conv2, act_type="relu")
pool2 = mx.sym.Pooling(data=relu2, pool_type="max", kernel=(2,2), stride=(2,2))
# FC1: 7*7*64 => 512
flatten = mx.sym.Flatten(data=pool2)
fc1 = mx.symbol.FullyConnected(data=flatten, num_hidden=512)
relu3 = mx.sym.Activation(data=fc1, act_type="relu")
drop3 = mx.sym.Dropout(data=relu3, p=0.25)
# FC2: 512 => 10
fc2 = mx.sym.FullyConnected(data=drop3, num_hidden=NUM_CLASSES)
net = mx.sym.SoftmaxOutput(data=fc2, name='softmax')

Train Network


In [11]:
import logging
logging.getLogger().setLevel(logging.DEBUG)

train_gen.reset()
val_gen.reset()

model = mx.mod.Module(symbol=net, data_names=["data"], label_names=["softmax_label"])

checkpoint = mx.callback.do_checkpoint(MODEL_FILE)
model.fit(train_gen, 
          eval_data=val_gen, 
          optimizer="adam", 
          optimizer_params={"learning_rate": LEARNING_RATE},
          eval_metric="acc",
          num_epoch=NUM_EPOCHS,
          epoch_end_callback=checkpoint)


INFO:root:Epoch[0] Train-accuracy=0.901997
INFO:root:Epoch[0] Time cost=152.544
INFO:root:Saved checkpoint to "../../data/mxnet-mnist-cn-0001.params"
INFO:root:Epoch[0] Validation-accuracy=0.903441
INFO:root:Epoch[1] Train-accuracy=0.909876
INFO:root:Epoch[1] Time cost=153.703
INFO:root:Saved checkpoint to "../../data/mxnet-mnist-cn-0002.params"
INFO:root:Epoch[1] Validation-accuracy=0.913430
INFO:root:Epoch[2] Train-accuracy=0.921700
INFO:root:Epoch[2] Time cost=156.109
INFO:root:Saved checkpoint to "../../data/mxnet-mnist-cn-0003.params"
INFO:root:Epoch[2] Validation-accuracy=0.923536
INFO:root:Epoch[3] Train-accuracy=0.933645
INFO:root:Epoch[3] Time cost=153.288
INFO:root:Saved checkpoint to "../../data/mxnet-mnist-cn-0004.params"
INFO:root:Epoch[3] Validation-accuracy=0.942484
INFO:root:Epoch[4] Train-accuracy=0.945044
INFO:root:Epoch[4] Time cost=152.634
INFO:root:Saved checkpoint to "../../data/mxnet-mnist-cn-0005.params"
INFO:root:Epoch[4] Validation-accuracy=0.968888

Evaluate Network


In [12]:
test_gen = mx.io.NDArrayIter(Xtest, label=Ytest, batch_size=BATCH_SIZE)
test_accuracy = model.score(test_gen, eval_metric="acc")
print(test_accuracy)


[('accuracy', 0.9688884493670886)]

In [ ]: