CSCI 599 - In-class demo of a variational autoencoder in TensorFlow


In [1]:
import tensorflow as tf
from tensorflow.contrib.slim import fully_connected as fc
import numpy as np

In [2]:
from tensorflow.examples.tutorials.mnist import input_data
# Image geometry: each sample is a w x h image flattened to input_dim values
# (28 * 28 = 784).
w = h = 28
input_dim = w * h

# Load the dataset from the local 'MNIST_data' directory with one-hot labels,
# and remember how many training examples there are.
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)
num_sample = mnist.train.num_examples


Extracting MNIST_data/train-images-idx3-ubyte.gz
Extracting MNIST_data/train-labels-idx1-ubyte.gz
Extracting MNIST_data/t10k-images-idx3-ubyte.gz
Extracting MNIST_data/t10k-labels-idx1-ubyte.gz

In [3]:
class VariationalAutoencoder(object):
    """Fully-connected variational autoencoder built with TensorFlow 1.x.

    The encoder maps a flattened input of size ``input_dim`` (module-level
    constant) to the mean and log-variance of a diagonal Gaussian over an
    ``n_z``-dimensional latent code; the decoder maps a latent code back to
    per-pixel Bernoulli probabilities.

    Each instance owns its own ``tf.Graph`` (``self.graph``) and
    ``tf.Session`` (``self.sess``), so several models can be created in the
    same kernel (e.g. by re-running the training cell) without the layer
    scopes colliding. Call ``close()`` to release the session.

    Args:
        learning_rate: step size passed to the Adam optimizer.
        batch_size: stored on the instance only; the graph itself accepts any
            batch size because the placeholder's first dimension is ``None``.
        n_z: dimensionality of the latent code.
    """

    def __init__(self, learning_rate=1e-4, batch_size=128, n_z=5):
        self.learning_rate = learning_rate
        self.batch_size = batch_size
        self.n_z = n_z

        # Build inside a private graph. The layers use fixed scope names, so
        # building a second model in the shared default graph would fail with
        # "Variable ... already exists".
        self.graph = tf.Graph()
        with self.graph.as_default():
            self.build()
            init_op = tf.global_variables_initializer()

        # launch a session bound to this model's graph
        self.sess = tf.Session(graph=self.graph)
        self.sess.run(init_op)

    def build(self):
        """Build the network, the loss and the training op.

        Ops are added to whatever graph is the default when this is called;
        ``__init__`` calls it with ``self.graph`` as the default graph.
        """
        # input
        self.x = tf.placeholder(name='x', dtype=tf.float32, shape=[None, input_dim])

        # encoder
        f1 = fc(self.x, 512, scope='enc_fc1', activation_fn=tf.nn.elu)
        f2 = fc(f1, 384, scope='enc_fc2', activation_fn=tf.nn.elu)
        f3 = fc(f2, 256, scope='enc_fc3', activation_fn=tf.nn.elu)

        self.z_mu = fc(f3, self.n_z, scope='enc_fc4_mu', activation_fn=None)
        # log (sigma^2)
        self.z_log_sigma_z_sq = fc(f3, self.n_z, scope='enc_fc4_sigma', activation_fn=None)

        # Reparameterization: z = mu + sigma * eps with eps ~ N(0, I).
        eps = tf.random_normal(shape=tf.shape(self.z_log_sigma_z_sq),
                               mean=0, stddev=1, dtype=tf.float32)
        # sigma = exp(0.5 * log sigma^2). Equivalent to sqrt(exp(.)) but it
        # overflows much later and has a well-defined gradient when
        # exp(log sigma^2) underflows to zero.
        self.z = self.z_mu + tf.exp(0.5 * self.z_log_sigma_z_sq) * eps

        # decoder
        g1 = fc(self.z, 256, scope='dec_fc1', activation_fn=tf.nn.elu)
        g2 = fc(g1, 384, scope='dec_fc2', activation_fn=tf.nn.elu)
        g3 = fc(g2, 512, scope='dec_fc3', activation_fn=tf.nn.elu)
        self.x_hat = fc(g3, input_dim, scope='dec_fc4', activation_fn=tf.nn.sigmoid)

        # loss
        # Reconstruction term: Bernoulli cross-entropy between x and x_hat,
        # summed over pixels. epsilon keeps log() away from log(0).
        epsilon = 1e-10
        recon_loss = -tf.reduce_sum(
            self.x * tf.log(self.x_hat + epsilon) + (1 - self.x) * tf.log(1 - self.x_hat + epsilon),
            axis=1
        )

        # Latent term: KL( N(mu, sigma^2) || N(0, I) ), summed over latent dims.
        latent_loss = -0.5 * tf.reduce_sum(
            1 + self.z_log_sigma_z_sq - tf.square(self.z_mu) - tf.exp(self.z_log_sigma_z_sq),
            axis=1
        )

        # total loss, averaged over the batch
        self.total_loss = tf.reduce_mean(recon_loss + latent_loss)

        # optimizer
        self.train_op = tf.train.AdamOptimizer(
            learning_rate=self.learning_rate).minimize(self.total_loss)

    def run_single_stage(self, x):
        """Run one optimizer step on batch ``x`` and return the batch loss."""
        _, loss = self.sess.run([self.train_op, self.total_loss], feed_dict={self.x: x})
        return loss

    def reconstructor(self, x):
        """Encode ``x``, sample a latent code and return the decoded ``x_hat``."""
        return self.sess.run(self.x_hat, feed_dict={self.x: x})

    def generator(self, z):
        """Decode latent codes ``z`` (fed in place of the sampled code) into ``x_hat``."""
        return self.sess.run(self.x_hat, feed_dict={self.z: z})

    def transformer(self, x):
        """Return a latent code for ``x`` (e.g. for visualization).

        The value is a stochastic sample ``mu + sigma * eps``, not the
        posterior mean, so repeated calls on the same ``x`` differ.
        """
        return self.sess.run(self.z, feed_dict={self.x: x})

    def close(self):
        """Close the underlying session and release its resources."""
        self.sess.close()

