Trial for spatial transformer network


In [1]:
%matplotlib inline
import os
os.environ['THEANO_FLAGS']='device=gpu0'

import matplotlib
import numpy as np
np.random.seed(123)
import matplotlib.pyplot as plt
import lasagne
import theano
import theano.tensor as T
conv = lasagne.layers.Conv2DLayer
pool = lasagne.layers.MaxPool2DLayer
NUM_EPOCHS = 500
BATCH_SIZE = 256
LEARNING_RATE = 0.001
DIM = 60
NUM_CLASSES = 10
mnist_cluttered = "mnist_cluttered_60x60_6distortions.npz"


Using gpu device 0: GeForce GT 740M (CNMeM is disabled)

In [4]:
!wget -c -N https://s3.amazonaws.com/lasagne/recipes/datasets/mnist_cluttered_60x60_6distortions.npz


--2016-02-10 01:27:18--  https://s3.amazonaws.com/lasagne/recipes/datasets/mnist_cluttered_60x60_6distortions.npz
Resolving s3.amazonaws.com (s3.amazonaws.com)... 54.231.49.220
Connecting to s3.amazonaws.com (s3.amazonaws.com)|54.231.49.220|:443... connected.
HTTP request sent, awaiting response... 206 Partial Content
Length: 43046126 (41M), 40174182 (38M) remaining [application/octet-stream]
The sizes do not match (local 2871944) -- retrieving.

--2016-02-10 01:27:19--  https://s3.amazonaws.com/lasagne/recipes/datasets/mnist_cluttered_60x60_6distortions.npz
Reusing existing connection to s3.amazonaws.com:443.
HTTP request sent, awaiting response... 206 Partial Content
Length: 43046126 (41M), 40174182 (38M) remaining [application/octet-stream]
Saving to: ‘mnist_cluttered_60x60_6distortions.npz’

100%[++====================================>] 4,30,46,126  144KB/s   in 8m 42s 

2016-02-10 01:36:02 (75.2 KB/s) - ‘mnist_cluttered_60x60_6distortions.npz’ saved [43046126/43046126]


In [5]:
def load_data():
    data = np.load(mnist_cluttered)
    X_train, y_train = data['x_train'], np.argmax(data['y_train'], axis=-1)
    X_valid, y_valid = data['x_valid'], np.argmax(data['y_valid'], axis=-1)
    X_test, y_test = data['x_test'], np.argmax(data['y_test'], axis=-1)

    # reshape for convolutions
    X_train = X_train.reshape((X_train.shape[0], 1, DIM, DIM))
    X_valid = X_valid.reshape((X_valid.shape[0], 1, DIM, DIM))
    X_test = X_test.reshape((X_test.shape[0], 1, DIM, DIM))
    
    print "Train samples:", X_train.shape
    print "Validation samples:", X_valid.shape
    print "Test samples:", X_test.shape

    return dict(
        X_train=lasagne.utils.floatX(X_train),
        y_train=y_train.astype('int32'),
        X_valid=lasagne.utils.floatX(X_valid),
        y_valid=y_valid.astype('int32'),
        X_test=lasagne.utils.floatX(X_test),
        y_test=y_test.astype('int32'),
        num_examples_train=X_train.shape[0],
        num_examples_valid=X_valid.shape[0],
        num_examples_test=X_test.shape[0],
        input_height=X_train.shape[2],
        input_width=X_train.shape[3],
        output_dim=10,)
data = load_data()


Train samples: (50000, 1, 60, 60)
Validation samples: (10000, 1, 60, 60)
Test samples: (10000, 1, 60, 60)

In [6]:
plt.figure(figsize=(7,7))
plt.imshow(data['X_train'][101].reshape(DIM, DIM), cmap='gray', interpolation='none')
plt.title('Cluttered MNIST', fontsize=20)
plt.axis('off')
plt.show()



In [7]:
def build_model(input_width, input_height, output_dim,
                batch_size=BATCH_SIZE):
    ini = lasagne.init.HeUniform()
    l_in = lasagne.layers.InputLayer(shape=(None, 1, input_width, input_height),)

    # Localization network
    b = np.zeros((2, 3), dtype=theano.config.floatX)
    b[0, 0] = 1
    b[1, 1] = 1
    b = b.flatten()
    loc_l1 = pool(l_in, pool_size=(2, 2))
    loc_l2 = conv(
        loc_l1, num_filters=20, filter_size=(5, 5), W=ini)
    loc_l3 = pool(loc_l2, pool_size=(2, 2))
    loc_l4 = conv(loc_l3, num_filters=20, filter_size=(5, 5), W=ini)
    loc_l5 = lasagne.layers.DenseLayer(
        loc_l4, num_units=50, W=lasagne.init.HeUniform('relu'))
    loc_out = lasagne.layers.DenseLayer(
        loc_l5, num_units=6, b=b, W=lasagne.init.Constant(0.0), 
        nonlinearity=lasagne.nonlinearities.identity)
    
    # Transformer network
    l_trans1 = lasagne.layers.TransformerLayer(l_in, loc_out, downsample_factor=3.0)
    print "Transformer network output shape: ", l_trans1.output_shape
    
    # Classification network
    class_l1 = conv(
        l_trans1,
        num_filters=32,
        filter_size=(3, 3),
        nonlinearity=lasagne.nonlinearities.rectify,
        W=ini,
    )
    class_l2 = pool(class_l1, pool_size=(2, 2))
    class_l3 = conv(
        class_l2,
        num_filters=32,
        filter_size=(3, 3),
        nonlinearity=lasagne.nonlinearities.rectify,
        W=ini,
    )
    class_l4 = pool(class_l3, pool_size=(2, 2))
    class_l5 = lasagne.layers.DenseLayer(
        class_l4,
        num_units=256,
        nonlinearity=lasagne.nonlinearities.rectify,
        W=ini,
    )

    l_out = lasagne.layers.DenseLayer(
        class_l5,
        num_units=output_dim,
        nonlinearity=lasagne.nonlinearities.softmax,
        W=ini,
    )

    return l_out, l_trans1

model, l_transform = build_model(DIM, DIM, NUM_CLASSES)
model_params = lasagne.layers.get_all_params(model, trainable=True)


Transformer network output shape:  (None, 1, 20, 20)

In [8]:
X = T.tensor4()
y = T.ivector()

# training output
output_train = lasagne.layers.get_output(model, X, deterministic=False)

# evaluation output. Also includes output of transform for plotting
output_eval, transform_eval = lasagne.layers.get_output([model, l_transform], X, deterministic=True)

sh_lr = theano.shared(lasagne.utils.floatX(LEARNING_RATE))
cost = T.mean(T.nnet.categorical_crossentropy(output_train, y))
updates = lasagne.updates.adam(cost, model_params, learning_rate=sh_lr)

train = theano.function([X, y], [cost, output_train], updates=updates)
eval = theano.function([X], [output_eval, transform_eval])

In [9]:
def train_epoch(X, y):
    num_samples = X.shape[0]
    num_batches = int(np.ceil(num_samples / float(BATCH_SIZE)))
    costs = []
    correct = 0
    for i in range(num_batches):
        idx = range(i*BATCH_SIZE, np.minimum((i+1)*BATCH_SIZE, num_samples))
        X_batch = X[idx]
        y_batch = y[idx]
        cost_batch, output_train = train(X_batch, y_batch)
        costs += [cost_batch]
        preds = np.argmax(output_train, axis=-1)
        correct += np.sum(y_batch == preds)

    return np.mean(costs), correct / float(num_samples)


def eval_epoch(X, y):
    output_eval, transform_eval = eval(X)
    preds = np.argmax(output_eval, axis=-1)
    acc = np.mean(preds == y)
    return acc, transform_eval

