Loading the MNIST data

This is the MNIST data obtainable at http://yann.lecun.com/exdb/mnist/

The data is supplied as IDX files compressed in gzip format. The code below unzips the data, converts each IDX file to an ndarray, reshapes and one-hot encodes as necessary, scales the pixel values, and finally pickles everything for easy loading into the main script.
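The two transformations worth spelling out are the one-hot encoding and the scaling. Indexing an identity matrix with the label array is a neat one-liner for one-hot encoding; and note that `MinMaxScaler` scales each pixel *column* to [0, 1] using that column's own min and max, whereas dividing by 255 scales globally. A tiny sketch with made-up labels and pixels (the global-divide variant shown for simplicity):

```python
import numpy as np

# Hypothetical tiny batch for illustration (not real MNIST data)
labels = np.array([3, 0, 7])
images = np.array([[0, 128, 255],
                   [64, 32, 16],
                   [255, 0, 128]], dtype=np.uint8)

# One-hot encode: row i of the 10x10 identity matrix is the encoding of digit i
one_hot = np.eye(10)[labels]

# Global min-max scaling: MNIST pixels span 0-255, so divide by 255
scaled = images.astype(np.float64) / 255.0
```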

It's worth noting that pickle files written with Python 3.X's highest protocol cannot be read by Python 2.X, so if you haven't yet moved to Python 3.X you should download the gzips yourself and run this script locally (with a lower pickle protocol) to generate Python 2.X-compatible pickle files. YMMV.
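The reason is that `pickle.HIGHEST_PROTOCOL` in Python 3 is protocol 3 or higher, while Python 2 only understands protocols up to 2. If you need Python 2-readable files, pass `protocol=2` explicitly, as this sketch shows:

```python
import pickle

data = {'xtrain': [1, 2, 3]}

# Protocol 2 is the highest pickle protocol Python 2 can read;
# pickle.HIGHEST_PROTOCOL in Python 3 is >= 3 and Python 2 cannot load it
py2_compatible = pickle.dumps(data, protocol=2)

# Still loads fine under Python 3
restored = pickle.loads(py2_compatible)
```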

Finally, the details of the data are available on the website above. But in a nutshell: the training set contains 60 000 images and the test set contains 10 000 images. I randomly set aside 10 000 of the training points as a validation set.
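Under the hood, the split amounts to shuffling the indices once and carving off the last chunk for validation. A minimal NumPy-only sketch of the same idea, using tiny stand-in sizes rather than the real 60 000/10 000:

```python
import numpy as np

rng = np.random.RandomState(0)

# Stand-in for the MNIST training set (tiny sizes for illustration)
n_total, n_val = 60, 10
x = rng.rand(n_total, 784)
y = rng.randint(0, 10, size=n_total)

# Shuffle indices once, then take the last n_val for validation --
# essentially what train_test_split does with a fixed random_state
idx = rng.permutation(n_total)
train_idx, val_idx = idx[:-n_val], idx[-n_val:]
xtrain, xval = x[train_idx], x[val_idx]
ytrain, yval = y[train_idx], y[val_idx]
```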


In [93]:
import pickle
import gzip
import idx2numpy
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MinMaxScaler

In [94]:
# Uncompress the gzips and convert the IDX files to ndarray
with gzip.open('data/gzips/train-images-idx3-ubyte.gz', 'rb') as f:
    xtrain = idx2numpy.convert_from_file(f)

with gzip.open('data/gzips/train-labels-idx1-ubyte.gz', 'rb') as f:
    ytrain = idx2numpy.convert_from_file(f)

# Flatten each 28x28 image into a 784-element row vector
xtrain = xtrain.reshape(len(xtrain), -1)
# Scale each pixel column into the range [0, 1]
xtrain = MinMaxScaler().fit_transform(xtrain)
# One-hot encode the y values: row i of the identity matrix encodes digit i
ytrain = np.eye(10)[ytrain]
# Separate out the validation set. Note: the random_state parameter ensures you get the same split as me.
xtrain, xval, ytrain, yval = train_test_split(xtrain, ytrain, test_size=10000, random_state=0)

# Write the pickled files for importing easily into other scripts
with open('data/pickled/xtrain.pickle', 'wb') as f:
    pickle.dump(xtrain, f, pickle.HIGHEST_PROTOCOL)
    
with open('data/pickled/xval.pickle', 'wb') as f:
    pickle.dump(xval, f, pickle.HIGHEST_PROTOCOL)

with open('data/pickled/ytrain.pickle', 'wb') as f:
    pickle.dump(ytrain, f, pickle.HIGHEST_PROTOCOL)
    
with open('data/pickled/yval.pickle', 'wb') as f:
    pickle.dump(yval, f, pickle.HIGHEST_PROTOCOL)


/home/nobody/anaconda3/lib/python3.5/site-packages/sklearn/utils/validation.py:420: DataConversionWarning: Data with input dtype uint8 was converted to float64 by MinMaxScaler.
  warnings.warn(msg, DataConversionWarning)

In [95]:
# As above, but for the test set
with gzip.open('data/gzips/t10k-images-idx3-ubyte.gz', 'rb') as f:
    xtest = idx2numpy.convert_from_file(f)
    
with gzip.open('data/gzips/t10k-labels-idx1-ubyte.gz', 'rb') as f:
    ytest = idx2numpy.convert_from_file(f)

xtest = xtest.reshape(len(xtest), -1)
xtest = MinMaxScaler().fit_transform(xtest)
ytest = np.eye(10)[ytest]

with open('data/pickled/xtest.pickle', 'wb') as f:
    pickle.dump(xtest, f, pickle.HIGHEST_PROTOCOL)
    
with open('data/pickled/ytest.pickle', 'wb') as f:
    pickle.dump(ytest, f, pickle.HIGHEST_PROTOCOL)
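For completeness, the loading side in the main script is a one-liner per file. A small sketch, wrapped in a helper so it also accepts an already-open file object (the name `load_pickle` is my own, not part of the original script):

```python
import pickle

def load_pickle(file_or_path):
    """Load one pickled array; accepts a path string or an open binary file."""
    if hasattr(file_or_path, 'read'):
        return pickle.load(file_or_path)
    with open(file_or_path, 'rb') as f:
        return pickle.load(f)
```

Usage in the main script would then be e.g. `xtrain = load_pickle('data/pickled/xtrain.pickle')`, and similarly for the other five files.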


/home/nobody/anaconda3/lib/python3.5/site-packages/sklearn/utils/validation.py:420: DataConversionWarning: Data with input dtype uint8 was converted to float64 by MinMaxScaler.
  warnings.warn(msg, DataConversionWarning)