Variational Autoencoder

This scripts contains module for implementing variational autoencoder, the module contains:

  1. mnist_loader: loads mnist data, which will be used for this project
  2. xavier_init: initialize weights for vae
  3. vae_init: build variational autoencoder, return a tensorflow session
  4. vae_train: train variational autoencoder on MNIST dataset

import libraries


In [1]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import tensorflow as tf

import matplotlib.pyplot as plt
%matplotlib inline

mnist_loader


In [2]:
def mnist_loader():
    """
    Load MNIST data in tensorflow readable format
    The script comes from:
    https://raw.githubusercontent.com/tensorflow/tensorflow/master/tensorflow/examples/tutorials/mnist/input_data.py
    """
    import gzip
    import os
    import tempfile
    import numpy
    from six.moves import urllib
    from six.moves import xrange
    from tensorflow.contrib.learn.python.learn.datasets.mnist import read_data_sets
    
    mnist = read_data_sets('MNIST_data', one_hot=True)
    n_samples = mnist.train.num_examples
    
    return (mnist, n_samples)

In [3]:
(mnist, n_samples) = mnist_loader()


Extracting MNIST_data/train-images-idx3-ubyte.gz
Extracting MNIST_data/train-labels-idx1-ubyte.gz
Extracting MNIST_data/t10k-images-idx3-ubyte.gz
Extracting MNIST_data/t10k-labels-idx1-ubyte.gz

Test mnist data


In [4]:
print('Number of available data: %d' % n_samples)


Number of available data: 55000

We are generating synthetic data in this project, so all the 55000 samples can be used for training


In [5]:
x_sample = mnist.test.next_batch(100)[0]

plt.figure(figsize=(8, 4))
for i in range(6):

    plt.subplot(2, 3, i + 1)
    plt.imshow(x_sample[i].reshape(28, 28), vmin=0, vmax=1, cmap="gray")
    plt.title("MNIST Data")
    plt.colorbar()
plt.tight_layout()


xavier_init

