Demo of the paper "Lateral Connections in Denoising Autoencoders Support Supervised Learning"

What's remarkable about Rasmus et al. (2015) is that they achieve state-of-the-art performance on permutation-invariant MNIST without dropout (though the denoising step may be providing a similar form of regularization). Unlike previous work with autoencoders, they reach this performance with a single semi-supervised cost function used for the entire training process.

It's also worth noting that they report by far the best semi-supervised performance, roughly 0.75% error with only 500 labeled MNIST digits.


In [1]:
import time
import numpy as np
import theano
import theano.tensor as T
import peano
import peano.pops as P
from pylearn2.space import CompositeSpace, VectorSpace

dtype = theano.config.floatX


Using gpu device 0: GeForce GTX TITAN (CNMeM is disabled)

Build the model

The lateral connections make this model slightly more tedious to build, since many layers depend on more than one previous layer.
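peano's `Lateral` layer does most of the work here. As a rough, hypothetical sketch (the actual peano implementation may differ), a lateral combinator merges the bottom-up pre-activation z with the top-down signal u through per-unit learned parameters, something like:

# Hypothetical sketch of a lateral combinator layer -- NOT the actual
# peano.pops implementation, just the general idea: each unit learns how to
# mix its bottom-up signal z with the top-down signal u.
import numpy as np
import theano
import theano.tensor as T

class LateralSketch(object):
    def __init__(self, size):
        floatX = theano.config.floatX
        make = lambda v: theano.shared(np.full(size, v, dtype=floatX))
        self.a = make(1.)   # weight on the lateral (bottom-up) term
        self.b = make(0.)   # weight on the top-down term
        self.c = make(0.)   # weight on the multiplicative (gating) term
        self.d = make(0.)   # per-unit bias
        self.params = [self.a, self.b, self.c, self.d]

    def apply(self, z, u):
        # with u = 0 this reduces to a scaled copy of z, which is how the
        # topmost lateral connection is used below (ll6.apply(y_s, 0.))
        return self.a * z + self.b * u + self.c * z * u + self.d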


In [2]:
# encoder pre-activations z1..z5: Linear + BatchNormalization at each level
z1 = P.nnet.Sequential('z1')
z1.add(P.nnet.Linear(784, 1000))
z1.add(P.nnet.BatchNormalization(1000))

z2 = P.nnet.Sequential('z2')
z2.add(P.nnet.Linear(1000, 500))
z2.add(P.nnet.BatchNormalization(500))

z3 = P.nnet.Sequential('z3')
z3.add(P.nnet.Linear(500, 250))
z3.add(P.nnet.BatchNormalization(250))

z4 = P.nnet.Sequential('z4')
z4.add(P.nnet.Linear(250, 250))
z4.add(P.nnet.BatchNormalization(250))

z5 = P.nnet.Sequential('z5')
z5.add(P.nnet.Linear(250, 250))
z5.add(P.nnet.BatchNormalization(250))

# lateral connections, one per level (input, five hidden layers, output)
ll0 = P.nnet.Lateral(784)
ll1 = P.nnet.Lateral(1000)
ll2 = P.nnet.Lateral(500)
ll3 = P.nnet.Lateral(250)
ll4 = P.nnet.Lateral(250)
ll5 = P.nnet.Lateral(250)
ll6 = P.nnet.Lateral(10)

# top-down (decoder) linear mappings
u6 = P.nnet.Linear(10, 250)
u5 = P.nnet.Linear(250, 250)
u4 = P.nnet.Linear(250, 250)
u3 = P.nnet.Linear(250, 500)
u2 = P.nnet.Linear(500, 1000)
u1 = P.nnet.Linear(1000, 784)

# supervised softmax classifier on top of the encoder
sl = P.nnet.Sequential('sl')
sl.add(P.nnet.Linear(250, 10))
sl.add(T.nnet.softmax)

# bottom-up (encoder) pass; xt will receive the noise-corrupted input
xt = T.matrix(dtype=dtype)

z1f = z1.apply(xt)
h1 = T.nnet.relu(z1f)

z2f = z2.apply(h1)
h2 = T.nnet.relu(z2f)

z3f = z3.apply(h2)
h3 = T.nnet.relu(z3f)

z4f = z4.apply(h3)
h4 = T.nnet.relu(z4f)

z5f = z5.apply(h4)
h5 = T.nnet.relu(z5f)

# classifier output, then the top-down (decoder) pass; the topmost lateral
# connection receives no top-down signal
y_s = sl.apply(z5f)
zh6 = ll6.apply(y_s, 0.)
u6f = u6.apply(zh6)

zh5 = ll5.apply(z5f, u6f)
u5f = u5.apply(zh5)

zh4 = ll4.apply(z4f, u5f)
u4f = u4.apply(zh4)

zh3 = ll3.apply(z3f, u4f)
u3f = u3.apply(zh3)

zh2 = ll2.apply(z2f, u3f)
u2f = u2.apply(zh2)

zh1 = ll1.apply(z1f, u2f)
u1f = u1.apply(zh1)

# reconstruction of the input from its lateral connection and the top-down signal
xh = ll0.apply(xt, u1f)

Gather the parameters and construct the cost functions


In [3]:
# collect the trainable parameters from every layer
params = []
for l in [z1,z2,z3,z4,z5,ll0,ll1,ll2,ll3,ll4,ll5,ll6,u6,u5,u4,u3,u2,u1,sl]:
    params += l.params

x_true = T.matrix(dtype=dtype)
y_true = T.matrix(dtype=dtype)
lr = T.scalar(dtype=dtype)

# denoising reconstruction cost and supervised cost; the reconstruction
# term gets a large weight
r_cost = P.cost.mean_squared_error(x_true, xh)
s_cost = P.cost.cross_entropy(y_true, y_s)

cost = s_cost + 500.*r_cost
misclass_cost = T.neq(T.argmax(y_true, axis=1), T.argmax(y_s, axis=1)).mean()
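
For reference, assuming peano's cost helpers follow the usual definitions (an assumption; the library source is the authority), plain-Theano equivalents would look roughly like this:

# Plausible plain-Theano equivalents of the peano cost helpers (assumed,
# not copied from the library).
import theano.tensor as T

def mean_squared_error_sketch(target, output):
    # mean over units and examples of the squared reconstruction error
    return T.sqr(target - output).mean()

def cross_entropy_sketch(target, output):
    # categorical cross-entropy against one-hot targets, averaged over examples
    return T.nnet.categorical_crossentropy(output, target).mean()

The factor of 500 in the cell above simply rescales the reconstruction term relative to the cross-entropy.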

Take derivatives and compile the appropriate functions


In [4]:
# gradients of the combined cost, optimized with Adam
gparams = T.grad(cost, wrt=params)
updates = peano.optimizer.adam_update(params, gparams, alpha=lr)

learn_mlp_fn = theano.function(inputs = [xt, x_true, y_true, lr],
                                outputs = cost,
                                updates = updates)

misclass_mlp_fn = theano.function(inputs = [xt, y_true],
                                    outputs = misclass_cost)

encode_mlp_fn = theano.function(inputs = [xt],
                                    outputs = xh)

decode_mlp_fn = theano.function(inputs = [xt, y_s],
                                    outputs = xh)


WARNING (theano.gof.cmodule): WARNING: your Theano flags `gcc.cxxflags` specify an `-march=X` flags.
         It is better to let Theano/g++ find it automatically, but we don't do it now
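
peano.optimizer.adam_update isn't shown here either; a minimal sketch of a standard Adam update in Theano (assuming the usual hyperparameters, which may not match peano's defaults) is:

# Minimal sketch of a standard Adam update in Theano; peano's adam_update
# may differ in defaults and details.
import numpy as np
import theano
import theano.tensor as T

def adam_update_sketch(params, grads, alpha, b1=0.9, b2=0.999, eps=1e-8):
    floatX = theano.config.floatX
    updates = []
    t_prev = theano.shared(np.asarray(0., dtype=floatX))
    t = t_prev + 1.
    for p, g in zip(params, grads):
        m = theano.shared(np.zeros(p.get_value().shape, dtype=floatX))
        v = theano.shared(np.zeros(p.get_value().shape, dtype=floatX))
        m_t = b1 * m + (1. - b1) * g          # biased first-moment estimate
        v_t = b2 * v + (1. - b2) * T.sqr(g)   # biased second-moment estimate
        m_hat = m_t / (1. - b1 ** t)          # bias correction
        v_hat = v_t / (1. - b2 ** t)
        updates.append((m, m_t))
        updates.append((v, v_t))
        updates.append((p, p - alpha * m_hat / (T.sqrt(v_hat) + eps)))
    updates.append((t_prev, t))
    return updates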

In accordance with the paper, we train on the entire MNIST training set (all 60000 digits), corrupting the inputs with additive Gaussian noise (standard deviation 0.3). The learning rate is held at 0.002 for the first 50 epochs and then decayed linearly to zero over the remaining 50. After 100 epochs we evaluate on the MNIST test set (10000 digits). Since this is the actual test set, we are not allowed to tweak anything based on it; the test set error is the final error for this model.


In [5]:
from pylearn2.datasets import mnist
ds = mnist.MNIST(which_set='train', start=0, stop=60000)
val = mnist.MNIST(which_set='test', start=0, stop=10000)
val_X, val_y = val.get_data()
# one-hot encode the test labels
val_y = np.squeeze(np.eye(10)[val_y]).astype(dtype)

data_space = VectorSpace(dim=784)
label_space = VectorSpace(dim=10)

# learning rate schedule for the second half of training: linear decay to zero
lrd = np.linspace(.002, 0., 50).astype(dtype)
for i in range(100):
    ds_iter = ds.iterator(mode='sequential', batch_size=100, data_specs=(CompositeSpace((data_space, label_space)), ('features', 'targets')))
    t0 = time.time()
    for X, y in ds_iter:
        # corrupt the input with additive Gaussian noise (std 0.3); the clean
        # X is the reconstruction target
        if i < 50:
            learn_mlp_fn(X + 0.3*np.random.randn(*X.shape).astype(dtype), X, y, 0.002)
        else:
            learn_mlp_fn(X + 0.3*np.random.randn(*X.shape).astype(dtype), X, y, lrd[i-50])
    print 'epoch', i, time.time()-t0, 'seconds'
print 'Test set error:', misclass_mlp_fn(val_X, val_y)


epoch 0 14.3241860867 seconds
epoch 1 14.3138051033 seconds
epoch 2 14.302243948 seconds
epoch 3 14.2785608768 seconds
epoch 4 14.2825241089 seconds
epoch 5 14.2982139587 seconds
epoch 6 14.3066959381 seconds
epoch 7 14.3182621002 seconds
epoch 8 14.2831978798 seconds
epoch 9 14.3160161972 seconds
epoch 10 14.3053920269 seconds
epoch 11 14.277520895 seconds
epoch 12 14.2811539173 seconds
epoch 13 14.2720370293 seconds
epoch 14 14.30133605 seconds
epoch 15 14.285476923 seconds
epoch 16 14.2630209923 seconds
epoch 17 14.2744040489 seconds
epoch 18 14.2751760483 seconds
epoch 19 14.2680740356 seconds
epoch 20 14.2793338299 seconds
epoch 21 14.2843091488 seconds
epoch 22 14.2869808674 seconds
epoch 23 14.276250124 seconds
epoch 24 14.2990670204 seconds
epoch 25 14.2792639732 seconds
epoch 26 14.2720110416 seconds
epoch 27 14.2625980377 seconds
epoch 28 14.259770155 seconds
epoch 29 14.2725348473 seconds
epoch 30 14.3095350266 seconds
epoch 31 14.3106970787 seconds
epoch 32 14.3134691715 seconds
epoch 33 14.3098089695 seconds
epoch 34 14.2686629295 seconds
epoch 35 14.267441988 seconds
epoch 36 14.2607939243 seconds
epoch 37 14.2618298531 seconds
epoch 38 14.2534959316 seconds
epoch 39 14.2622108459 seconds
epoch 40 14.2575678825 seconds
epoch 41 14.2629668713 seconds
epoch 42 14.2600951195 seconds
epoch 43 14.2706940174 seconds
epoch 44 14.2730967999 seconds
epoch 45 14.2634661198 seconds
epoch 46 14.2567288876 seconds
epoch 47 14.2662849426 seconds
epoch 48 14.2648868561 seconds
epoch 49 14.2646889687 seconds
epoch 50 14.2747149467 seconds
epoch 51 14.284815073 seconds
epoch 52 14.3004801273 seconds
epoch 53 14.2752230167 seconds
epoch 54 14.3085579872 seconds
epoch 55 14.2957808971 seconds
epoch 56 14.3062949181 seconds
epoch 57 14.2829530239 seconds
epoch 58 14.2828259468 seconds
epoch 59 14.3139169216 seconds
epoch 60 14.2948460579 seconds
epoch 61 14.2714400291 seconds
epoch 62 14.2658820152 seconds
epoch 63 14.258245945 seconds
epoch 64 14.295976162 seconds
epoch 65 14.2694180012 seconds
epoch 66 14.2725749016 seconds
epoch 67 14.3092057705 seconds
epoch 68 14.2872700691 seconds
epoch 69 14.2699389458 seconds
epoch 70 14.2696049213 seconds
epoch 71 14.2747981548 seconds
epoch 72 14.2779450417 seconds
epoch 73 14.2599079609 seconds
epoch 74 14.2562818527 seconds
epoch 75 14.2988798618 seconds
epoch 76 14.2895288467 seconds
epoch 77 14.258767128 seconds
epoch 78 14.2590417862 seconds
epoch 79 14.265556097 seconds
epoch 80 14.289244175 seconds
epoch 81 14.2696130276 seconds
epoch 82 14.2715451717 seconds
epoch 83 14.268184185 seconds
epoch 84 14.2826249599 seconds
epoch 85 14.2723329067 seconds
epoch 86 14.2629330158 seconds
epoch 87 14.2630178928 seconds
epoch 88 14.2938277721 seconds
epoch 89 14.3064641953 seconds
epoch 90 14.2896800041 seconds
epoch 91 14.2934010029 seconds
epoch 92 14.3221580982 seconds
epoch 93 14.3061950207 seconds
epoch 94 14.311524868 seconds
epoch 95 14.3170070648 seconds
epoch 96 14.2952461243 seconds
epoch 97 14.2871549129 seconds
epoch 98 14.3108949661 seconds
epoch 99 14.3015282154 seconds
Test set error: 0.0092

With everything said and done, we achieve ~0.92% error on the test set, which is very good but short of the 0.68% reported in the paper. Finicky details such as weight initialization could account for the difference.
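
For instance, one easy thing to try would be a different initialization for the Linear layers. A Glorot-style uniform initializer (a hypothetical helper, not something used in the run above) would look like:

# Hypothetical Glorot/Xavier-style uniform initializer one could try when
# constructing the Linear weight matrices; not used in the run above.
import numpy as np

def glorot_uniform(n_in, n_out, rng=np.random):
    limit = np.sqrt(6. / (n_in + n_out))
    return rng.uniform(-limit, limit, size=(n_in, n_out)).astype('float32')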