KNN algorithm example with the sklearn digits dataset

Purpose

We will use the digits dataset to train a k-Nearest Neighbors (KNN) classifier to recognize handwritten digits. KNN classifies a new image by finding the k training images closest to it in pixel space and taking a majority vote of their labels.


In [6]:
%matplotlib inline
import matplotlib
import numpy as np
import matplotlib.pyplot as plt

from sklearn.datasets import load_digits
from sklearn.neighbors import KNeighborsClassifier
from sklearn.model_selection import cross_val_score  # sklearn.cross_validation has been removed in favor of model_selection

We first load the dataset; you can then read its description using the 'DESCR' key:


In [7]:
digits = load_digits()
print(digits['DESCR'])


 Optical Recognition of Handwritten Digits Data Set

Notes
-----
Data Set Characteristics:
    :Number of Instances: 5620
    :Number of Attributes: 64
    :Attribute Information: 8x8 image of integer pixels in the range 0..16.
    :Missing Attribute Values: None
    :Creator: E. Alpaydin (alpaydin '@' boun.edu.tr)
    :Date: July; 1998

This is a copy of the test set of the UCI ML hand-written digits datasets
http://archive.ics.uci.edu/ml/datasets/Optical+Recognition+of+Handwritten+Digits

The data set contains images of hand-written digits: 10 classes where
each class refers to a digit.

Preprocessing programs made available by NIST were used to extract
normalized bitmaps of handwritten digits from a preprinted form. From a
total of 43 people, 30 contributed to the training set and different 13
to the test set. 32x32 bitmaps are divided into nonoverlapping blocks of
4x4 and the number of on pixels are counted in each block. This generates
an input matrix of 8x8 where each element is an integer in the range
0..16. This reduces dimensionality and gives invariance to small
distortions.

For info on NIST preprocessing routines, see M. D. Garris, J. L. Blue, G.
T. Candela, D. L. Dimmick, J. Geist, P. J. Grother, S. A. Janet, and C.
L. Wilson, NIST Form-Based Handprint Recognition System, NISTIR 5469,
1994.

References
----------
  - C. Kaynak (1995) Methods of Combining Multiple Classifiers and Their
    Applications to Handwritten Digit Recognition, MSc Thesis, Institute of
    Graduate Studies in Science and Engineering, Bogazici University.
  - E. Alpaydin, C. Kaynak (1998) Cascading Classifiers, Kybernetika.
  - Ken Tang and Ponnuthurai N. Suganthan and Xi Yao and A. Kai Qin.
    Linear dimensionality reduction using relevance weighted LDA. School of
    Electrical and Electronic Engineering Nanyang Technological University.
    2005.
  - Claudio Gentile. A New Approximate Maximal Margin Classification
    Algorithm. NIPS. 2000.
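
To make the preprocessing above concrete, here is a minimal sketch of the 32x32 -> 8x8 reduction, assuming a hypothetical 32x32 binary bitmap (the variable 'bitmap' is illustrative, not part of the dataset):

# Sum each non-overlapping 4x4 block of a 32x32 binary bitmap;
# every block count falls in the range 0..16, matching the DESCR.
bitmap = np.random.randint(0, 2, size=(32, 32))
blocks = bitmap.reshape(8, 4, 8, 4).sum(axis=(1, 3))
print(blocks.shape)  # (8, 8)
print(blocks.max())  # at most 16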

With the dataset loaded, we set the feature matrix (X) and the response vector (y):


In [8]:
X = digits.data
y = digits.target

In [10]:
knn = KNeighborsClassifier(n_neighbors=4)
scores = cross_val_score(knn, X, y, cv=10, scoring='accuracy')
print(scores)
print(scores.mean())


[ 0.92972973  0.98907104  0.98342541  0.97777778  0.95530726  0.98882682
  0.98324022  0.97752809  0.97740113  0.96590909]
0.972821657254
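
Here cross_val_score performs 10-fold cross-validation: the data is split into 10 folds, the classifier is trained on 9 of them and scored on the held-out fold, producing the 10 accuracy values printed above. A rough manual equivalent (for classifiers, scikit-learn stratifies the folds by default):

from sklearn.model_selection import StratifiedKFold

skf = StratifiedKFold(n_splits=10)
fold_scores = []
for train_idx, test_idx in skf.split(X, y):
    knn.fit(X[train_idx], y[train_idx])
    fold_scores.append(knn.score(X[test_idx], y[test_idx]))
print(np.mean(fold_scores))

Next, we repeat this evaluation over a range of k values to see which performs best.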

In [39]:
k_range = range(1, 31)
k_scores = []
for k in k_range:
    knn = KNeighborsClassifier(n_neighbors=k)
    scores = cross_val_score(knn, X, y, cv=10, scoring='accuracy')
    k_scores.append(scores.mean())
print(k_scores)


[0.97614938602520218, 0.97671197596272707, 0.9777892113798643, 0.97282165725397274, 0.97394828725469063, 0.96951482797974065, 0.9700794639117305, 0.96672098562665276, 0.96731915128765722, 0.96453789195260331, 0.96561232468072655, 0.96620804017654616, 0.96507178414018602, 0.96451271993566468, 0.9645131249223089, 0.96396081370163866, 0.96339584195022621, 0.96396081370163866, 0.96339266795162293, 0.96120054609571248, 0.96062052503467377, 0.95953046703899125, 0.95953964367433675, 0.95897153338799246, 0.95671826644805547, 0.9556071553369444, 0.95504586203672626, 0.9561758067351066, 0.95449982908147535, 0.95449355201161157]

In [40]:
plt.plot(k_range, k_scores)
plt.xlabel('Value of K for KNN')
plt.ylabel('Cross-Validated Accuracy')


Out[40]:
<matplotlib.text.Text at 0x10a9a7438>
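
From the printed scores, accuracy peaks early: the best mean accuracy (about 0.978) occurs at k = 3, after which performance slowly declines. You can pull the best value out programmatically:

best_k = k_range[int(np.argmax(k_scores))]
print('Best k: {0} (mean accuracy: {1:.4f})'.format(best_k, max(k_scores)))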

In [14]:
# Displaying different keys/attributes
# of the dataset
print('Keys:', digits.keys())

# Loading data
# This includes the pixel values for each of the samples
digits_data = digits['data']
print('Data for 1st element:', digits_data[0])

# Targets
# This is the actual number for each sample, i.e. the 'truth'
digits_targetnames = digits['target_names']
print('Target names:', digits_targetnames)

digits_target = digits['target']
print('Targets:', digits_target)


Keys: ['images', 'data', 'target_names', 'DESCR', 'target']
Data for 1st element: [  0.   0.   5.  13.   9.   1.   0.   0.   0.   0.  13.  15.  10.  15.   5.
   0.   0.   3.  15.   2.   0.  11.   8.   0.   0.   4.  12.   0.   0.   8.
   8.   0.   0.   5.   8.   0.   0.   9.   8.   0.   0.   4.  11.   0.   1.
  12.   7.   0.   0.   2.  14.   5.  10.  12.   0.   0.   0.   0.   6.  13.
  10.   0.   0.   0.]
Target names: [0 1 2 3 4 5 6 7 8 9]
Targets: [0 1 2 ..., 8 9 8]

This means that you have 1797 samples (scikit-learn ships the 1797-image test portion of the UCI dataset, even though the DESCR above lists 5620 instances for the full set), each characterized by 64 different features (pixel values).
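
A minimal check to confirm those numbers from the data array itself:

print('Shape of the data:', digits_data.shape)  # (1797, 64)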

We can also visualize some of the data, using the 'images' key:


In [15]:
# Choosing a colormap
color_map_used = plt.get_cmap('autumn')

In [16]:
# Visualizing some of the targets
fig, axes = plt.subplots(2,5, sharex=True, sharey=True, figsize=(20,12))
axes_f = axes.flatten()
for ii in range(len(axes_f)):
    axes_f[ii].imshow(digits['images'][ii], cmap = color_map_used)
    axes_f[ii].text(1, -1, 'Target: {0}'.format(digits_target[ii]), fontsize=30)
plt.show()


The algorithm will be able to use the pixel values to determine that the first image is a '0', the second is a '1', and so on.
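
As a quick sanity check (illustrative only), here is a minimal sketch that fits a KNN classifier on every sample except the first and then predicts that held-out image, using the k = 3 suggested by the cross-validation results above:

knn = KNeighborsClassifier(n_neighbors=3)
knn.fit(X[1:], y[1:])
print(knn.predict(X[:1]))  # we expect [0], the true target of the first sample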

Let's see some examples of the number 2:


In [17]:
IDX2 = np.where(digits_target == 2)[0]
print('There are {0} samples of the number 2 in the dataset'.format(IDX2.size))

fig, axes = plt.subplots(2,5, sharex=True, sharey=True, figsize=(20,12))
axes_f = axes.flatten()
for ii in range(len(axes_f)):
    axes_f[ii].imshow(digits['images'][IDX2][ii], cmap = color_map_used)
    axes_f[ii].text(1, -1, 'Target: {0}'.format(digits_target[IDX2][ii]), fontsize=30)
plt.show()


There are 177 samples of the number 2 in the dataset

In [18]:
print('And now the number 4\n')
IDX4 = np.where(digits_target == 4)[0]
fig, axes = plt.subplots(2,5, sharex=True, sharey=True, figsize=(20,12))
axes_f = axes.flatten()
for ii in range(len(axes_f)):
    axes_f[ii].imshow(digits['images'][IDX4][ii], cmap = color_map_used)
    axes_f[ii].text(1, -1, 'Target: {0}'.format(digits_target[IDX4][ii]), fontsize=30)
plt.show()


And now the number 4

You can see how different individual inputs are by subtracting one image from another. Here, I'm subtracting two images that both represent the number '4':


In [19]:
# Difference between two samples of the number 4
plt.imshow(digits['images'][IDX4][1] - digits['images'][IDX4][8], cmap=color_map_used)
plt.show()


This figure shows how different two samples can be from each other.
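
This per-pixel difference is closely related to how KNN measures similarity: by default, KNeighborsClassifier uses the Euclidean distance between the flattened 64-pixel vectors. A minimal sketch of that distance for the same two samples:

a = digits['images'][IDX4][1].flatten()
b = digits['images'][IDX4][8].flatten()
print('Euclidean distance:', np.linalg.norm(a - b))

Even two examples of the same digit can sit fairly far apart in pixel space, which is why taking a majority vote over the k nearest neighbors, rather than trusting a single closest match, makes the classifier more robust.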