In [1]:
#!/usr/bin/env python
"""Variational auto-encoder for MNIST data.
References
----------
http://edwardlib.org/tutorials/decoder
http://edwardlib.org/tutorials/inference-networks
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

In [4]:
import edward as ed
import numpy as np
import os
import tensorflow as tf

from edward.models import Bernoulli, Normal
from edward.util import Progbar
from keras.layers import Dense
from observations import mnist
from scipy.misc import imsave

In [5]:
def generator(array, batch_size):
  """Generate batch with respect to array's first axis."""
  start = 0  # pointer to where we are in iteration
  while True:
    stop = start + batch_size
    diff = stop - array.shape[0]
    if diff <= 0:
      batch = array[start:stop]
      start += batch_size
    else:
      batch = np.concatenate((array[start:], array[:diff]))
      start = diff
    batch = batch.astype(np.float32) / 255.0  # normalize pixel intensities
    batch = np.random.binomial(1, batch)  # binarize images
    yield batch

In [6]:
ed.set_seed(42)

data_dir = "/tmp/data"
out_dir = "/tmp/out"
if not os.path.exists(out_dir):
  os.makedirs(out_dir)
M = 100  # batch size during training
d = 2  # latent dimension

In [7]:
# DATA. MNIST batches are fed at training time.
(x_train, _), (x_test, _) = mnist(data_dir)
x_train_generator = generator(x_train, M)


>> Downloading /tmp/data/train-images-idx3-ubyte.gz.part 
>> [9.5 MB/9.5 MB] 105% @173.8 KB/s,[0s remaining, 58s elapsed]          
URL https://storage.googleapis.com/cvdf-datasets/mnist/train-images-idx3-ubyte.gz downloaded to /tmp/data/train-images-idx3-ubyte.gz 
>> Downloading /tmp/data/train-labels-idx1-ubyte.gz.part 
>> [28.2 KB/28.2 KB] 3630% @1.0 MB/s,[0s remaining, 0s elapsed]        
URL https://storage.googleapis.com/cvdf-datasets/mnist/train-labels-idx1-ubyte.gz downloaded to /tmp/data/train-labels-idx1-ubyte.gz 
>> Downloading /tmp/data/t10k-images-idx3-ubyte.gz.part 
>> [1.6 MB/1.6 MB] 127% @76.1 KB/s,[0s remaining, 26s elapsed]        
URL https://storage.googleapis.com/cvdf-datasets/mnist/t10k-images-idx3-ubyte.gz downloaded to /tmp/data/t10k-images-idx3-ubyte.gz 
>> Downloading /tmp/data/t10k-labels-idx1-ubyte.gz.part 
>> [4.4 KB/4.4 KB] 23086% @2.0 MB/s,[0s remaining, 0s elapsed]        
URL https://storage.googleapis.com/cvdf-datasets/mnist/t10k-labels-idx1-ubyte.gz downloaded to /tmp/data/t10k-labels-idx1-ubyte.gz 

In [8]:
# MODEL
# Define a subgraph of the full model, corresponding to a minibatch of
# size M.
z = Normal(loc=tf.zeros([M, d]), scale=tf.ones([M, d]))
hidden = Dense(256, activation='relu')(z.value())
x = Bernoulli(logits=Dense(28 * 28)(hidden))

In [9]:
# INFERENCE
# Define a subgraph of the variational model, corresponding to a
# minibatch of size M.
x_ph = tf.placeholder(tf.int32, [M, 28 * 28])
hidden = Dense(256, activation='relu')(tf.cast(x_ph, tf.float32))
qz = Normal(loc=Dense(d)(hidden),
            scale=Dense(d, activation='softplus')(hidden))

In [10]:
# Bind p(x, z) and q(z | x) to the same TensorFlow placeholder for x.
inference = ed.KLqp({z: qz}, data={x: x_ph})
optimizer = tf.train.RMSPropOptimizer(0.01, epsilon=1.0)
inference.initialize(optimizer=optimizer)

In [11]:
tf.global_variables_initializer().run()

