This project demonstrates the use of several machine learning algorithms to identify handwritten digits. The algorithms come from the scikit-learn package, and selected parameters are tuned using cross-validation.
The handwritten digit dataset used in this code was taken from Andrew Ng's Coursera course. Dr. Ng obtained this data from the MNIST database, http://yann.lecun.com/exdb/mnist/.
In [14]:
%matplotlib inline
%load_ext autoreload
%autoreload 2
import h5py
import support as sp
Let's load the dataset. X is an array containing the handwritten digit data: 5000 unique grayscale images of 20x20 pixels each. y is the vector of digit labels. We can use examples from X and y to train a machine learning model to identify digits in new images.
In [15]:
# Optional: uncomment to download digits.hdf5 (note: this snippet uses Python 2 syntax)
#import urllib
#import ssl
#context=ssl._create_unverified_context()
#data=urllib.URLopener(context=context)
#data.retrieve("https://www.dropbox.com/s/iee58ulksn2kv9b/digits.hdf5?dl=1", "digits.hdf5")
#print "handwritten digit images and labels retrieved"
In [16]:
X, y = sp.loaddata('digits.hdf5')
In the next code block we define the fractions of the original dataset to reserve for selecting the best set of parameters for our algorithm (cross-validation set) and for evaluating the accuracy of the final model (test set). The remaining examples form the calibration (training) set. Recommended selections for cv_frac and test_frac include 0.05, 0.10, 0.15, 0.20 and 0.25.
In [17]:
cv_frac = 0.1
test_frac = 0.1
X_cal, y_cal, X_cv, y_cv, X_test, y_test = sp.get_sets(X, y, cv_frac, test_frac)
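The body of `sp.get_sets` is not shown in this notebook, so the following is only a minimal sketch of how such a three-way split could work. The function name `split_three_ways`, the `seed` argument, and the toy data are assumptions made for illustration; the real helper may shuffle or round differently.

```python
import random


def split_three_ways(X, y, cv_frac, test_frac, seed=0):
    """Shuffle the examples and split them into calibration, CV, and test sets.

    Hypothetical stand-in for sp.get_sets; not the notebook's actual code.
    """
    if len(X) != len(y):
        raise ValueError("X and y must have the same number of examples")
    if cv_frac < 0 or test_frac < 0 or cv_frac + test_frac >= 1:
        raise ValueError("fractions must be non-negative and sum to less than 1")

    n = len(X)
    indices = list(range(n))
    random.Random(seed).shuffle(indices)  # seeded so the split is reproducible

    n_cv = int(round(cv_frac * n))
    n_test = int(round(test_frac * n))

    # Disjoint index ranges guarantee that no example lands in two sets.
    cv_idx = indices[:n_cv]
    test_idx = indices[n_cv:n_cv + n_test]
    cal_idx = indices[n_cv + n_test:]

    def take(idx):
        return [X[i] for i in idx], [y[i] for i in idx]

    X_cal, y_cal = take(cal_idx)
    X_cv, y_cv = take(cv_idx)
    X_test, y_test = take(test_idx)
    return X_cal, y_cal, X_cv, y_cv, X_test, y_test


# Toy data standing in for the 5000 digit images and their labels.
X_toy = [[i, i + 1] for i in range(5000)]
y_toy = [i % 10 for i in range(5000)]
X_cal, y_cal, X_cv, y_cv, X_test, y_test = split_three_ways(X_toy, y_toy, 0.1, 0.1)
print(len(X_cal), len(X_cv), len(X_test))
```

Keeping the three index ranges disjoint is the point of the design: the test set can only give an unbiased accuracy estimate if none of its examples were used for fitting or for parameter selection.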
In this code block we select a machine learning algorithm. We also retrieve the valid set of model parameters to iterate through in our cross-validation procedure.
Valid choices for mtype are as follows:
In [18]:
mtype = "NN"
mc = sp.allmodels(mtype)  # use the selection above instead of repeating the literal
mc.get_Cvec()
We now perform cross-validation to select the best parameters for our chosen algorithm without biasing our evaluation of the final prediction accuracy.
In [19]:
acc_cv = sp.cross_validate(mc, X_cal, y_cal, X_cv, y_cv)
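The internals of `sp.cross_validate` are not shown here. As a rough illustration of the idea, the sketch below fits a model on the calibration set for each candidate parameter, scores it on the held-out validation set, and keeps the best-scoring candidate. A tiny pure-Python k-nearest-neighbours classifier is swapped in as a stand-in model so the example stays self-contained; it is not the notebook's "NN" model, and the names `knn_predict`, `accuracy`, and `select_by_validation` are hypothetical.

```python
from collections import Counter


def knn_predict(X_train, y_train, x, k):
    """Predict the label of x by majority vote among its k nearest training points."""
    by_distance = sorted(
        range(len(X_train)),
        key=lambda i: sum((a - b) ** 2 for a, b in zip(X_train[i], x)),
    )
    votes = Counter(y_train[i] for i in by_distance[:k])
    return votes.most_common(1)[0][0]


def accuracy(y_true, y_pred):
    """Fraction of predictions that match the true labels."""
    return sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)


def select_by_validation(candidates, X_cal, y_cal, X_cv, y_cv):
    """Score each candidate k on the validation set; return (best_k, scores)."""
    scores = {}
    for k in candidates:
        y_pred = [knn_predict(X_cal, y_cal, x, k) for x in X_cv]
        scores[k] = accuracy(y_cv, y_pred)
    # max() returns the first maximal element, so ties go to the earliest candidate.
    best_k = max(candidates, key=lambda k: scores[k])
    return best_k, scores


# Toy two-class data; the last calibration point is deliberately mislabeled noise.
X_cal_toy = [[0.0, 0.0], [0.1, 0.2], [0.2, 0.1], [1.0, 1.0],
             [0.9, 1.1], [1.1, 0.9], [0.15, 0.15]]
y_cal_toy = [0, 0, 0, 1, 1, 1, 1]
X_cv_toy = [[0.05, 0.1], [0.95, 1.0], [0.18, 0.16], [1.05, 0.95]]
y_cv_toy = [0, 1, 0, 1]

best_k, scores = select_by_validation([1, 3, 5], X_cal_toy, y_cal_toy, X_cv_toy, y_cv_toy)
print(best_k, scores)
```

Because the test set plays no part in this loop, the accuracy measured on it later is not inflated by the parameter search.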
Now we evaluate the accuracy of the digit identification on the test data, using the parameter selected in the cross-validation step.
In [20]:
y_test_pred = sp.test_acc(mc, acc_cv, X_cal, y_cal, X_test, y_test)
Let's see whether the prediction accuracy depends on which digit we are trying to predict.
In [21]:
sp.plot_acc(y_test, y_test_pred)
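`sp.plot_acc` presumably computes an accuracy for each digit before plotting it, but its code is not shown. The sketch below is one way such a per-digit breakdown could be computed; `per_class_accuracy` and the toy labels are assumptions for illustration. The quantity computed is the fraction of examples of each true digit that were predicted correctly (per-class recall).

```python
from collections import defaultdict


def per_class_accuracy(y_true, y_pred):
    """Return {label: fraction of examples with that true label predicted correctly}."""
    totals = defaultdict(int)
    correct = defaultdict(int)
    for t, p in zip(y_true, y_pred):
        totals[t] += 1
        if t == p:
            correct[t] += 1
    return {label: correct[label] / totals[label] for label in sorted(totals)}


# Toy labels standing in for y_test and y_test_pred.
y_true_toy = [0, 0, 1, 1, 1, 2, 2, 2, 2]
y_pred_toy = [0, 0, 1, 7, 1, 2, 2, 3, 3]
class_acc = per_class_accuracy(y_true_toy, y_pred_toy)
for digit, acc in class_acc.items():
    print(digit, round(acc, 2))
```

A breakdown like this can reveal digits that are systematically confused with one another even when the overall accuracy looks high.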
As a final visualization, let's look at a few randomly selected digits and how successful we are at identifying them.
In [22]:
sp.plot_prediction(X_test, y_test_pred)