MNIST Digit Classification - FCN


In [1]:
from __future__ import division, print_function
from sklearn.metrics import accuracy_score, confusion_matrix
from sklearn.preprocessing import OneHotEncoder
import matplotlib.pyplot as plt
import mxnet as mx
import numpy as np
import os
%matplotlib inline

In [2]:
DATA_DIR = "../../data"
TRAIN_FILE = os.path.join(DATA_DIR, "mnist_train.csv")
TEST_FILE = os.path.join(DATA_DIR, "mnist_test.csv")

MODEL_FILE = os.path.join(DATA_DIR, "mxnet-mnist-fcn")

LEARNING_RATE = 0.001
INPUT_SIZE = 28*28
BATCH_SIZE = 128
NUM_CLASSES = 10
NUM_EPOCHS = 10

Prepare Data


In [9]:
def parse_file(filename):
    xdata, ydata = [], []
    fin = open(filename, "rb")
    i = 0
    for line in fin:
        if i % 10000 == 0:
            print("{:s}: {:d} lines read".format(
                os.path.basename(filename), i))
        cols = line.strip().split(",")
        ydata.append(int(cols[0]))
#         xdata.append([float(x) / 255. for x in cols[1:]])
        xdata.append([float(x) for x in cols[1:]])
        i += 1
    fin.close()
    print("{:s}: {:d} lines read".format(os.path.basename(filename), i))
    X = np.array(xdata)
    ohe = OneHotEncoder(n_values=NUM_CLASSES)
    Y = ohe.fit_transform([ydata]).todense().reshape(len(ydata), -1)
    return X, Y

Xtrain, Ytrain = parse_file(TRAIN_FILE)
Xtest, Ytest = parse_file(TEST_FILE)
print(Xtrain.shape, Ytrain.shape, Xtest.shape, Ytest.shape)


mnist_train.csv: 0 lines read
mnist_train.csv: 10000 lines read
mnist_train.csv: 20000 lines read
mnist_train.csv: 30000 lines read
mnist_train.csv: 40000 lines read
mnist_train.csv: 50000 lines read
mnist_train.csv: 60000 lines read
mnist_test.csv: 0 lines read
mnist_test.csv: 10000 lines read
(60000, 784) (60000, 10) (10000, 784) (10000, 10)

In [10]:
train_gen = mx.io.NDArrayIter(Xtrain, label=Ytrain, batch_size=BATCH_SIZE, shuffle=True)
val_gen = mx.io.NDArrayIter(Xtest, label=Ytest, batch_size=BATCH_SIZE)

Define Network


In [11]:
# Create a place holder variable for the input data
data = mx.sym.Variable('data')
# FC1: 784 => 128
fc1  = mx.sym.FullyConnected(data=data, name='fc1', num_hidden=128)
fc1 = mx.sym.Activation(data=fc1, name='relu1', act_type="relu")
fc1 = mx.sym.Dropout(data=fc1, name="drop1", p=0.2)
# FC2: 128 => 64
fc2  = mx.sym.FullyConnected(data=fc1, name='fc2', num_hidden=64)
fc2 = mx.sym.Activation(data=fc2, name='relu2', act_type="relu")
fc2 = mx.sym.Dropout(data=fc2, name="drop2", p=0.2)
# FC3: 64 => 10
fc3  = mx.sym.FullyConnected(data=fc2, name='fc3', num_hidden=NUM_CLASSES)
# The softmax and loss layer
net  = mx.sym.SoftmaxOutput(data=fc3, name='softmax')

Train Network

No built-in method to capture loss and accuracy during training. One can register a custom callback to collect the training accuracy at the end of every epoch, but apparently the general approach in the MXNet community is to eyeball the numbers.


In [12]:
import logging
logging.getLogger().setLevel(logging.DEBUG)

train_gen.reset()
val_gen.reset()

model = mx.mod.Module(symbol=net, data_names=["data"], label_names=["softmax_label"])

checkpoint = mx.callback.do_checkpoint(MODEL_FILE)
num_batches_per_epoch = len(Xtrain) // BATCH_SIZE
model.fit(train_gen, 
          eval_data=val_gen, 
          optimizer="adam", 
          optimizer_params={"learning_rate": LEARNING_RATE},
          eval_metric="acc",
          num_epoch=NUM_EPOCHS,
          epoch_end_callback=checkpoint)


INFO:root:Epoch[0] Train-accuracy=0.900528
INFO:root:Epoch[0] Time cost=4.420
INFO:root:Saved checkpoint to "../../data/mxnet-mnist-fcn-0001.params"
INFO:root:Epoch[0] Validation-accuracy=0.900732
INFO:root:Epoch[1] Train-accuracy=0.903641
INFO:root:Epoch[1] Time cost=4.675
INFO:root:Saved checkpoint to "../../data/mxnet-mnist-fcn-0002.params"
INFO:root:Epoch[1] Validation-accuracy=0.903432
INFO:root:Epoch[2] Train-accuracy=0.907108
INFO:root:Epoch[2] Time cost=4.431
INFO:root:Saved checkpoint to "../../data/mxnet-mnist-fcn-0003.params"
INFO:root:Epoch[2] Validation-accuracy=0.906339
INFO:root:Epoch[3] Train-accuracy=0.911914
INFO:root:Epoch[3] Time cost=4.675
INFO:root:Saved checkpoint to "../../data/mxnet-mnist-fcn-0004.params"
INFO:root:Epoch[3] Validation-accuracy=0.911284
INFO:root:Epoch[4] Train-accuracy=0.916744
INFO:root:Epoch[4] Time cost=4.923
INFO:root:Saved checkpoint to "../../data/mxnet-mnist-fcn-0005.params"
INFO:root:Epoch[4] Validation-accuracy=0.916367
INFO:root:Epoch[5] Train-accuracy=0.919689
INFO:root:Epoch[5] Time cost=4.893
INFO:root:Saved checkpoint to "../../data/mxnet-mnist-fcn-0006.params"
INFO:root:Epoch[5] Validation-accuracy=0.923724
INFO:root:Epoch[6] Train-accuracy=0.923829
INFO:root:Epoch[6] Time cost=4.436
INFO:root:Saved checkpoint to "../../data/mxnet-mnist-fcn-0007.params"
INFO:root:Epoch[6] Validation-accuracy=0.926632
INFO:root:Epoch[7] Train-accuracy=0.927071
INFO:root:Epoch[7] Time cost=4.685
INFO:root:Saved checkpoint to "../../data/mxnet-mnist-fcn-0008.params"
INFO:root:Epoch[7] Validation-accuracy=0.923803
INFO:root:Epoch[8] Train-accuracy=0.929866
INFO:root:Epoch[8] Time cost=4.482
INFO:root:Saved checkpoint to "../../data/mxnet-mnist-fcn-0009.params"
INFO:root:Epoch[8] Validation-accuracy=0.930825
INFO:root:Epoch[9] Train-accuracy=0.932438
INFO:root:Epoch[9] Time cost=4.813
INFO:root:Saved checkpoint to "../../data/mxnet-mnist-fcn-0010.params"
INFO:root:Epoch[9] Validation-accuracy=0.939784

Evaluate Network


In [14]:
test_gen = mx.io.NDArrayIter(Xtest, label=Ytest, batch_size=BATCH_SIZE)
test_accuracy = model.score(test_gen, eval_metric="acc")
print(test_accuracy)


[('accuracy', 0.939784414556962)]

In [ ]: