Theano + Lasagne :: MNIST CNN

This is a quick illustration of a Convolutional Neural Network being trained on the MNIST data.

(Credit for initially creating this workbook: Eben Olson, https://github.com/ebenolson/pydata2015)


In [ ]:
import numpy as np
import theano
import theano.tensor as T
import lasagne

import matplotlib.pyplot as plt
%matplotlib inline

import gzip
import pickle

In [ ]:
# Seed for reproducibility
np.random.seed(42)

Get the MNIST data

Put it into useful subsets, and show some of it as a sanity check


In [ ]:
# Download the MNIST digits dataset (Already downloaded locally)
# !wget -N --directory-prefix=./data/MNIST/ http://deeplearning.net/data/mnist/mnist.pkl.gz

In [ ]:
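# The encoding argument lets Python 3 read this pickle, which appears to have been written by Python 2
with gzip.open('./data/MNIST/mnist.pkl.gz') as f:
    train, val, test = pickle.load(f, encoding='iso-8859-1')

X_train, y_train = train
X_val, y_val = val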
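
As a sanity check on the loaded data, a few digits can be laid side by side and inspected. The helper below is a minimal sketch: `tile_images` is a hypothetical name, not part of the original workbook, and it runs on random stand-in data so that it works without the MNIST file.

```python
import numpy as np

def tile_images(X, n, side=28):
    """Lay the first n flattened images side by side in one 2D array."""
    # Each row of X is a flattened side x side image (784 values for MNIST).
    images = X[:n].reshape(-1, side, side)
    # Join the images along the width axis: result is (side, n * side).
    return np.concatenate(list(images), axis=1)

# Random stand-in for X_train so the sketch runs on its own.
demo = np.random.rand(5, 784).astype('float32')
grid = tile_images(demo, 5)
print(grid.shape)  # (28, 140)
```

With the real data, `plt.imshow(tile_images(X_train, 10), cmap='gray')` should show the first ten training digits in a row.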

In [ ]:
def batch_gen(X, y, N):
    while True:
        idx = np.random.choice(len(y), N)
        yield X[idx].astype('float32'), y[idx].astype('int32')
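
To see what the generator yields, the sketch below repeats `batch_gen` and runs it on toy data (`X_toy` and `y_toy` are made-up stand-ins, not MNIST). Note that `np.random.choice` samples with replacement by default, so one pass of `len(X) // N` batches is only approximately one epoch.

```python
import numpy as np

def batch_gen(X, y, N):
    # Same generator as above, repeated so this sketch is self-contained.
    while True:
        idx = np.random.choice(len(y), N)
        yield X[idx].astype('float32'), y[idx].astype('int32')

# Toy stand-in data: 100 flattened "images" with labels 0..9.
X_toy = np.random.rand(100, 784)
y_toy = np.arange(100) % 10

X_batch, y_batch = next(batch_gen(X_toy, y_toy, 64))
print(X_batch.shape, X_batch.dtype)  # (64, 784) float32
print(y_batch.shape, y_batch.dtype)  # (64,) int32
```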

Create the Network

This is a Convolutional Neural Network (CNN), where each 'filter' in a given layer is produced by sliding a small matrix of weights (here 3x3) across the whole of the previous layer and taking a weighted sum at every position (a convolution operation). Depending on its weights, a filter can produce effects such as averaging or edge detection.
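
To make the scanning operation concrete, here is a minimal NumPy sketch with two hand-written 3x3 kernels. The helper `conv2d_same` and both kernels are illustrative assumptions, not Lasagne code, and the helper does not flip the kernel (strictly a cross-correlation). In the network itself the filter weights are learned rather than fixed.

```python
import numpy as np

def conv2d_same(image, kernel):
    """Slide a small kernel over a 2D image, zero-padding to keep its size."""
    kh, kw = kernel.shape
    padded = np.pad(image, ((kh // 2, kh // 2), (kw // 2, kw // 2)),
                    mode='constant')
    out = np.zeros(image.shape, dtype='float32')
    for r in range(image.shape[0]):
        for c in range(image.shape[1]):
            # Weighted sum of the patch currently under the kernel.
            out[r, c] = np.sum(padded[r:r + kh, c:c + kw] * kernel)
    return out

# Toy image: dark left half, bright right half (a vertical edge).
image = np.zeros((6, 6), dtype='float32')
image[:, 3:] = 1.0

average_kernel = np.ones((3, 3)) / 9.0          # blurs the image
edge_kernel = np.array([[-1.0, 0.0, 1.0]] * 3)  # responds to vertical edges

averaged = conv2d_same(image, average_kernel)
edges = conv2d_same(image, edge_kernel)
print(averaged[2])
print(edges[2])
```

The edge kernel gives zero in the flat dark region and a strong response where the dark-to-bright step occurs, while the averaging kernel smooths the step.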


In [ ]:
# We need to reshape from a 1D feature vector to a 1 channel 2D image.
# Then we apply 3 convolutional filters with 3x3 kernel size.
l_in = lasagne.layers.InputLayer((None, 784))

l_shape = lasagne.layers.ReshapeLayer(l_in, (-1, 1, 28, 28))

l_conv = lasagne.layers.Conv2DLayer(l_shape, num_filters=3, filter_size=3, pad=1)

l_out = lasagne.layers.DenseLayer(l_conv,
                                  num_units=10,
                                  nonlinearity=lasagne.nonlinearities.softmax)
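
The sizes involved can be checked with some plain arithmetic. This sketch assumes a stride of 1 and one bias per filter (which, as far as I know, are the `Conv2DLayer` defaults) and does not call Lasagne at all.

```python
# Shape and parameter arithmetic for the network above (plain Python).
height = width = 28
filter_size, pad, stride = 3, 1, 1   # stride=1 is an assumed default
in_channels, num_filters, num_units = 1, 3, 10

# Standard output-size formula for a padded, strided convolution.
out_height = (height + 2 * pad - filter_size) // stride + 1
out_width = (width + 2 * pad - filter_size) // stride + 1

# One 3x3 weight matrix per (input channel, filter) pair, plus one bias per filter.
conv_params = num_filters * in_channels * filter_size ** 2 + num_filters

# The dense layer sees the flattened convolution output.
dense_inputs = num_filters * out_height * out_width
dense_params = dense_inputs * num_units + num_units

print(out_height, out_width)       # 28 28
print(conv_params, dense_params)   # 30 23530
print(conv_params + dense_params)  # 23560
```

So `pad=1` keeps the 28x28 image size, and almost all of the trainable parameters sit in the dense output layer rather than in the three filters.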

Compile and train the network.

Accuracy is much better than the single layer network, despite the small number of filters.


In [ ]:
X_sym = T.matrix()
y_sym = T.ivector()

output = lasagne.layers.get_output(l_out, X_sym)
pred = output.argmax(-1)

loss = T.mean(lasagne.objectives.categorical_crossentropy(output, y_sym))

acc = T.mean(T.eq(pred, y_sym))

params = lasagne.layers.get_all_params(l_out)
grad = T.grad(loss, params)
updates = lasagne.updates.adam(grad, params, learning_rate=0.005)

f_train = theano.function([X_sym, y_sym], [loss, acc], updates=updates)
f_val = theano.function([X_sym, y_sym], [loss, acc])
f_predict = theano.function([X_sym], pred)
print("Built network")
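
For intuition, the loss and accuracy expressions above can be reproduced in NumPy on a tiny hand-made batch. The probabilities and labels below are invented for illustration. With integer targets, the categorical cross-entropy of one example is the negative log of the probability assigned to its true class.

```python
import numpy as np

# Made-up softmax outputs for 3 examples over 3 classes.
probs = np.array([[0.7, 0.2, 0.1],
                  [0.1, 0.8, 0.1],
                  [0.3, 0.4, 0.3]])
labels = np.array([0, 1, 2])

# Cross-entropy: -log(probability of the correct class), averaged over the batch.
per_example = -np.log(probs[np.arange(len(labels)), labels])
loss = per_example.mean()

# Accuracy: fraction of examples whose most probable class is the label.
pred = probs.argmax(-1)
acc = (pred == labels).mean()

print(pred)  # [0 1 1]
print(loss, acc)
```

The third example is misclassified (predicted 1, labelled 2), so the accuracy is 2/3 and that example contributes the largest term to the loss.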

In [ ]:
BATCH_SIZE = 64
N_BATCHES = len(X_train) // BATCH_SIZE
N_VAL_BATCHES = len(X_val) // BATCH_SIZE

In [ ]:
train_batches = batch_gen(X_train, y_train, BATCH_SIZE)
val_batches = batch_gen(X_val, y_val, BATCH_SIZE)

for epoch in range(5):
    train_loss = 0
    train_acc = 0
    for _ in range(N_BATCHES):
        X, y = next(train_batches)
        loss, acc = f_train(X, y)
        train_loss += loss
        train_acc += acc
    train_loss /= N_BATCHES
    train_acc /= N_BATCHES

    val_loss = 0
    val_acc = 0
    for _ in range(N_VAL_BATCHES):
        X, y = next(val_batches)
        loss, acc = f_val(X, y)
        val_loss += loss
        val_acc += acc
    val_loss /= N_VAL_BATCHES
    val_acc /= N_VAL_BATCHES
    
    print('Epoch {:2d}, Train loss {:.03f}     (validation loss     : {:.03f}) ratio {:.03f}'.format(
            epoch, train_loss, val_loss, val_loss/train_loss))
    print('          Train accuracy {:.03f} (validation accuracy : {:.03f})'.format(train_acc, val_acc))
print("DONE")

Look at the Output after the Convolutional Layer

Since the convolutional layer only has 3 filters, we can map these to red, green and blue for easier visualisation.


In [ ]:
filtered = lasagne.layers.get_output(l_conv, X_sym)
f_filter = theano.function([X_sym], filtered)

In [ ]:
# Filter the first few training examples
im = f_filter(X_train[:10])
print(im.shape)

In [ ]:
# Rearrange dimensions from (N, channels, H, W) to (N, H, W, channels) so we can plot the result as RGB images
im = im.transpose(0, 2, 3, 1)

We can see that each filter detected different features in the images, e.g. horizontal, diagonal or vertical segments.


In [ ]:
plt.figure(figsize=(16,8))
for i in range(10):
    plt.subplot(1, 10, i+1)
    plt.imshow(im[i], interpolation='nearest')
    plt.axis('off')
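
The filter outputs are not guaranteed to lie in [0, 1], and `imshow` clips floating-point RGB values to that range. One option is to rescale each image before plotting. This is a minimal sketch: `normalise_for_rgb` is a hypothetical helper, not part of the original workbook, and it is shown on random stand-in data.

```python
import numpy as np

def normalise_for_rgb(im):
    """Scale each (H, W, 3) image in a batch to the range [0, 1]."""
    # Per-image minimum and maximum, kept broadcastable against im.
    lo = im.min(axis=(1, 2, 3), keepdims=True)
    hi = im.max(axis=(1, 2, 3), keepdims=True)
    # Guard against division by zero for constant images.
    return (im - lo) / np.maximum(hi - lo, 1e-8)

# Stand-in for the filter outputs: 2 images of 4x4 pixels, 3 channels.
rng = np.random.RandomState(0)
fake = rng.rand(2, 4, 4, 3) * 5.0
scaled = normalise_for_rgb(fake)
print(scaled.min(), scaled.max())
```

Calling `plt.imshow(normalise_for_rgb(im)[i], interpolation='nearest')` in the loop above would then use the full colour range for each digit, at the cost of making brightness incomparable between images.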
