This project demonstrates how several machine learning algorithms can be used to identify handwritten digits. The algorithms come from the scikit-learn package, and selected parameters are tuned using cross-validation.
The handwritten digit dataset used in this code was taken from Andrew Ng's Coursera course. Dr. Ng obtained this data from the MNIST database, http://yann.lecun.com/exdb/mnist/.
In :%matplotlib inline
%load_ext autoreload
%autoreload 2
import h5py
import support as sp
The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
Let's load the dataset. X is an array containing the handwritten digit data: 5000 unique grayscale images of 20x20 pixels each. y is the vector of digit labels. We can use examples from X and y to train a machine learning model to identify digits in new images.
In :#import urllib
#import ssl
#context=ssl._create_unverified_context()
#data=urllib.URLopener(context=context)
#data.retrieve("https://www.dropbox.com/s/iee58ulksn2kv9b/digits.hdf5?dl=1", "digits.hdf5")
#print "handwritten digit images and labels retrieved"
In :X, y = sp.loaddata('digits.hdf5')
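To make the data layout concrete, here is a minimal sketch of how one row of X could be turned back into a 20x20 image. It assumes that X has shape (5000, 400), with one flattened image per row; the source only states the image count and size, so that shape, the stand-in array `X_demo`, and the row-major ordering are all assumptions.

```python
import numpy as np

# Stand-in for X (assumed shape: one flattened 20x20 image per row)
X_demo = np.zeros((5000, 400))
X_demo[0, :20] = 1.0  # mark the first 20 entries of the first image

# Row-major reshape; the real data may need order="F" or a transpose
# if the images were flattened column by column
image = X_demo[0].reshape(20, 20)

print(image.shape)     # (20, 20)
print(image[0].sum())  # 20.0, the marked entries fill the first image row
```

If the digits appear rotated or mirrored when plotted, the flattening order is the first thing to check.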
In the next code block we define the fractions of the original dataset that will be reserved for selecting the best set of parameters for our algorithm (the cross-validation set) and for evaluating the accuracy of the final model (the test set). Recommended selections for cv_frac and test_frac include 0.05, 0.10, 0.15, 0.20 and 0.25.
In :cv_frac = 0.1
test_frac = 0.1
X_cal, y_cal, X_cv, y_cv, X_test, y_test = sp.get_sets(X, y, cv_frac, test_frac)
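The implementation of sp.get_sets is not shown here. As a rough illustration of the idea, the sketch below shuffles the examples once and slices them into calibration, cross-validation, and test sets. The function name `split_sets`, the `seed` parameter, and the synthetic stand-in data are hypothetical and are not part of the support module.

```python
import numpy as np

def split_sets(X, y, cv_frac, test_frac, seed=0):
    """Shuffle the examples, then split into calibration, CV, and test sets."""
    n = X.shape[0]
    idx = np.random.default_rng(seed).permutation(n)
    n_cv = int(round(cv_frac * n))
    n_test = int(round(test_frac * n))
    cv_idx = idx[:n_cv]
    test_idx = idx[n_cv:n_cv + n_test]
    cal_idx = idx[n_cv + n_test:]  # everything left over is used for fitting
    return (X[cal_idx], y[cal_idx],
            X[cv_idx], y[cv_idx],
            X[test_idx], y[test_idx])

# Synthetic stand-in with the same number of examples as the digit data
X_demo = np.random.default_rng(1).random((5000, 400))
y_demo = np.arange(5000) % 10

Xc, yc, Xv, yv, Xt, yt = split_sets(X_demo, y_demo, cv_frac=0.1, test_frac=0.1)
print(Xc.shape[0], Xv.shape[0], Xt.shape[0])  # 4000 500 500
```

With cv_frac = test_frac = 0.1 and 5000 examples, 500 images go to each held-out set and 4000 remain for fitting.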
In this code block we select a machine learning algorithm. We also retrieve the valid set of model parameters to iterate through in our cross-validation procedure.
Valid choices for mtype are as follows:
In :mtype = "NN"
mc = sp.allmodels("NN")
mc.get_Cvec()
C gives the regularization strength
We now perform cross-validation to select the best parameters for our chosen algorithm. Because this step uses the cross-validation set rather than the test set, it does not bias our evaluation of the final prediction accuracy.
In :acc_cv = sp.cross_validate(mc, X_cal, y_cal, X_cv, y_cv)
C gives the regularization strength
prediction accuracy, cross-validation data (C=1e-07): 92.2%
prediction accuracy, cross-validation data (C=1e-06): 93.0%
prediction accuracy, cross-validation data (C=1e-05): 93.2%
prediction accuracy, cross-validation data (C=0.0001): 93.2%
prediction accuracy, cross-validation data (C=0.001): 93.4%
prediction accuracy, cross-validation data (C=0.01): 93.0%
prediction accuracy, cross-validation data (C=0.1): 94.6%
prediction accuracy, cross-validation data (C=1.0): 95.8%
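The loop behind output like the above can be sketched directly with scikit-learn. This is not the code in sp.cross_validate, which is not shown; it is a minimal stand-in that fits one model per candidate C on the fitting data, scores each on a held-out validation set, and keeps the best-scoring C. It assumes LogisticRegression as the model, scikit-learn's bundled 8x8 digits instead of the 20x20 images used here, and a hypothetical grid `C_grid`. Note that in scikit-learn's LogisticRegression, C is the inverse of the regularization strength, which may differ from how the support module defines C.

```python
from sklearn.datasets import load_digits
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split

# Stand-in data: scikit-learn's bundled 8x8 digits, not the 20x20 set
X_all, y_all = load_digits(return_X_y=True)
X_all = X_all / 16.0  # pixel values in this dataset range from 0 to 16

# Hold out 10% of the examples as a validation (cross-validation) set
X_fit, X_val, y_fit, y_val = train_test_split(
    X_all, y_all, test_size=0.1, random_state=0, stratify=y_all)

C_grid = [1e-3, 1e-2, 1e-1, 1.0]  # hypothetical candidate values
val_acc = {}
for C in C_grid:
    model = LogisticRegression(C=C, max_iter=2000)
    model.fit(X_fit, y_fit)
    val_acc[C] = model.score(X_val, y_val)  # fraction classified correctly
    print("prediction accuracy, cross-validation data (C=%g): %.1f%%"
          % (C, 100 * val_acc[C]))

# Keep the value of C with the highest validation accuracy
best_C = max(val_acc, key=val_acc.get)
print("selected C:", best_C)
```

When the best value sits at the edge of the grid, as C=1.0 does in the output above, it can be worth extending the grid in that direction before settling on a value.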
Now we evaluate the accuracy of the digit identification on the test data, using the parameter selected in the cross-validation step.
In :y_test_pred = sp.test_acc(mc, acc_cv, X_cal, y_cal, X_test, y_test)
prediction accuracy, test data (C=1.0): 95.6%
Let's see whether the prediction accuracy depends on which digit we are trying to predict.
In :sp.plot_acc(y_test, y_test_pred)
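The quantity behind such a plot is the accuracy computed separately for each true digit. The sketch below shows one way to compute it with NumPy; the function `per_digit_accuracy` and the toy labels are hypothetical, and sp.plot_acc may compute or display the result differently.

```python
import numpy as np

def per_digit_accuracy(y_true, y_pred):
    """Return {digit: fraction of that digit's examples predicted correctly}."""
    y_true = np.asarray(y_true).ravel()
    y_pred = np.asarray(y_pred).ravel()
    acc = {}
    for digit in np.unique(y_true):
        mask = y_true == digit  # examples whose true label is this digit
        acc[int(digit)] = float(np.mean(y_pred[mask] == digit))
    return acc

# Toy labels: digit 0 is right 1 time out of 2, digit 1 is right 2 out of 3
y_true_demo = [0, 0, 1, 1, 1, 2]
y_pred_demo = [0, 1, 1, 1, 0, 2]
demo_acc = per_digit_accuracy(y_true_demo, y_pred_demo)
print(demo_acc)
```

Each value is a per-class recall, so a digit with few test examples will have a noisier estimate than the overall accuracy.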
As a final visualization, let's look at a few randomly selected digits and how successful we are at identifying them.
In :sp.plot_prediction(X_test, y_test_pred)