In [1]:
%matplotlib inline
import os
os.environ['THEANO_FLAGS']='device=gpu0'

import matplotlib
import numpy as np
np.random.seed(123)
import matplotlib.pyplot as plt
import lasagne
import theano
import theano.tensor as T
conv = lasagne.layers.Conv2DLayer
pool = lasagne.layers.MaxPool2DLayer
NUM_EPOCHS = 500
BATCH_SIZE = 256
LEARNING_RATE = 0.001
DIM = 60
NUM_CLASSES = 10
mnist_cluttered = "mnist_cluttered_60x60_6distortions.npz"



Using gpu device 0: Graphics Device

Spatial Transformer Network

We use Lasagne to classify cluttered MNIST digits with the spatial transformer network introduced in [1]. A spatial transformer network applies a learned affine transformation to its input.
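The six parameters form a 2x3 matrix theta that maps normalized target coordinates to the source coordinates they are sampled from; the identity theta leaves the input unchanged. A minimal NumPy sketch of this mapping (coordinates normalized to [-1, 1] as in [1]):

In [ ]:
theta = np.array([[1.0, 0.0, 0.0],   # identity transform: no scaling,
                  [0.0, 1.0, 0.0]])  # rotation or translation

# A target coordinate (x_t, y_t) is mapped to the source coordinate
# (x_s, y_s) from which the corresponding output pixel is sampled.
x_t, y_t = 0.5, -0.25
x_s, y_s = theta.dot([x_t, y_t, 1.0])
print x_s, y_s  # identity: 0.5 -0.25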

Load data

We test the spatial transformer network using cluttered MNIST data.

Download the data (41 MB) with:


In [1]:
!wget -N https://s3.amazonaws.com/lasagne/recipes/datasets/mnist_cluttered_60x60_6distortions.npz


--2015-08-19 14:52:08--  https://s3.amazonaws.com/lasagne/recipes/datasets/mnist_cluttered_60x60_6distortions.npz
Resolving s3.amazonaws.com... 54.231.48.99
Connecting to s3.amazonaws.com|54.231.48.99|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 43046126 (41M) [application/octet-stream]
Server file no newer than local file 'mnist_cluttered_60x60_6distortions.npz' -- not retrieving.


In [2]:
def load_data():
    data = np.load(mnist_cluttered)
    X_train, y_train = data['x_train'], np.argmax(data['y_train'], axis=-1)
    X_valid, y_valid = data['x_valid'], np.argmax(data['y_valid'], axis=-1)
    X_test, y_test = data['x_test'], np.argmax(data['y_test'], axis=-1)

    # reshape for convolutions
    X_train = X_train.reshape((X_train.shape[0], 1, DIM, DIM))
    X_valid = X_valid.reshape((X_valid.shape[0], 1, DIM, DIM))
    X_test = X_test.reshape((X_test.shape[0], 1, DIM, DIM))
    
    print "Train samples:", X_train.shape
    print "Validation samples:", X_valid.shape
    print "Test samples:", X_test.shape

    return dict(
        X_train=lasagne.utils.floatX(X_train),
        y_train=y_train.astype('int32'),
        X_valid=lasagne.utils.floatX(X_valid),
        y_valid=y_valid.astype('int32'),
        X_test=lasagne.utils.floatX(X_test),
        y_test=y_test.astype('int32'),
        num_examples_train=X_train.shape[0],
        num_examples_valid=X_valid.shape[0],
        num_examples_test=X_test.shape[0],
        input_height=X_train.shape[2],
        input_width=X_train.shape[3],
        output_dim=10,)
data = load_data()


Train samples: (50000, 1, 60, 60)
Validation samples: (10000, 1, 60, 60)
Test samples: (10000, 1, 60, 60)

In [11]:
plt.figure(figsize=(7,7))
plt.imshow(data['X_train'][101].reshape(DIM, DIM), cmap='gray', interpolation='none')
plt.title('Cluttered MNIST', fontsize=20)
plt.axis('off')
plt.show()


Building the model

We use a model in which the localization network is a two-layer convolutional network operating directly on the image input. Its output is a 6-dimensional vector specifying the parameters of the affine transformation.

The localization network feeds into the transformer layer, which applies the transformation to the image input. In our setup the transformer layer downsamples the input by a factor of 3 (see the sketch after the diagram below).

Finally, two convolutional layers and two fully connected layers compute the output probabilities.

The model

Input -> localization_network -> TransformerLayer -> output_network -> predictions
   |                                |
   >--------------------------------^
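To make the transformer step concrete, here is a rough NumPy sketch of what it computes (an illustration under simplifying assumptions, not Lasagne's implementation; the hypothetical transform_sketch helper uses nearest-neighbour instead of bilinear sampling): build a regular grid of target coordinates at the downsampled resolution, map it through theta, and sample the input at the resulting source coordinates. With downsample_factor=3.0, a 60x60 input yields a 20x20 output.

In [ ]:
def transform_sketch(img, theta, downsample_factor=3.0):
    # img: 2D array; theta: flat 6-vector, as produced by loc_out below
    H, W = img.shape
    out_H, out_W = int(H / downsample_factor), int(W / downsample_factor)

    # Regular grid of target coordinates, normalized to [-1, 1]
    xt, yt = np.meshgrid(np.linspace(-1, 1, out_W),
                         np.linspace(-1, 1, out_H))
    grid = np.vstack([xt.ravel(), yt.ravel(), np.ones(out_H * out_W)])

    # Map target coordinates to source coordinates
    xs, ys = theta.reshape(2, 3).dot(grid)

    # Back to pixel indices; nearest-neighbour sampling for brevity
    cols = np.clip(np.round((xs + 1) * (W - 1) / 2).astype(int), 0, W - 1)
    rows = np.clip(np.round((ys + 1) * (H - 1) / 2).astype(int), 0, H - 1)
    return img[rows, cols].reshape(out_H, out_W)

# With the identity theta the whole 60x60 image is resampled to 20x20
patch = transform_sketch(data['X_train'][101, 0],
                         np.array([1, 0, 0, 0, 1, 0], dtype='float32'))
print patch.shape

In the trained network the learned theta replaces the identity, letting the model zoom in on the digit and discard clutter before classification.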

In [4]:
def build_model(input_width, input_height, output_dim,
                batch_size=BATCH_SIZE):
    ini = lasagne.init.HeUniform()
    l_in = lasagne.layers.InputLayer(shape=(None, 1, input_height, input_width))

    # Localization network
    # Initialize the final localization layer's bias to the identity
    # transform so training starts from an untransformed input
    b = np.zeros((2, 3), dtype=theano.config.floatX)
    b[0, 0] = 1
    b[1, 1] = 1
    b = b.flatten()
    loc_l1 = pool(l_in, pool_size=(2, 2))
    loc_l2 = conv(
        loc_l1, num_filters=20, filter_size=(5, 5), W=ini)
    loc_l3 = pool(loc_l2, pool_size=(2, 2))
    loc_l4 = conv(loc_l3, num_filters=20, filter_size=(5, 5), W=ini)
    loc_l5 = lasagne.layers.DenseLayer(
        loc_l4, num_units=50, W=lasagne.init.HeUniform('relu'))
    loc_out = lasagne.layers.DenseLayer(
        loc_l5, num_units=6, b=b, W=lasagne.init.Constant(0.0), 
        nonlinearity=lasagne.nonlinearities.identity)
    
    # Transformer network
    l_trans1 = lasagne.layers.TransformerLayer(l_in, loc_out, downsample_factor=3.0)
    print "Transformer network output shape: ", l_trans1.output_shape
    
    # Classification network
    class_l1 = conv(
        l_trans1,
        num_filters=32,
        filter_size=(3, 3),
        nonlinearity=lasagne.nonlinearities.rectify,
        W=ini,
    )
    class_l2 = pool(class_l1, pool_size=(2, 2))
    class_l3 = conv(
        class_l2,
        num_filters=32,
        filter_size=(3, 3),
        nonlinearity=lasagne.nonlinearities.rectify,
        W=ini,
    )
    class_l4 = pool(class_l3, pool_size=(2, 2))
    class_l5 = lasagne.layers.DenseLayer(
        class_l4,
        num_units=256,
        nonlinearity=lasagne.nonlinearities.rectify,
        W=ini,
    )

    l_out = lasagne.layers.DenseLayer(
        class_l5,
        num_units=output_dim,
        nonlinearity=lasagne.nonlinearities.softmax,
        W=ini,
    )

    return l_out, l_trans1

model, l_transform = build_model(DIM, DIM, NUM_CLASSES)
model_params = lasagne.layers.get_all_params(model, trainable=True)


Transformer network output shape:  (None, 1, 20, 20)

In [5]:
X = T.tensor4()
y = T.ivector()

# training output
output_train = lasagne.layers.get_output(model, X, deterministic=False)

# evaluation output. Also includes output of transform for plotting
output_eval, transform_eval = lasagne.layers.get_output([model, l_transform], X, deterministic=True)

sh_lr = theano.shared(lasagne.utils.floatX(LEARNING_RATE))
cost = T.mean(T.nnet.categorical_crossentropy(output_train, y))
updates = lasagne.updates.adam(cost, model_params, learning_rate=sh_lr)

train = theano.function([X, y], [cost, output_train], updates=updates)
# named eval_fn rather than eval to avoid shadowing the builtin
eval_fn = theano.function([X], [output_eval, transform_eval])

In [6]:
def train_epoch(X, y):
    num_samples = X.shape[0]
    num_batches = int(np.ceil(num_samples / float(BATCH_SIZE)))
    costs = []
    correct = 0
    for i in range(num_batches):
        idx = range(i*BATCH_SIZE, np.minimum((i+1)*BATCH_SIZE, num_samples))
        X_batch = X[idx]
        y_batch = y[idx]
        cost_batch, output_train = train(X_batch, y_batch)
        costs += [cost_batch]
        preds = np.argmax(output_train, axis=-1)
        correct += np.sum(y_batch == preds)

    return np.mean(costs), correct / float(num_samples)


def eval_epoch(X, y):
    output_eval, transform_eval = eval_fn(X)
    preds = np.argmax(output_eval, axis=-1)
    acc = np.mean(preds == y)
    return acc, transform_eval

Training


In [7]:
valid_accs, train_accs, test_accs = [], [], []
try:
    for n in range(NUM_EPOCHS):
        train_cost, train_acc = train_epoch(data['X_train'], data['y_train'])
        valid_acc, valid_transform = eval_epoch(data['X_valid'], data['y_valid'])
        test_acc, test_transform = eval_epoch(data['X_test'], data['y_test'])
        valid_accs += [valid_acc]
        test_accs += [test_acc]
        train_accs += [train_acc]

        if (n+1) % 20 == 0:
            new_lr = sh_lr.get_value() * 0.7
            print "New LR:", new_lr
            sh_lr.set_value(lasagne.utils.floatX(new_lr))

        print "Epoch {0}: Train cost {1}, Train acc {2}, val acc {3}, test acc {4}".format(
                n, train_cost, train_acc, valid_acc, test_acc)
except KeyboardInterrupt:
    pass


Epoch 0: Train cost 1.72300577164, Train acc 0.38824, val acc 0.6114, test acc 0.6087
Epoch 1: Train cost 0.867130100727, Train acc 0.71758, val acc 0.7745, test acc 0.7759
Epoch 2: Train cost 0.618825733662, Train acc 0.79848, val acc 0.8199, test acc 0.827
Epoch 3: Train cost 0.475057393312, Train acc 0.8489, val acc 0.8602, test acc 0.8613
Epoch 4: Train cost 0.369837403297, Train acc 0.88208, val acc 0.8697, test acc 0.8723
Epoch 5: Train cost 0.336995840073, Train acc 0.89126, val acc 0.8957, test acc 0.8974
Epoch 6: Train cost 0.288021206856, Train acc 0.90742, val acc 0.9005, test acc 0.8993
Epoch 7: Train cost 0.260697960854, Train acc 0.915, val acc 0.9081, test acc 0.9091
Epoch 8: Train cost 0.235620766878, Train acc 0.92484, val acc 0.917, test acc 0.9214
Epoch 9: Train cost 0.232491567731, Train acc 0.9245, val acc 0.9205, test acc 0.921
Epoch 10: Train cost 0.214803680778, Train acc 0.92916, val acc 0.9249, test acc 0.926
Epoch 11: Train cost 0.191879570484, Train acc 0.93728, val acc 0.9306, test acc 0.9317
Epoch 12: Train cost 0.187945634127, Train acc 0.93854, val acc 0.9365, test acc 0.937
Epoch 13: Train cost 0.177504748106, Train acc 0.94238, val acc 0.9329, test acc 0.933
Epoch 14: Train cost 0.161393344402, Train acc 0.9479, val acc 0.9246, test acc 0.9269
Epoch 15: Train cost 0.158181488514, Train acc 0.9482, val acc 0.9353, test acc 0.9382
Epoch 16: Train cost 0.162177875638, Train acc 0.94768, val acc 0.9399, test acc 0.9385
Epoch 17: Train cost 0.150974154472, Train acc 0.95074, val acc 0.9417, test acc 0.944
Epoch 18: Train cost 0.13878442347, Train acc 0.9546, val acc 0.9514, test acc 0.9481
New LR: 0.000700000033248
Epoch 19: Train cost 0.139381811023, Train acc 0.95302, val acc 0.9465, test acc 0.9477
Epoch 20: Train cost 0.115818083286, Train acc 0.96186, val acc 0.9498, test acc 0.9515
Epoch 21: Train cost 0.10844618082, Train acc 0.96364, val acc 0.9537, test acc 0.9544
Epoch 22: Train cost 0.104168988764, Train acc 0.9651, val acc 0.95, test acc 0.9522
Epoch 23: Train cost 0.100386917591, Train acc 0.96664, val acc 0.9523, test acc 0.9533
Epoch 24: Train cost 0.101429723203, Train acc 0.9666, val acc 0.9516, test acc 0.9557
Epoch 25: Train cost 0.0968987718225, Train acc 0.96804, val acc 0.9523, test acc 0.9556
Epoch 26: Train cost 0.0905688554049, Train acc 0.97016, val acc 0.955, test acc 0.9533
Epoch 27: Train cost 0.0892679914832, Train acc 0.97024, val acc 0.9574, test acc 0.9537
Epoch 28: Train cost 0.0790596753359, Train acc 0.9733, val acc 0.956, test acc 0.9577
Epoch 29: Train cost 0.0846520811319, Train acc 0.97228, val acc 0.9586, test acc 0.9575
Epoch 30: Train cost 0.0861563980579, Train acc 0.9711, val acc 0.9553, test acc 0.9579
Epoch 31: Train cost 0.084160938859, Train acc 0.9713, val acc 0.9574, test acc 0.9565
Epoch 32: Train cost 0.0740946382284, Train acc 0.97538, val acc 0.9583, test acc 0.9568
Epoch 33: Train cost 0.0750161111355, Train acc 0.97476, val acc 0.9522, test acc 0.9558
Epoch 34: Train cost 0.0719307512045, Train acc 0.97592, val acc 0.9534, test acc 0.9601
Epoch 35: Train cost 0.0688360854983, Train acc 0.97742, val acc 0.9568, test acc 0.9578
Epoch 36: Train cost 0.0659850463271, Train acc 0.97732, val acc 0.9586, test acc 0.9602
Epoch 37: Train cost 0.0669036284089, Train acc 0.97736, val acc 0.9606, test acc 0.9581
Epoch 38: Train cost 0.0615548193455, Train acc 0.9792, val acc 0.9584, test acc 0.9538
New LR: 0.000490000023274
Epoch 39: Train cost 0.0617390647531, Train acc 0.9795, val acc 0.9585, test acc 0.9574
Epoch 40: Train cost 0.0535897053778, Train acc 0.9818, val acc 0.9563, test acc 0.9582
Epoch 41: Train cost 0.0471548065543, Train acc 0.98434, val acc 0.9622, test acc 0.9613
Epoch 42: Train cost 0.0408403426409, Train acc 0.98648, val acc 0.9635, test acc 0.9624
Epoch 43: Train cost 0.0405819378793, Train acc 0.98642, val acc 0.9636, test acc 0.9619
Epoch 44: Train cost 0.0374028384686, Train acc 0.98754, val acc 0.9606, test acc 0.9614
Epoch 45: Train cost 0.0365789830685, Train acc 0.98828, val acc 0.9591, test acc 0.9574
Epoch 46: Train cost 0.0347327440977, Train acc 0.98848, val acc 0.962, test acc 0.9613

Plot results


In [19]:
plt.figure(figsize=(9,9))
plt.plot(1-np.array(train_accs), label='Training Error')
plt.plot(1-np.array(valid_accs), label='Validation Error')
plt.legend(fontsize=20)
plt.xlabel('Epoch', fontsize=20)
plt.ylabel('Error', fontsize=20)
plt.show()



In [41]:
plt.figure(figsize=(7,14))
for i in range(3):
    plt.subplot(321+i*2)
    plt.imshow(data['X_test'][i].reshape(DIM, DIM), cmap='gray', interpolation='none')
    if i == 0:
        plt.title('Original 60x60', fontsize=20)
    plt.axis('off')
    plt.subplot(322+i*2)
    plt.imshow(test_transform[i].reshape(DIM//3, DIM//3), cmap='gray', interpolation='none')
    if i == 0:
        plt.title('Transformed 20x20', fontsize=20)
    plt.axis('off')
plt.tight_layout()


References

[1] Jaderberg, Max, et al. "Spatial Transformer Networks." arXiv preprint arXiv:1506.02025 (2015).