In this tutorial, we show how to implement a variational auto-encoder (VAE) in NNabla. A VAE can be a powerful model for unsupervised learning; it is trained by maximizing a variational lower bound on the data likelihood. VAEs can also be extended to many applications, such as image generation, as demonstrated by VQ-VAE.

We will use MNIST for our tutorial. Although MNIST is fully labeled, we will assume a setting in which the labels are missing.


In [ ]:
# If you run this notebook on Google Colab, uncomment and run the following to set up dependencies.
# !pip install nnabla-ext-cuda100
# !git clone https://github.com/sony/nnabla-examples.git
# %cd nnabla-examples

As always, let's start by importing dependencies.


In [ ]:
import nnabla as nn
import nnabla.functions as F
import nnabla.parametric_functions as PF
import nnabla.monitor as M
import nnabla.solver as S
from nnabla.logger import logger
from utils.neu.save_nnp import save_nnp

import nnabla.utils.save as save
import numpy as np
import time
import os

Let's also define a data iterator for MNIST. You can disregard the details for now.


In [ ]:
import struct
import zlib

from nnabla.logger import logger
from nnabla.utils.data_iterator import data_iterator
from nnabla.utils.data_source import DataSource
from nnabla.utils.data_source_loader import download


def load_mnist(train=True):
    '''
    Load MNIST dataset images and labels from the original page by Yann LeCun or the cache file.
    Args:
        train (bool): The testing dataset will be returned if False. Training data has 60000 images, while testing has 10000 images.
    Returns:
        numpy.ndarray: A shape of (#images, 1, 28, 28). Values in [0.0, 1.0].
        numpy.ndarray: A shape of (#images, 1). Values in {0, 1, ..., 9}.
    '''
    if train:
        image_uri = 'http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz'
        label_uri = 'http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz'
    else:
        image_uri = 'http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz'
        label_uri = 'http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz'
    logger.info('Getting label data from {}.'.format(label_uri))
    r = download(label_uri)
    data = zlib.decompress(r.read(), zlib.MAX_WBITS | 32)
    _, size = struct.unpack('>II', data[0:8])
    labels = np.frombuffer(data[8:], np.uint8).reshape(-1, 1)
    r.close()
    logger.info('Getting label data done.')

    logger.info('Getting image data from {}.'.format(image_uri))
    r = download(image_uri)
    data = zlib.decompress(r.read(), zlib.MAX_WBITS | 32)
    _, size, height, width = struct.unpack('>IIII', data[0:16])
    images = np.frombuffer(data[16:], np.uint8).reshape(
        size, 1, height, width)
    r.close()
    logger.info('Getting image data done.')

    return images, labels

class MnistDataSource(DataSource):
    '''
    Get data directly from the MNIST dataset on the Internet (yann.lecun.com).
    '''

    def _get_data(self, position):
        image = self._images[self._indexes[position]]
        label = self._labels[self._indexes[position]]
        return (image, label)

    def __init__(self, train=True, shuffle=False, rng=None):
        super(MnistDataSource, self).__init__(shuffle=shuffle)
        self._train = train

        self._images, self._labels = load_mnist(train)

        self._size = self._labels.size
        self._variables = ('x', 'y')
        if rng is None:
            rng = np.random.RandomState(313)
        self.rng = rng
        self.reset()

    def reset(self):
        if self._shuffle:
            self._indexes = self.rng.permutation(self._size)
        else:
            self._indexes = np.arange(self._size)
        super(MnistDataSource, self).reset()

    @property
    def images(self):
        """Get copy of whole data with a shape of (N, 1, H, W)."""
        return self._images.copy()

    @property
    def labels(self):
        """Get copy of whole label with a shape of (N, 1)."""
        return self._labels.copy()

def data_iterator_mnist(batch_size,
                        train=True,
                        rng=None,
                        shuffle=True,
                        with_memory_cache=False,
                        with_file_cache=False):
    '''
    Provide DataIterator with :py:class:`MnistDataSource`.
    with_memory_cache and with_file_cache both default to False,
    because :py:class:`MnistDataSource` is able to store all data in memory.
    For example,
    .. code-block:: python
        with data_iterator_mnist(batch_size, True) as di:
            for data in di:
                SOME CODE TO USE data.
    '''
    return data_iterator(MnistDataSource(train=train, shuffle=shuffle, rng=rng),
                         batch_size,
                         rng,
                         with_memory_cache,
                         with_file_cache)
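
The header parsing in `load_mnist` can be tried offline on a synthetic buffer. The following is a minimal sketch, assuming the IDX layout that the loader above relies on (a big-endian magic number, then the dimensions, then raw `uint8` pixels). `make_fake_idx_images` and `parse_idx_images` are hypothetical helpers written for this illustration; they are not part of NNabla.

```python
import struct
import zlib

import numpy as np


def make_fake_idx_images(n, height, width, seed=0):
    """Build a gzip-compressed IDX3 image file in memory (hypothetical helper)."""
    rng = np.random.RandomState(seed)
    pixels = rng.randint(0, 256, size=(n, height, width)).astype(np.uint8)
    # 2051 is the magic number of IDX3 image files.
    header = struct.pack('>IIII', 2051, n, height, width)
    # wbits = MAX_WBITS | 16 asks zlib to write a gzip container.
    compressor = zlib.compressobj(wbits=zlib.MAX_WBITS | 16)
    blob = compressor.compress(header + pixels.tobytes()) + compressor.flush()
    return blob, pixels


def parse_idx_images(blob):
    """Same decoding steps as load_mnist, applied to an in-memory blob."""
    # wbits = MAX_WBITS | 32 auto-detects the gzip/zlib header.
    data = zlib.decompress(blob, zlib.MAX_WBITS | 32)
    _, size, height, width = struct.unpack('>IIII', data[0:16])
    return np.frombuffer(data[16:], np.uint8).reshape(size, 1, height, width)


blob, expected = make_fake_idx_images(n=3, height=28, width=28)
images = parse_idx_images(blob)
print(images.shape, images.dtype)
```

The channel axis of size 1 is added by the `reshape`, which is why the iterator yields arrays of shape `(batch_size, 1, 28, 28)`.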

We now define a function for our VAE network, which is the core of this tutorial. This function computes the ELBO (evidence lower bound) loss. Note that this example uses a Bernoulli decoder, i.e., each pixel is modeled as a Bernoulli variable.

Encoder:

We first define the encoder network. The input is normalized and passed through two fully-connected (affine) layers, each followed by an ELU non-linear activation (the original paper uses the Softplus activation). From the output of these two layers we compute mu and the log-variance, from which sigma is derived; these are then used to compute z. Note that during training random noise is added to z, whereas at test time z is set equal to mu.
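
The sampling step can be sketched in plain NumPy. This is only an illustration of the reparameterization z = mu + sigma * epsilon under the assumption that the encoder outputs mu and the log-variance; `sample_z` is a hypothetical helper and the numbers are made up.

```python
import numpy as np


def sample_z(mu, logvar, rng, test=False):
    """Draw z = mu + sigma * epsilon; return mu unchanged in test mode."""
    if test:
        return mu
    sigma = np.exp(0.5 * logvar)  # log-variance -> standard deviation
    epsilon = rng.standard_normal(mu.shape)  # noise independent of the parameters
    return mu + sigma * epsilon


rng = np.random.default_rng(313)
mu = np.full((100000, 2), 1.5)
logvar = np.full((100000, 2), np.log(0.25))  # corresponds to sigma = 0.5
z_train = sample_z(mu, logvar, rng)
z_test = sample_z(mu, logvar, rng, test=True)
print(z_train.mean(axis=0), z_train.std(axis=0))
```

Because the randomness lives entirely in epsilon, z is a deterministic, differentiable function of mu and the log-variance, which is what lets gradients flow back into the encoder.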

Decoder:

We now define the decoder network, which also stacks two fully-connected layers followed by ELU, but this time the input is z computed by the encoder. Applying one more fully-connected layer to the output gives the parameters (logits) of the Bernoulli distribution of each pixel; the sigmoid that turns them into probabilities is applied inside the loss.
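
In the `vae` function of this tutorial, the output of the last affine layer is passed to `F.sigmoid_cross_entropy`, so it plays the role of logits rather than probabilities. The NumPy sketch below illustrates that relationship; the `sigmoid_cross_entropy` defined here is a hypothetical stand-in written for illustration, not NNabla's implementation.

```python
import numpy as np


def sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))


def sigmoid_cross_entropy(logits, targets):
    """Stable form of -[t * log(s(a)) + (1 - t) * log(1 - s(a))]."""
    return (np.maximum(logits, 0.0) - logits * targets
            + np.log1p(np.exp(-np.abs(logits))))


logits = np.array([-3.0, -0.5, 0.0, 2.0])
targets = np.array([0.0, 1.0, 1.0, 1.0])  # binarized pixels

# Naive version: squash to probabilities first, then take logs.
p = sigmoid(logits)
naive = -(targets * np.log(p) + (1.0 - targets) * np.log(1.0 - p))
stable = sigmoid_cross_entropy(logits, targets)
print(np.allclose(naive, stable))
```

Working on logits avoids taking the logarithm of a probability that has saturated to 0 or 1, which is one reason to keep the sigmoid inside the loss.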

ELBO Loss:

We first binarize the normalized input. Then we estimate the expectations, under q(z|x), of three log-densities: log q(z|x), log p(z), and log p(x|z). Finally, from these estimates we compute the VAE loss as the negative evidence lower bound.
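
For a diagonal Gaussian q(z|x) and a standard normal prior p(z), the two latent terms combine into the closed-form KL divergence 0.5 * sum(mu^2 + sigma^2 - 1 - log sigma^2). The NumPy sketch below works through this with made-up numbers; `kl_to_standard_normal` and `negative_elbo` are hypothetical helper names, and `recon_nll` stands in for the per-image reconstruction term.

```python
import numpy as np


def kl_to_standard_normal(mu, logvar):
    """KL(N(mu, diag(exp(logvar))) || N(0, I)), summed over latent dimensions."""
    # -E_q[log p(z)] and -E_q[log q(z|x)], each without the shared
    # 0.5 * log(2 * pi) constant per dimension, which cancels in the difference.
    neg_logpz = 0.5 * np.sum(mu * mu + np.exp(logvar), axis=1)
    neg_logqz = 0.5 * np.sum(1.0 + logvar, axis=1)
    return neg_logpz - neg_logqz


def negative_elbo(recon_nll, mu, logvar):
    """Batch mean of reconstruction negative log-likelihood plus KL term."""
    return np.mean(recon_nll + kl_to_standard_normal(mu, logvar))


mu = np.array([[0.0, 0.0], [1.0, -2.0]])
logvar = np.zeros((2, 2))  # unit variance in every dimension
recon_nll = np.array([80.0, 90.0])  # made-up reconstruction terms
kl = kl_to_standard_normal(mu, logvar)
loss = negative_elbo(recon_nll, mu, logvar)
print(kl, loss)
```

The first row has mu = 0 and unit variance, so q coincides with the prior and its KL term vanishes; only the second row is penalized for moving away from the prior.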


In [ ]:
def vae(x, shape_z, test=False):
    """
    Args:
        x(`~nnabla.Variable`): N-D array
        shape_z(tuple of int): size of z
        test(bool): True=test (z = mu), False=train (z is sampled)
    Returns:
        ~nnabla.Variable: ELBO loss (negative evidence lower bound)
    """

    #############################################
    # Encoder of 2 fully connected layers       #
    #############################################

    # Normalize input
    xa = x / 256.
    batch_size = x.shape[0]

    # Two fully connected layers; ELU replaces the Softplus of the original paper.
    h = F.elu(PF.affine(xa, (500,), name='fc1'))
    h = F.elu(PF.affine(h, (500,), name='fc2'))

    # The outputs are the parameters of the Gaussian probability density.
    mu = PF.affine(h, shape_z, name='fc_mu')
    logvar = PF.affine(h, shape_z, name='fc_logvar')
    sigma = F.exp(0.5 * logvar)

    # The prior variable and the reparameterization trick
    if not test:
        # training with reparameterization trick
        epsilon = F.randn(mu=0, sigma=1, shape=(batch_size,) + shape_z)
        z = mu + sigma * epsilon
    else:
        # test without randomness
        z = mu

    #############################################
    # Decoder of 2 fully connected layers       #
    #############################################

    # Two fully connected layers; ELU replaces the Softplus of the original paper.
    h = F.elu(PF.affine(z, (500,), name='fc3'))
    h = F.elu(PF.affine(h, (500,), name='fc4'))

    # The outputs are the logits of the Bernoulli distribution for each pixel.
    prob = PF.affine(h, (1, 28, 28), name='fc5')

    #############################################
    # ELBO components and loss objective        #
    #############################################

    # Binarized input
    xb = F.greater_equal_scalar(xa, 0.5)

    # -E_q(z|x)[log(q(z|x))], i.e. the entropy of q(z|x),
    # without some constant terms that will be canceled after summation of loss
    logqz = 0.5 * F.sum(1.0 + logvar, axis=1)

    # -E_q(z|x)[log(p(z))]
    # without some constant terms that will be canceled after summation of loss
    logpz = 0.5 * F.sum(mu * mu + sigma * sigma, axis=1)

    # -E_q(z|x)[log(p(x|z))], the reconstruction cross entropy
    logpx = F.sum(F.sigmoid_cross_entropy(prob, xb), axis=(1, 2, 3))

    # VAE loss, the negative evidence lower bound
    loss = F.mean(logpx + logpz - logqz)

    return loss

Before we start training, let's set the context to use the GPU and define data iterators for training and test.


In [ ]:
from nnabla.ext_utils import get_extension_context
ctx = get_extension_context('cudnn')
nn.set_default_context(ctx)

batch_size = 100

# Initialize data provider
di_l = data_iterator_mnist(batch_size, True)
di_t = data_iterator_mnist(batch_size, False)

We now define the input variable and loss variables for our graph. We define two losses for training and test respectively, which are computed from the VAE network we defined above.

Let's also set up the solver and monitors to track the progress of training. We use Adam as our solver (which, incidentally, was proposed by one of the authors of the VAE paper!).


In [ ]:
# Network
shape_x = (1, 28, 28)
shape_z = (50,)
x = nn.Variable((batch_size,) + shape_x)
loss_l = vae(x, shape_z, test=False)
loss_t = vae(x, shape_z, test=True)

# Create solver
learning_rate = 3e-4
solver = S.Adam(learning_rate)
solver.set_parameters(nn.get_parameters())

# Monitors for training and validation
model_save_path = 'tmp.monitor.vae'
monitor = M.Monitor(model_save_path)
monitor_training_loss = M.MonitorSeries(
    "Training loss", monitor, interval=600)
monitor_test_loss = M.MonitorSeries("Test loss", monitor, interval=600)
monitor_time = M.MonitorTimeElapsed("Elapsed time", monitor, interval=600)

We now perform training, where we iteratively retrieve the data, compute the loss from the VAE network, and update the parameters with the solver. Note that the labels provided by the data iterators are discarded by assigning them to the throwaway variable _.

By the end of training, you should see a test loss of approximately 85.


In [ ]:
# Training Loop.
max_iter = 60000
weight_decay = 0
for i in range(max_iter):

    # Initialize gradients
    solver.zero_grad()

    # Forward, backward and update
    x.d, _ = di_l.next()
    loss_l.forward(clear_no_need_grad=True)
    loss_l.backward(clear_buffer=True)
    solver.weight_decay(weight_decay)
    solver.update()

    # Forward for test
    x.d, _ = di_t.next()
    loss_t.forward(clear_no_need_grad=True)

    # Monitor for logging
    monitor_training_loss.add(i, loss_l.d.copy())
    monitor_test_loss.add(i, loss_t.d.copy())
    monitor_time.add(i)