In [10]:
valid_accs, train_accs, test_accs = [], [], []
try:
    for n in range(NUM_EPOCHS):
        train_cost, train_acc = train_epoch(data['X_train'], data['y_train'])
        valid_acc, valid_trainsform = eval_epoch(data['X_valid'], data['y_valid'])
        test_acc, test_transform = eval_epoch(data['X_test'], data['y_test'])
        valid_accs += [valid_acc]
        test_accs += [test_acc]
        train_accs += [train_acc]

        if (n+1) % 20 == 0:
            new_lr = sh_lr.get_value() * 0.7
            print "New LR:", new_lr
            sh_lr.set_value(lasagne.utils.floatX(new_lr))

        print "Epoch {0}: Train cost {1}, Train acc {2}, val acc {3}, test acc {4}".format(
                n, train_cost, train_acc, valid_acc, test_acc)
except KeyboardInterrupt:
    pass


Epoch 0: Train cost 1.72685241699, Train acc 0.3871, val acc 0.6362, test acc 0.6351
Epoch 1: Train cost 0.882956385612, Train acc 0.71368, val acc 0.7879, test acc 0.7805
Epoch 2: Train cost 0.62586081028, Train acc 0.79688, val acc 0.8238, test acc 0.824
Epoch 3: Train cost 0.495990216732, Train acc 0.84106, val acc 0.8597, test acc 0.8627
Epoch 4: Train cost 0.399075329304, Train acc 0.87132, val acc 0.871, test acc 0.873
Epoch 5: Train cost 0.345367133617, Train acc 0.88812, val acc 0.888, test acc 0.8892
Epoch 6: Train cost 0.324547797441, Train acc 0.89486, val acc 0.8905, test acc 0.8868
Epoch 7: Train cost 0.276472955942, Train acc 0.9112, val acc 0.9071, test acc 0.9058
Epoch 8: Train cost 0.259725809097, Train acc 0.91544, val acc 0.9094, test acc 0.9096
Epoch 9: Train cost 0.244311392307, Train acc 0.92178, val acc 0.9208, test acc 0.9194
Epoch 10: Train cost 0.219494447112, Train acc 0.92828, val acc 0.9313, test acc 0.9309
Epoch 11: Train cost 0.21216057241, Train acc 0.93154, val acc 0.9349, test acc 0.9313
Epoch 12: Train cost 0.190080523491, Train acc 0.93848, val acc 0.9286, test acc 0.9312
Epoch 13: Train cost 0.192602977157, Train acc 0.93694, val acc 0.9337, test acc 0.9341
Epoch 14: Train cost 0.176975011826, Train acc 0.942, val acc 0.939, test acc 0.9385
Epoch 15: Train cost 0.169229134917, Train acc 0.94522, val acc 0.9276, test acc 0.9268
Epoch 16: Train cost 0.164277136326, Train acc 0.9458, val acc 0.9383, test acc 0.9352
Epoch 17: Train cost 0.155019000173, Train acc 0.9493, val acc 0.9443, test acc 0.9468
Epoch 18: Train cost 0.151129618287, Train acc 0.95038, val acc 0.945, test acc 0.9447
New LR: 0.000700000033248
Epoch 19: Train cost 0.155249327421, Train acc 0.94976, val acc 0.9377, test acc 0.9392
Epoch 20: Train cost 0.130315035582, Train acc 0.95706, val acc 0.9504, test acc 0.9495
Epoch 21: Train cost 0.123107776046, Train acc 0.95894, val acc 0.9438, test acc 0.9433

In [11]:
plt.figure(figsize=(7,14))
for i in range(3):
    plt.subplot(321+i*2)
    plt.imshow(data['X_test'][i].reshape(DIM, DIM), cmap='gray', interpolation='none')
    if i == 0:
        plt.title('Original 60x60', fontsize=20)
    plt.axis('off')
    plt.subplot(322+i*2)
    plt.imshow(test_transform[i].reshape(DIM//3, DIM//3), cmap='gray', interpolation='none')
    if i == 0:
        plt.title('Transformed 20x20', fontsize=20)
    plt.axis('off')
plt.tight_layout()



In [ ]: