In [1]:
import mxnet as mx
import numpy as np
import os
import logging
import matplotlib.pyplot as plt
import matplotlib.cm as cm

Building a Variational Autoencoder in MXNet

Xiaoyu Lu, July 5th, 2017

This tutorial guides you through the process of building a variational autoencoder (VAE) in MXNet. In this notebook we focus on an example using the MNIST handwritten digit dataset. Refer to Auto-Encoding Variational Bayes for more details on the model.

1. Loading the Data

We first load the MNIST dataset, which contains 60,000 training and 10,000 test examples. The following code imports the required modules and loads the data. The images are stored in a 4-D matrix with shape (batch_size, num_channels, width, height). For the MNIST dataset there is only one color channel, and both width and height are 28, so we flatten each image into a 784-dimensional vector (28*28); for visualization we reshape it back to a 28x28 array. See below for a visualization.


In [7]:
mnist = mx.test_utils.get_mnist()
image = np.reshape(mnist['train_data'],(60000,28*28))
label = image
image_test = np.reshape(mnist['test_data'],(10000,28*28))
label_test = image_test
[N,features] = np.shape(image)          #number of examples and features

In [3]:
f, (ax1, ax2, ax3, ax4) = plt.subplots(1,4,  sharex='col', sharey='row',figsize=(12,3))
ax1.imshow(np.reshape(image[0,:],(28,28)), interpolation='nearest', cmap=cm.Greys)
ax2.imshow(np.reshape(image[1,:],(28,28)), interpolation='nearest', cmap=cm.Greys)
ax3.imshow(np.reshape(image[2,:],(28,28)), interpolation='nearest', cmap=cm.Greys)
ax4.imshow(np.reshape(image[3,:],(28,28)), interpolation='nearest', cmap=cm.Greys)
plt.show()


We can optionally save the model parameters under the prefix given by the variable model_prefix. We then create data iterators for MXNet, with each batch containing 100 images.


In [4]:
model_prefix = None

batch_size = 100
nd_iter = mx.io.NDArrayIter(data={'data':image},label={'loss_label':label},
                            batch_size = batch_size)
nd_iter_test = mx.io.NDArrayIter(data={'data':image_test},label={'loss_label':label_test},
                            batch_size = batch_size)
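To verify that the iterators are set up as expected, we can peek at a single batch and check its shapes before training; a quick sanity check using the standard NDArrayIter interface:

batch = nd_iter.next()
print(batch.data[0].shape)   # (100, 784): one mini-batch of flattened images
print(batch.label[0].shape)  # (100, 784): the labels are the images themselves
nd_iter.reset()              # rewind the iterator so training starts from the first batch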

2. Building the Network Architecture

2.1 Gaussian MLP as encoder

Next we construct the neural network. As in the paper, we use multilayer perceptrons (MLPs) for both the encoder and the decoder. For the encoder, a Gaussian MLP is used:

\begin{align} \log q_{\phi}(z|x) &= \log \mathcal{N}(z;\mu,\sigma^2 I) \\ \textit{ where } \mu &= W_2h+b_2, \quad \log \sigma^2 = W_3h+b_3\\ h &= \tanh(W_1x+b_1) \end{align}

where $\{W_1,W_2,W_3,b_1,b_2,b_3\}$ are the weights and biases of the MLP. Note that mu and logvar below are symbols, not arrays; we can use get_internals() to access their outputs, after which we can sample the latent variable $z$.


In [5]:
## define data and loss labels as symbols 
data = mx.sym.var('data')
loss_label = mx.sym.var('loss_label')

## define the fully connected and activation layers for the encoder, using a tanh activation function
encoder_h  = mx.sym.FullyConnected(data=data, name="encoder_h",num_hidden=400)
act_h = mx.sym.Activation(data=encoder_h, act_type="tanh",name="activation_h")

## define mu and log variance which are the fully connected layers of the previous activation layer
mu  = mx.sym.FullyConnected(data=act_h, name="mu",num_hidden = 5)
logvar  = mx.sym.FullyConnected(data=act_h, name="logvar",num_hidden = 5)

## sample the latent variable z via the reparameterization trick:
## z = mu + sigma * epsilon, with epsilon ~ Normal(0, I) and sigma = exp(0.5 * logvar)
z = mu + mx.symbol.exp(0.5 * logvar) * mx.symbol.random_normal(loc=0, scale=1, shape=(batch_size, 5))
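Because mu, logvar and z are symbols, we can inspect the graph before binding it: get_internals() lists the intermediate outputs mentioned above, and infer_shape() confirms the shape of z. A minimal sketch:

print(z.get_internals().list_outputs())   # includes 'mu_output' and 'logvar_output'
arg_shapes, out_shapes, aux_shapes = z.infer_shape(data=(batch_size, features))
print(out_shapes)   # [(100, 5)]: one 5-dimensional latent vector per image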

2.2 Bernoulli MLP as decoder

In this case we let $p_\theta(x|z)$ be a multivariate Bernoulli whose probabilities are computed from $z$ with a feed-forward neural network with a single hidden layer:

\begin{align} \log p(x|z) &= \sum_{i=1}^D x_i\log y_i + (1-x_i)\log (1-y_i) \\ \textit{ where } y &= f_\sigma(W_5\tanh (W_4z+b_4)+b_5) \end{align}

where $f_\sigma(\cdot)$ is the elementwise sigmoid activation function and $\{W_4,W_5,b_4,b_5\}$ are the weights and biases of the decoder MLP. A Bernoulli likelihood is suitable for this type of data, but you can easily extend the model to other likelihood types by passing the likelihood argument of the VAE class; see section 4 for details.


In [6]:
# define fully connected and tanh activation layers for the decoder
decoder_z = mx.sym.FullyConnected(data=z, name="decoder_z",num_hidden=400)
act_z = mx.sym.Activation(data=decoder_z, act_type="tanh",name="activation_z")

# define the output layer with sigmoid activation function, where the dimension is equal to the input dimension
decoder_x = mx.sym.FullyConnected(data=act_z, name="decoder_x",num_hidden=features)
y = mx.sym.Activation(data=decoder_x, act_type="sigmoid",name='activation_x')

2.3 Joint Loss Function for the Encoder and the Decoder

The variational lower bound can be estimated as:

\begin{align} \mathcal{L}(\theta,\phi;x^{(i)}) \approx \frac{1}{2}\sum_{j=1}^{J}\left(1+\log ((\sigma_j^{(i)})^2)-(\mu_j^{(i)})^2-(\sigma_j^{(i)})^2\right) + \log p_\theta(x^{(i)}|z^{(i)}) \end{align}

where the first term is the KL divergence of the approximate posterior from the prior, and the second term is an expected negative reconstruction error. We would like to maximize this lower bound, so we can define the loss to be $-\mathcal{L}$ for MXNet to minimize.


In [7]:
# define the objective (negative ELBO) that MXNet will minimize
# KL is the Gaussian KL term of the lower bound in closed form
# (the negative KL divergence of q(z|x) from the standard normal prior)
KL = 0.5 * mx.symbol.sum(1 + logvar - pow(mu, 2) - mx.symbol.exp(logvar), axis=1)
# loss is the negative Bernoulli log-likelihood minus the KL term, i.e. -L
loss = -mx.symbol.sum(mx.symbol.broadcast_mul(loss_label, mx.symbol.log(y))
                      + mx.symbol.broadcast_mul(1 - loss_label, mx.symbol.log(1 - y)), axis=1) - KL
output = mx.symbol.MakeLoss(sum(loss), name='loss')
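As a quick check on the KL term: its closed form vanishes when the approximate posterior equals the standard normal prior (mu = 0, log sigma^2 = 0). A minimal NDArray sketch of the same formula:

mu_test = mx.nd.zeros((1, 5))
logvar_test = mx.nd.zeros((1, 5))
kl = 0.5 * mx.nd.sum(1 + logvar_test - mu_test**2 - mx.nd.exp(logvar_test), axis=1)
print(kl.asnumpy())   # [ 0.]: no KL penalty when q(z|x) matches the prior N(0, I)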

3. Training the model

Now we can define the model and train it. We initialize the weights and biases to be Gaussian(0, 0.01) and use stochastic gradient descent for optimization. To warm-start training, one may initialize with pre-trained parameters arg_params using init=mx.initializer.Load(arg_params). To save intermediate results, we can optionally use epoch_end_callback = mx.callback.do_checkpoint(model_prefix, 1), which saves the parameters to the path given by model_prefix with a period of one epoch. To assess performance, we output the negative ELBO $-\mathcal{L}$ after each epoch with eval_metric = 'Loss'. To plot the training loss over mini-batches, we record the metric in a list via a callback and pass it to the batch_end_callback argument.


In [8]:
# set up the log
nd_iter.reset()
logging.getLogger().setLevel(logging.DEBUG)  

# define a callback function to record the training loss
def log_to_list(period, lst):
    def _callback(param):
        """Append the current metric value to lst every `period` batches."""
        if param.nbatch % period == 0:
            name, value = param.eval_metric.get()
            lst.append(value)
    return _callback

# define the model
model = mx.mod.Module(
    symbol = output,
    data_names = ['data'],
    label_names = ['loss_label'])

In [9]:
# train the model and record the training loss in a list
training_loss = list()

# initialize the parameters for training with Normal(0, 0.01)
init = mx.init.Normal(0.01)
model.fit(nd_iter,  # train data
              initializer=init,
              # if eval_data is supplied, test loss will also be reported
              #eval_data = nd_iter_test,
              optimizer='sgd',  # use SGD for training
              optimizer_params={'learning_rate':1e-3,'wd':1e-2},
              # save parameters after each epoch if model_prefix is supplied
              epoch_end_callback = None if model_prefix is None else mx.callback.do_checkpoint(model_prefix, 1),
              batch_end_callback = log_to_list(N//batch_size, training_loss),
              num_epoch=100,
              eval_metric = 'Loss')


INFO:root:Epoch[0] Train-loss=375.023381
INFO:root:Epoch[0] Time cost=6.127
INFO:root:Epoch[1] Train-loss=212.780315
INFO:root:Epoch[1] Time cost=6.409
INFO:root:Epoch[2] Train-loss=208.209400
INFO:root:Epoch[2] Time cost=6.619
INFO:root:Epoch[3] Train-loss=206.146854
INFO:root:Epoch[3] Time cost=6.648
INFO:root:Epoch[4] Train-loss=204.530598
INFO:root:Epoch[4] Time cost=7.000
INFO:root:Epoch[5] Train-loss=202.799992
INFO:root:Epoch[5] Time cost=6.778
INFO:root:Epoch[6] Train-loss=200.333474
INFO:root:Epoch[6] Time cost=7.187
INFO:root:Epoch[7] Train-loss=197.506393
INFO:root:Epoch[7] Time cost=6.712
INFO:root:Epoch[8] Train-loss=195.969775
INFO:root:Epoch[8] Time cost=6.896
INFO:root:Epoch[9] Train-loss=195.418288
INFO:root:Epoch[9] Time cost=6.887
INFO:root:Epoch[10] Train-loss=194.739763
INFO:root:Epoch[10] Time cost=6.745
INFO:root:Epoch[11] Train-loss=194.380536
INFO:root:Epoch[11] Time cost=6.706
INFO:root:Epoch[12] Train-loss=193.955462
INFO:root:Epoch[12] Time cost=6.592
INFO:root:Epoch[13] Train-loss=193.493671
INFO:root:Epoch[13] Time cost=6.775
INFO:root:Epoch[14] Train-loss=192.958739
INFO:root:Epoch[14] Time cost=6.600
INFO:root:Epoch[15] Train-loss=191.928542
INFO:root:Epoch[15] Time cost=6.586
INFO:root:Epoch[16] Train-loss=189.797939
INFO:root:Epoch[16] Time cost=6.700
INFO:root:Epoch[17] Train-loss=186.672446
INFO:root:Epoch[17] Time cost=6.869
INFO:root:Epoch[18] Train-loss=184.616599
INFO:root:Epoch[18] Time cost=7.144
INFO:root:Epoch[19] Train-loss=183.305978
INFO:root:Epoch[19] Time cost=6.997
INFO:root:Epoch[20] Train-loss=181.944634
INFO:root:Epoch[20] Time cost=6.481
INFO:root:Epoch[21] Train-loss=181.005329
INFO:root:Epoch[21] Time cost=6.754
INFO:root:Epoch[22] Train-loss=178.363118
INFO:root:Epoch[22] Time cost=7.000
INFO:root:Epoch[23] Train-loss=176.363421
INFO:root:Epoch[23] Time cost=6.923
INFO:root:Epoch[24] Train-loss=174.573954
INFO:root:Epoch[24] Time cost=6.510
INFO:root:Epoch[25] Train-loss=173.245940
INFO:root:Epoch[25] Time cost=6.926
INFO:root:Epoch[26] Train-loss=172.082522
INFO:root:Epoch[26] Time cost=6.733
INFO:root:Epoch[27] Train-loss=171.123084
INFO:root:Epoch[27] Time cost=6.616
INFO:root:Epoch[28] Train-loss=170.239300
INFO:root:Epoch[28] Time cost=7.004
INFO:root:Epoch[29] Train-loss=169.538416
INFO:root:Epoch[29] Time cost=6.341
INFO:root:Epoch[30] Train-loss=168.952901
INFO:root:Epoch[30] Time cost=6.736
INFO:root:Epoch[31] Train-loss=168.169076
INFO:root:Epoch[31] Time cost=6.616
INFO:root:Epoch[32] Train-loss=167.208973
INFO:root:Epoch[32] Time cost=6.446
INFO:root:Epoch[33] Train-loss=165.732213
INFO:root:Epoch[33] Time cost=6.405
INFO:root:Epoch[34] Train-loss=163.606801
INFO:root:Epoch[34] Time cost=6.139
INFO:root:Epoch[35] Train-loss=161.985880
INFO:root:Epoch[35] Time cost=6.678
INFO:root:Epoch[36] Train-loss=160.763072
INFO:root:Epoch[36] Time cost=8.749
INFO:root:Epoch[37] Train-loss=160.025193
INFO:root:Epoch[37] Time cost=6.519
INFO:root:Epoch[38] Train-loss=159.319723
INFO:root:Epoch[38] Time cost=7.584
INFO:root:Epoch[39] Train-loss=158.670701
INFO:root:Epoch[39] Time cost=6.874
INFO:root:Epoch[40] Train-loss=158.225733
INFO:root:Epoch[40] Time cost=6.402
INFO:root:Epoch[41] Train-loss=157.741337
INFO:root:Epoch[41] Time cost=8.617
INFO:root:Epoch[42] Train-loss=157.301411
INFO:root:Epoch[42] Time cost=6.515
INFO:root:Epoch[43] Train-loss=156.765170
INFO:root:Epoch[43] Time cost=6.447
INFO:root:Epoch[44] Train-loss=156.389668
INFO:root:Epoch[44] Time cost=6.130
INFO:root:Epoch[45] Train-loss=155.815434
INFO:root:Epoch[45] Time cost=6.155
INFO:root:Epoch[46] Train-loss=155.432254
INFO:root:Epoch[46] Time cost=6.158
INFO:root:Epoch[47] Train-loss=155.114027
INFO:root:Epoch[47] Time cost=6.749
INFO:root:Epoch[48] Train-loss=154.612441
INFO:root:Epoch[48] Time cost=6.255
INFO:root:Epoch[49] Train-loss=154.137659
INFO:root:Epoch[49] Time cost=7.813
INFO:root:Epoch[50] Train-loss=153.634072
INFO:root:Epoch[50] Time cost=7.408
INFO:root:Epoch[51] Train-loss=153.417397
INFO:root:Epoch[51] Time cost=7.747
INFO:root:Epoch[52] Train-loss=152.851887
INFO:root:Epoch[52] Time cost=8.587
INFO:root:Epoch[53] Train-loss=152.575068
INFO:root:Epoch[53] Time cost=7.554
INFO:root:Epoch[54] Train-loss=152.084419
INFO:root:Epoch[54] Time cost=6.628
INFO:root:Epoch[55] Train-loss=151.724836
INFO:root:Epoch[55] Time cost=6.535
INFO:root:Epoch[56] Train-loss=151.302525
INFO:root:Epoch[56] Time cost=7.148
INFO:root:Epoch[57] Train-loss=150.960916
INFO:root:Epoch[57] Time cost=7.195
INFO:root:Epoch[58] Train-loss=150.603895
INFO:root:Epoch[58] Time cost=6.649
INFO:root:Epoch[59] Train-loss=150.237795
INFO:root:Epoch[59] Time cost=6.222
INFO:root:Epoch[60] Train-loss=149.936080
INFO:root:Epoch[60] Time cost=8.450
INFO:root:Epoch[61] Train-loss=149.514617
INFO:root:Epoch[61] Time cost=6.113
INFO:root:Epoch[62] Train-loss=149.229345
INFO:root:Epoch[62] Time cost=6.088
INFO:root:Epoch[63] Train-loss=148.893769
INFO:root:Epoch[63] Time cost=6.558
INFO:root:Epoch[64] Train-loss=148.526837
INFO:root:Epoch[64] Time cost=7.590
INFO:root:Epoch[65] Train-loss=148.249951
INFO:root:Epoch[65] Time cost=6.180
INFO:root:Epoch[66] Train-loss=147.940414
INFO:root:Epoch[66] Time cost=6.242
INFO:root:Epoch[67] Train-loss=147.621304
INFO:root:Epoch[67] Time cost=8.501
INFO:root:Epoch[68] Train-loss=147.294314
INFO:root:Epoch[68] Time cost=7.645
INFO:root:Epoch[69] Train-loss=147.074479
INFO:root:Epoch[69] Time cost=7.092
INFO:root:Epoch[70] Train-loss=146.796387
INFO:root:Epoch[70] Time cost=6.914
INFO:root:Epoch[71] Train-loss=146.508842
INFO:root:Epoch[71] Time cost=6.606
INFO:root:Epoch[72] Train-loss=146.230444
INFO:root:Epoch[72] Time cost=7.755
INFO:root:Epoch[73] Train-loss=145.970296
INFO:root:Epoch[73] Time cost=6.409
INFO:root:Epoch[74] Train-loss=145.711610
INFO:root:Epoch[74] Time cost=6.334
INFO:root:Epoch[75] Train-loss=145.460053
INFO:root:Epoch[75] Time cost=7.269
INFO:root:Epoch[76] Train-loss=145.156451
INFO:root:Epoch[76] Time cost=6.744
INFO:root:Epoch[77] Train-loss=144.957674
INFO:root:Epoch[77] Time cost=7.100
INFO:root:Epoch[78] Train-loss=144.729749
INFO:root:Epoch[78] Time cost=6.242
INFO:root:Epoch[79] Train-loss=144.481728
INFO:root:Epoch[79] Time cost=6.865
INFO:root:Epoch[80] Train-loss=144.236061
INFO:root:Epoch[80] Time cost=6.632
INFO:root:Epoch[81] Train-loss=144.030473
INFO:root:Epoch[81] Time cost=6.764
INFO:root:Epoch[82] Train-loss=143.776374
INFO:root:Epoch[82] Time cost=6.564
INFO:root:Epoch[83] Train-loss=143.538847
INFO:root:Epoch[83] Time cost=6.181
INFO:root:Epoch[84] Train-loss=143.326444
INFO:root:Epoch[84] Time cost=6.220
INFO:root:Epoch[85] Train-loss=143.078987
INFO:root:Epoch[85] Time cost=6.823
INFO:root:Epoch[86] Train-loss=142.877117
INFO:root:Epoch[86] Time cost=7.755
INFO:root:Epoch[87] Train-loss=142.667316
INFO:root:Epoch[87] Time cost=6.068
INFO:root:Epoch[88] Train-loss=142.461755
INFO:root:Epoch[88] Time cost=6.111
INFO:root:Epoch[89] Train-loss=142.270438
INFO:root:Epoch[89] Time cost=6.221
INFO:root:Epoch[90] Train-loss=142.047086
INFO:root:Epoch[90] Time cost=8.061
INFO:root:Epoch[91] Train-loss=141.855774
INFO:root:Epoch[91] Time cost=6.433
INFO:root:Epoch[92] Train-loss=141.688955
INFO:root:Epoch[92] Time cost=7.153
INFO:root:Epoch[93] Train-loss=141.442910
INFO:root:Epoch[93] Time cost=7.113
INFO:root:Epoch[94] Train-loss=141.279274
INFO:root:Epoch[94] Time cost=7.152
INFO:root:Epoch[95] Train-loss=141.086522
INFO:root:Epoch[95] Time cost=6.472
INFO:root:Epoch[96] Train-loss=140.901925
INFO:root:Epoch[96] Time cost=6.767
INFO:root:Epoch[97] Train-loss=140.722496
INFO:root:Epoch[97] Time cost=7.044
INFO:root:Epoch[98] Train-loss=140.579295
INFO:root:Epoch[98] Time cost=7.040
INFO:root:Epoch[99] Train-loss=140.386067
INFO:root:Epoch[99] Time cost=6.669

In [23]:
ELBO = [-l for l in training_loss]
plt.plot(ELBO)
plt.ylabel('ELBO')
plt.xlabel('epoch')
plt.title("training curve for mini batches")
plt.show()


As expected, the ELBO increases monotonically over the epochs, and we reproduce the results given in the paper Auto-Encoding Variational Bayes. Now we can extract/load the parameters and feed the network forward to calculate $y$, the reconstructed images, and we can also calculate the ELBO on the test set.


In [80]:
arg_params = model.get_params()[0]

# if the parameters were saved, we can load them from, e.g., the 100th epoch:
#sym, arg_params, aux_params = mx.model.load_checkpoint(model_prefix, 100)
#assert sym.tojson() == output.tojson()

e = y.bind(mx.cpu(), {'data': nd_iter_test.data[0][1],
                     'encoder_h_weight': arg_params['encoder_h_weight'],
                     'encoder_h_bias': arg_params['encoder_h_bias'],
                     'mu_weight': arg_params['mu_weight'],
                     'mu_bias': arg_params['mu_bias'],
                     'logvar_weight':arg_params['logvar_weight'],
                     'logvar_bias':arg_params['logvar_bias'],
                     'decoder_z_weight':arg_params['decoder_z_weight'],
                     'decoder_z_bias':arg_params['decoder_z_bias'],
                     'decoder_x_weight':arg_params['decoder_x_weight'],
                     'decoder_x_bias':arg_params['decoder_x_bias'],                
                     'loss_label':label})

x_fit = e.forward()
x_construction = x_fit[0].asnumpy()

In [78]:
# reconstructed ('learned') images on the test set
f, ((ax1, ax2, ax3, ax4)) = plt.subplots(1,4,  sharex='col', sharey='row',figsize=(12,3))
ax1.imshow(np.reshape(image_test[0,:],(28,28)), interpolation='nearest', cmap=cm.Greys)
ax1.set_title('True image')
ax2.imshow(np.reshape(x_construction[0,:],(28,28)), interpolation='nearest', cmap=cm.Greys)
ax2.set_title('Learned image')
ax3.imshow(np.reshape(x_construction[999,:],(28,28)), interpolation='nearest', cmap=cm.Greys)
ax3.set_title('Learned image')
ax4.imshow(np.reshape(x_construction[9999,:],(28,28)), interpolation='nearest', cmap=cm.Greys)
ax4.set_title('Learned image')
plt.show()



In [37]:
# calculate the loss on the test set; the test ELBO is minus this value
metric = mx.metric.Loss()
model.score(nd_iter_test, metric)


Out[37]:
[('loss', 139.73684648437501)]
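Since the reported metric is the negative ELBO, the test ELBO is simply its negation:

test_elbo = -model.score(nd_iter_test, mx.metric.Loss())[0][1]
print(test_elbo)   # about -139.7 for the run above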

4. All together: MXNet-based class VAE


In [2]:
from VAE import VAE

One can directly call the class VAE to do the training. The output contains the learned model and the training loss. The constructor signature is:

VAE(n_latent=5, num_hidden_ecoder=400, num_hidden_decoder=400, x_train=None, x_valid=None, batch_size=100, learning_rate=0.001, weight_decay=0.01, num_epoch=100, optimizer='sgd', model_prefix=None, initializer=mx.init.Normal(0.01), likelihood=Bernoulli)


In [8]:
# we can initialize the weights and biases with previously learned parameters
#init = mx.initializer.Load(params)

# call the VAE; the output contains the learned model and the training loss
out = VAE(n_latent=2, x_train=image, x_valid=None, num_epoch=200)


INFO:root:Epoch[0] Train-loss=377.146422
INFO:root:Epoch[0] Time cost=5.989
INFO:root:Epoch[1] Train-loss=211.998043
INFO:root:Epoch[1] Time cost=6.303
INFO:root:Epoch[2] Train-loss=207.103096
INFO:root:Epoch[2] Time cost=7.368
INFO:root:Epoch[3] Train-loss=204.958183
INFO:root:Epoch[3] Time cost=7.530
INFO:root:Epoch[4] Train-loss=203.342700
INFO:root:Epoch[4] Time cost=8.887
INFO:root:Epoch[5] Train-loss=201.649251
INFO:root:Epoch[5] Time cost=9.147
INFO:root:Epoch[6] Train-loss=199.782661
INFO:root:Epoch[6] Time cost=8.924
INFO:root:Epoch[7] Train-loss=198.044015
INFO:root:Epoch[7] Time cost=8.920
INFO:root:Epoch[8] Train-loss=195.732077
INFO:root:Epoch[8] Time cost=8.857
INFO:root:Epoch[9] Train-loss=194.070547
INFO:root:Epoch[9] Time cost=9.216
INFO:root:Epoch[10] Train-loss=193.186871
INFO:root:Epoch[10] Time cost=8.966
INFO:root:Epoch[11] Train-loss=192.700208
INFO:root:Epoch[11] Time cost=8.843
INFO:root:Epoch[12] Train-loss=192.191504
INFO:root:Epoch[12] Time cost=8.152
INFO:root:Epoch[13] Train-loss=191.842837
INFO:root:Epoch[13] Time cost=6.180
INFO:root:Epoch[14] Train-loss=191.310450
INFO:root:Epoch[14] Time cost=6.067
INFO:root:Epoch[15] Train-loss=190.520681
INFO:root:Epoch[15] Time cost=6.058
INFO:root:Epoch[16] Train-loss=189.784146
INFO:root:Epoch[16] Time cost=6.046
INFO:root:Epoch[17] Train-loss=188.515020
INFO:root:Epoch[17] Time cost=6.062
INFO:root:Epoch[18] Train-loss=187.530712
INFO:root:Epoch[18] Time cost=6.088
INFO:root:Epoch[19] Train-loss=186.194826
INFO:root:Epoch[19] Time cost=6.491
INFO:root:Epoch[20] Train-loss=185.492288
INFO:root:Epoch[20] Time cost=6.182
INFO:root:Epoch[21] Train-loss=184.922654
INFO:root:Epoch[21] Time cost=6.058
INFO:root:Epoch[22] Train-loss=184.677911
INFO:root:Epoch[22] Time cost=6.042
INFO:root:Epoch[23] Train-loss=183.921396
INFO:root:Epoch[23] Time cost=5.994
INFO:root:Epoch[24] Train-loss=183.600690
INFO:root:Epoch[24] Time cost=6.038
INFO:root:Epoch[25] Train-loss=183.388476
INFO:root:Epoch[25] Time cost=6.025
INFO:root:Epoch[26] Train-loss=182.972208
INFO:root:Epoch[26] Time cost=6.014
INFO:root:Epoch[27] Train-loss=182.561678
INFO:root:Epoch[27] Time cost=6.064
INFO:root:Epoch[28] Train-loss=182.475261
INFO:root:Epoch[28] Time cost=5.983
INFO:root:Epoch[29] Train-loss=182.308808
INFO:root:Epoch[29] Time cost=6.371
INFO:root:Epoch[30] Train-loss=182.135900
INFO:root:Epoch[30] Time cost=6.038
INFO:root:Epoch[31] Train-loss=181.978367
INFO:root:Epoch[31] Time cost=6.924
INFO:root:Epoch[32] Train-loss=181.677153
INFO:root:Epoch[32] Time cost=8.205
INFO:root:Epoch[33] Train-loss=181.677775
INFO:root:Epoch[33] Time cost=6.017
INFO:root:Epoch[34] Train-loss=181.257998
INFO:root:Epoch[34] Time cost=6.056
INFO:root:Epoch[35] Train-loss=181.125288
INFO:root:Epoch[35] Time cost=6.020
INFO:root:Epoch[36] Train-loss=181.018858
INFO:root:Epoch[36] Time cost=6.035
INFO:root:Epoch[37] Train-loss=180.785110
INFO:root:Epoch[37] Time cost=6.049
INFO:root:Epoch[38] Train-loss=180.452598
INFO:root:Epoch[38] Time cost=6.083
INFO:root:Epoch[39] Train-loss=180.362733
INFO:root:Epoch[39] Time cost=6.198
INFO:root:Epoch[40] Train-loss=180.060788
INFO:root:Epoch[40] Time cost=6.049
INFO:root:Epoch[41] Train-loss=180.022728
INFO:root:Epoch[41] Time cost=6.135
INFO:root:Epoch[42] Train-loss=179.648499
INFO:root:Epoch[42] Time cost=6.055
INFO:root:Epoch[43] Train-loss=179.507952
INFO:root:Epoch[43] Time cost=6.108
INFO:root:Epoch[44] Train-loss=179.303132
INFO:root:Epoch[44] Time cost=6.020
INFO:root:Epoch[45] Train-loss=178.945211
INFO:root:Epoch[45] Time cost=6.004
INFO:root:Epoch[46] Train-loss=178.808598
INFO:root:Epoch[46] Time cost=6.016
INFO:root:Epoch[47] Train-loss=178.550906
INFO:root:Epoch[47] Time cost=6.050
INFO:root:Epoch[48] Train-loss=178.403674
INFO:root:Epoch[48] Time cost=6.115
INFO:root:Epoch[49] Train-loss=178.237544
INFO:root:Epoch[49] Time cost=6.004
INFO:root:Epoch[50] Train-loss=178.033747
INFO:root:Epoch[50] Time cost=6.051
INFO:root:Epoch[51] Train-loss=177.802884
INFO:root:Epoch[51] Time cost=6.028
INFO:root:Epoch[52] Train-loss=177.533980
INFO:root:Epoch[52] Time cost=6.052
INFO:root:Epoch[53] Train-loss=177.490143
INFO:root:Epoch[53] Time cost=6.019
INFO:root:Epoch[54] Train-loss=177.136637
INFO:root:Epoch[54] Time cost=6.014
INFO:root:Epoch[55] Train-loss=177.062524
INFO:root:Epoch[55] Time cost=6.024
INFO:root:Epoch[56] Train-loss=176.869033
INFO:root:Epoch[56] Time cost=6.065
INFO:root:Epoch[57] Train-loss=176.704606
INFO:root:Epoch[57] Time cost=6.037
INFO:root:Epoch[58] Train-loss=176.470091
INFO:root:Epoch[58] Time cost=6.012
INFO:root:Epoch[59] Train-loss=176.261440
INFO:root:Epoch[59] Time cost=6.215
INFO:root:Epoch[60] Train-loss=176.133904
INFO:root:Epoch[60] Time cost=6.042
INFO:root:Epoch[61] Train-loss=175.941920
INFO:root:Epoch[61] Time cost=6.000
INFO:root:Epoch[62] Train-loss=175.731296
INFO:root:Epoch[62] Time cost=6.025
INFO:root:Epoch[63] Train-loss=175.613303
INFO:root:Epoch[63] Time cost=6.002
INFO:root:Epoch[64] Train-loss=175.438844
INFO:root:Epoch[64] Time cost=5.982
INFO:root:Epoch[65] Train-loss=175.254716
INFO:root:Epoch[65] Time cost=6.016
INFO:root:Epoch[66] Train-loss=175.090210
INFO:root:Epoch[66] Time cost=6.008
INFO:root:Epoch[67] Train-loss=174.895443
INFO:root:Epoch[67] Time cost=6.008
INFO:root:Epoch[68] Train-loss=174.701321
INFO:root:Epoch[68] Time cost=6.418
INFO:root:Epoch[69] Train-loss=174.553292
INFO:root:Epoch[69] Time cost=6.072
INFO:root:Epoch[70] Train-loss=174.349379
INFO:root:Epoch[70] Time cost=6.048
INFO:root:Epoch[71] Train-loss=174.174641
INFO:root:Epoch[71] Time cost=6.036
INFO:root:Epoch[72] Train-loss=173.966333
INFO:root:Epoch[72] Time cost=6.017
INFO:root:Epoch[73] Train-loss=173.798454
INFO:root:Epoch[73] Time cost=6.018
INFO:root:Epoch[74] Train-loss=173.635657
INFO:root:Epoch[74] Time cost=5.985
INFO:root:Epoch[75] Train-loss=173.423795
INFO:root:Epoch[75] Time cost=6.016
INFO:root:Epoch[76] Train-loss=173.273981
INFO:root:Epoch[76] Time cost=6.018
INFO:root:Epoch[77] Train-loss=173.073401
INFO:root:Epoch[77] Time cost=5.996
INFO:root:Epoch[78] Train-loss=172.888044
INFO:root:Epoch[78] Time cost=6.035
INFO:root:Epoch[79] Train-loss=172.694943
INFO:root:Epoch[79] Time cost=8.492
INFO:root:Epoch[80] Train-loss=172.504260
INFO:root:Epoch[80] Time cost=7.380
INFO:root:Epoch[81] Train-loss=172.323245
INFO:root:Epoch[81] Time cost=6.063
INFO:root:Epoch[82] Train-loss=172.131274
INFO:root:Epoch[82] Time cost=6.209
INFO:root:Epoch[83] Train-loss=171.932986
INFO:root:Epoch[83] Time cost=6.060
INFO:root:Epoch[84] Train-loss=171.755262
INFO:root:Epoch[84] Time cost=6.068
INFO:root:Epoch[85] Train-loss=171.556803
INFO:root:Epoch[85] Time cost=6.004
INFO:root:Epoch[86] Train-loss=171.384773
INFO:root:Epoch[86] Time cost=6.059
INFO:root:Epoch[87] Train-loss=171.185034
INFO:root:Epoch[87] Time cost=6.001
INFO:root:Epoch[88] Train-loss=170.995980
INFO:root:Epoch[88] Time cost=6.143
INFO:root:Epoch[89] Train-loss=170.818701
INFO:root:Epoch[89] Time cost=6.690
INFO:root:Epoch[90] Train-loss=170.629929
INFO:root:Epoch[90] Time cost=6.869
INFO:root:Epoch[91] Train-loss=170.450824
INFO:root:Epoch[91] Time cost=7.156
INFO:root:Epoch[92] Train-loss=170.261806
INFO:root:Epoch[92] Time cost=6.972
INFO:root:Epoch[93] Train-loss=170.070318
INFO:root:Epoch[93] Time cost=6.595
INFO:root:Epoch[94] Train-loss=169.906993
INFO:root:Epoch[94] Time cost=6.561
INFO:root:Epoch[95] Train-loss=169.734455
INFO:root:Epoch[95] Time cost=6.744
INFO:root:Epoch[96] Train-loss=169.564318
INFO:root:Epoch[96] Time cost=6.601
INFO:root:Epoch[97] Train-loss=169.373926
INFO:root:Epoch[97] Time cost=6.725
INFO:root:Epoch[98] Train-loss=169.215408
INFO:root:Epoch[98] Time cost=6.391
INFO:root:Epoch[99] Train-loss=169.039854
INFO:root:Epoch[99] Time cost=6.677
INFO:root:Epoch[100] Train-loss=168.869222
INFO:root:Epoch[100] Time cost=6.370
INFO:root:Epoch[101] Train-loss=168.703175
INFO:root:Epoch[101] Time cost=6.607
INFO:root:Epoch[102] Train-loss=168.523054
INFO:root:Epoch[102] Time cost=6.368
INFO:root:Epoch[103] Train-loss=168.365964
INFO:root:Epoch[103] Time cost=10.267
INFO:root:Epoch[104] Train-loss=168.181174
INFO:root:Epoch[104] Time cost=11.132
INFO:root:Epoch[105] Train-loss=168.021498
INFO:root:Epoch[105] Time cost=10.187
INFO:root:Epoch[106] Train-loss=167.858251
INFO:root:Epoch[106] Time cost=10.676
INFO:root:Epoch[107] Train-loss=167.690670
INFO:root:Epoch[107] Time cost=10.973
INFO:root:Epoch[108] Train-loss=167.535069
INFO:root:Epoch[108] Time cost=10.108
INFO:root:Epoch[109] Train-loss=167.373971
INFO:root:Epoch[109] Time cost=11.013
INFO:root:Epoch[110] Train-loss=167.207507
INFO:root:Epoch[110] Time cost=11.427
INFO:root:Epoch[111] Train-loss=167.043077
INFO:root:Epoch[111] Time cost=10.349
INFO:root:Epoch[112] Train-loss=166.884060
INFO:root:Epoch[112] Time cost=13.129
INFO:root:Epoch[113] Train-loss=166.746976
INFO:root:Epoch[113] Time cost=11.255
INFO:root:Epoch[114] Train-loss=166.572499
INFO:root:Epoch[114] Time cost=10.037
INFO:root:Epoch[115] Train-loss=166.445170
INFO:root:Epoch[115] Time cost=10.406
INFO:root:Epoch[116] Train-loss=166.284912
INFO:root:Epoch[116] Time cost=10.170
INFO:root:Epoch[117] Train-loss=166.171475
INFO:root:Epoch[117] Time cost=10.034
INFO:root:Epoch[118] Train-loss=166.015457
INFO:root:Epoch[118] Time cost=10.047
INFO:root:Epoch[119] Train-loss=165.882208
INFO:root:Epoch[119] Time cost=10.008
INFO:root:Epoch[120] Train-loss=165.753836
INFO:root:Epoch[120] Time cost=10.056
INFO:root:Epoch[121] Train-loss=165.626045
INFO:root:Epoch[121] Time cost=10.704
INFO:root:Epoch[122] Train-loss=165.492859
INFO:root:Epoch[122] Time cost=10.609
INFO:root:Epoch[123] Train-loss=165.361132
INFO:root:Epoch[123] Time cost=10.027
INFO:root:Epoch[124] Train-loss=165.256487
INFO:root:Epoch[124] Time cost=11.225
INFO:root:Epoch[125] Train-loss=165.119995
INFO:root:Epoch[125] Time cost=11.266
INFO:root:Epoch[126] Train-loss=165.012773
INFO:root:Epoch[126] Time cost=10.547
INFO:root:Epoch[127] Train-loss=164.898748
INFO:root:Epoch[127] Time cost=10.339
INFO:root:Epoch[128] Train-loss=164.775702
INFO:root:Epoch[128] Time cost=10.875
INFO:root:Epoch[129] Train-loss=164.692449
INFO:root:Epoch[129] Time cost=8.412
INFO:root:Epoch[130] Train-loss=164.564323
INFO:root:Epoch[130] Time cost=7.239
INFO:root:Epoch[131] Train-loss=164.468273
INFO:root:Epoch[131] Time cost=10.096
INFO:root:Epoch[132] Train-loss=164.328320
INFO:root:Epoch[132] Time cost=9.680
INFO:root:Epoch[133] Train-loss=164.256156
INFO:root:Epoch[133] Time cost=10.707
INFO:root:Epoch[134] Train-loss=164.151625
INFO:root:Epoch[134] Time cost=13.835
INFO:root:Epoch[135] Train-loss=164.046402
INFO:root:Epoch[135] Time cost=10.049
INFO:root:Epoch[136] Train-loss=163.960676
INFO:root:Epoch[136] Time cost=9.625
INFO:root:Epoch[137] Train-loss=163.873193
INFO:root:Epoch[137] Time cost=9.845
INFO:root:Epoch[138] Train-loss=163.783837
INFO:root:Epoch[138] Time cost=9.618
INFO:root:Epoch[139] Train-loss=163.658903
INFO:root:Epoch[139] Time cost=10.411
INFO:root:Epoch[140] Train-loss=163.588920
INFO:root:Epoch[140] Time cost=9.633
INFO:root:Epoch[141] Train-loss=163.493254
INFO:root:Epoch[141] Time cost=10.668
INFO:root:Epoch[142] Train-loss=163.401188
INFO:root:Epoch[142] Time cost=10.644
INFO:root:Epoch[143] Train-loss=163.334470
INFO:root:Epoch[143] Time cost=9.665
INFO:root:Epoch[144] Train-loss=163.235133
INFO:root:Epoch[144] Time cost=9.612
INFO:root:Epoch[145] Train-loss=163.168029
INFO:root:Epoch[145] Time cost=9.578
INFO:root:Epoch[146] Train-loss=163.092392
INFO:root:Epoch[146] Time cost=10.215
INFO:root:Epoch[147] Train-loss=163.014362
INFO:root:Epoch[147] Time cost=12.296
INFO:root:Epoch[148] Train-loss=162.891574
INFO:root:Epoch[148] Time cost=9.578
INFO:root:Epoch[149] Train-loss=162.831664
INFO:root:Epoch[149] Time cost=9.536
INFO:root:Epoch[150] Train-loss=162.768784
INFO:root:Epoch[150] Time cost=9.607
INFO:root:Epoch[151] Train-loss=162.695416
INFO:root:Epoch[151] Time cost=9.681
INFO:root:Epoch[152] Train-loss=162.620814
INFO:root:Epoch[152] Time cost=9.464
INFO:root:Epoch[153] Train-loss=162.527031
INFO:root:Epoch[153] Time cost=9.518
INFO:root:Epoch[154] Train-loss=162.466575
INFO:root:Epoch[154] Time cost=9.562
INFO:root:Epoch[155] Train-loss=162.409388
INFO:root:Epoch[155] Time cost=9.483
INFO:root:Epoch[156] Train-loss=162.308957
INFO:root:Epoch[156] Time cost=9.545
INFO:root:Epoch[157] Train-loss=162.211725
INFO:root:Epoch[157] Time cost=9.542
INFO:root:Epoch[158] Train-loss=162.141098
INFO:root:Epoch[158] Time cost=9.768
INFO:root:Epoch[159] Train-loss=162.124311
INFO:root:Epoch[159] Time cost=7.155
INFO:root:Epoch[160] Train-loss=162.013039
INFO:root:Epoch[160] Time cost=6.147
INFO:root:Epoch[161] Train-loss=161.954485
INFO:root:Epoch[161] Time cost=9.121
INFO:root:Epoch[162] Train-loss=161.913859
INFO:root:Epoch[162] Time cost=9.936
INFO:root:Epoch[163] Train-loss=161.830799
INFO:root:Epoch[163] Time cost=8.612
INFO:root:Epoch[164] Train-loss=161.768672
INFO:root:Epoch[164] Time cost=9.722
INFO:root:Epoch[165] Train-loss=161.689120
INFO:root:Epoch[165] Time cost=9.478
INFO:root:Epoch[166] Train-loss=161.598279
INFO:root:Epoch[166] Time cost=9.466
INFO:root:Epoch[167] Train-loss=161.551172
INFO:root:Epoch[167] Time cost=9.419
INFO:root:Epoch[168] Train-loss=161.488880
INFO:root:Epoch[168] Time cost=9.457
INFO:root:Epoch[169] Train-loss=161.410458
INFO:root:Epoch[169] Time cost=9.504
INFO:root:Epoch[170] Train-loss=161.340681
INFO:root:Epoch[170] Time cost=9.866
INFO:root:Epoch[171] Train-loss=161.281700
INFO:root:Epoch[171] Time cost=9.526
INFO:root:Epoch[172] Train-loss=161.215523
INFO:root:Epoch[172] Time cost=9.511
INFO:root:Epoch[173] Train-loss=161.152452
INFO:root:Epoch[173] Time cost=9.498
INFO:root:Epoch[174] Train-loss=161.058544
INFO:root:Epoch[174] Time cost=9.561
INFO:root:Epoch[175] Train-loss=161.036475
INFO:root:Epoch[175] Time cost=9.463
INFO:root:Epoch[176] Train-loss=161.009996
INFO:root:Epoch[176] Time cost=9.629
INFO:root:Epoch[177] Train-loss=160.853546
INFO:root:Epoch[177] Time cost=9.518
INFO:root:Epoch[178] Train-loss=160.860520
INFO:root:Epoch[178] Time cost=9.395
INFO:root:Epoch[179] Train-loss=160.810621
INFO:root:Epoch[179] Time cost=9.452
INFO:root:Epoch[180] Train-loss=160.683071
INFO:root:Epoch[180] Time cost=9.411
INFO:root:Epoch[181] Train-loss=160.674101
INFO:root:Epoch[181] Time cost=8.784
INFO:root:Epoch[182] Train-loss=160.554823
INFO:root:Epoch[182] Time cost=7.265
INFO:root:Epoch[183] Train-loss=160.536528
INFO:root:Epoch[183] Time cost=6.108
INFO:root:Epoch[184] Train-loss=160.525913
INFO:root:Epoch[184] Time cost=6.349
INFO:root:Epoch[185] Train-loss=160.399412
INFO:root:Epoch[185] Time cost=7.364
INFO:root:Epoch[186] Train-loss=160.380027
INFO:root:Epoch[186] Time cost=7.651
INFO:root:Epoch[187] Train-loss=160.272921
INFO:root:Epoch[187] Time cost=7.309
INFO:root:Epoch[188] Train-loss=160.243907
INFO:root:Epoch[188] Time cost=7.162
INFO:root:Epoch[189] Train-loss=160.194351
INFO:root:Epoch[189] Time cost=8.941
INFO:root:Epoch[190] Train-loss=160.130400
INFO:root:Epoch[190] Time cost=10.242
INFO:root:Epoch[191] Train-loss=160.073841
INFO:root:Epoch[191] Time cost=10.528
INFO:root:Epoch[192] Train-loss=160.021623
INFO:root:Epoch[192] Time cost=9.482
INFO:root:Epoch[193] Train-loss=159.938673
INFO:root:Epoch[193] Time cost=9.465
INFO:root:Epoch[194] Train-loss=159.885823
INFO:root:Epoch[194] Time cost=9.523
INFO:root:Epoch[195] Train-loss=159.886516
INFO:root:Epoch[195] Time cost=9.599
INFO:root:Epoch[196] Train-loss=159.797400
INFO:root:Epoch[196] Time cost=8.675
INFO:root:Epoch[197] Train-loss=159.705562
INFO:root:Epoch[197] Time cost=9.551
INFO:root:Epoch[198] Train-loss=159.738354
INFO:root:Epoch[198] Time cost=9.919
INFO:root:Epoch[199] Train-loss=159.619932
INFO:root:Epoch[199] Time cost=10.121

In [12]:
# encode test images to obtain mu and logvar which are used for sampling
[mu,logvar] = VAE.encoder(out,image_test)
#sample in the latent space
z = VAE.sampler(mu,logvar)
# decode from the latent space to obtain reconstructed images
x_construction = VAE.decoder(out,z)

In [13]:
f, ((ax1, ax2, ax3, ax4)) = plt.subplots(1,4,  sharex='col', sharey='row',figsize=(12,3))
ax1.imshow(np.reshape(image_test[0,:],(28,28)), interpolation='nearest', cmap=cm.Greys)
ax1.set_title('True image')
ax2.imshow(np.reshape(x_construction[0,:],(28,28)), interpolation='nearest', cmap=cm.Greys)
ax2.set_title('Learned image')
ax3.imshow(np.reshape(x_construction[999,:],(28,28)), interpolation='nearest', cmap=cm.Greys)
ax3.set_title('Learned image')
ax4.imshow(np.reshape(x_construction[9999,:],(28,28)), interpolation='nearest', cmap=cm.Greys)
ax4.set_title('Learned image')
plt.show()



In [78]:
z1 = z[:,0]
z2 = z[:,1]

fig = plt.figure()
ax = fig.add_subplot(111)
ax.plot(z1,z2,'ko')
plt.title("latent space")

#np.where((z1>3) & (z2<2) & (z2>0))
#select the points from the latent space
a_vec = [2,5,7,789,25,9993]
for i in range(len(a_vec)):
    ax.plot(z1[a_vec[i]],z2[a_vec[i]],'ro')  
    ax.annotate('z%d' %i, xy=(z1[a_vec[i]],z2[a_vec[i]]), 
                xytext=(z1[a_vec[i]],z2[a_vec[i]]),color = 'r',fontsize=15)


f, axes = plt.subplots(1, 6, sharex='col', sharey='row', figsize=(12, 2.5))
for i in range(len(a_vec)):
    axes[i].imshow(np.reshape(x_construction[a_vec[i],:],(28,28)), interpolation='nearest', cmap=cm.Greys)
    axes[i].set_title('z%d' % i)

plt.show()


Above is a plot of points in the 2D latent space together with their corresponding decoded images. Points that are close in the latent space are decoded to similar digits, and moving from left to right we can see how one digit gradually evolves into another.
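Since the decoder maps any latent vector to an image, we can also generate entirely new digits by sampling z from the prior N(0, I) and decoding, rather than encoding a test image first. A minimal sketch, assuming VAE.decoder accepts an arbitrary array of latent vectors (as it does for the sampled z above):

# draw latent vectors from the prior N(0, I); the latent space is 2-D here
z_prior = np.random.randn(10, 2)
# decode them into images with the trained model
x_gen = VAE.decoder(out, z_prior)

f, axes = plt.subplots(1, 10, figsize=(12, 1.5))
for i in range(10):
    axes[i].imshow(np.reshape(x_gen[i,:], (28,28)), interpolation='nearest', cmap=cm.Greys)
    axes[i].axis('off')
plt.show()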