In [1]:
import mxnet as mx
import numpy as np
import os
import logging
import matplotlib.pyplot as plt
import matplotlib.cm as cm

Building a Variational Autoencoder in MXNet

Xiaoyu Lu, July 5th, 2017

This tutorial guides you through the process of building a variational autoencoder (VAE) in MXNet. In this notebook we'll focus on an example using the MNIST handwritten digit recognition dataset. Refer to Auto-Encoding Variational Bayes for more details on the model.

Prerequisites

To complete this tutorial, we need the following Python packages:

  • numpy, matplotlib

1. Loading the Data

We first load the MNIST dataset, which contains 60,000 training and 10,000 test examples. The following code imports the required modules and loads the data. The images are stored in a 4-D matrix with shape (batch_size, num_channels, width, height). For the MNIST dataset there is only one color channel, and both width and height are 28, so we reshape each image into a 28x28 array. See below for a visualization:


In [2]:
mnist = mx.test_utils.get_mnist()
image = np.reshape(mnist['train_data'],(60000,28*28))
label = image          # in an autoencoder the reconstruction target is the input itself
image_test = np.reshape(mnist['test_data'],(10000,28*28))
label_test = image_test
[N,features] = np.shape(image)          # number of examples and features
print(N,features)


60000 784

In [3]:
nsamples = 5
idx = np.random.choice(len(mnist['train_data']), nsamples)
_, axarr = plt.subplots(1, nsamples, sharex='col', sharey='row',figsize=(12,3))

for i,j in enumerate(idx):
    axarr[i].imshow(np.reshape(image[j,:],(28,28)), interpolation='nearest', cmap=cm.Greys)

plt.show()


We can optionally save the model parameters during training by setting the path prefix model_prefix. We first create data iterators for MXNet, with each batch containing 100 images.


In [4]:
model_prefix = None

batch_size = 100
latent_dim = 5
nd_iter = mx.io.NDArrayIter(data={'data':image},label={'loss_label':label},
                            batch_size = batch_size)
nd_iter_test = mx.io.NDArrayIter(data={'data':image_test},label={'loss_label':label_test},
                            batch_size = batch_size)
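
As a quick sanity check (illustrative, not part of the original tutorial), we can draw one batch from the iterator and confirm its shape before training:

batch = nd_iter.next()
print(batch.data[0].shape)    # expect (100, 784)
nd_iter.reset()               # rewind the iterator before training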

2. Building the Network Architecture

2.1 Gaussian MLP as encoder

Next we construct the neural network. As in the paper, we use a multilayer perceptron (MLP) for both the encoder and the decoder. For the encoder, a Gaussian MLP is used:

\begin{align} \log q_{\phi}(z|x) &= \log \mathcal{N}(z;\mu,\sigma^2 I) \\ \textit{ where } \mu &= W_2h+b_2,\quad \log \sigma^2 = W_3h+b_3\\ h &= \tanh(W_1x+b_1) \end{align}

where $\{W_1,W_2,W_3,b_1,b_2,b_3\}$ are the weights and biases of the MLP. Note that mu and logvar defined below are symbols, so we can use get_internals() to retrieve their values later, after which we can sample the latent variable $z$.


In [5]:
## define data and loss labels as symbols 
data = mx.sym.var('data')
loss_label = mx.sym.var('loss_label')

## define the fully connected and activation layers of the encoder; we use the tanh activation function
encoder_h  = mx.sym.FullyConnected(data=data, name="encoder_h",num_hidden=400)
act_h = mx.sym.Activation(data=encoder_h, act_type="tanh",name="activation_h")

## define mu and log variance as fully connected layers on top of the previous activation layer
mu  = mx.sym.FullyConnected(data=act_h, name="mu",num_hidden = latent_dim)
logvar  = mx.sym.FullyConnected(data=act_h, name="logvar",num_hidden = latent_dim)

## sample the latent variable z ~ Normal(mu, sigma^2) via the reparameterization trick: z = mu + sigma*epsilon, epsilon ~ N(0,1)
z = mu + mx.symbol.broadcast_mul(mx.symbol.exp(0.5 * logvar), 
                                 mx.symbol.random_normal(loc=0, scale=1, shape=(batch_size, latent_dim)),
                                 name="z")

2.2 Bernoulli MLP as decoder

In this case we let $p_\theta(x|z)$ be a multivariate Bernoulli whose probabilities are computed from $z$ with a feedforward neural network with a single hidden layer:

\begin{align} \log p(x|z) &= \sum_{i=1}^D x_i\log y_i + (1-x_i)\log (1-y_i) \\ \textit{ where } y &= f_\sigma(W_5\tanh (W_4z+b_4)+b_5) \end{align}

where $f_\sigma(\cdot)$ is the elementwise sigmoid activation function and $\{W_4,W_5,b_4,b_5\}$ are the weights and biases of the decoder MLP. A Bernoulli likelihood is suitable for this type of near-binary data, but you can easily extend the model to other likelihood types by passing the likelihood argument of the VAE class; see section 4 for details.


In [6]:
# define fully connected and tanh activation layers for the decoder
decoder_z = mx.sym.FullyConnected(data=z, name="decoder_z",num_hidden=400)
act_z = mx.sym.Activation(data=decoder_z, act_type="tanh",name="activation_z")

# define the output layer with sigmoid activation function, where the dimension is equal to the input dimension
decoder_x = mx.sym.FullyConnected(data=act_z, name="decoder_x",num_hidden=features)
y = mx.sym.Activation(data=decoder_x, act_type="sigmoid",name='activation_x')

2.3 Joint Loss Function for the Encoder and the Decoder

The variational lower bound, also called the evidence lower bound (ELBO), can be estimated as:

\begin{align} \mathcal{L}(\theta,\phi;x^{(i)}) \approx \frac{1}{2}\sum_{j=1}^{J}\left(1+\log ((\sigma_j^{(i)})^2)-(\mu_j^{(i)})^2-(\sigma_j^{(i)})^2\right) + \log p_\theta(x^{(i)}|z^{(i)}) \end{align}

where the first term is the negative KL divergence of the approximate posterior from the prior, which is available in closed form for a Gaussian posterior and a standard normal prior, and the second term is the expected reconstruction log-likelihood. We would like to maximize this lower bound, so we define the loss to be $-\mathcal{L}$ (the negative ELBO) for MXNet to minimize.


In [7]:
# define the objective (negative ELBO) loss function that needs to be minimized
# closed-form negative KL divergence term, summed over the latent dimensions
KL = 0.5*mx.symbol.sum(1+logvar-pow(mu,2)-mx.symbol.exp(logvar),axis=1)
# Bernoulli reconstruction log-likelihood (cross-entropy), summed over pixels
loss = -mx.symbol.sum(mx.symbol.broadcast_mul(loss_label,mx.symbol.log(y))
                      + mx.symbol.broadcast_mul(1-loss_label,mx.symbol.log(1-y)),axis=1)-KL
output = mx.symbol.MakeLoss(sum(loss),name='loss')
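
The cross-entropy term above corresponds to the Bernoulli likelihood. If the data were real-valued rather than near-binary, one could swap in a Gaussian reconstruction term instead; a minimal sketch assuming a fixed unit observation variance (gaussian_ll and loss_gaussian are illustrative names, not part of the tutorial):

# Gaussian log-likelihood with unit variance: log p(x|z) = -0.5*sum_i (x_i - y_i)^2 + const
gaussian_ll = -0.5*mx.symbol.sum(mx.symbol.square(loss_label - y), axis=1)
loss_gaussian = -gaussian_ll - KL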

3. Training the model

Now we can define the model and train it. First we initialize the weights and the biases to be Gaussian(0, 0.01), and then use stochastic gradient descent for optimization. To warm-start the training, one may also initialize with pre-trained parameters arg_params using init=mx.initializer.Load(arg_params).

To save intermediate results, we can optionally use epoch_end_callback = mx.callback.do_checkpoint(model_prefix, 1), which saves the parameters to the path given by model_prefix after every epoch. To assess performance, we output $-\mathcal{L}$ (the negative ELBO) after each epoch with eval_metric = 'Loss', which tracks the loss defined above. We also record the training loss for mini-batches by accessing the log, saving the values to a list, and passing the resulting callback to the batch_end_callback argument; with period N//batch_size it records one value per epoch.


In [8]:
# set up the log
nd_iter.reset()
logging.getLogger().setLevel(logging.DEBUG)  

# define a callback function to track the training loss
def log_to_list(period, lst):
    def _callback(param):
        """The checkpoint function."""
        if param.nbatch % period == 0:
            name, value = param.eval_metric.get()
            lst.append(value)
    return _callback

# define the model
model = mx.mod.Module(
    symbol = output ,
    data_names=['data'],
    label_names = ['loss_label'])

In [9]:
# training the model, save training loss as a list.
training_loss=list()

# initialize the parameters for training with Normal(0, 0.01)
init = mx.init.Normal(0.01)
model.fit(nd_iter,  # train data
          initializer=init,
          # if eval_data is supplied, test loss will also be reported
          # eval_data = nd_iter_test,
          optimizer='sgd',  # use SGD to train
          optimizer_params={'learning_rate':1e-3,'wd':1e-2},  
          # save parameters for each epoch if model_prefix is supplied
          epoch_end_callback = None if model_prefix is None else mx.callback.do_checkpoint(model_prefix, 1),
          batch_end_callback = log_to_list(N//batch_size, training_loss),
          num_epoch=100,
          eval_metric = 'Loss')


INFO:root:Epoch[0] Train-loss=373.547317
INFO:root:Epoch[0] Time cost=5.020
INFO:root:Epoch[1] Train-loss=212.232684
INFO:root:Epoch[1] Time cost=4.651
INFO:root:Epoch[2] Train-loss=207.448528
INFO:root:Epoch[2] Time cost=4.665
INFO:root:Epoch[3] Train-loss=205.369479
INFO:root:Epoch[3] Time cost=4.758
INFO:root:Epoch[4] Train-loss=203.651983
INFO:root:Epoch[4] Time cost=4.672
INFO:root:Epoch[5] Train-loss=202.061007
INFO:root:Epoch[5] Time cost=5.087
INFO:root:Epoch[6] Train-loss=199.348143
INFO:root:Epoch[6] Time cost=5.056
INFO:root:Epoch[7] Train-loss=196.266242
INFO:root:Epoch[7] Time cost=4.813
INFO:root:Epoch[8] Train-loss=194.694945
INFO:root:Epoch[8] Time cost=4.776
INFO:root:Epoch[9] Train-loss=193.699284
INFO:root:Epoch[9] Time cost=4.756
INFO:root:Epoch[10] Train-loss=193.036517
INFO:root:Epoch[10] Time cost=4.757
INFO:root:Epoch[11] Train-loss=192.555736
INFO:root:Epoch[11] Time cost=4.678
INFO:root:Epoch[12] Train-loss=192.020813
INFO:root:Epoch[12] Time cost=4.630
INFO:root:Epoch[13] Train-loss=191.648876
INFO:root:Epoch[13] Time cost=5.158
INFO:root:Epoch[14] Train-loss=191.057798
INFO:root:Epoch[14] Time cost=4.781
INFO:root:Epoch[15] Train-loss=190.315835
INFO:root:Epoch[15] Time cost=5.117
INFO:root:Epoch[16] Train-loss=189.311271
INFO:root:Epoch[16] Time cost=4.707
INFO:root:Epoch[17] Train-loss=187.285967
INFO:root:Epoch[17] Time cost=4.745
INFO:root:Epoch[18] Train-loss=185.271324
INFO:root:Epoch[18] Time cost=4.692
INFO:root:Epoch[19] Train-loss=183.510888
INFO:root:Epoch[19] Time cost=4.762
INFO:root:Epoch[20] Train-loss=181.756008
INFO:root:Epoch[20] Time cost=4.838
INFO:root:Epoch[21] Train-loss=180.546818
INFO:root:Epoch[21] Time cost=4.764
INFO:root:Epoch[22] Train-loss=179.479776
INFO:root:Epoch[22] Time cost=4.791
INFO:root:Epoch[23] Train-loss=178.352077
INFO:root:Epoch[23] Time cost=4.981
INFO:root:Epoch[24] Train-loss=177.385084
INFO:root:Epoch[24] Time cost=5.292
INFO:root:Epoch[25] Train-loss=175.920123
INFO:root:Epoch[25] Time cost=5.097
INFO:root:Epoch[26] Train-loss=174.377171
INFO:root:Epoch[26] Time cost=4.907
INFO:root:Epoch[27] Train-loss=172.590589
INFO:root:Epoch[27] Time cost=4.484
INFO:root:Epoch[28] Train-loss=170.933683
INFO:root:Epoch[28] Time cost=4.348
INFO:root:Epoch[29] Train-loss=169.866807
INFO:root:Epoch[29] Time cost=4.647
INFO:root:Epoch[30] Train-loss=169.182084
INFO:root:Epoch[30] Time cost=5.034
INFO:root:Epoch[31] Train-loss=168.121719
INFO:root:Epoch[31] Time cost=5.615
INFO:root:Epoch[32] Train-loss=167.389992
INFO:root:Epoch[32] Time cost=4.733
INFO:root:Epoch[33] Train-loss=166.189067
INFO:root:Epoch[33] Time cost=5.041
INFO:root:Epoch[34] Train-loss=163.783392
INFO:root:Epoch[34] Time cost=5.168
INFO:root:Epoch[35] Train-loss=162.167959
INFO:root:Epoch[35] Time cost=5.019
INFO:root:Epoch[36] Train-loss=161.192039
INFO:root:Epoch[36] Time cost=5.064
INFO:root:Epoch[37] Train-loss=160.307114
INFO:root:Epoch[37] Time cost=5.180
INFO:root:Epoch[38] Train-loss=159.591957
INFO:root:Epoch[38] Time cost=5.440
INFO:root:Epoch[39] Train-loss=159.109593
INFO:root:Epoch[39] Time cost=5.119
INFO:root:Epoch[40] Train-loss=158.463844
INFO:root:Epoch[40] Time cost=5.299
INFO:root:Epoch[41] Train-loss=158.037287
INFO:root:Epoch[41] Time cost=4.856
INFO:root:Epoch[42] Train-loss=157.598576
INFO:root:Epoch[42] Time cost=5.227
INFO:root:Epoch[43] Train-loss=157.097344
INFO:root:Epoch[43] Time cost=5.237
INFO:root:Epoch[44] Train-loss=156.594472
INFO:root:Epoch[44] Time cost=4.783
INFO:root:Epoch[45] Train-loss=156.177069
INFO:root:Epoch[45] Time cost=4.834
INFO:root:Epoch[46] Train-loss=155.825302
INFO:root:Epoch[46] Time cost=4.902
INFO:root:Epoch[47] Train-loss=155.318117
INFO:root:Epoch[47] Time cost=4.966
INFO:root:Epoch[48] Train-loss=154.890766
INFO:root:Epoch[48] Time cost=5.012
INFO:root:Epoch[49] Train-loss=154.504158
INFO:root:Epoch[49] Time cost=4.844
INFO:root:Epoch[50] Train-loss=154.035214
INFO:root:Epoch[50] Time cost=4.736
INFO:root:Epoch[51] Train-loss=153.692903
INFO:root:Epoch[51] Time cost=5.057
INFO:root:Epoch[52] Train-loss=153.257554
INFO:root:Epoch[52] Time cost=5.044
INFO:root:Epoch[53] Train-loss=152.849715
INFO:root:Epoch[53] Time cost=4.783
INFO:root:Epoch[54] Train-loss=152.483047
INFO:root:Epoch[54] Time cost=4.842
INFO:root:Epoch[55] Train-loss=152.091617
INFO:root:Epoch[55] Time cost=5.044
INFO:root:Epoch[56] Train-loss=151.715490
INFO:root:Epoch[56] Time cost=5.029
INFO:root:Epoch[57] Train-loss=151.362293
INFO:root:Epoch[57] Time cost=4.873
INFO:root:Epoch[58] Train-loss=151.003241
INFO:root:Epoch[58] Time cost=4.729
INFO:root:Epoch[59] Train-loss=150.619678
INFO:root:Epoch[59] Time cost=5.068
INFO:root:Epoch[60] Train-loss=150.296043
INFO:root:Epoch[60] Time cost=4.458
INFO:root:Epoch[61] Train-loss=149.964152
INFO:root:Epoch[61] Time cost=4.828
INFO:root:Epoch[62] Train-loss=149.694102
INFO:root:Epoch[62] Time cost=5.012
INFO:root:Epoch[63] Train-loss=149.290113
INFO:root:Epoch[63] Time cost=5.193
INFO:root:Epoch[64] Train-loss=148.934186
INFO:root:Epoch[64] Time cost=4.999
INFO:root:Epoch[65] Train-loss=148.657502
INFO:root:Epoch[65] Time cost=4.810
INFO:root:Epoch[66] Train-loss=148.331948
INFO:root:Epoch[66] Time cost=5.201
INFO:root:Epoch[67] Train-loss=148.018539
INFO:root:Epoch[67] Time cost=4.833
INFO:root:Epoch[68] Train-loss=147.746825
INFO:root:Epoch[68] Time cost=5.187
INFO:root:Epoch[69] Train-loss=147.406399
INFO:root:Epoch[69] Time cost=5.355
INFO:root:Epoch[70] Train-loss=147.181831
INFO:root:Epoch[70] Time cost=4.989
INFO:root:Epoch[71] Train-loss=146.860770
INFO:root:Epoch[71] Time cost=4.934
INFO:root:Epoch[72] Train-loss=146.604369
INFO:root:Epoch[72] Time cost=5.283
INFO:root:Epoch[73] Train-loss=146.351628
INFO:root:Epoch[73] Time cost=5.062
INFO:root:Epoch[74] Train-loss=146.102506
INFO:root:Epoch[74] Time cost=4.540
INFO:root:Epoch[75] Train-loss=145.828805
INFO:root:Epoch[75] Time cost=4.875
INFO:root:Epoch[76] Train-loss=145.571626
INFO:root:Epoch[76] Time cost=4.856
INFO:root:Epoch[77] Train-loss=145.365383
INFO:root:Epoch[77] Time cost=5.003
INFO:root:Epoch[78] Train-loss=145.101047
INFO:root:Epoch[78] Time cost=4.718
INFO:root:Epoch[79] Train-loss=144.810765
INFO:root:Epoch[79] Time cost=5.127
INFO:root:Epoch[80] Train-loss=144.619876
INFO:root:Epoch[80] Time cost=4.737
INFO:root:Epoch[81] Train-loss=144.399066
INFO:root:Epoch[81] Time cost=4.742
INFO:root:Epoch[82] Train-loss=144.220090
INFO:root:Epoch[82] Time cost=4.810
INFO:root:Epoch[83] Train-loss=143.904279
INFO:root:Epoch[83] Time cost=5.176
INFO:root:Epoch[84] Train-loss=143.734935
INFO:root:Epoch[84] Time cost=4.921
INFO:root:Epoch[85] Train-loss=143.499403
INFO:root:Epoch[85] Time cost=4.692
INFO:root:Epoch[86] Train-loss=143.304287
INFO:root:Epoch[86] Time cost=4.778
INFO:root:Epoch[87] Train-loss=143.096145
INFO:root:Epoch[87] Time cost=4.962
INFO:root:Epoch[88] Train-loss=142.877920
INFO:root:Epoch[88] Time cost=4.815
INFO:root:Epoch[89] Train-loss=142.677429
INFO:root:Epoch[89] Time cost=5.127
INFO:root:Epoch[90] Train-loss=142.499622
INFO:root:Epoch[90] Time cost=5.463
INFO:root:Epoch[91] Train-loss=142.300291
INFO:root:Epoch[91] Time cost=4.639
INFO:root:Epoch[92] Train-loss=142.111362
INFO:root:Epoch[92] Time cost=5.064
INFO:root:Epoch[93] Train-loss=141.912848
INFO:root:Epoch[93] Time cost=4.894
INFO:root:Epoch[94] Train-loss=141.723130
INFO:root:Epoch[94] Time cost=4.635
INFO:root:Epoch[95] Train-loss=141.516580
INFO:root:Epoch[95] Time cost=5.063
INFO:root:Epoch[96] Train-loss=141.362380
INFO:root:Epoch[96] Time cost=4.785
INFO:root:Epoch[97] Train-loss=141.178878
INFO:root:Epoch[97] Time cost=4.699
INFO:root:Epoch[98] Train-loss=141.004168
INFO:root:Epoch[98] Time cost=4.959
INFO:root:Epoch[99] Train-loss=140.865592
INFO:root:Epoch[99] Time cost=5.155

In [10]:
ELBO = [-loss for loss in training_loss]
plt.plot(ELBO)
plt.ylabel('ELBO')
plt.xlabel('epoch')
plt.title("training curve for mini batches")
plt.show()



As expected, the ELBO increases steadily over the epochs, reproducing the results reported in the paper Auto-Encoding Variational Bayes. Now we can extract/load the parameters and feed the network forward to calculate $y$, the reconstructed image; we can also calculate the ELBO for the test set.


In [11]:
arg_params = model.get_params()[0]
nd_iter_test.reset()
test_batch = nd_iter_test.next()

# if the parameters were saved, they can be loaded with `load_checkpoint`, e.g. from the 100th epoch:
# sym, arg_params, aux_params = mx.model.load_checkpoint(model_prefix, 100)
# assert sym.tojson() == output.tojson()

e = y.bind(mx.cpu(), {'data': test_batch.data[0],
                     'encoder_h_weight': arg_params['encoder_h_weight'],
                     'encoder_h_bias': arg_params['encoder_h_bias'],
                     'mu_weight': arg_params['mu_weight'],
                     'mu_bias': arg_params['mu_bias'],
                     'logvar_weight':arg_params['logvar_weight'],
                     'logvar_bias':arg_params['logvar_bias'],
                     'decoder_z_weight':arg_params['decoder_z_weight'],
                     'decoder_z_bias':arg_params['decoder_z_bias'],
                     'decoder_x_weight':arg_params['decoder_x_weight'],
                     'decoder_x_bias':arg_params['decoder_x_bias'],                
                     'loss_label':label})

x_fit = e.forward()
x_construction = x_fit[0].asnumpy()
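
As a quick quantitative check (illustrative, not in the original notebook), the mean squared reconstruction error of this batch can be computed directly:

recon_err = np.mean((x_construction - test_batch.data[0].asnumpy())**2)
print(recon_err)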

In [12]:
# reconstructed images on the test set
f, ((ax1, ax2, ax3, ax4)) = plt.subplots(1,4,  sharex='col', sharey='row',figsize=(12,3))
ax1.imshow(np.reshape(image_test[0,:],(28,28)), interpolation='nearest', cmap=cm.Greys)
ax1.set_title('True image')
ax2.imshow(np.reshape(x_construction[0,:],(28,28)), interpolation='nearest', cmap=cm.Greys)
ax2.set_title('Learned image')
ax3.imshow(np.reshape(image_test[99,:],(28,28)), interpolation='nearest', cmap=cm.Greys)
ax3.set_title('True image')
ax4.imshow(np.reshape(x_construction[99,:],(28,28)), interpolation='nearest', cmap=cm.Greys)
ax4.set_title('Learned image')
plt.show()



In [13]:
# score the test set; the reported loss is the negative ELBO
metric = mx.gluon.metric.Loss()
model.score(nd_iter_test, metric)


Out[13]:
[('loss', 140.17346005859375)]

4. All together: MXNet-based class VAE


In [14]:
from VAE import VAE

One can directly call the VAE class to do the training:

VAE(n_latent=5,num_hidden_ecoder=400,num_hidden_decoder=400,x_train=None,x_valid=None,
batch_size=100,learning_rate=0.001,weight_decay=0.01,num_epoch=100,optimizer='sgd',model_prefix=None,
initializer = mx.init.Normal(0.01),likelihood=Bernoulli)

The outputs are the learned model and training loss.


In [15]:
# one can initialize the weights and biases with previously learned parameters:
# init = mx.initializer.Load(params)

# call the VAE; the output contains the learned model and the training loss
out = VAE(n_latent=2, x_train=image, x_valid=None, num_epoch=200)


INFO:root:Epoch[0] Train-loss=383.478870
INFO:root:Epoch[0] Time cost=5.075
INFO:root:Epoch[1] Train-loss=211.923867
INFO:root:Epoch[1] Time cost=4.741
INFO:root:Epoch[2] Train-loss=206.789445
INFO:root:Epoch[2] Time cost=4.601
INFO:root:Epoch[3] Train-loss=204.428186
INFO:root:Epoch[3] Time cost=4.865
INFO:root:Epoch[4] Train-loss=202.417322
INFO:root:Epoch[4] Time cost=4.606
INFO:root:Epoch[5] Train-loss=200.635136
INFO:root:Epoch[5] Time cost=4.711
INFO:root:Epoch[6] Train-loss=199.009614
INFO:root:Epoch[6] Time cost=5.159
INFO:root:Epoch[7] Train-loss=197.565788
INFO:root:Epoch[7] Time cost=4.588
INFO:root:Epoch[8] Train-loss=196.524507
INFO:root:Epoch[8] Time cost=4.905
INFO:root:Epoch[9] Train-loss=195.725745
INFO:root:Epoch[9] Time cost=4.426
INFO:root:Epoch[10] Train-loss=194.902025
INFO:root:Epoch[10] Time cost=4.685
INFO:root:Epoch[11] Train-loss=194.026873
INFO:root:Epoch[11] Time cost=4.622
INFO:root:Epoch[12] Train-loss=193.350646
INFO:root:Epoch[12] Time cost=4.712
INFO:root:Epoch[13] Train-loss=192.737502
INFO:root:Epoch[13] Time cost=4.618
INFO:root:Epoch[14] Train-loss=192.338165
INFO:root:Epoch[14] Time cost=4.763
INFO:root:Epoch[15] Train-loss=191.888625
INFO:root:Epoch[15] Time cost=5.168
INFO:root:Epoch[16] Train-loss=191.170650
INFO:root:Epoch[16] Time cost=4.809
INFO:root:Epoch[17] Train-loss=190.307264
INFO:root:Epoch[17] Time cost=4.622
INFO:root:Epoch[18] Train-loss=188.988063
INFO:root:Epoch[18] Time cost=4.543
INFO:root:Epoch[19] Train-loss=187.616311
INFO:root:Epoch[19] Time cost=5.154
INFO:root:Epoch[20] Train-loss=186.352783
INFO:root:Epoch[20] Time cost=4.661
INFO:root:Epoch[21] Train-loss=185.428020
INFO:root:Epoch[21] Time cost=5.193
INFO:root:Epoch[22] Train-loss=184.543097
INFO:root:Epoch[22] Time cost=4.519
INFO:root:Epoch[23] Train-loss=184.029907
INFO:root:Epoch[23] Time cost=4.732
INFO:root:Epoch[24] Train-loss=183.643270
INFO:root:Epoch[24] Time cost=5.011
INFO:root:Epoch[25] Train-loss=183.246912
INFO:root:Epoch[25] Time cost=4.706
INFO:root:Epoch[26] Train-loss=183.065233
INFO:root:Epoch[26] Time cost=4.673
INFO:root:Epoch[27] Train-loss=182.680542
INFO:root:Epoch[27] Time cost=4.628
INFO:root:Epoch[28] Train-loss=182.428677
INFO:root:Epoch[28] Time cost=4.772
INFO:root:Epoch[29] Train-loss=182.219946
INFO:root:Epoch[29] Time cost=4.571
INFO:root:Epoch[30] Train-loss=182.070927
INFO:root:Epoch[30] Time cost=4.603
INFO:root:Epoch[31] Train-loss=181.837968
INFO:root:Epoch[31] Time cost=4.559
INFO:root:Epoch[32] Train-loss=181.624303
INFO:root:Epoch[32] Time cost=5.069
INFO:root:Epoch[33] Train-loss=181.534547
INFO:root:Epoch[33] Time cost=4.654
INFO:root:Epoch[34] Train-loss=181.239556
INFO:root:Epoch[34] Time cost=4.776
INFO:root:Epoch[35] Train-loss=181.098188
INFO:root:Epoch[35] Time cost=4.571
INFO:root:Epoch[36] Train-loss=180.820560
INFO:root:Epoch[36] Time cost=4.815
INFO:root:Epoch[37] Train-loss=180.828095
INFO:root:Epoch[37] Time cost=4.455
INFO:root:Epoch[38] Train-loss=180.495569
INFO:root:Epoch[38] Time cost=5.096
INFO:root:Epoch[39] Train-loss=180.389106
INFO:root:Epoch[39] Time cost=4.797
INFO:root:Epoch[40] Train-loss=180.200965
INFO:root:Epoch[40] Time cost=5.054
INFO:root:Epoch[41] Train-loss=179.851014
INFO:root:Epoch[41] Time cost=4.642
INFO:root:Epoch[42] Train-loss=179.719933
INFO:root:Epoch[42] Time cost=4.603
INFO:root:Epoch[43] Train-loss=179.431740
INFO:root:Epoch[43] Time cost=4.341
INFO:root:Epoch[44] Train-loss=179.235384
INFO:root:Epoch[44] Time cost=4.638
INFO:root:Epoch[45] Train-loss=179.108771
INFO:root:Epoch[45] Time cost=4.754
INFO:root:Epoch[46] Train-loss=178.714163
INFO:root:Epoch[46] Time cost=4.457
INFO:root:Epoch[47] Train-loss=178.508338
INFO:root:Epoch[47] Time cost=4.960
INFO:root:Epoch[48] Train-loss=178.288002
INFO:root:Epoch[48] Time cost=4.562
INFO:root:Epoch[49] Train-loss=178.083288
INFO:root:Epoch[49] Time cost=4.619
INFO:root:Epoch[50] Train-loss=177.791330
INFO:root:Epoch[50] Time cost=4.580
INFO:root:Epoch[51] Train-loss=177.570741
INFO:root:Epoch[51] Time cost=4.704
INFO:root:Epoch[52] Train-loss=177.287114
INFO:root:Epoch[52] Time cost=5.172
INFO:root:Epoch[53] Train-loss=177.122645
INFO:root:Epoch[53] Time cost=4.678
INFO:root:Epoch[54] Train-loss=176.816022
INFO:root:Epoch[54] Time cost=4.819
INFO:root:Epoch[55] Train-loss=176.670484
INFO:root:Epoch[55] Time cost=4.568
INFO:root:Epoch[56] Train-loss=176.459671
INFO:root:Epoch[56] Time cost=4.450
INFO:root:Epoch[57] Train-loss=176.174175
INFO:root:Epoch[57] Time cost=4.579
INFO:root:Epoch[58] Train-loss=175.935856
INFO:root:Epoch[58] Time cost=4.552
INFO:root:Epoch[59] Train-loss=175.739928
INFO:root:Epoch[59] Time cost=4.385
INFO:root:Epoch[60] Train-loss=175.579695
INFO:root:Epoch[60] Time cost=4.496
INFO:root:Epoch[61] Train-loss=175.403871
INFO:root:Epoch[61] Time cost=5.088
INFO:root:Epoch[62] Train-loss=175.157114
INFO:root:Epoch[62] Time cost=4.628
INFO:root:Epoch[63] Train-loss=174.953950
INFO:root:Epoch[63] Time cost=4.826
INFO:root:Epoch[64] Train-loss=174.743393
INFO:root:Epoch[64] Time cost=4.832
INFO:root:Epoch[65] Train-loss=174.554056
INFO:root:Epoch[65] Time cost=4.375
INFO:root:Epoch[66] Train-loss=174.366719
INFO:root:Epoch[66] Time cost=4.583
INFO:root:Epoch[67] Train-loss=174.160622
INFO:root:Epoch[67] Time cost=4.586
INFO:root:Epoch[68] Train-loss=173.981699
INFO:root:Epoch[68] Time cost=5.149
INFO:root:Epoch[69] Train-loss=173.751617
INFO:root:Epoch[69] Time cost=4.495
INFO:root:Epoch[70] Train-loss=173.548732
INFO:root:Epoch[70] Time cost=4.588
INFO:root:Epoch[71] Train-loss=173.380950
INFO:root:Epoch[71] Time cost=5.042
INFO:root:Epoch[72] Train-loss=173.158519
INFO:root:Epoch[72] Time cost=4.817
INFO:root:Epoch[73] Train-loss=172.970726
INFO:root:Epoch[73] Time cost=4.791
INFO:root:Epoch[74] Train-loss=172.782357
INFO:root:Epoch[74] Time cost=4.377
INFO:root:Epoch[75] Train-loss=172.581992
INFO:root:Epoch[75] Time cost=4.518
INFO:root:Epoch[76] Train-loss=172.385020
INFO:root:Epoch[76] Time cost=4.863
INFO:root:Epoch[77] Train-loss=172.198309
INFO:root:Epoch[77] Time cost=5.104
INFO:root:Epoch[78] Train-loss=172.022333
INFO:root:Epoch[78] Time cost=4.571
INFO:root:Epoch[79] Train-loss=171.816585
INFO:root:Epoch[79] Time cost=4.557
INFO:root:Epoch[80] Train-loss=171.643714
INFO:root:Epoch[80] Time cost=4.567
INFO:root:Epoch[81] Train-loss=171.460581
INFO:root:Epoch[81] Time cost=4.735
INFO:root:Epoch[82] Train-loss=171.284854
INFO:root:Epoch[82] Time cost=5.012
INFO:root:Epoch[83] Train-loss=171.113129
INFO:root:Epoch[83] Time cost=4.877
INFO:root:Epoch[84] Train-loss=170.947790
INFO:root:Epoch[84] Time cost=4.487
INFO:root:Epoch[85] Train-loss=170.766223
INFO:root:Epoch[85] Time cost=4.723
INFO:root:Epoch[86] Train-loss=170.602559
INFO:root:Epoch[86] Time cost=4.803
INFO:root:Epoch[87] Train-loss=170.448713
INFO:root:Epoch[87] Time cost=4.636
INFO:root:Epoch[88] Train-loss=170.273053
INFO:root:Epoch[88] Time cost=4.562
INFO:root:Epoch[89] Train-loss=170.099485
INFO:root:Epoch[89] Time cost=4.567
INFO:root:Epoch[90] Train-loss=169.934289
INFO:root:Epoch[90] Time cost=4.905
INFO:root:Epoch[91] Train-loss=169.768920
INFO:root:Epoch[91] Time cost=4.636
INFO:root:Epoch[92] Train-loss=169.620803
INFO:root:Epoch[92] Time cost=4.429
INFO:root:Epoch[93] Train-loss=169.448189
INFO:root:Epoch[93] Time cost=4.985
INFO:root:Epoch[94] Train-loss=169.295794
INFO:root:Epoch[94] Time cost=4.649
INFO:root:Epoch[95] Train-loss=169.143627
INFO:root:Epoch[95] Time cost=4.602
INFO:root:Epoch[96] Train-loss=168.989410
INFO:root:Epoch[96] Time cost=4.904
INFO:root:Epoch[97] Train-loss=168.841089
INFO:root:Epoch[97] Time cost=4.602
INFO:root:Epoch[98] Train-loss=168.694906
INFO:root:Epoch[98] Time cost=4.589
INFO:root:Epoch[99] Train-loss=168.527604
INFO:root:Epoch[99] Time cost=4.560
INFO:root:Epoch[100] Train-loss=168.385596
INFO:root:Epoch[100] Time cost=4.835
INFO:root:Epoch[101] Train-loss=168.246526
INFO:root:Epoch[101] Time cost=4.558
INFO:root:Epoch[102] Train-loss=168.093663
INFO:root:Epoch[102] Time cost=4.609
INFO:root:Epoch[103] Train-loss=167.938807
INFO:root:Epoch[103] Time cost=4.599
INFO:root:Epoch[104] Train-loss=167.814916
INFO:root:Epoch[104] Time cost=4.394
INFO:root:Epoch[105] Train-loss=167.676473
INFO:root:Epoch[105] Time cost=4.724
INFO:root:Epoch[106] Train-loss=167.560241
INFO:root:Epoch[106] Time cost=4.316
INFO:root:Epoch[107] Train-loss=167.424132
INFO:root:Epoch[107] Time cost=4.646
INFO:root:Epoch[108] Train-loss=167.284482
INFO:root:Epoch[108] Time cost=4.472
INFO:root:Epoch[109] Train-loss=167.184511
INFO:root:Epoch[109] Time cost=4.768
INFO:root:Epoch[110] Train-loss=167.037793
INFO:root:Epoch[110] Time cost=4.717
INFO:root:Epoch[111] Train-loss=166.916652
INFO:root:Epoch[111] Time cost=4.803
INFO:root:Epoch[112] Train-loss=166.796803
INFO:root:Epoch[112] Time cost=4.617
INFO:root:Epoch[113] Train-loss=166.655028
INFO:root:Epoch[113] Time cost=4.420
INFO:root:Epoch[114] Train-loss=166.561129
INFO:root:Epoch[114] Time cost=4.333
INFO:root:Epoch[115] Train-loss=166.434593
INFO:root:Epoch[115] Time cost=4.526
INFO:root:Epoch[116] Train-loss=166.322805
INFO:root:Epoch[116] Time cost=4.310
INFO:root:Epoch[117] Train-loss=166.195452
INFO:root:Epoch[117] Time cost=4.458
INFO:root:Epoch[118] Train-loss=166.073792
INFO:root:Epoch[118] Time cost=4.333
INFO:root:Epoch[119] Train-loss=165.967437
INFO:root:Epoch[119] Time cost=4.459
INFO:root:Epoch[120] Train-loss=165.876094
INFO:root:Epoch[120] Time cost=5.070
INFO:root:Epoch[121] Train-loss=165.748064
INFO:root:Epoch[121] Time cost=4.782
INFO:root:Epoch[122] Train-loss=165.656283
INFO:root:Epoch[122] Time cost=4.640
INFO:root:Epoch[123] Train-loss=165.540462
INFO:root:Epoch[123] Time cost=4.522
INFO:root:Epoch[124] Train-loss=165.448734
INFO:root:Epoch[124] Time cost=4.858
INFO:root:Epoch[125] Train-loss=165.347751
INFO:root:Epoch[125] Time cost=4.842
INFO:root:Epoch[126] Train-loss=165.230048
INFO:root:Epoch[126] Time cost=4.495
INFO:root:Epoch[127] Train-loss=165.147932
INFO:root:Epoch[127] Time cost=4.766
INFO:root:Epoch[128] Train-loss=165.036021
INFO:root:Epoch[128] Time cost=4.526
INFO:root:Epoch[129] Train-loss=164.977613
INFO:root:Epoch[129] Time cost=5.091
INFO:root:Epoch[130] Train-loss=164.881467
INFO:root:Epoch[130] Time cost=5.223
INFO:root:Epoch[131] Train-loss=164.785627
INFO:root:Epoch[131] Time cost=4.165
INFO:root:Epoch[132] Train-loss=164.707629
INFO:root:Epoch[132] Time cost=4.527
INFO:root:Epoch[133] Train-loss=164.598039
INFO:root:Epoch[133] Time cost=4.167
INFO:root:Epoch[134] Train-loss=164.502932
INFO:root:Epoch[134] Time cost=4.354
INFO:root:Epoch[135] Train-loss=164.422286
INFO:root:Epoch[135] Time cost=4.387
INFO:root:Epoch[136] Train-loss=164.344749
INFO:root:Epoch[136] Time cost=4.662
INFO:root:Epoch[137] Train-loss=164.264898
INFO:root:Epoch[137] Time cost=4.671
INFO:root:Epoch[138] Train-loss=164.178707
INFO:root:Epoch[138] Time cost=4.776
INFO:root:Epoch[139] Train-loss=164.109071
INFO:root:Epoch[139] Time cost=4.787
INFO:root:Epoch[140] Train-loss=163.993291
INFO:root:Epoch[140] Time cost=4.726
INFO:root:Epoch[141] Train-loss=163.956234
INFO:root:Epoch[141] Time cost=4.337
INFO:root:Epoch[142] Train-loss=163.845638
INFO:root:Epoch[142] Time cost=4.787
INFO:root:Epoch[143] Train-loss=163.790882
INFO:root:Epoch[143] Time cost=5.563
INFO:root:Epoch[144] Train-loss=163.723495
INFO:root:Epoch[144] Time cost=4.529
INFO:root:Epoch[145] Train-loss=163.634262
INFO:root:Epoch[145] Time cost=5.028
INFO:root:Epoch[146] Train-loss=163.552854
INFO:root:Epoch[146] Time cost=4.933
INFO:root:Epoch[147] Train-loss=163.501429
INFO:root:Epoch[147] Time cost=4.912
INFO:root:Epoch[148] Train-loss=163.444245
INFO:root:Epoch[148] Time cost=5.034
INFO:root:Epoch[149] Train-loss=163.348476
INFO:root:Epoch[149] Time cost=4.600
INFO:root:Epoch[150] Train-loss=163.256955
INFO:root:Epoch[150] Time cost=4.704
INFO:root:Epoch[151] Train-loss=163.216139
INFO:root:Epoch[151] Time cost=4.670
INFO:root:Epoch[152] Train-loss=163.144691
INFO:root:Epoch[152] Time cost=4.678
INFO:root:Epoch[153] Train-loss=163.050236
INFO:root:Epoch[153] Time cost=4.595
INFO:root:Epoch[154] Train-loss=162.991225
INFO:root:Epoch[154] Time cost=5.307
INFO:root:Epoch[155] Train-loss=162.907200
INFO:root:Epoch[155] Time cost=4.684
INFO:root:Epoch[156] Train-loss=162.838075
INFO:root:Epoch[156] Time cost=4.686
INFO:root:Epoch[157] Train-loss=162.759286
INFO:root:Epoch[157] Time cost=4.750
INFO:root:Epoch[158] Train-loss=162.725998
INFO:root:Epoch[158] Time cost=4.637
INFO:root:Epoch[159] Train-loss=162.635852
INFO:root:Epoch[159] Time cost=4.498
INFO:root:Epoch[160] Train-loss=162.563777
INFO:root:Epoch[160] Time cost=5.048
INFO:root:Epoch[161] Train-loss=162.527387
INFO:root:Epoch[161] Time cost=5.040
INFO:root:Epoch[162] Train-loss=162.395881
INFO:root:Epoch[162] Time cost=4.764
INFO:root:Epoch[163] Train-loss=162.353654
INFO:root:Epoch[163] Time cost=4.561
INFO:root:Epoch[164] Train-loss=162.285584
INFO:root:Epoch[164] Time cost=5.051
INFO:root:Epoch[165] Train-loss=162.204332
INFO:root:Epoch[165] Time cost=4.455
INFO:root:Epoch[166] Train-loss=162.147100
INFO:root:Epoch[166] Time cost=5.021
INFO:root:Epoch[167] Train-loss=162.051296
INFO:root:Epoch[167] Time cost=4.551
INFO:root:Epoch[168] Train-loss=161.978708
INFO:root:Epoch[168] Time cost=4.744
INFO:root:Epoch[169] Train-loss=161.927990
INFO:root:Epoch[169] Time cost=4.821
INFO:root:Epoch[170] Train-loss=161.883088
INFO:root:Epoch[170] Time cost=4.365
INFO:root:Epoch[171] Train-loss=161.785367
INFO:root:Epoch[171] Time cost=4.448
INFO:root:Epoch[172] Train-loss=161.716386
INFO:root:Epoch[172] Time cost=4.622
INFO:root:Epoch[173] Train-loss=161.656391
INFO:root:Epoch[173] Time cost=4.500
INFO:root:Epoch[174] Train-loss=161.598127
INFO:root:Epoch[174] Time cost=4.677
INFO:root:Epoch[175] Train-loss=161.518613
INFO:root:Epoch[175] Time cost=4.958
INFO:root:Epoch[176] Train-loss=161.418783
INFO:root:Epoch[176] Time cost=4.607
INFO:root:Epoch[177] Train-loss=161.407767
INFO:root:Epoch[177] Time cost=4.427
INFO:root:Epoch[178] Train-loss=161.319552
INFO:root:Epoch[178] Time cost=4.930
INFO:root:Epoch[179] Train-loss=161.234087
INFO:root:Epoch[179] Time cost=4.240
INFO:root:Epoch[180] Train-loss=161.187404
INFO:root:Epoch[180] Time cost=4.484
INFO:root:Epoch[181] Train-loss=161.123118
INFO:root:Epoch[181] Time cost=4.937
INFO:root:Epoch[182] Train-loss=160.999420
INFO:root:Epoch[182] Time cost=4.489
INFO:root:Epoch[183] Train-loss=160.955369
INFO:root:Epoch[183] Time cost=4.894
INFO:root:Epoch[184] Train-loss=160.908542
INFO:root:Epoch[184] Time cost=4.269
INFO:root:Epoch[185] Train-loss=160.846908
INFO:root:Epoch[185] Time cost=4.998
INFO:root:Epoch[186] Train-loss=160.765964
INFO:root:Epoch[186] Time cost=4.467
INFO:root:Epoch[187] Train-loss=160.687773
INFO:root:Epoch[187] Time cost=4.609
INFO:root:Epoch[188] Train-loss=160.652674
INFO:root:Epoch[188] Time cost=5.327
INFO:root:Epoch[189] Train-loss=160.551175
INFO:root:Epoch[189] Time cost=4.267
INFO:root:Epoch[190] Train-loss=160.477424
INFO:root:Epoch[190] Time cost=4.798
INFO:root:Epoch[191] Train-loss=160.501221
INFO:root:Epoch[191] Time cost=4.695
INFO:root:Epoch[192] Train-loss=160.370335
INFO:root:Epoch[192] Time cost=4.640
INFO:root:Epoch[193] Train-loss=160.279749
INFO:root:Epoch[193] Time cost=4.653
INFO:root:Epoch[194] Train-loss=160.242415
INFO:root:Epoch[194] Time cost=5.044
INFO:root:Epoch[195] Train-loss=160.197063
INFO:root:Epoch[195] Time cost=4.684
INFO:root:Epoch[196] Train-loss=160.132983
INFO:root:Epoch[196] Time cost=4.460
INFO:root:Epoch[197] Train-loss=160.083149
INFO:root:Epoch[197] Time cost=4.713
INFO:root:Epoch[198] Train-loss=160.025012
INFO:root:Epoch[198] Time cost=4.779
INFO:root:Epoch[199] Train-loss=159.945513
INFO:root:Epoch[199] Time cost=4.659

In [16]:
# encode test images to obtain mu and logvar which are used for sampling
[mu,logvar] = VAE.encoder(out,image_test)
# sample in the latent space
z = VAE.sampler(mu,logvar)
# decode from the latent space to obtain reconstructed images
x_construction = VAE.decoder(out,z)
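
Besides reconstructing test images, one can generate brand-new digits by sampling $z$ from the standard normal prior and decoding it. A sketch, assuming VAE.decoder accepts any array of shape (n, n_latent) as above:

z_prior = np.random.randn(10, 2)         # n_latent = 2 in this run
x_generated = VAE.decoder(out, z_prior)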

In [17]:
f, ((ax1, ax2, ax3, ax4)) = plt.subplots(1,4,  sharex='col', sharey='row',figsize=(12,3))
ax1.imshow(np.reshape(image_test[0,:],(28,28)), interpolation='nearest', cmap=cm.Greys)
ax1.set_title('True image')
ax2.imshow(np.reshape(x_construction[0,:],(28,28)), interpolation='nearest', cmap=cm.Greys)
ax2.set_title('Learned image')
ax3.imshow(np.reshape(image_test[146,:],(28,28)), interpolation='nearest', cmap=cm.Greys)
ax3.set_title('True image')
ax4.imshow(np.reshape(x_construction[146,:],(28,28)), interpolation='nearest', cmap=cm.Greys)
ax4.set_title('Learned image')
plt.show()



In [18]:
z1 = z[:,0]
z2 = z[:,1]

fig = plt.figure()
ax = fig.add_subplot(111)
ax.plot(z1,z2,'ko')
plt.title("latent space")

#np.where((z1>3) & (z2<2) & (z2>0))
#select the points from the latent space
a_vec = [2,5,7,789,25,9993]
for i in range(len(a_vec)):
    ax.plot(z1[a_vec[i]],z2[a_vec[i]],'ro')  
    ax.annotate('z%d' %i, xy=(z1[a_vec[i]],z2[a_vec[i]]), 
                xytext=(z1[a_vec[i]],z2[a_vec[i]]),color = 'r',fontsize=15)


f, axarr = plt.subplots(1, 6, sharex='col', sharey='row', figsize=(12, 2.5))
for i in range(len(a_vec)):
    axarr[i].imshow(np.reshape(x_construction[a_vec[i],:],(28,28)), interpolation='nearest', cmap=cm.Greys)
    axarr[i].set_title('z%d' % i)

plt.show()



Above is a plot of points in the 2-D latent space together with their corresponding decoded images. Points that are close in the latent space are mapped to similar digits by the decoder, and we can see how the digits evolve from left to right.
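
To see this evolution more systematically, one could decode a regular grid of latent points; a sketch under the same assumption about VAE.decoder as above:

# decode an 8x8 grid over the 2-D latent space to see digits morph as z varies
grid = np.array([[a, b] for a in np.linspace(-3, 3, 8) for b in np.linspace(-3, 3, 8)])
x_grid = VAE.decoder(out, grid)
f, axes = plt.subplots(8, 8, figsize=(8, 8))
for k in range(64):
    axes[k // 8, k % 8].imshow(np.reshape(x_grid[k], (28, 28)), cmap=cm.Greys)
    axes[k // 8, k % 8].axis('off')
plt.show()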