MLP network for MNIST digit classification

Reaches ~95% test accuracy after 20 epochs of training with the SGD optimizer.


In [6]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Activation, Dropout
from tensorflow.keras.utils import to_categorical, plot_model
from tensorflow.keras.datasets import mnist

# load mnist dataset
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# compute the number of labels
num_labels = len(np.unique(y_train))

# convert to one-hot vector
# e.g. 3 -> [0 0 0 1 0 0 0 0 0 0]
y_train = to_categorical(y_train)
y_test = to_categorical(y_test)
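# y_train is now (60000, 10) and y_test (10000, 10): one column per digit class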

# image dimensions (assumed square)
image_size = x_train.shape[1]
input_size = image_size * image_size

# reshape images into flat vectors and normalize pixel values
x_train = np.reshape(x_train, [-1, input_size])
x_train = x_train.astype('float32') / 255
x_test = np.reshape(x_test, [-1, input_size])
x_test = x_test.astype('float32') / 255
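# x_train is now (60000, 784) and x_test (10000, 784),
# with pixel values scaled from [0, 255] to [0.0, 1.0]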

# network parameters
batch_size = 128
hidden_units = 256
dropout = 0.45  # used by the dropout variant sketched at the end

# model is a 3-layer MLP with ReLU after each hidden layer
# (a dropout variant is sketched at the end of the notebook)
model = Sequential()
model.add(Dense(hidden_units, input_dim=input_size))
model.add(Activation('relu'))
model.add(Dense(hidden_units))
model.add(Activation('relu'))
model.add(Dense(num_labels))
# softmax output: one probability per class, matching the one-hot labels
model.add(Activation('softmax'))
model.summary()
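# parameter count check against the summary below:
# (784*256 + 256) + (256*256 + 256) + (256*10 + 10)
# = 200,960 + 65,792 + 2,570 = 269,322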
# plot_model(model, to_file='mlp-mnist.png', show_shapes=True)

# categorical cross-entropy is the matching loss for one-hot labels
# the 'sgd' string selects SGD with its default learning rate of 0.01
# accuracy is the standard metric for classification tasks
model.compile(loss='categorical_crossentropy',
              optimizer='sgd',
              metrics=['accuracy'])
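# equivalent explicit form: optimizer=SGD(learning_rate=0.01)
# (from tensorflow.keras.optimizers import SGD)
# alternatively, loss='sparse_categorical_crossentropy' accepts integer
# labels directly and would make the to_categorical step above unnecessary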
# train the network
model.fit(x_train, y_train, epochs=20, batch_size=batch_size)

# validate the model on test dataset to determine generalization
loss, acc = model.evaluate(x_test,
                           y_test,
                           batch_size=batch_size,
                           verbose=0)
print("\nTest accuracy: %.1f%%" % (100.0 * acc))


Model: "sequential_2"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
dense_6 (Dense)              (None, 256)               200960    
_________________________________________________________________
activation_6 (Activation)    (None, 256)               0         
_________________________________________________________________
dense_7 (Dense)              (None, 256)               65792     
_________________________________________________________________
activation_7 (Activation)    (None, 256)               0         
_________________________________________________________________
dense_8 (Dense)              (None, 10)                2570      
_________________________________________________________________
activation_8 (Activation)    (None, 10)                0         
=================================================================
Total params: 269,322
Trainable params: 269,322
Non-trainable params: 0
_________________________________________________________________
Train on 60000 samples
Epoch 1/20
60000/60000 [==============================] - 3s 57us/sample - loss: 1.1975 - accuracy: 0.7223
Epoch 2/20
60000/60000 [==============================] - 3s 43us/sample - loss: 0.4732 - accuracy: 0.8772
Epoch 3/20
60000/60000 [==============================] - 3s 44us/sample - loss: 0.3731 - accuracy: 0.8975
Epoch 4/20
60000/60000 [==============================] - 2s 37us/sample - loss: 0.3302 - accuracy: 0.9069
Epoch 5/20
60000/60000 [==============================] - 3s 44us/sample - loss: 0.3032 - accuracy: 0.9145
Epoch 6/20
60000/60000 [==============================] - 2s 38us/sample - loss: 0.2829 - accuracy: 0.9197
Epoch 7/20
60000/60000 [==============================] - 2s 38us/sample - loss: 0.2666 - accuracy: 0.9248
Epoch 8/20
60000/60000 [==============================] - 3s 45us/sample - loss: 0.2526 - accuracy: 0.9286
Epoch 9/20
60000/60000 [==============================] - 3s 43us/sample - loss: 0.2407 - accuracy: 0.9326
Epoch 10/20
60000/60000 [==============================] - 2s 41us/sample - loss: 0.2299 - accuracy: 0.9350
Epoch 11/20
60000/60000 [==============================] - 2s 40us/sample - loss: 0.2203 - accuracy: 0.9380
Epoch 12/20
60000/60000 [==============================] - 3s 44us/sample - loss: 0.2115 - accuracy: 0.9405
Epoch 13/20
60000/60000 [==============================] - 2s 41us/sample - loss: 0.2033 - accuracy: 0.9426
Epoch 14/20
60000/60000 [==============================] - 2s 39us/sample - loss: 0.1956 - accuracy: 0.9441
Epoch 15/20
60000/60000 [==============================] - 2s 41us/sample - loss: 0.1886 - accuracy: 0.9462
Epoch 16/20
60000/60000 [==============================] - 3s 43us/sample - loss: 0.1816 - accuracy: 0.9485
Epoch 17/20
60000/60000 [==============================] - 2s 40us/sample - loss: 0.1756 - accuracy: 0.9503
Epoch 18/20
60000/60000 [==============================] - 2s 40us/sample - loss: 0.1697 - accuracy: 0.9519
Epoch 19/20
60000/60000 [==============================] - 2s 40us/sample - loss: 0.1644 - accuracy: 0.9538
Epoch 20/20
60000/60000 [==============================] - 2s 39us/sample - loss: 0.1593 - accuracy: 0.9551

Test accuracy: 95.3%
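
To inspect individual predictions, model.predict returns an array of class probabilities and argmax recovers the most likely digit. A minimal sketch, assuming model is the trained network from the run above:

In [ ]:
# class probabilities for the first 5 test images, shape (5, 10)
probs = model.predict(x_test[:5])
predicted = np.argmax(probs, axis=1)     # most likely digit per image
actual = np.argmax(y_test[:5], axis=1)   # undo the one-hot encoding
print("predicted:", predicted)
print("actual:   ", actual)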

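A regularized variant adds a Dropout layer after each ReLU, using the dropout rate defined in the first cell. A minimal sketch, not the run whose output is shown above; dropout typically slows training convergence slightly but can improve test accuracy:

In [ ]:
from tensorflow.keras.layers import Dropout

# same 3-layer MLP, with dropout regularization after each hidden layer
model = Sequential()
model.add(Dense(hidden_units, input_dim=input_size))
model.add(Activation('relu'))
model.add(Dropout(dropout))   # randomly zero 45% of activations during training
model.add(Dense(hidden_units))
model.add(Activation('relu'))
model.add(Dropout(dropout))
model.add(Dense(num_labels))
model.add(Activation('softmax'))

model.compile(loss='categorical_crossentropy',
              optimizer='sgd',
              metrics=['accuracy'])
model.fit(x_train, y_train, epochs=20, batch_size=batch_size)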