In [4]:
def trainer(learning_rate=1e-4, batch_size=100, num_epoch=100, n_z=10):
    """Train a VariationalAutoencoder on the module-level ``mnist`` data.

    Runs ``num_sample // batch_size`` optimizer steps per epoch and prints the
    loss of the last batch of each epoch.

    Args:
        learning_rate: step size forwarded to the model.
        batch_size: number of training examples per optimizer step.
        num_epoch: number of passes over the training set.
        n_z: dimensionality of the latent code.

    Returns:
        The trained VariationalAutoencoder instance.

    Raises:
        ValueError: if ``batch_size`` does not allow at least one step per
            epoch (non-positive, or larger than ``num_sample``).
    """
    # Validate before building the model, so a bad batch size fails fast
    # instead of surfacing as an unbound `loss` after the graph is built.
    if batch_size < 1 or num_sample // batch_size < 1:
        raise ValueError(
            'batch_size must be between 1 and num_sample ({}), got {}'.format(
                num_sample, batch_size))
    steps_per_epoch = num_sample // batch_size

    # model
    model = VariationalAutoencoder(learning_rate=learning_rate,
                                   batch_size=batch_size, n_z=n_z)

    # training loop
    for epoch in range(num_epoch):
        for _ in range(steps_per_epoch):
            batch = mnist.train.next_batch(batch_size)

            # batch[0] is the image, batch[1] is the label (unused: the
            # model is trained without labels)
            loss = model.run_single_stage(batch[0])
        # Single-argument print(...) behaves the same on Python 2 and 3.
        print('[Epoch] {} loss {}'.format(epoch, loss))
    print('Done!')
    return model

In [5]:
# Train for 10 epochs with mini-batches of 128 and a 5-dimensional latent code.
model = trainer(
    learning_rate=1e-4,
    batch_size=128,
    num_epoch=10,
    n_z=5
)


[Epoch] 0 loss 181.502
[Epoch] 1 loss 156.259
[Epoch] 2 loss 149.605
---------------------------------------------------------------------------
KeyboardInterrupt                         Traceback (most recent call last)
<ipython-input-5-087663072c0b> in <module>()
----> 1 model = trainer(learning_rate=1e-4, batch_size=128, num_epoch=10, n_z=5)

<ipython-input-4-003cc8534389> in trainer(learning_rate, batch_size, num_epoch, n_z)
     10 
     11             # training loop, batch[0] is the image, batch[1] is the label
---> 12             loss = model.run_single_stage(batch[0])
     13         print '[Epoch]', epoch, 'loss', loss
     14     print('Done!')

<ipython-input-3-9d30e583e2c3> in run_single_stage(self, x)
     61     def run_single_stage(self, x):
     62         """ execute a forward and a backward pass, report loss """
---> 63         _, loss = self.sess.run([self.train_op, self.total_loss], feed_dict={self.x: x})
     64         return loss
     65 

/Applications/anaconda/lib/python2.7/site-packages/tensorflow/python/client/session.pyc in run(self, fetches, feed_dict, options, run_metadata)
    893     try:
    894       result = self._run(None, fetches, feed_dict, options_ptr,
--> 895                          run_metadata_ptr)
    896       if run_metadata:
    897         proto_data = tf_session.TF_GetBuffer(run_metadata_ptr)

/Applications/anaconda/lib/python2.7/site-packages/tensorflow/python/client/session.pyc in _run(self, handle, fetches, feed_dict, options, run_metadata)
   1122     if final_fetches or final_targets or (handle and feed_dict_tensor):
   1123       results = self._do_run(handle, final_targets, final_fetches,
-> 1124                              feed_dict_tensor, options, run_metadata)
   1125     else:
   1126       results = []

/Applications/anaconda/lib/python2.7/site-packages/tensorflow/python/client/session.pyc in _do_run(self, handle, target_list, fetch_list, feed_dict, options, run_metadata)
   1319     if handle is None:
   1320       return self._do_call(_run_fn, self._session, feeds, fetches, targets,
-> 1321                            options, run_metadata)
   1322     else:
   1323       return self._do_call(_prun_fn, self._session, handle, feeds, fetches)

/Applications/anaconda/lib/python2.7/site-packages/tensorflow/python/client/session.pyc in _do_call(self, fn, *args)
   1325   def _do_call(self, fn, *args):
   1326     try:
-> 1327       return fn(*args)
   1328     except errors.OpError as e:
   1329       message = compat.as_text(e.message)

/Applications/anaconda/lib/python2.7/site-packages/tensorflow/python/client/session.pyc in _run_fn(session, feed_dict, fetch_list, target_list, options, run_metadata)
   1304           return tf_session.TF_Run(session, options,
   1305                                    feed_dict, fetch_list, target_list,
-> 1306                                    status, run_metadata)
   1307 
   1308     def _prun_fn(session, handle, feed_dict, fetch_list):

KeyboardInterrupt: 

In [ ]: