The MNIST database of handwritten digits, available from Yann LeCun's website, has a training set of 60,000 examples and a test set of 10,000 examples. It is a subset of a larger set available from NIST.
The digits have been size-normalized and centered in a fixed-size image. It is a good database for people who want to try learning techniques and pattern recognition methods on real-world data while spending minimal effort on preprocessing and formatting.
training images: | 60000 |
test images: | 10000 |
image pixels: | 28x28 |
image format: | raw vector of 784 elements, encoding 0-255 |
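To make the image format concrete, here is a minimal sketch of how a raw vector of 784 elements maps onto a 28x28 grid and how the 0-255 encoding can be scaled to the range [0, 1]. It assumes the pixels are stored row by row; the helper names `to_rows` and `normalize` and the synthetic vector `fake` are made up for this illustration and are not part of any library.

```python
IMG_SIDE = 28
N_PIXELS = IMG_SIDE * IMG_SIDE  # 784 elements per image


def to_rows(vector, side=IMG_SIDE):
    """Split a flat pixel vector (assumed row-major) into `side` rows."""
    if len(vector) != side * side:
        raise ValueError("expected %d pixels, got %d" % (side * side, len(vector)))
    return [vector[r * side:(r + 1) * side] for r in range(side)]


def normalize(vector):
    """Scale 0-255 pixel intensities to floats in [0, 1]."""
    return [p / 255.0 for p in vector]


# Synthetic stand-in for one image (not real MNIST data)
fake = [i % 256 for i in range(N_PIXELS)]

rows = to_rows(fake)
scaled = normalize(fake)

print(len(rows), len(rows[0]))   # grid dimensions
print(min(scaled), max(scaled))  # value range after scaling
```

Scaling to [0, 1] is a common choice before feeding the pixels to a network, but it is optional: the plotting code in this notebook works directly on the raw 0-255 values.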
You can obtain and use it in Python in two steps:
sudo pip install python_mnist
In [4]:
# import the mnist class
from mnist import MNIST
# init with the 'data' dir
mndata = MNIST('./data')
# Load data
mndata.load_training()
mndata.load_testing()
# The number of pixels per side of all images
img_side = 28
# Each input is a raw vector.
# The number of units of the network
# corresponds to the number of input elements
n_mnist_pixels = img_side*img_side
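After loading, it can be worth checking that the data match the figures in the table above. The following is a minimal sketch of such a check; `check_split` is a hypothetical helper written for this example (it is not part of the `mnist` package), and the tiny lists used here are synthetic stand-ins so that the cell runs without the data files.

```python
def check_split(images, labels, expected_count, n_pixels=784):
    """Raise ValueError if a loaded split does not have the expected layout."""
    if len(images) != expected_count or len(labels) != expected_count:
        raise ValueError("expected %d samples, got %d images and %d labels"
                         % (expected_count, len(images), len(labels)))
    for idx, (img, label) in enumerate(zip(images, labels)):
        # Every image must be a flat vector of n_pixels intensities
        if len(img) != n_pixels:
            raise ValueError("image %d has %d pixels" % (idx, len(img)))
        # Intensities are encoded in the range 0-255
        if min(img) < 0 or max(img) > 255:
            raise ValueError("image %d has pixels outside 0-255" % idx)
        # Labels are the digits 0 to 9
        if not 0 <= label <= 9:
            raise ValueError("label %d is not a digit: %r" % (idx, label))
    return True


# Synthetic stand-in for a loaded split (three blank "images")
fake_images = [[0] * 784, [255] * 784, [128] * 784]
fake_labels = [5, 0, 4]

ok = check_split(fake_images, fake_labels, expected_count=3)
print(ok)
```

With the real data, the same helper could be called as `check_split(mndata.train_images, mndata.train_labels, 60000, n_mnist_pixels)`, and with `10000` for the test split.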
Below, as an example, we take the first ten samples from the training set and plot them:
In [5]:
%matplotlib inline
from pylab import *
# Define the number of samples to take
num_samples = 10
# create a figure where we will store all samples
figure(figsize=(10,1))
# Iterate over samples indices
for sample in range(num_samples):
    # The image corresponding to the 'sample' index
    img = mndata.train_images[sample]
    # The label of the image
    label = mndata.train_labels[sample]
    # The image is stored as a flattened vector,
    # so we reshape it back into a matrix
    aimg = array(img).reshape(img_side, img_side)
    # Open a subplot for each sample
    subplot(1, num_samples, sample+1)
    # The corresponding digit is the title of the plot
    title(label)
    # We use imshow to plot the matrix of pixels
    imshow(aimg, interpolation = 'none',
           aspect = 'auto', cmap = cm.binary)
    axis("off")
show()
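When no plotting backend is at hand, a sample can also be inspected as text. The sketch below is one possible way to do this and is only an illustration: `render_ascii`, its `threshold` parameter, and the synthetic image `fake` are assumptions introduced here, and the pixels are again assumed to be stored row by row.

```python
def render_ascii(vector, side=28, threshold=128):
    """Render a flat pixel vector as text, one line per image row.

    Pixels at or above `threshold` are drawn as '#', the rest as '.'.
    """
    lines = []
    for r in range(side):
        row = vector[r * side:(r + 1) * side]
        lines.append("".join("#" if p >= threshold else "." for p in row))
    return "\n".join(lines)


# Synthetic 28x28 image (not real MNIST data): a vertical bar
# two pixels wide in columns 13 and 14
fake = [255 if c in (13, 14) else 0 for r in range(28) for c in range(28)]

art = render_ascii(fake)
print(art)
```

With the real data, `render_ascii(mndata.train_images[0])` would print the first training digit in the same way.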
The next cell is just for styling:
In [6]:
from IPython.display import HTML
def css_styling():
    # Read the notebook stylesheet and wrap it for display
    with open("../style/ipybn.css", "r") as css_file:
        styles = css_file.read()
    return HTML(styles)
css_styling()