Variational Autoencoder in TensorFlow

Variational Autoencoders (VAEs) are popular models that allow for unsupervised (and semi-supervised) learning. In this notebook, we'll implement a simple VAE on the MNIST dataset.

One of the primary goals of the VAE (and autoencoders in general) is to reconstruct the original input. Why would we want to do that? At first glance, such a model seems silly: a simple identity function achieves the same thing with perfect results. However, with an autoencoder, we can learn a compressed representation in a smaller latent space, allowing us to learn the features and structure of the data. Autoencoders are composed of two arms, the encoder and the decoder, which convert values from the data space to the latent space and vice versa, respectively.

Importantly, since we're simply reconstructing the original input, we do not necessarily need labels to do our learning, as we did in previous examples. This is significant, as labels are often far more expensive to acquire than raw data, often prohibitively so. VAEs therefore allow us to leverage abundant unlabeled data. That said, VAEs can also take advantage of labels when they are available, in either a fully supervised or semi-supervised setting. Altogether, autoencoders can achieve impressive results on tasks like denoising, segmentation, and even predicting future images.

Imports and Data

First, some package imports and loading of the data. This is similar to what we've done before, with the main difference being that we're going to use TensorFlow Slim, as a follow-up to notebook 02A.


In [1]:
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

slim = tf.contrib.slim

# Import data
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)


Extracting MNIST_data/train-images-idx3-ubyte.gz
Extracting MNIST_data/train-labels-idx1-ubyte.gz
Extracting MNIST_data/t10k-images-idx3-ubyte.gz
Extracting MNIST_data/t10k-labels-idx1-ubyte.gz

Encoder

The encoder deterministically transforms the data $x$ from the data space into the latent space of $z$. Since we're dealing with a variational autoencoder, we attempt to model the distribution of the latent variable given the input, represented by $q(z|x)$. This isn't immediately obvious in the code implementation, but we place a standard Gaussian prior on $z$ and assume $q(z|x)$ is Gaussian, so our encoder returns the mean and variance (actually the log-variance) of that distribution. We use the log-variance because our network outputs unconstrained real numbers, while variances must be positive.

MNIST is a very simple dataset, so let's also keep the model simple: an MLP with 2 fully connected layers. We name the output mu_logvar, as we will be interpreting the first half of the final 128-dimensional vector as the mean $\mu$ and the second half as the log-variance log($\sigma^2$).
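
Concretely, since mu_logvar is 128-dimensional, the latent variable $z$ will be 64-dimensional, and the encoder parameterizes a diagonal Gaussian approximate posterior $q(z|x) = \mathcal{N}\big(z;\, \mu(x),\, \mathrm{diag}(\sigma^2(x))\big)$ with $\sigma^2(x) = \exp(\mathrm{logvar}(x))$, so exponentiating the second half of the output always yields a valid (positive) variance.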


In [2]:
def encoder(x):
    """Network q(z|x)"""
    with slim.arg_scope([slim.fully_connected],
                        activation_fn=tf.nn.relu,
                        weights_initializer=tf.truncated_normal_initializer(0.0, 0.1)):
        mu_logvar = slim.fully_connected(x, 128, scope='fc1')
        mu_logvar = slim.fully_connected(mu_logvar, 128, activation_fn=None, scope='fc2')
        
    return mu_logvar

Note that we use a couple of features of TF-Slim here:

  1. We use slim.fully_connected() to specify which layers we want to use, without having to worry about defining weight or bias variables beforehand.

  2. We use slim.arg_scope() to specify default arguments so we can leave them out of the definitions of each of the fully connected layers. We can still override the activation_fn for the last layer though.

For this simple model, TF-Slim doesn't actually benefit us all that much, but for the sake of demonstration, we'll stick with it.
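
If it helps to see what those defaults expand to, here is a rough sketch of the same two layers written with tf.layers.dense instead of TF-Slim (the encoder_no_slim name is just for illustration, and initializer/scoping details may differ slightly from the slim version):

def encoder_no_slim(x):
    """Sketch of q(z|x) without TF-Slim; roughly equivalent to encoder() above"""
    init = tf.truncated_normal_initializer(0.0, 0.1)
    # These keyword arguments are what slim.arg_scope was supplying automatically above
    h = tf.layers.dense(x, 128, activation=tf.nn.relu, kernel_initializer=init, name='fc1')
    # The final layer drops the nonlinearity so the mean and log-variance can take any real value
    mu_logvar = tf.layers.dense(h, 128, activation=None, kernel_initializer=init, name='fc2')
    return mu_logvar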

Decoder

The decoder is the generative arm of the autoencoder. Just like our encoder learned the parameters of a distribution $q(z|x)$, our decoder will learn the parameters of a distribution $p(x|z)$. Because $x$ is binary data (black and white pixels), we will use a Bernoulli distribution. Our generative neural network will learn the mean of this Bernoulli distribution for each pixel we want to generate. Another viewpoint: if our neural network outputs $\hat{x}_j$ for pixel $j$, it means we believe the pixel will be white with probability $\hat{x}_j$.

Again, since MNIST is simple, we'll use a 2-layer MLP for the decoder. Importantly, since we are focusing on reconstruction, we make sure that the final output of the decoder $\hat{x}$ has the same dimensions as our input $x$.
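
In other words, the decoder defines $p(x|z) = \prod_{j=1}^{784} \hat{x}_j^{x_j}\,(1 - \hat{x}_j)^{1 - x_j}$, a product of independent per-pixel Bernoullis whose parameters $\hat{x}_j$ come from the network.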


In [3]:
def decoder(mu_logvar):
    """Network p(x|z)"""
    # Interpret z as concatenation of mean and log variance
    mu, logvar = tf.split(mu_logvar, num_or_size_splits=2, axis=1)

    # Standard deviation must be positive
    stddev = tf.sqrt(tf.exp(logvar))

    # Draw a z from the distribution
    epsilon = tf.random_normal(tf.shape(stddev))
    z = mu + tf.multiply(stddev, epsilon)

    # Decoding arm
    with slim.arg_scope([slim.fully_connected],
                        activation_fn=tf.nn.relu,
                        weights_initializer=tf.truncated_normal_initializer(0.0, 0.1)):        
        x_logits = slim.fully_connected(z, 128, scope='fc1')
        x_logits = slim.fully_connected(x_logits, 784, activation_fn=None, scope='fc2')
        
        # x_hat to be generated from a Bernoulli distribution
        x_dist = tf.contrib.distributions.Bernoulli(logits=x_logits, dtype=tf.float32)
        
    return x_logits, x_dist
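
Note that the sampling step in the middle of decoder() is the reparameterization trick: rather than sampling $z$ directly from $\mathcal{N}(\mu, \sigma^2)$, we draw $\epsilon \sim \mathcal{N}(0, I)$ and compute $z = \mu + \sigma \odot \epsilon$. The two are equivalent in distribution, but the latter expresses $z$ as a deterministic, differentiable function of $\mu$ and $\sigma$, which lets gradients flow back through the sampling step into the encoder during training.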

Loss

Prof. Jun Zhu talked in class about the theoretical motivation for the loss of the VAE model. Like all variational inference techniques, it tries to match the variational posterior distribution (here parameterized by a neural network) to the true posterior. However, at the end of the derivation, we can think of our model as trading off two goals:

  1. Reconstruction loss: Our generator produces parameters to a Bernoulli distribution that is supposed to represent $p(x | z)$; because we assume that $z$ is the latent representation of an actual data point $x$, we can measure how well we achieve this goal by measuring the likelihood of $x$ according to that Bernoulli distribution. Another way of thinking of this is that we can measure how similar our reconstructed image is to our original image. The measure of similarity we use is cross-entropy: we think of our model as classifying each pixel as black or white, and we measure how good the classifier is using the classic sigmoid cross-entropy loss.

  2. KL Divergence: Because this model is variational, we also include a KL penalty to impose a Gaussian prior on the latent space. The exact derivation of this term can be found in the original Auto-Encoding Variational Bayes paper. Is a standard Gaussian prior a good assumption? What are the potential weaknesses of this approach?
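
Written out for a single example, these two terms are exactly what the next cell computes. The reconstruction term is the Bernoulli negative log-likelihood, i.e. the summed sigmoid cross-entropy over pixels,

$$-\log p(x|z) = -\sum_j \left[ x_j \log \hat{x}_j + (1 - x_j) \log (1 - \hat{x}_j) \right], \qquad \hat{x}_j = \mathrm{sigmoid}(\mathrm{logit}_j),$$

and the KL divergence between the diagonal Gaussian $q(z|x)$ and the standard normal prior has the closed form

$$D_{KL}\big(q(z|x)\,\|\,\mathcal{N}(0, I)\big) = -\frac{1}{2} \sum_j \left( 1 + \log \sigma_j^2 - \mu_j^2 - \sigma_j^2 \right).$$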

For optimization, we use the Adam algorithm that we've used before.


In [4]:
def optimizer(x_logits, x, mu_logvar):
    """Define loss functions (reconstruction, KL divergence) and optimizer"""
    with tf.variable_scope('optimizer') as scope:            
        # Reconstruction loss
        reconstruction = tf.reduce_sum(tf.nn.sigmoid_cross_entropy_with_logits(labels=x, logits=x_logits), axis=1)
 
        # KL divergence
        mu, logvar = tf.split(mu_logvar, num_or_size_splits=2, axis=1)
        kl_d = -0.5 * tf.reduce_sum(1.0 + logvar - tf.square(mu) - tf.exp(logvar), axis=1)
        
        # Total loss
        loss = tf.reduce_mean(reconstruction + kl_d)
            
        # ADAM optimizer
        train_step = tf.train.AdamOptimizer().minimize(loss)
    
    return train_step

Visualization

It'll be nice to visualize the reconstructions that our model generates to see what it learns. This helper function plots the original inputs in one column and the reconstructions next to them in another column. I also may or may not have stolen it from Alex Lew, who included it in his GAN notebook (03B)...


In [5]:
def visualize_row(image, reconstruction, img_width=28, cmap='gray'):
    """
    Takes in a tensor of images of given width, and displays the originals in
    one column and their reconstructions alongside them in a second column,
    using `cmap` to map from numbers to colors.
    """
    fig, ax = plt.subplots(1, 2)
    image = np.reshape(image, [-1, img_width])
    reconstruction = np.reshape(reconstruction, [-1, img_width])
    ax[0].imshow(np.clip(image, 0, 1), cmap=cmap)
    ax[1].imshow(np.clip(reconstruction, 0, 1), cmap=cmap)
    plt.show()

Define the graph and train

All of the functions we've written thus far are just that: functions. We still need to call them to assemble our TensorFlow computation graph. At this point, this should be becoming familiar.

One of the small differences is the inclusion of tf.reset_default_graph(), added to remedy a small, unfortunate side effect of using Jupyter and TensorFlow in conjunction. You don't have to worry about it too much to understand the model; a more detailed explanation, if you're interested, is given below [1].


In [6]:
# Reset the graph
tf.reset_default_graph()

# Define input placeholder
x = tf.placeholder(tf.float32, [None, 784], name='x')

# Define VAE graph
with tf.variable_scope('encoder'):
    mu_logvar = encoder(x)
with tf.variable_scope('decoder'):
    x_logits, x_dist = decoder(mu_logvar)
    x_hat = x_dist.sample()

# Optimization
with tf.variable_scope('unlabeled') as scope:
    train_step_unlabeled = optimizer(x_logits, x, mu_logvar)

*[1] The primary purpose of TensorFlow is to construct a computation graph connecting Tensors and operations. Each of these nodes must be assigned a unique name; if the user does not specify one, a unique name is automatically generated, like 'Placeholder_2', with the number at the end incrementing each time you create a new node of that type. Attempting to create a node with a name already found in the graph raises an error.*

*So how can this be problematic? In the Coding Environments notebook ([00B](https://github.com/kevinjliang/Duke-Tsinghua-MLSS-2017/blob/master/00B_Coding_Environments.ipynb)), it was mentioned that code from previously run cells persists. As such, if we're programming interactively and want to rebuild our graph after some updates, the new updated nodes we want to add collide with the names from our previous run, throwing an error. Why didn't we have to worry about this before? In the past, we haven't been naming our variables, so TensorFlow has been giving the nodes new unique names every time we update the graph and adding them to the collection of nodes from previous runs; the old nodes are never called, so they just sit there. However, TF-Slim does name the variables it generates, thus causing the problem. We can solve this by creating a new graph object before we define our computation graph, so every time we want to make modifications to the graph, we start anew.*

*If you're confused by that explanation, I wouldn't worry about it. It's not necessary for the program to run. It's there so we can re-run the cell defining the computation graph without restarting the entire kernel to clear memory of previous variables. In a traditionally written Python program (i.e. not IPython), you wouldn't need to do this.*

For training, we'll stay simple and train for 20000 iterations, visualizing our results with 5 digits from the validation set after every 1000 minibatches. Notice that this model is completely unsupervised: we never include the digit labels at any point in the process. Within a few thousand iterations, the model should start producing reasonable looking results:


In [7]:
with tf.Session() as sess:    
    # Initialize all variables
    sess.run(tf.global_variables_initializer())
    
    # Train VAE model
    for i in range(20000):        
        # Get a training minibatch
        batch = mnist.train.next_batch(100)
        
        # Binarize the data
        x_binarized = (batch[0] > 0.5).astype(np.float32)
        
        # Train on minibatch
        sess.run(train_step_unlabeled, feed_dict={x: x_binarized}) # No labels
            
        # Visualize reconstructions every 1000 iterations
        if i % 1000 == 0:
            batch = mnist.validation.next_batch(5)
            x_binarized = (batch[0] > 0.5).astype(np.float32)
            reconstructions = sess.run(x_hat, feed_dict={x: x_binarized})
            print("Iteration {0}:".format(i))
            visualize_row(batch[0], reconstructions)


Iterations 0 through 19000, printed every 1000 minibatches: [figures for each checkpoint showing the 5 validation digits alongside their sampled reconstructions]