Anomaly Detection on MNIST

This notebook shows how a Deep Learning Auto-Encoder model can be used to find outliers in a dataset.

Consider a three-layer neural network (input layer, one hidden layer, output layer) with the same number of input neurons (features) as output neurons. The loss function is the mean squared error (MSE) between the input and the output, so the network is forced to learn the identity function through a nonlinear, reduced (lower-dimensional) representation of the original data. Such a model is called an autoencoder. Autoencoders have been used extensively for unsupervised, layer-wise pretraining of supervised deep learning models, but here we use the autoencoder to discover anomalies in data.
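A minimal NumPy sketch of the forward pass and loss, assuming the 784-50-784 layout with tanh hidden units and sigmoid outputs that this notebook builds later. The weights are random, untrained placeholders, and the names `W_enc`, `W_dec` and `autoencode` are hypothetical, so the sketch illustrates only the shapes and the loss, not a trained model.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical, untrained parameters for a 784 -> 50 -> 784 autoencoder
n_in, n_hidden = 784, 50
W_enc = rng.normal(scale=0.01, size=(n_in, n_hidden))
b_enc = np.zeros(n_hidden)
W_dec = rng.normal(scale=0.01, size=(n_hidden, n_in))
b_dec = np.zeros(n_in)

def autoencode(X):
    """Encode X (batch x 784) into 50 tanh features, then decode with a sigmoid."""
    H = np.tanh(X @ W_enc + b_enc)                      # reduced, nonlinear representation
    X_hat = 1.0 / (1.0 + np.exp(-(H @ W_dec + b_dec)))  # reconstruction in (0, 1)
    return H, X_hat

X = rng.random((4, n_in))                           # 4 fake "images" with pixels in [0, 1)
H, X_hat = autoencode(X)
mse_per_sample = np.mean((X_hat - X) ** 2, axis=1)  # one reconstruction error per row
print(H.shape, X_hat.shape, mse_per_sample.shape)   # (4, 50) (4, 784) (4,)
```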

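We use the well-known MNIST dataset of hand-written digits (0 to 9), where each row contains the 28×28 = 784 gray-scale pixel values of one digitized digit. In the pickled version loaded below, the raw 0 to 255 pixel values have already been scaled to floats between 0 and 1, which matches the range of the autoencoder's sigmoid output layer.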
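As an illustration on a synthetic image rather than real MNIST data, flattening a 28×28 image into a 784-element row and restoring it is a lossless reshape: pixel (r, c) lands at index r*28 + c. The plotting helper later in this notebook relies on the inverse `reshape((28, 28))`.

```python
import numpy as np

image = np.arange(28 * 28, dtype=np.float32).reshape(28, 28) / (28 * 28)  # synthetic image
row = image.reshape(-1)            # one row of the data matrix: 784 features
restored = row.reshape((28, 28))   # back to 2-D for display
print(row.shape)                         # (784,)
print(np.array_equal(image, restored))   # True
print(row[3 * 28 + 5] == image[3, 5])    # True: pixel (3, 5) is feature 89
```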

Load Theano/Lasagne and the MNIST training/testing datasets


In [ ]:
import numpy as np
import theano
import lasagne

import matplotlib.pyplot as plt
%matplotlib inline

import gzip
import pickle

In [ ]:
# Seed for reproducibility
np.random.seed(42)

In [ ]:
# Download the MNIST digits dataset (the file is already available locally, so this is commented out)
# !wget -N --directory-prefix=./data/MNIST/ http://deeplearning.net/data/mnist/mnist.pkl.gz

# Load the training, validation and test splits, each a (features, labels) pair of numpy arrays
with gzip.open('data/MNIST/mnist.pkl.gz', 'rb') as f:
    train, val, test = pickle.load(f, encoding='iso-8859-1')

X_train, y_train = train
# Omit the validation set...
X_test, y_test = test

#X_train[1000][:]

In [ ]:
# For training, sample random minibatches of N examples (with replacement); the labels are not needed
def batch_gen(X, N):
    while True:
        idx = np.random.choice(len(X), N)
        yield X[idx].astype('float32')
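`batch_gen` draws each minibatch with replacement (`np.random.choice` defaults to `replace=True`), so an "epoch" of `N_BATCHES` batches covers the training set only on average: some rows are seen twice and others not at all. If exact passes are preferred, a hypothetical alternative `shuffled_batch_gen` reshuffles once per pass and walks through the permutation. This is a sketch, not part of the original notebook.

```python
import numpy as np

def shuffled_batch_gen(X, N, rng=np.random):
    """Yield float32 minibatches of size N, visiting every row at most once per pass.

    Hypothetical alternative to batch_gen: sampling without replacement within
    each pass; the trailing partial batch (len(X) % N rows) is dropped.
    """
    n_full = len(X) // N
    while True:
        perm = rng.permutation(len(X))
        for b in range(n_full):
            idx = perm[b * N:(b + 1) * N]
            yield X[idx].astype('float32')
```

With `N_BATCHES = len(X_train) // BATCH_SIZE`, one "epoch" of this generator is exactly one shuffled pass over the data (minus the remainder rows).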

Finding outliers - ugly hand-written digits

We train a Deep Learning Auto-Encoder to learn a compressed (low-dimensional) nonlinear representation of the dataset, thereby learning the intrinsic structure of the training data. The autoencoder is then used to map every test set image to its reconstruction by passing it through the lower-dimensional network. We find outliers in the test dataset by comparing the reconstruction of each scanned digit with its original pixel values: a high reconstruction error indicates that the test point does not conform to the structure of the training data, so it can be called an outlier.
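The decision rule can be made concrete with a percentile threshold. This is a minimal sketch, assuming one reconstruction error per example is already available (as computed later in this notebook). The percentile cutoff and the name `flag_outliers` are illustrative choices, not part of the original method.

```python
import numpy as np

def flag_outliers(errors, percentile=99.0):
    """Return (mask, threshold): mask marks examples whose error exceeds the percentile cutoff."""
    errors = np.asarray(errors)
    threshold = np.percentile(errors, percentile)
    return errors > threshold, threshold

# Made-up reconstruction errors: mostly small, with two large ones
errors = np.array([0.010, 0.012, 0.011, 0.009, 0.013, 0.080, 0.010, 0.095])
mask, threshold = flag_outliers(errors, percentile=75)
print(np.flatnonzero(mask))  # [5 7]
```

A fixed percentile flags a fixed fraction of the data regardless of how clean it is; an absolute error cutoff chosen from validation data is an alternative design.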

Learn what's normal from the training data

We train an unsupervised autoencoder model on the training dataset. For simplicity, the model has a single hidden layer of 50 tanh neurons, which provide 50 nonlinear features from which to reconstruct the original data. For now, please accept that 50 hidden units is a reasonable choice...

To keep things quick, we train the autoencoder for only 5 epochs (roughly five passes over the entire training dataset, since minibatches are sampled at random).
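To put the model size in perspective, the parameter count of the 784-50-784 network follows from the usual dense-layer formula (weights plus one bias per output unit). Here is a quick check in Python, where `dense_params` is a hypothetical helper and the 50,000-image training split of `mnist.pkl.gz` is assumed for the batch count.

```python
def dense_params(n_in, n_out):
    """Weights (n_in * n_out) plus one bias per output unit."""
    return n_in * n_out + n_out

encoder = dense_params(784, 50)   # 39,200 weights + 50 biases
decoder = dense_params(50, 784)   # 39,200 weights + 784 biases
print(encoder, decoder, encoder + decoder)  # 39250 39984 79234

# Minibatches per "epoch" with BATCH_SIZE = 64 and 50,000 training images
batches_per_epoch = 50000 // 64
print(batches_per_epoch)  # 781
```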


In [ ]:
# A very simple network, an autoencoder with a single hidden layer of 50 neurons
l_in = lasagne.layers.InputLayer(shape=(None, 784))
l_hidden = lasagne.layers.DenseLayer(l_in,
                                    num_units=50,
                                    nonlinearity=lasagne.nonlinearities.tanh)
l_out = lasagne.layers.DenseLayer(l_hidden,
                                 num_units=784,
                                 nonlinearity=lasagne.nonlinearities.sigmoid)

In [ ]:
# Symbolic variable for our input features
X_sym = theano.tensor.matrix()

In [ ]:
# Theano expressions for the output distribution and loss vs the original input
output = lasagne.layers.get_output(l_out, X_sym)

# The per-sample loss is the squared error averaged over the 784 pixels (MSE); the minibatch loss averages it over samples
sample_loss = theano.tensor.mean(lasagne.objectives.squared_error(output, X_sym), axis=1)
minibatch_loss = theano.tensor.mean(sample_loss)
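The two `mean` reductions are easy to confuse. In NumPy terms (a sketch on random arrays standing in for `output` and `X_sym`), `sample_loss` averages the squared error over the 784 pixels of each row, and `minibatch_loss` averages those per-row values over the batch.

```python
import numpy as np

rng = np.random.default_rng(1)
X = rng.random((64, 784))        # stand-in for a minibatch X_sym
output = rng.random((64, 784))   # stand-in for the reconstructions

squared_error = (output - X) ** 2          # elementwise, shape (64, 784)
sample_loss = squared_error.mean(axis=1)   # per-example MSE, shape (64,)
minibatch_loss = sample_loss.mean()        # scalar that training minimizes

# Because every row has the same length, this equals the mean over all elements
print(sample_loss.shape, np.isclose(minibatch_loss, squared_error.mean()))  # (64,) True
```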

In [ ]:
# We retrieve all the trainable parameters in our network
params = lasagne.layers.get_all_params(l_out, trainable=True)

# Compute Adam updates for training (the numbers on the right compare early ... late training losses for each update rule)
updates = lasagne.updates.adam(minibatch_loss, params)      # 0.065 ... 0.032
#updates = lasagne.updates.adagrad(minibatch_loss, params)   # 0.056 ... 0.037
#updates = lasagne.updates.rmsprop(minibatch_loss, params)   # 0.059 ... 0.041
#updates = lasagne.updates.adadelta(minibatch_loss, params)  # 0.101 ... 0.065

print(params)
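For intuition about what an adaptive update rule such as `lasagne.updates.adam` does to each parameter, here is a NumPy sketch of the Adam rule applied to a toy quadratic. The hyperparameter values are the commonly used defaults (step size 0.001, beta1 0.9, beta2 0.999, epsilon 1e-8); Lasagne's implementation may arrange the bias correction differently, so treat this as an illustration of the algorithm rather than a copy of the library code. The helper name `adam_step` is hypothetical.

```python
import numpy as np

def adam_step(theta, grad, m, v, t, lr=0.001, beta1=0.9, beta2=0.999, eps=1e-8):
    """One Adam update; returns new (theta, m, v). t is the 1-based step count."""
    m = beta1 * m + (1 - beta1) * grad          # running mean of gradients
    v = beta2 * v + (1 - beta2) * grad ** 2     # running mean of squared gradients
    m_hat = m / (1 - beta1 ** t)                # bias correction for the zero initialization
    v_hat = v / (1 - beta2 ** t)
    theta = theta - lr * m_hat / (np.sqrt(v_hat) + eps)
    return theta, m, v

# Minimize f(theta) = sum(theta**2), whose gradient is 2 * theta
theta = np.array([1.0, -2.0])
m = np.zeros_like(theta)
v = np.zeros_like(theta)
for t in range(1, 5001):
    theta, m, v = adam_step(theta, 2 * theta, m, v, t, lr=0.01)
print(theta)  # both entries end up close to 0
```

Each parameter's step is scaled by its own gradient history, which is why the optimizers listed above can train at noticeably different speeds on the same loss.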

In [ ]:
# We define a training function that will compute the loss, and take a single optimization step
f_train = theano.function([X_sym], minibatch_loss, updates=updates)

# The prediction function doesn't require targets; it returns the reconstructions and the per-sample autoencoder loss
f_predict = theano.function([X_sym], [output, sample_loss])

print("Theano functions created")

In [ ]:
# We'll choose a batch size, and calculate the number of batches in an "epoch"
BATCH_SIZE = 64
N_BATCHES = len(X_train) // BATCH_SIZE

# Minibatch generator for the training set
train_batches = batch_gen(X_train, BATCH_SIZE)

NB: Each epoch should take 10-20 seconds, although the first one may take a little longer...


In [ ]:
for epoch in range(5):
    train_loss = 0
    for _ in range(N_BATCHES):
        X = next(train_batches)
        loss = f_train(X)          # one Adam step on this minibatch
        train_loss += loss
    train_loss /= N_BATCHES        # average minibatch loss over the epoch
    print('Epoch {:2d}, Train loss {:.03f}'.format(epoch, train_loss))
print("DONE")

Find outliers in the test data

We now compute the per-row reconstruction error for the test set: each row is passed through the autoencoder model (trained on the training data), and the mean squared error (MSE) between the row and its reconstruction is computed.


In [ ]:
test_reconstructed, test_loss = f_predict(X_test)
test_loss.shape

Visualize the good, the bad and the ugly

We will need a helper function for plotting handwritten digits:


In [ ]:
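def plot_by_index(X, indices):
    """Plot the rows of X selected by indices as 28x28 gray-scale images in a single row."""
    n_cols = max(12, len(indices))   # 12 slots as before, more if more images are given
    plt.figure(figsize=(12, 3))
    for i, idx in enumerate(indices):
        plt.subplot(1, n_cols, i + 1)
        plt.imshow(X[idx].reshape((28, 28)), cmap='gray', interpolation='nearest')
        plt.axis('off')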

Let's look at the test set points with low/median/high reconstruction errors. We will now visualize the original test set points and their reconstructions obtained by propagating them through the narrow neural net.
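All three views come from a single `argsort`. As a sketch (the helper name `good_bad_ugly` is hypothetical), the lowest, middle and highest `k` indices can be pulled out in one place, mirroring the slices used in the following cells; `k` is assumed to be even, as with the notebook's 12.

```python
import numpy as np

def good_bad_ugly(losses, k=12):
    """Return index arrays for the k lowest, k median and k highest losses (k even)."""
    order = np.argsort(losses)
    mid = len(order) // 2
    return order[:k], order[mid - k // 2: mid + k // 2], order[-k:]

losses = np.array([0.5, 0.1, 0.9, 0.3, 0.7, 0.2])
good, bad, ugly = good_bad_ugly(losses, k=2)
print(good, bad, ugly)  # [1 5] [3 0] [4 2]
```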


In [ ]:
# Sort the test set by reconstruction error (ascending)
test_loss_sorted_indices = np.argsort(test_loss)

# Here are the best ones
test_loss[ test_loss_sorted_indices[0:10] ]

The good

Let's plot the 12 digits with the lowest reconstruction error. The first row shows the original scanned images, the second row their reconstructions.


In [ ]:
indices = test_loss_sorted_indices[0:12]

plot_by_index(X_test, indices)
plot_by_index(test_reconstructed, indices)

Clearly, well-written 1s appear in both the training and test sets and are easy for the autoencoder to reconstruct with minimal error. Nothing is as easy as a straight line.

The bad

Now let's look at the 12 digits with median reconstruction error.


In [ ]:
mid = len(test_loss_sorted_indices)//2
indices = test_loss_sorted_indices[mid-6:mid+6]

plot_by_index(X_test, indices)
plot_by_index(test_reconstructed, indices)

These test set digits look "normal" - it is plausible that they resemble digits from the training data to a large extent, but they do have some particularities that cause some reconstruction error.

The ugly

And here are the biggest outliers: the 12 digits with the highest reconstruction error!


In [ ]:
indices = test_loss_sorted_indices[-12:]

plot_by_index(X_test, indices)
plot_by_index(test_reconstructed, indices)

Now here are some pretty ugly digits that are plausibly uncommon in the training data; some are hard even for humans to classify.

Voila!

We were able to find outliers with Deep Learning Auto-Encoder models.

We would love to hear your use-case for Anomaly detection...

Exercises

  • What is the test set error before network training?

  • Check whether 20 hidden units or 100 hidden units would have been a better choice

  • See how the learning progresses using adadelta updates

  • Try adding your own example digits to ./images/mnist/ - there's a template .png file there to start - and have a look at the reconstruction errors
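Images saved by common editors are often RGB or RGBA, and `plt.imread` returns floats in [0, 1] for PNG files but typically 8-bit integers for JPEG files. A hypothetical helper `to_mnist_row` (a sketch, assuming a 28×28 image with a light digit on a dark background, like MNIST) that normalizes such arrays into the 784-value rows the autoencoder expects:

```python
import numpy as np

def to_mnist_row(im):
    """Convert a 28x28 gray, gray+alpha, RGB or RGBA image array into a float32 row of 784 values in [0, 1]."""
    im = np.asarray(im)
    if np.issubdtype(im.dtype, np.integer):   # e.g. uint8 JPEG data in 0..255
        im = im / 255.0
    if im.ndim == 3:
        # gray+alpha keeps channel 0; RGB/RGBA averages the three color channels
        im = im[..., 0] if im.shape[2] < 3 else im[..., :3].mean(axis=2)
    if im.shape != (28, 28):
        raise ValueError('expected a 28x28 image, got shape %s' % (im.shape,))
    return im.astype('float32').reshape(784)
```

Rows produced this way can be stacked with `np.array(rows)` and passed to `f_predict`.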


In [ ]:
im = plt.imread('./images/mnist/template_28x28.png')
plt.imshow(im, cmap='gray')

In [ ]:
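import os

image_dir = './images/mnist/'

# Collect the PNG/JPEG files in a stable order
image_files = sorted(os.path.join(image_dir, f) for f in os.listdir(image_dir)
                     if f.lower().endswith(('.png', '.jpg', '.jpeg')))

v = []
for f in image_files:
    im = plt.imread(f)
    if im.ndim == 3:              # RGB(A) file: keep the first color channel as gray
        im = im[..., 0]
    if im.dtype == np.uint8:      # JPEGs load as 0..255, while the MNIST data here is 0..1
        im = im / 255.0
    v.append(im.flatten())        # each image must be 28x28, i.e. 784 values

v = np.array(v, dtype='float32')
v_reconstructed, v_loss = f_predict(v)

# Print the reconstruction error of each file
for f, loss in zip(image_files, v_loss):
    print('%.4f  %s' % (loss, f))

v_all_indices = np.arange(len(v))

plot_by_index(v, v_all_indices)
plot_by_index(v_reconstructed, v_all_indices)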
