In [1]:
%matplotlib inline
import os
# Theano reads THEANO_FLAGS when it is first imported, so this must be set
# before `import lasagne` / `import theano` below (lasagne imports theano).
os.environ['THEANO_FLAGS']='device=gpu0'
import matplotlib
import numpy as np
# Seed numpy's global RNG for reproducible runs.
# NOTE(review): assumes lasagne's weight initializers draw from numpy's global
# RNG by default — confirm if exact reproducibility matters.
np.random.seed(123)
import matplotlib.pyplot as plt
import lasagne
import theano
import theano.tensor as T
# Short aliases for the layer types used throughout build_model.
conv = lasagne.layers.Conv2DLayer
pool = lasagne.layers.MaxPool2DLayer
# Training hyperparameters (the learning rate is decayed during training).
NUM_EPOCHS = 500
BATCH_SIZE = 256
LEARNING_RATE = 0.001
# Images are DIM x DIM single-channel pixels (see the reshape in load_data).
DIM = 60
NUM_CLASSES = 10
# Filename of the .npz archive read by load_data.
mnist_cluttered = "mnist_cluttered_60x60_6distortions.npz"
We use Lasagne to classify cluttered MNIST digits with the spatial transformer network introduced in [1]. A spatial transformer network applies a learned affine transformation to its input.
We test the spatial transformer network using cluttered MNIST data.
Download the data (41 MB) with:
In [1]:
# Fetch the archive named by `mnist_cluttered`; -N (timestamping) skips the
# download when the local copy is already up to date.
!wget -N https://s3.amazonaws.com/lasagne/recipes/datasets/mnist_cluttered_60x60_6distortions.npz
In [2]:
def load_data():
data = np.load(mnist_cluttered)
X_train, y_train = data['x_train'], np.argmax(data['y_train'], axis=-1)
X_valid, y_valid = data['x_valid'], np.argmax(data['y_valid'], axis=-1)
X_test, y_test = data['x_test'], np.argmax(data['y_test'], axis=-1)
# reshape for convolutions
X_train = X_train.reshape((X_train.shape[0], 1, DIM, DIM))
X_valid = X_valid.reshape((X_valid.shape[0], 1, DIM, DIM))
X_test = X_test.reshape((X_test.shape[0], 1, DIM, DIM))
print "Train samples:", X_train.shape
print "Validation samples:", X_valid.shape
print "Test samples:", X_test.shape
return dict(
X_train=lasagne.utils.floatX(X_train),
y_train=y_train.astype('int32'),
X_valid=lasagne.utils.floatX(X_valid),
y_valid=y_valid.astype('int32'),
X_test=lasagne.utils.floatX(X_test),
y_test=y_test.astype('int32'),
num_examples_train=X_train.shape[0],
num_examples_valid=X_valid.shape[0],
num_examples_test=X_test.shape[0],
input_height=X_train.shape[2],
input_width=X_train.shape[3],
output_dim=10,)
data = load_data()
In [11]:
# Display one raw training image to show what the clutter looks like.
fig, ax = plt.subplots(figsize=(7, 7))
sample_img = data['X_train'][101].reshape(DIM, DIM)
ax.imshow(sample_img, cmap='gray', interpolation='none')
ax.set_title('Cluttered MNIST', fontsize=20)
ax.axis('off')
plt.show()
We use a model in which the localization network is a two-layer convolutional network that operates directly on the input image. The output of the localization network is a 6-dimensional vector holding the parameters of the affine transformation.
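The localization network feeds into the transformer layer, which applies the predicted transformation to the original input image (not to the localization features). In our setup the transformer layer also downsamples the input by a factor of 3, so the 60x60 input becomes a 20x20 output.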
Finally, two convolutional layers (each followed by max pooling) and two fully connected layers compute the output class probabilities.
The model
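```
Input -> localization_network -> TransformerLayer -> output_network -> predictions
  |                                     |
  >-------------------------------------^
```
The input image goes both to the localization network and, directly, to the transformer layer.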
In [4]:
def build_model(input_width, input_height, output_dim,
batch_size=BATCH_SIZE):
ini = lasagne.init.HeUniform()
l_in = lasagne.layers.InputLayer(shape=(None, 1, input_width, input_height),)
# Localization network
b = np.zeros((2, 3), dtype=theano.config.floatX)
b[0, 0] = 1
b[1, 1] = 1
b = b.flatten()
loc_l1 = pool(l_in, pool_size=(2, 2))
loc_l2 = conv(
loc_l1, num_filters=20, filter_size=(5, 5), W=ini)
loc_l3 = pool(loc_l2, pool_size=(2, 2))
loc_l4 = conv(loc_l3, num_filters=20, filter_size=(5, 5), W=ini)
loc_l5 = lasagne.layers.DenseLayer(
loc_l4, num_units=50, W=lasagne.init.HeUniform('relu'))
loc_out = lasagne.layers.DenseLayer(
loc_l5, num_units=6, b=b, W=lasagne.init.Constant(0.0),
nonlinearity=lasagne.nonlinearities.identity)
# Transformer network
l_trans1 = lasagne.layers.TransformerLayer(l_in, loc_out, downsample_factor=3.0)
print "Transformer network output shape: ", l_trans1.output_shape
# Classification network
class_l1 = conv(
l_trans1,
num_filters=32,
filter_size=(3, 3),
nonlinearity=lasagne.nonlinearities.rectify,
W=ini,
)
class_l2 = pool(class_l1, pool_size=(2, 2))
class_l3 = conv(
class_l2,
num_filters=32,
filter_size=(3, 3),
nonlinearity=lasagne.nonlinearities.rectify,
W=ini,
)
class_l4 = pool(class_l3, pool_size=(2, 2))
class_l5 = lasagne.layers.DenseLayer(
class_l4,
num_units=256,
nonlinearity=lasagne.nonlinearities.rectify,
W=ini,
)
l_out = lasagne.layers.DenseLayer(
class_l5,
num_units=output_dim,
nonlinearity=lasagne.nonlinearities.softmax,
W=ini,
)
return l_out, l_trans1
model, l_transform = build_model(DIM, DIM, NUM_CLASSES)
model_params = lasagne.layers.get_all_params(model, trainable=True)
In [5]:
# Symbolic inputs: a batch of images and its integer class labels.
X = T.tensor4()
y = T.ivector()

# The learning rate lives in a shared variable so the training loop can decay
# it with set_value() without recompiling the functions below.
sh_lr = theano.shared(lasagne.utils.floatX(LEARNING_RATE))

# Training graph: non-deterministic forward pass, mean cross-entropy loss and
# Adam updates for every trainable parameter.
output_train = lasagne.layers.get_output(model, X, deterministic=False)
cost = T.nnet.categorical_crossentropy(output_train, y).mean()
updates = lasagne.updates.adam(cost, model_params, learning_rate=sh_lr)

# Evaluation graph: deterministic pass that also exposes the transformer
# layer's output so the learned warps can be plotted.
output_eval, transform_eval = lasagne.layers.get_output(
    [model, l_transform], X, deterministic=True)

train = theano.function(inputs=[X, y], outputs=[cost, output_train], updates=updates)
eval = theano.function(inputs=[X], outputs=[output_eval, transform_eval])
In [6]:
def train_epoch(X, y, batch_size=BATCH_SIZE):
    """Run one pass of minibatch training over (X, y).

    Samples are visited in their stored order, in contiguous batches of
    `batch_size`; the last batch may be smaller. Each batch is passed to the
    compiled `train` function, which applies one optimizer update.

    Parameters
    ----------
    X : array, first axis indexes samples (input images)
    y : int array of class labels aligned with X
    batch_size : int, default BATCH_SIZE

    Returns
    -------
    (mean_cost, accuracy)
        The mean of the per-batch costs, and the fraction of samples
        classified correctly by the training-mode forward pass.

    Raises
    ------
    ValueError
        If X contains no samples.
    """
    num_samples = X.shape[0]
    if num_samples == 0:
        # Previously this ended in a bare ZeroDivisionError.
        raise ValueError("train_epoch called with an empty dataset")

    costs = []
    correct = 0
    for start in range(0, num_samples, batch_size):
        stop = min(start + batch_size, num_samples)
        # Contiguous slices are views; the old range()-list fancy indexing
        # copied every batch.
        X_batch = X[start:stop]
        y_batch = y[start:stop]
        cost_batch, batch_probs = train(X_batch, y_batch)
        costs.append(cost_batch)
        preds = np.argmax(batch_probs, axis=-1)
        correct += np.sum(y_batch == preds)
    return np.mean(costs), correct / float(num_samples)
def eval_epoch(X, y, batch_size=None):
    """Evaluate the model deterministically on (X, y).

    Parameters
    ----------
    X : array, first axis indexes samples (input images)
    y : int array of class labels aligned with X
    batch_size : int or None
        If None (the default, as before), the whole set goes through the
        compiled `eval` function in a single call. Otherwise X is processed in
        chunks of this size, which bounds the memory used per call. Outputs are
        concatenated, so the returned values have the same layout.

    Returns
    -------
    (acc, transform_eval)
        The classification accuracy, and the transformer-layer output for
        every sample in X.
    """
    num_samples = X.shape[0]
    if batch_size is None or num_samples <= batch_size:
        output_eval, transform_eval = eval(X)
    else:
        chunks = [eval(X[start:start + batch_size])
                  for start in range(0, num_samples, batch_size)]
        output_eval = np.concatenate([out for out, _ in chunks], axis=0)
        transform_eval = np.concatenate([trans for _, trans in chunks], axis=0)
    preds = np.argmax(output_eval, axis=-1)
    acc = np.mean(preds == y)
    return acc, transform_eval
In [7]:
valid_accs, train_accs, test_accs = [], [], []
try:
for n in range(NUM_EPOCHS):
train_cost, train_acc = train_epoch(data['X_train'], data['y_train'])
valid_acc, valid_trainsform = eval_epoch(data['X_valid'], data['y_valid'])
test_acc, test_transform = eval_epoch(data['X_test'], data['y_test'])
valid_accs += [valid_acc]
test_accs += [test_acc]
train_accs += [train_acc]
if (n+1) % 20 == 0:
new_lr = sh_lr.get_value() * 0.7
print "New LR:", new_lr
sh_lr.set_value(lasagne.utils.floatX(new_lr))
print "Epoch {0}: Train cost {1}, Train acc {2}, val acc {3}, test acc {4}".format(
n, train_cost, train_acc, valid_acc, test_acc)
except KeyboardInterrupt:
pass
In [19]:
# Learning curves: error = 1 - accuracy, one point per epoch.
fig, ax = plt.subplots(figsize=(9, 9))
for accs, curve_label in ((train_accs, 'Training Error'), (valid_accs, 'Validation Error')):
    ax.plot(1 - np.array(accs), label=curve_label)
ax.legend(fontsize=20)
ax.set_xlabel('Epoch', fontsize=20)
ax.set_ylabel('Error', fontsize=20)
plt.show()
In [41]:
# Show a few test images (left) next to the transformer's output for the same
# image (right), taken from the last completed training epoch.
n_show = 3
# Derive the image sizes from the arrays instead of hard-coding DIM//3, so the
# plot stays correct if build_model's downsample_factor changes.
# Assumes both arrays are laid out as (N, 1, H, W).
orig_h, orig_w = data['X_test'].shape[-2:]
trans_h, trans_w = test_transform.shape[-2:]

plt.figure(figsize=(7, 14))
for i in range(n_show):
    plt.subplot(n_show, 2, 2 * i + 1)
    plt.imshow(data['X_test'][i].reshape(orig_h, orig_w), cmap='gray', interpolation='none')
    if i == 0:
        plt.title('Original {0}x{1}'.format(orig_h, orig_w), fontsize=20)
    plt.axis('off')

    plt.subplot(n_show, 2, 2 * i + 2)
    plt.imshow(test_transform[i].reshape(trans_h, trans_w), cmap='gray', interpolation='none')
    if i == 0:
        plt.title('Transformed {0}x{1}'.format(trans_h, trans_w), fontsize=20)
    plt.axis('off')
plt.tight_layout()
plt.show()