This function helps us to set initial weights properly to prevent the updates in deep layers to be too small or too large. In this implementation, we'll sample the weights from a uniform distribution:
$\mathbf{W} \ \sim \ uniform(-\sqrt{\frac{6}{\#Neuron_{in}+\#Neuron_{out}}},\sqrt{\frac{6}{\#Neuron_{in}+\#Neuron_{out}}})$

More detailed explanations of why we use xavier initialization can be found here


In [6]:
def xavier_init(neuron_in, neuron_out, constant=1):
    low = -constant*np.sqrt(6/(neuron_in + neuron_out))
    high = constant*np.sqrt(6/(neuron_in + neuron_out))
    return tf.random_uniform((neuron_in, neuron_out), minval=low, maxval=high, dtype=tf.float32)

Test xavier_init

For a 3*3 neural network, the weights should be sampled from uniform(-1,1)


In [7]:
sess = tf.Session()
weights = []

for i in range(1000):
    weights.append(sess.run(xavier_init(3,3)))

weights = np.array(weights).reshape((-1,1))

In [8]:
n, bins, patches = plt.hist(weights, bins=20)

plt.xlabel('weight value')
plt.ylabel('counts')
plt.title('Histogram of Weights Initialized by Xavier')
plt.show()


vae_init

This function initialize a variational encoder, returns tensorflow session, optimizer, cost function and input data will be returned (for further training) The architecture can be defined by setting the parameters of the function, the default setup is:

  • input nodes: 784
  • 1st layer of encoder: 500
  • 2nd layer of encoder: 500
  • 1st layer of decoder: 500
  • 2nd layer of decoder: 500
  • z: 20

In [9]:
def init_weights(config):
    """
    Initialize weights with specified configuration using Xavier algorithm
    """
    encoder_weights = dict()
    decoder_weights = dict()
    
    # two layers encoder
    encoder_weights['h1'] = tf.Variable(xavier_init(config['x_in'], config['encoder_1']))
    encoder_weights['h2'] = tf.Variable(xavier_init(config['encoder_1'], config['encoder_2']))
    encoder_weights['mu'] = tf.Variable(xavier_init(config['encoder_2'], config['z']))
    encoder_weights['sigma'] = tf.Variable(xavier_init(config['encoder_2'], config['z']))
    encoder_weights['b1'] = tf.Variable(tf.zeros([config['encoder_1']], dtype=tf.float32))
    encoder_weights['b2'] = tf.Variable(tf.zeros([config['encoder_2']], dtype=tf.float32))
    encoder_weights['bias_mu'] = tf.Variable(tf.zeros([config['z']], dtype=tf.float32))
    encoder_weights['bias_sigma'] = tf.Variable(tf.zeros([config['z']], dtype=tf.float32))
    
    # two layers decoder
    decoder_weights['h1'] = tf.Variable(xavier_init(config['z'], config['decoder_1']))
    decoder_weights['h2'] = tf.Variable(xavier_init(config['decoder_1'], config['decoder_2']))
    decoder_weights['mu'] = tf.Variable(xavier_init(config['decoder_2'], config['x_in']))
    decoder_weights['sigma'] = tf.Variable(xavier_init(config['decoder_2'], config['x_in']))
    decoder_weights['b1'] = tf.Variable(tf.zeros([config['decoder_1']], dtype=tf.float32))
    decoder_weights['b2'] = tf.Variable(tf.zeros([config['decoder_2']], dtype=tf.float32))
    decoder_weights['bias_mu'] = tf.Variable(tf.zeros([config['x_in']], dtype=tf.float32))
    decoder_weights['bias_sigma'] = tf.Variable(tf.zeros([config['x_in']], dtype=tf.float32))
    
    return (encoder_weights, decoder_weights)


def forward_z(x, encoder_weights):
    """
    Compute mean and sigma of z
    """
    layer_1 = tf.nn.softplus(tf.add(tf.matmul(x, encoder_weights['h1']), encoder_weights['b1']))
    layer_2 = tf.nn.softplus(tf.add(tf.matmul(layer_1, encoder_weights['h2']), encoder_weights['b2']))
    z_mean = tf.add(tf.matmul(layer_2, encoder_weights['mu']), encoder_weights['bias_mu'])
    z_sigma = tf.add(tf.matmul(layer_2, encoder_weights['sigma']), encoder_weights['bias_sigma'])
    
    return(z_mean, z_sigma)


def reconstruct_x(z, decoder_weights):
    """
    Use z to reconstruct x
    """
    layer_1 = tf.nn.softplus(tf.add(tf.matmul(z, decoder_weights['h1']), decoder_weights['b1']))
    layer_2 = tf.nn.softplus(tf.add(tf.matmul(layer_1, decoder_weights['h2']), decoder_weights['b2']))
    x_prime = tf.nn.sigmoid(tf.add(tf.matmul(layer_2, decoder_weights['mu']), decoder_weights['bias_mu']))
    
    return x_prime


def optimize_func(z, z_mean, z_sigma, x, x_prime, learn_rate):
    """
    Define cost and optimize function
    """
    # define loss function
    # reconstruction lost
    recons_loss = -tf.reduce_sum(x * tf.log(1e-10 + x_prime) + (1-x) * tf.log(1e-10 + 1 - x_prime), 1)
    # KL distance
    latent_loss = -0.5 * tf.reduce_sum(1 + z_sigma - tf.square(z_mean) - tf.exp(z), 1)
    # summing two loss terms together
    cost = tf.reduce_mean(recons_loss + latent_loss)
    
    # use ADAM to optimize
    optimizer = tf.train.AdamOptimizer(learning_rate=learn_rate).minimize(cost)
    
    return (cost, optimizer)

def vae_init(batch_size=100, learn_rate=0.001, config={}):
    """
    This function build a varational autoencoder based on https://jmetzen.github.io/2015-11-27/vae.html
    In consideration of simplicity and future work on optimization, we removed the class structure
    A tensorflow session, optimizer and cost function as well as input data will be returned
    """
    # default configuration of network
    # x_in = 784
    # encoder_1 = 500
    # encoder_2 = 500
    # decoder_1 = 500
    # decoder_2 = 500
    # z = 20
    
    # use default setting if no configuration is specified
    if not config:
        config['x_in'] = 784
        config['encoder_1'] = 500
        config['encoder_2'] = 500
        config['decoder_1'] = 500
        config['decoder_2'] = 500
        config['z'] = 20
    
    # input
    x = tf.placeholder(tf.float32, [None, config['x_in']])
    
    # initialize weights
    (encoder_weights, decoder_weights) = init_weights(config)
    
    # compute mean and sigma of z
    (z_mean, z_sigma) = forward_z(x, encoder_weights)
    
    # compute z by drawing sample from normal distribution
    eps = tf.random_normal((batch_size, config['z']), 0, 1, dtype=tf.float32)
    z_val = tf.add(z_mean, tf.multiply(tf.sqrt(tf.exp(z_sigma)), eps))
    
    # use z to reconstruct the network
    x_prime = reconstruct_x(z_val, decoder_weights)
    
    # define loss function
    (cost, optimizer) = optimize_func(z_val, z_mean, z_sigma, x, x_prime, learn_rate)
    
    # initialize all variables
    init = tf.global_variables_initializer()
    
    # define and return the session
    sess = tf.InteractiveSession()
    sess.run(init)
    
    return (sess, optimizer, cost, x, x_prime)

vae_train

This function loads the previously initialized VAE and do the training.
If verbose if set as 1, then every verb_step the program will print out cost information


In [10]:
def vae_train(sess, optimizer, cost, x, n_samples, batch_size=100, learn_rate=0.001, train_epoch=10, verb=1, verb_step=5):
    
    for epoch in range(train_epoch):
        avg_cost = 0
        total_batch = int(n_samples / batch_size)
        for i in range(total_batch):
            batch_x, _ = mnist.train.next_batch(batch_size)
            
            _, c = sess.run((optimizer, cost), feed_dict={x: batch_x})
            avg_cost += c / n_samples * batch_size
        
        if verb:
            if epoch % verb_step == 0:
                # print('Epoch:%04d\tCost=%.2f' % (epoch+1, avg_cost))
                print('Epoch:%04d' % (epoch+1), 'cost=', '{:.9f}'.format(avg_cost))

In [11]:
(sess, optimizer, cost, x, x_prime) = vae_init()

In [12]:
vae_train(sess, optimizer, cost, x, n_samples, train_epoch=75)


Epoch:0001 cost= 215.538792725
Epoch:0006 cost= 132.551954249
Epoch:0011 cost= 124.163822632
Epoch:0016 cost= 120.474978013
Epoch:0021 cost= 116.404384960
Epoch:0026 cost= 114.047195462
Epoch:0031 cost= 112.488438388
Epoch:0036 cost= 111.200778087
Epoch:0041 cost= 110.112403856
Epoch:0046 cost= 109.075612072
Epoch:0051 cost= 108.250699394
Epoch:0056 cost= 107.574343012
Epoch:0061 cost= 106.552557276
Epoch:0066 cost= 105.814334675
Epoch:0071 cost= 105.250467918

In [22]:
examples_to_show = 100
x_sample = mnist.test.images[:examples_to_show]
x_reconstruct = sess.run(x_prime,  feed_dict={x: x_sample})

plt.figure(figsize=(8, 2))
for i in range(8):

    ax = plt.subplot(2, 8, i + 1)
    plt.imshow(x_sample[i].reshape(28, 28), vmin=0, vmax=1, cmap="gray")
    plt.axis('off')
    ax.set_xticklabels([])
    ax.set_yticklabels([])
    ax.set_aspect('equal')
    
    ax = plt.subplot(2, 8, 8 + i + 1)
    plt.imshow(x_reconstruct[i].reshape(28, 28), vmin=0, vmax=1, cmap="gray")
    plt.axis('off')
    ax.set_xticklabels([])
    ax.set_yticklabels([])
    ax.set_aspect('equal')
    
plt.tight_layout()
plt.savefig('../]data/{}.png'.format('vae_pic'), bbox_inches='tight')