The MNIST dataset

The MNIST database of handwritten digits, available from Yann LeCun's website, has a training set of 60,000 examples and a test set of 10,000 examples. It is a subset of a larger set available from NIST.

The digits have been size-normalized and centered in a fixed-size image. It is a good database for people who want to try learning techniques and pattern recognition methods on real-world data while spending minimal effort on preprocessing and formatting.


MNIST in numbers

training images: 60000
test images: 10000
image size: 28x28 pixels
image format: raw vector of 784 elements (28x28 pixels stored row by row)
pixel encoding: integers 0-255 (grayscale intensity)
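
The numbers above mean that each raw 784-element vector is a 28x28 image stored row by row. Below is a minimal NumPy sketch of moving between the two layouts and rescaling the 0-255 intensities to [0, 1]; the helper names to_matrix and to_unit_range are hypothetical, and NumPy is assumed to be installed.

```python
import numpy as np

IMG_SIDE = 28                    # pixels per side
N_PIXELS = IMG_SIDE * IMG_SIDE   # 784 elements in each raw vector


def to_matrix(raw_vector):
    """Reshape one raw 784-element vector into a 28x28 matrix (row by row)."""
    vec = np.asarray(raw_vector, dtype=np.uint8)
    if vec.size != N_PIXELS:
        raise ValueError("expected %d elements, got %d" % (N_PIXELS, vec.size))
    # Element r*28 + c of the vector lands at row r, column c of the matrix
    return vec.reshape(IMG_SIDE, IMG_SIDE)


def to_unit_range(raw_vector):
    """Map 0-255 integer intensities to floats in [0, 1]."""
    return np.asarray(raw_vector, dtype=np.float32) / 255.0


# Synthetic vector (not real MNIST data): values 0..255 followed by zeros, 784 in total
raw = list(range(256)) + [0] * (N_PIXELS - 256)
m = to_matrix(raw)
print(m.shape)                   # (28, 28)
print(m[0, 27], m[1, 0])         # 27 28
print(to_unit_range(raw).max())  # 1.0
```

The element at index 28 starts the second row, which is why m[1, 0] is 28: reshaping a flat vector in C (row-major) order is exactly the inverse of the row-by-row storage.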

MNIST with Python

Get the dataset and the Python tools

You can obtain the dataset and use it in Python in two steps:

  • Install the python_mnist package: sudo pip install python_mnist
  • Download the dataset files into a folder called data, placed in the same directory as your Python scripts. You can find them here. Unzip them inside that folder.
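
The python_mnist package takes care of parsing these files for you. To see what that involves, here is a minimal sketch of a reader for the IDX format described on Yann LeCun's page: a big-endian header made of two zero bytes, a type code (0x08 for unsigned bytes), the number of dimensions and one 32-bit size per dimension, followed by the raw data. The function read_idx is hypothetical (it is not part of python_mnist) and assumes unsigned-byte data, which is what the MNIST files contain.

```python
import gzip
import struct

import numpy as np


def read_idx(path):
    """Read an unsigned-byte IDX file (plain or .gz) into a NumPy array."""
    opener = gzip.open if path.endswith(".gz") else open
    with opener(path, "rb") as f:
        # Magic number: 2 zero bytes, 1 type byte, 1 byte with the number of dimensions
        zero, dtype_code, ndim = struct.unpack(">HBB", f.read(4))
        if zero != 0 or dtype_code != 0x08:
            raise ValueError("not an unsigned-byte IDX file: %s" % path)
        # One big-endian unsigned 32-bit size per dimension
        shape = struct.unpack(">" + "I" * ndim, f.read(4 * ndim))
        # The rest of the file is the data, one byte per value
        data = np.frombuffer(f.read(), dtype=np.uint8)
    return data.reshape(shape)
```

Assuming the unzipped training files keep their original names inside ./data, read_idx("./data/train-images-idx3-ubyte") should return an array of shape (60000, 28, 28) and read_idx("./data/train-labels-idx1-ubyte") one of shape (60000,).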
Using the dataset in Python

Now we can load the dataset with the load_training and load_testing methods of the MNIST object:


# import the mnist class
from mnist import MNIST 

# init with the 'data' dir
mndata = MNIST('./data')    

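# Load the training and test sets; images and labels are then
# available as attributes such as mndata.train_images and mndata.train_labels
mndata.load_training() 
mndata.load_testing()

# The number of pixels per side of all images
img_side = 28

# Each image is a raw vector of img_side*img_side values.
# The number of input units of a network 
# must match the number of input elements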
n_mnist_pixels = img_side*img_side
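
To make the comment about input units concrete, here is a minimal NumPy sketch of a hypothetical single-layer network whose weight matrix has one row per input element, so its first dimension must equal n_mnist_pixels. The 10 output units (one per digit), the random initialization and the names images_to_matrix and forward are assumptions for illustration, not part of python_mnist.

```python
import numpy as np

img_side = 28
n_mnist_pixels = img_side * img_side   # 784 input units
n_classes = 10                         # assumption: one output unit per digit 0-9


def images_to_matrix(images):
    """Stack raw image vectors into an (n_samples, 784) float array in [0, 1]."""
    X = np.asarray(images, dtype=np.float64) / 255.0
    if X.ndim != 2 or X.shape[1] != n_mnist_pixels:
        raise ValueError("expected shape (n, %d), got %s" % (n_mnist_pixels, X.shape))
    return X


# Hypothetical weights: one row per input element, one column per output unit
rng = np.random.default_rng(0)
W = rng.normal(scale=0.01, size=(n_mnist_pixels, n_classes))
b = np.zeros(n_classes)


def forward(X):
    """Return class scores of shape (n_samples, n_classes)."""
    return X @ W + b
```

With the data loaded above, images_to_matrix(mndata.train_images[:100]) should give a (100, 784) array, and forward applied to it a (100, 10) array of scores. A mismatch between the image length and the number of weight rows would surface immediately as a shape error in the matrix product.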

Below, as an example, we take the first ten samples from the training set and plot them:


%matplotlib inline
import numpy as np
import matplotlib.pyplot as plt

# Define the number of samples to take
num_samples = 10

# Create a figure that will hold all samples side by side
plt.figure(figsize=(10, 1))

# Iterate over sample indices (range works in Python 2 and 3; xrange is Python 2 only)
for sample in range(num_samples):

    # The image corresponding to the 'sample' index
    img = mndata.train_images[sample]

    # The label of the image
    label = mndata.train_labels[sample]

    # The image is stored as a flat vector of 784 values,
    # so we reshape it back into a 28x28 matrix
    aimg = np.array(img).reshape(img_side, img_side)

    # Open a subplot for each sample
    plt.subplot(1, num_samples, sample + 1)

    # The corresponding digit is the title of the plot
    plt.title(str(label))

    # We use imshow to plot the matrix of pixels
    plt.imshow(aimg, interpolation='none',
               aspect='auto', cmap=plt.cm.binary)
    plt.axis("off")

plt.show()