n_epoch = 100
n_iter_per_epoch = x_train.shape[0] // M
for epoch in range(1, n_epoch + 1):
  print("Epoch: {0}".format(epoch))
  avg_loss = 0.0

  pbar = Progbar(n_iter_per_epoch)
  for t in range(1, n_iter_per_epoch + 1):
    pbar.update(t)
    x_batch = next(x_train_generator)
    info_dict = inference.update(feed_dict={x_ph: x_batch})
    avg_loss += info_dict['loss']

  # Print a lower bound to the average marginal likelihood for an
  # image.
  avg_loss = avg_loss / n_iter_per_epoch
  avg_loss = avg_loss / M
  print("-log p(x) <= {:0.3f}".format(avg_loss))

  # Prior predictive check.
  images = x.eval()
  for m in range(M):
    imsave(os.path.join(out_dir, '%d.png') % m, images[m].reshape(28, 28))


Epoch: 1
600/600 [100%] ██████████████████████████████ Elapsed: 21s
-log p(x) <= 207.645
Epoch: 2
600/600 [100%] ██████████████████████████████ Elapsed: 39s
-log p(x) <= 210.973
Epoch: 3
600/600 [100%] ██████████████████████████████ Elapsed: 32s
-log p(x) <= 206.735
Epoch: 4
600/600 [100%] ██████████████████████████████ Elapsed: 31s
-log p(x) <= 207.010
Epoch: 5
600/600 [100%] ██████████████████████████████ Elapsed: 32s
-log p(x) <= 207.079
Epoch: 6
600/600 [100%] ██████████████████████████████ Elapsed: 34s
-log p(x) <= 209.015
Epoch: 7
600/600 [100%] ██████████████████████████████ Elapsed: 33s
-log p(x) <= 210.181
Epoch: 8
600/600 [100%] ██████████████████████████████ Elapsed: 31s
-log p(x) <= 209.804
Epoch: 9
600/600 [100%] ██████████████████████████████ Elapsed: 31s
-log p(x) <= 209.638
Epoch: 10
600/600 [100%] ██████████████████████████████ Elapsed: 31s
-log p(x) <= 212.237
Epoch: 11
600/600 [100%] ██████████████████████████████ Elapsed: 31s
-log p(x) <= 211.976
Epoch: 12
600/600 [100%] ██████████████████████████████ Elapsed: 32s
-log p(x) <= 211.576
Epoch: 13
600/600 [100%] ██████████████████████████████ Elapsed: 31s
-log p(x) <= 212.581
Epoch: 14
600/600 [100%] ██████████████████████████████ Elapsed: 31s
-log p(x) <= 213.042
Epoch: 15
600/600 [100%] ██████████████████████████████ Elapsed: 32s
-log p(x) <= 209.468
Epoch: 16
600/600 [100%] ██████████████████████████████ Elapsed: 32s
-log p(x) <= 213.523
Epoch: 17
600/600 [100%] ██████████████████████████████ Elapsed: 32s
-log p(x) <= 211.872
Epoch: 18
600/600 [100%] ██████████████████████████████ Elapsed: 33s
-log p(x) <= 213.069
Epoch: 19
600/600 [100%] ██████████████████████████████ Elapsed: 33s
-log p(x) <= 211.202
Epoch: 20
600/600 [100%] ██████████████████████████████ Elapsed: 32s
-log p(x) <= 208.602
Epoch: 21
600/600 [100%] ██████████████████████████████ Elapsed: 37s
-log p(x) <= 212.813
Epoch: 22
600/600 [100%] ██████████████████████████████ Elapsed: 38s
-log p(x) <= 214.470
Epoch: 23
600/600 [100%] ██████████████████████████████ Elapsed: 38s
-log p(x) <= 213.760
Epoch: 24
600/600 [100%] ██████████████████████████████ Elapsed: 38s
-log p(x) <= 214.503
Epoch: 25
600/600 [100%] ██████████████████████████████ Elapsed: 38s
-log p(x) <= 214.402
Epoch: 26
600/600 [100%] ██████████████████████████████ Elapsed: 39s
-log p(x) <= 215.500
Epoch: 27
600/600 [100%] ██████████████████████████████ Elapsed: 38s
-log p(x) <= 213.140
Epoch: 28
600/600 [100%] ██████████████████████████████ Elapsed: 39s
-log p(x) <= 214.080
Epoch: 29
600/600 [100%] ██████████████████████████████ Elapsed: 39s
-log p(x) <= 212.804
Epoch: 30
600/600 [100%] ██████████████████████████████ Elapsed: 38s
-log p(x) <= 213.027
Epoch: 31
600/600 [100%] ██████████████████████████████ Elapsed: 39s
-log p(x) <= 215.465
Epoch: 32
600/600 [100%] ██████████████████████████████ Elapsed: 38s
-log p(x) <= 214.424
Epoch: 33
600/600 [100%] ██████████████████████████████ ETA: 0 Elapsed: 39s
-log p(x) <= 213.801
Epoch: 34
600/600 [100%] ██████████████████████████████ Elapsed: 38s
-log p(x) <= 212.198
Epoch: 35
600/600 [100%] ██████████████████████████████ Elapsed: 38s
-log p(x) <= 213.701
Epoch: 36
600/600 [100%] ██████████████████████████████ Elapsed: 38s
-log p(x) <= 212.150
Epoch: 37
600/600 [100%] ██████████████████████████████ Elapsed: 38s
-log p(x) <= 212.205
Epoch: 38
600/600 [100%] ██████████████████████████████ Elapsed: 37s
-log p(x) <= 214.147
Epoch: 39
600/600 [100%] ██████████████████████████████ Elapsed: 38s
-log p(x) <= 214.269
Epoch: 40
600/600 [100%] ██████████████████████████████ Elapsed: 38s
-log p(x) <= 213.426
Epoch: 41
600/600 [100%] ██████████████████████████████ Elapsed: 37s
-log p(x) <= 215.381
Epoch: 42
600/600 [100%] ██████████████████████████████ Elapsed: 38s
-log p(x) <= 214.730
Epoch: 43
600/600 [100%] ██████████████████████████████ Elapsed: 38s
-log p(x) <= 214.604
Epoch: 44
600/600 [100%] ██████████████████████████████ Elapsed: 37s
-log p(x) <= 213.559
Epoch: 45
600/600 [100%] ██████████████████████████████ Elapsed: 37s
-log p(x) <= 214.851
Epoch: 46
600/600 [100%] ██████████████████████████████ Elapsed: 37s
-log p(x) <= 211.634
Epoch: 47
600/600 [100%] ██████████████████████████████ Elapsed: 37s
-log p(x) <= 213.170
Epoch: 48
600/600 [100%] ██████████████████████████████ Elapsed: 37s
-log p(x) <= 213.676
Epoch: 49
600/600 [100%] ██████████████████████████████ Elapsed: 37s
-log p(x) <= 213.487
Epoch: 50
600/600 [100%] ██████████████████████████████ Elapsed: 37s
-log p(x) <= 213.753
Epoch: 51
600/600 [100%] ██████████████████████████████ Elapsed: 37s
-log p(x) <= 213.010
Epoch: 52
600/600 [100%] ██████████████████████████████ Elapsed: 37s
-log p(x) <= 213.151
Epoch: 53
600/600 [100%] ██████████████████████████████ Elapsed: 37s
-log p(x) <= 212.584
Epoch: 54
600/600 [100%] ██████████████████████████████ Elapsed: 36s
-log p(x) <= 212.190
Epoch: 55
600/600 [100%] ██████████████████████████████ Elapsed: 36s
-log p(x) <= 213.579
Epoch: 56
600/600 [100%] ██████████████████████████████ Elapsed: 36s
-log p(x) <= 213.301
Epoch: 57
600/600 [100%] ██████████████████████████████ Elapsed: 36s
-log p(x) <= 212.727
Epoch: 58
600/600 [100%] ██████████████████████████████ Elapsed: 36s
-log p(x) <= 212.676
Epoch: 59
600/600 [100%] ██████████████████████████████ Elapsed: 37s
-log p(x) <= 213.336
Epoch: 60
600/600 [100%] ██████████████████████████████ Elapsed: 36s
-log p(x) <= 212.968
Epoch: 61
600/600 [100%] ██████████████████████████████ Elapsed: 36s
-log p(x) <= 213.369
Epoch: 62
600/600 [100%] ██████████████████████████████ Elapsed: 35s
-log p(x) <= 213.455
Epoch: 63
600/600 [100%] ██████████████████████████████ Elapsed: 36s
-log p(x) <= 213.088
Epoch: 64
600/600 [100%] ██████████████████████████████ Elapsed: 36s
-log p(x) <= 213.512
Epoch: 65
600/600 [100%] ██████████████████████████████ Elapsed: 35s
-log p(x) <= 213.587
Epoch: 66
600/600 [100%] ██████████████████████████████ Elapsed: 35s
-log p(x) <= 213.292
Epoch: 67
600/600 [100%] ██████████████████████████████ Elapsed: 35s
-log p(x) <= 214.088
Epoch: 68
600/600 [100%] ██████████████████████████████ Elapsed: 35s
-log p(x) <= 213.306
Epoch: 69
600/600 [100%] ██████████████████████████████ Elapsed: 35s
-log p(x) <= 212.897
Epoch: 70
600/600 [100%] ██████████████████████████████ Elapsed: 35s
-log p(x) <= 212.703
Epoch: 71
600/600 [100%] ██████████████████████████████ Elapsed: 35s
-log p(x) <= 215.801
Epoch: 72
600/600 [100%] ██████████████████████████████ Elapsed: 34s
-log p(x) <= 213.383
Epoch: 73
600/600 [100%] ██████████████████████████████ Elapsed: 35s
-log p(x) <= 214.901
Epoch: 74
600/600 [100%] ██████████████████████████████ Elapsed: 35s
-log p(x) <= 211.804
Epoch: 75
600/600 [100%] ██████████████████████████████ Elapsed: 35s
-log p(x) <= 212.938
Epoch: 76
600/600 [100%] ██████████████████████████████ Elapsed: 34s
-log p(x) <= 211.891
Epoch: 77
600/600 [100%] ██████████████████████████████ Elapsed: 35s
-log p(x) <= 213.905
Epoch: 78
600/600 [100%] ██████████████████████████████ Elapsed: 34s
-log p(x) <= 214.125
Epoch: 79
600/600 [100%] ██████████████████████████████ Elapsed: 34s
-log p(x) <= 212.074
Epoch: 80
600/600 [100%] ██████████████████████████████ Elapsed: 34s
-log p(x) <= 213.308
Epoch: 81
600/600 [100%] ██████████████████████████████ Elapsed: 34s
-log p(x) <= 212.881
Epoch: 82
600/600 [100%] ██████████████████████████████ Elapsed: 34s
-log p(x) <= 213.260
Epoch: 83
600/600 [100%] ██████████████████████████████ Elapsed: 34s
-log p(x) <= 212.237
Epoch: 84
600/600 [100%] ██████████████████████████████ Elapsed: 34s
-log p(x) <= 212.117
Epoch: 85
600/600 [100%] ██████████████████████████████ Elapsed: 33s
-log p(x) <= 213.889
Epoch: 86
600/600 [100%] ██████████████████████████████ Elapsed: 33s
-log p(x) <= 212.049
Epoch: 87
600/600 [100%] ██████████████████████████████ Elapsed: 33s
-log p(x) <= 211.101
Epoch: 88
600/600 [100%] ██████████████████████████████ Elapsed: 33s
-log p(x) <= 212.536
Epoch: 89
600/600 [100%] ██████████████████████████████ Elapsed: 33s
-log p(x) <= 211.041
Epoch: 90
600/600 [100%] ██████████████████████████████ Elapsed: 33s
-log p(x) <= 212.026
Epoch: 91
600/600 [100%] ██████████████████████████████ Elapsed: 33s
-log p(x) <= 211.746
Epoch: 92
600/600 [100%] ██████████████████████████████ Elapsed: 29s
-log p(x) <= 201.634
Epoch: 93
600/600 [100%] ██████████████████████████████ Elapsed: 29s
-log p(x) <= 197.529
Epoch: 94
600/600 [100%] ██████████████████████████████ Elapsed: 29s
-log p(x) <= 199.385
Epoch: 95
600/600 [100%] ██████████████████████████████ Elapsed: 29s
-log p(x) <= 201.137
Epoch: 96
600/600 [100%] ██████████████████████████████ Elapsed: 29s
-log p(x) <= 203.464
Epoch: 97
600/600 [100%] ██████████████████████████████ Elapsed: 29s
-log p(x) <= 204.189
Epoch: 98
600/600 [100%] ██████████████████████████████ Elapsed: 29s
-log p(x) <= 203.858
Epoch: 99
600/600 [100%] ██████████████████████████████ Elapsed: 29s
-log p(x) <= 203.330
Epoch: 100
600/600 [100%] ██████████████████████████████ Elapsed: 29s
-log p(x) <= 205.945

In [ ]: