Handwritten Digit Identification

This project demonstrates the use of several machine learning algorithms to identify handwritten digits. The algorithms come from the scikit-learn package, and selected parameters of each are optimized using cross-validation.

The handwritten digit dataset used in this code was taken from Andrew Ng's Coursera course. Dr. Ng obtained this data from the MNIST database, http://yann.lecun.com/exdb/mnist/.

In [14]:
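%matplotlib inline
# Automatically reload edited modules (such as the support module) before each cell runs
%load_ext autoreload
%autoreload 2
import h5py           # HDF5 file access (the dataset is stored in digits.hdf5)
import support as sp  # project helpers for loading, splitting, training and plotting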

Let's load the dataset. X is an array containing the handwritten digit data: 5000 grayscale images of 20x20 pixels each. y is the vector of corresponding digit labels. We can use examples from X and y to train a machine learning model to identify the digits in new images.

In [15]:
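import os
import urllib.request

# Download the dataset only if a local copy is not already present.
DATA_URL = "https://www.dropbox.com/s/iee58ulksn2kv9b/digits.hdf5?dl=1"
if not os.path.exists("digits.hdf5"):
    urllib.request.urlretrieve(DATA_URL, "digits.hdf5")
    print("handwritten digit images and labels retrieved")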

In [16]:
X, y = sp.loaddata('digits.hdf5')
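Each image is easier to inspect as a 2-D array. The sketch below assumes, as in the Coursera version of this dataset, that every row of X is one image unrolled into 400 pixel values; whether `sp.loaddata` returns that layout is an assumption, and the helper `row_to_image` is hypothetical, not part of the support module. Because the Coursera data originated in Octave, which stores arrays column-major, `order="F"` may be needed to avoid a transposed image.

```python
import numpy as np

def row_to_image(row, side=20, order="F"):
    """Reshape one unrolled image (side*side values) into a side x side array.

    order="F" fills the image column by column, matching data exported from
    Octave/MATLAB; if the digits appear transposed, switch to order="C".
    """
    row = np.asarray(row)
    if row.size != side * side:
        raise ValueError(f"expected {side * side} values, got {row.size}")
    return row.reshape(side, side, order=order)

# Placeholder data with the assumed layout: 5000 images x 400 pixels.
X_demo = np.random.rand(5000, 400)
img = row_to_image(X_demo[0])
print(X_demo.shape, img.shape)  # (5000, 400) (20, 20)
```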

In the next code block we define the fractions of the original dataset reserved for selecting the best parameter values for our algorithm (the cross-validation set) and for evaluating the accuracy of the final model (the test set). The remaining examples form the calibration set used for training. Recommended values for cv_frac and test_frac are 0.05, 0.10, 0.15, 0.20 and 0.25.

In [17]:
cv_frac = 0.1
test_frac = 0.1
X_cal, y_cal, X_cv, y_cv, X_test, y_test = sp.get_sets(X, y, cv_frac, test_frac)
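A three-way split like the one `sp.get_sets` returns can be sketched with scikit-learn's `train_test_split` applied twice. This is an assumption about how such a split might work, not the support module's code; `split_three_ways` is a hypothetical name. Both fractions refer to the original dataset, so the sketch converts them to integer counts first, which avoids rounding surprises when the second split is taken from the remainder. Stratifying keeps the ten digits in similar proportions across the three sets.

```python
import numpy as np
from sklearn.model_selection import train_test_split

def split_three_ways(X, y, cv_frac, test_frac, seed=0):
    """Split (X, y) into calibration, cross-validation and test sets.

    cv_frac and test_frac are fractions of the original dataset.
    """
    n = len(y)
    n_test = int(round(test_frac * n))
    n_cv = int(round(cv_frac * n))
    # Carve off the test set first, then the cross-validation set from the rest;
    # an integer test_size is interpreted as an absolute number of samples.
    X_rest, X_test, y_rest, y_test = train_test_split(
        X, y, test_size=n_test, stratify=y, random_state=seed)
    X_cal, X_cv, y_cal, y_cv = train_test_split(
        X_rest, y_rest, test_size=n_cv, stratify=y_rest, random_state=seed)
    return X_cal, y_cal, X_cv, y_cv, X_test, y_test

# Usage with placeholder data shaped like the digits set (5000 x 400).
X_demo = np.zeros((5000, 400))
y_demo = np.arange(5000) % 10
sets = split_three_ways(X_demo, y_demo, cv_frac=0.1, test_frac=0.1)
print([len(s) for s in sets[1::2]])  # [4000, 500, 500]
```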

In this code block we select a machine learning algorithm and retrieve the set of candidate parameter values to iterate through in our cross-validation procedure.

Valid choices for mtype are as follows:

  • GB: In gradient boosting, an ensemble of decision trees is built sequentially, and each new tree is fit to correct the errors (formally, the negative gradient of the loss) of the trees before it. In contrast to earlier boosting methods, gradient boosting works with any differentiable loss function; AdaBoost can be recovered with the exponential loss.
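  • LOG: In logistic regression, a sigmoid function of a linear combination of the inputs is fit to the data to estimate the probability that an example belongs to a class. Predictions greater than 0.5 are usually taken as positive examples and predictions less than 0.5 as negative examples. For the ten digit classes, one such classifier per digit (one-vs-rest) or a multinomial (softmax) generalization can be used.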
  • NN: An artificial neural network connects the input and output units through one or more hidden layers of units, each of which computes a simple function of its inputs (typically a weighted sum followed by a nonlinear activation). The hidden layers let the network learn the higher-order features that matter for the problem at hand, rather than having them specified by hand.
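  • RF: In random forest classification, the predictions of many decision trees, each trained on a bootstrap sample of the data with a random subset of features considered at each split, are combined into a single prediction with much stronger predictive capability than any individual tree.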
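  • SVM: Support vector machine classification is closely related to logistic regression, but instead of fitting a sigmoid it finds the decision boundary with the largest margin between the classes. Kernels (similarity functions such as the Gaussian, or RBF, kernel) allow it to learn nonlinear boundaries.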
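One plausible way to organize the choices above is a registry that pairs each mtype code with a scikit-learn estimator and a list of candidate values for one parameter. This is only a sketch of what something like `sp.allmodels` might hold; the names `MODELS` and `get_model`, the chosen parameters and the value grids are all hypothetical. Note that in scikit-learn's `LogisticRegression` and `SVC`, `C` is the inverse of the regularization strength (smaller `C` means stronger regularization), whereas `MLPClassifier` exposes its L2 penalty as `alpha`, where larger means stronger regularization.

```python
from sklearn.ensemble import GradientBoostingClassifier, RandomForestClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.neural_network import MLPClassifier
from sklearn.svm import SVC

# Hypothetical registry: mtype -> (base estimator, parameter name, candidate values).
# The parameter choices and grids are illustrative assumptions only.
# The estimators are shared instances: clone them before fitting.
MODELS = {
    "GB": (GradientBoostingClassifier(), "learning_rate", [0.01, 0.1, 1.0]),
    "LOG": (LogisticRegression(max_iter=1000), "C", [1e-3, 1e-2, 1e-1, 1.0, 10.0]),
    "NN": (MLPClassifier(max_iter=500), "alpha", [1e-5, 1e-4, 1e-3, 1e-2, 1e-1, 1.0]),
    "RF": (RandomForestClassifier(), "n_estimators", [10, 50, 100, 200]),
    "SVM": (SVC(kernel="rbf"), "C", [0.1, 1.0, 10.0, 100.0]),
}

def get_model(mtype):
    """Return (estimator, param_name, values) for one of the mtype codes."""
    if mtype not in MODELS:
        raise ValueError(f"unknown mtype {mtype!r}; choose from {sorted(MODELS)}")
    return MODELS[mtype]

# Usage: est, name, values = get_model("SVM")
```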

In [18]:
mtype = "NN"
mc = sp.allmodels(mtype)

C gives the regularization strength

We now use the cross-validation set to select the best parameter value for our chosen algorithm. Because the test set plays no part in this selection, our final evaluation of prediction accuracy is not biased by it.

In [19]:
acc_cv = sp.cross_validate(mc, X_cal, y_cal, X_cv, y_cv)

C gives the regularization strength
prediction accuracy, cross-validation data (C=1e-07): 92.2%
prediction accuracy, cross-validation data (C=1e-06): 93.0%
prediction accuracy, cross-validation data (C=1e-05): 93.2%
prediction accuracy, cross-validation data (C=0.0001): 93.2%
prediction accuracy, cross-validation data (C=0.001): 93.4%
prediction accuracy, cross-validation data (C=0.01): 93.0%
prediction accuracy, cross-validation data (C=0.1): 94.6%
prediction accuracy, cross-validation data (C=1.0): 95.8%
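The selection step can be sketched as a loop that fits one model per candidate value on the calibration set and scores it on the cross-validation set. This assumes `sp.cross_validate` works roughly this way (its source is not shown); `select_best_param` is a hypothetical name, and scikit-learn's bundled 8x8 `load_digits` images stand in for digits.hdf5. This is a single hold-out split; scikit-learn's `GridSearchCV` would instead run k-fold cross-validation over the calibration data.

```python
from sklearn.base import clone
from sklearn.datasets import load_digits
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split

def select_best_param(estimator, param_name, values, X_cal, y_cal, X_cv, y_cv):
    """Fit one model per candidate value on the calibration set, score each on
    the cross-validation set, and return the best value and all scores."""
    scores = {}
    for v in values:
        # clone() gives a fresh, unfitted copy so the runs share no state
        model = clone(estimator).set_params(**{param_name: v})
        model.fit(X_cal, y_cal)
        scores[v] = model.score(X_cv, y_cv)  # mean accuracy on held-out data
        print(f"prediction accuracy, cross-validation data ({param_name}={v}): "
              f"{100 * scores[v]:.1f}%")
    best = max(scores, key=scores.get)  # ties go to the first value tried
    return best, scores

# Example on scikit-learn's bundled 8x8 digits (a stand-in for digits.hdf5).
X_d, y_d = load_digits(return_X_y=True)
Xd_cal, Xd_cv, yd_cal, yd_cv = train_test_split(X_d, y_d, test_size=0.2, random_state=0)
best_C, cv_scores = select_best_param(
    LogisticRegression(max_iter=2000), "C", [1e-3, 1e-2, 1e-1, 1.0],
    Xd_cal, yd_cal, Xd_cv, yd_cv)
```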

Now we evaluate the digit identification accuracy on the test data, using the parameter value selected in the cross-validation step.

In [20]:
y_test_pred = sp.test_acc(mc, acc_cv, X_cal, y_cal, X_test, y_test)

prediction accuracy, test data (C=1.0): 95.6%
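A sketch of the final evaluation: refit with the selected value and score on the test set, which is used only once, after selection is finished. Since `sp.test_acc` receives `X_cal` and `y_cal`, this sketch assumes the final model is refit on the calibration set alone; refitting on the calibration and cross-validation sets combined is a common alternative. `evaluate_on_test` is a hypothetical name, the bundled `load_digits` data stands in for digits.hdf5, and `C=1.0` is an arbitrary stand-in value.

```python
from sklearn.base import clone
from sklearn.datasets import load_digits
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
from sklearn.svm import SVC

def evaluate_on_test(estimator, param_name, best_value, X_cal, y_cal, X_test, y_test):
    """Refit with the selected parameter value and report test-set accuracy."""
    model = clone(estimator).set_params(**{param_name: best_value})
    model.fit(X_cal, y_cal)
    y_pred = model.predict(X_test)
    acc = accuracy_score(y_test, y_pred)  # fraction of test digits identified correctly
    print(f"prediction accuracy, test data ({param_name}={best_value}): {100 * acc:.1f}%")
    return y_pred, acc

# Example on scikit-learn's bundled 8x8 digits (a stand-in for digits.hdf5).
X_d, y_d = load_digits(return_X_y=True)
Xd_cal, Xd_test, yd_cal, yd_test = train_test_split(X_d, y_d, test_size=0.2, random_state=0)
yd_pred, acc_test = evaluate_on_test(SVC(), "C", 1.0, Xd_cal, yd_cal, Xd_test, yd_test)
```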

Let's see whether the prediction accuracy depends on which digit we are trying to identify.

In [21]:
sp.plot_acc(y_test, y_test_pred)
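Per-digit accuracy is the fraction of test images of each digit that were classified correctly (per-class recall), which can be read off the diagonal of a confusion matrix. This is a sketch of the kind of numbers a plot like `sp.plot_acc` might show, not its actual implementation; `per_digit_accuracy` is a hypothetical name. It assumes labels 0 to 9; the original Coursera files label the digit zero as 10, so if this dataset does the same, pass `labels=range(1, 11)`.

```python
import numpy as np
from sklearn.metrics import confusion_matrix

def per_digit_accuracy(y_true, y_pred, labels=range(10)):
    """Return {digit: fraction of that digit's examples predicted correctly}.

    Digits that never occur in y_true get NaN.
    """
    labels = list(labels)
    cm = confusion_matrix(y_true, y_pred, labels=labels)  # rows: true, cols: predicted
    totals = cm.sum(axis=1)   # number of true examples of each digit
    correct = np.diag(cm)     # examples of each digit predicted correctly
    with np.errstate(divide="ignore", invalid="ignore"):
        acc = correct / totals  # 0/0 gives NaN for absent digits
    return dict(zip(labels, acc.tolist()))

print(per_digit_accuracy([0, 0, 1, 1], [0, 1, 1, 1], labels=[0, 1]))  # {0: 0.5, 1: 1.0}
```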

As a final visualization, let's look at a few randomly selected digits and how successfully we identify them.

In [22]:
sp.plot_prediction(X_test, y_test_pred)
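A grid of randomly chosen test images titled with their predicted labels can be sketched with matplotlib. This is a guess at what `sp.plot_prediction` produces, not its code; `show_predictions` is a hypothetical name, and it assumes each row of X is an unrolled 20x20 image, with `order="F"` as an assumption for Octave-derived data.

```python
import numpy as np
import matplotlib.pyplot as plt

def show_predictions(X, y_pred, n=16, side=20, order="F", seed=None):
    """Plot n randomly chosen images, titled with their predicted labels.

    Assumes each row of X is one image unrolled into side*side pixel values.
    """
    y_pred = np.ravel(y_pred)  # flatten in case labels are a column vector
    rng = np.random.default_rng(seed)
    idx = rng.choice(len(X), size=n, replace=False)  # n distinct random images
    ncols = int(np.ceil(np.sqrt(n)))
    nrows = int(np.ceil(n / ncols))
    fig, axes = plt.subplots(nrows, ncols, figsize=(1.5 * ncols, 1.5 * nrows))
    axes = np.ravel(axes)
    for ax in axes:
        ax.axis("off")  # hide ticks; unused panels stay blank
    for ax, i in zip(axes, idx):
        image = np.asarray(X[i]).reshape(side, side, order=order)
        ax.imshow(image, cmap="gray")
        ax.set_title(f"pred: {y_pred[i]}")
    fig.tight_layout()
    return fig
```

In the notebook, `show_predictions(X_test, y_test_pred, seed=0);` would draw the grid; the trailing semicolon stops the returned figure from being displayed a second time under the inline backend.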