Variational autoencoders

In this notebook, we define a VAE using TFP's "probabilistic layers". We fit the model to Fashion MNIST. Our code is based on https://github.com/tensorflow/probability/blob/master/tensorflow_probability/examples/jupyter_notebooks/Probabilistic_Layers_VAE.ipynb
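
As a reminder, a VAE is trained by maximizing the evidence lower bound (ELBO):

$\log p(x) \ge \mathbb{E}_{q(z|x)}[\log p(x|z)] - \mathrm{KL}\left(q(z|x) \,\|\, p(z)\right)$

In the code below, the decoder's log_prob supplies the reconstruction term, and a KLDivergenceRegularizer attached to the encoder supplies the KL term.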


In [0]:
from __future__ import absolute_import, division, print_function, unicode_literals

In [5]:
try:
  # %tensorflow_version only exists in Colab.
  %tensorflow_version 2.x
except Exception:
  pass

import tensorflow as tf
tf.__version__

import tensorflow_probability as tfp


---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
<ipython-input-5-79e10747b01a> in <module>()
      8 tf.__version__
      9 
---> 10 import tensorflow_probability as tfp

[... traceback through tensorflow_probability internals elided ...]

AttributeError: module 'tensorflow._api.v2.compat.v2.linalg' has no attribute 'matrix_rank'

The preinstalled stable tensorflow_probability calls tf.linalg.matrix_rank, which the TF 2.0 RC build on this machine does not yet provide, so the import fails. We fix this by upgrading both packages to matching nightly builds in the next cell.

In [0]:
!pip install -q --upgrade tf-nightly-gpu-2.0-preview
!pip install -q tfp-nightly

import tensorflow as tf
from tensorflow.python import tf2
if not tf2.enabled():
  import tensorflow.compat.v2 as tf
  tf.enable_v2_behavior()
  assert tf2.enabled()

In [0]:
import matplotlib.pyplot as plt
import numpy as np

from tensorflow.keras import layers
import tensorflow_datasets as tfds
import tensorflow_probability as tfp


tfk = tf.keras
tfkl = tf.keras.layers
tfpl = tfp.layers
tfd = tfp.distributions

Load dataset.


In [12]:
#dataname = 'mnist' 
dataname = 'fashion_mnist'
datasets, datasets_info = tfds.load(name=dataname, with_info=True, as_supervised=False)

if dataname == 'mnist':
  def _preprocess(sample):
    image = tf.cast(sample['image'], tf.float32) / 255.  # Scale to unit interval.
    image = image < tf.random.uniform(tf.shape(image))   # Randomly binarize.
    return image, image
else:
  def _preprocess(sample):
    image = tf.cast(sample['image'], tf.float32) / 255.  # Scale to unit interval.
    return image, image  # For validation we compute p(input|input).


train_dataset = (datasets['train']
                 .map(_preprocess)
                 .shuffle(int(10e3))  # Shuffle examples, not batches.
                 .batch(256)
                 .prefetch(tf.data.experimental.AUTOTUNE))
eval_dataset = (datasets['test']
                .map(_preprocess)
                .batch(256)
                .prefetch(tf.data.experimental.AUTOTUNE))


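As a quick sanity check on the pipeline, each training batch is a pair of identical image tensors (input and target) of shape (256, 28, 28, 1):

In [0]:
x_batch, y_batch = next(iter(train_dataset))
print(x_batch.shape)  # (256, 28, 28, 1) -- 28x28x1 Fashion MNIST images.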

Specify model.


In [0]:
input_shape = datasets_info.features['image'].shape
encoded_size = 16
base_depth = 32

In [0]:
prior = tfd.Independent(tfd.Normal(loc=tf.zeros(encoded_size), scale=1),
                        reinterpreted_batch_ndims=1)
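
The prior is a standard spherical Gaussian on the 16-dimensional latent space; wrapping the Normal in tfd.Independent declares the whole vector a single event, so log_prob returns one scalar per sample. A minimal check:

In [0]:
z = prior.sample(3)
print(z.shape)                  # (3, 16)
print(prior.log_prob(z).shape)  # (3,)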

In [0]:
encoder = tfk.Sequential([
    tfkl.InputLayer(input_shape=input_shape),
    tfkl.Lambda(lambda x: tf.cast(x, tf.float32) - 0.5),  # Center pixels around zero.
    tfkl.Conv2D(base_depth, 5, strides=1,
                padding='same', activation=tf.nn.leaky_relu),
    tfkl.Conv2D(base_depth, 5, strides=2,
                padding='same', activation=tf.nn.leaky_relu),
    tfkl.Conv2D(2 * base_depth, 5, strides=1,
                padding='same', activation=tf.nn.leaky_relu),
    tfkl.Conv2D(2 * base_depth, 5, strides=2,
                padding='same', activation=tf.nn.leaky_relu),
    tfkl.Conv2D(4 * encoded_size, 7, strides=1,
                padding='valid', activation=tf.nn.leaky_relu),
    tfkl.Flatten(),
    tfkl.Dense(tfpl.MultivariateNormalTriL.params_size(encoded_size),
               activation=None),  # loc + lower-triangular scale parameters of q(z|x).
    tfpl.MultivariateNormalTriL(
        encoded_size,
        activity_regularizer=tfpl.KLDivergenceRegularizer(prior)),  # Adds a KL(q || prior) penalty to the loss.
])
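
Because the final layer is a tfpl.MultivariateNormalTriL, calling the encoder on a batch returns a distribution object rather than a plain tensor. A minimal sketch with a dummy batch:

In [0]:
dummy = tf.zeros([4] + list(input_shape))
qz = encoder(dummy)
print(qz)                 # A MultivariateNormalTriL distribution.
print(qz.sample().shape)  # (4, 16)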

In [0]:
decoder = tfk.Sequential([
    tfkl.InputLayer(input_shape=[encoded_size]),
    tfkl.Reshape([1, 1, encoded_size]),  # Treat the code as a 1x1 feature map.
    tfkl.Conv2DTranspose(2 * base_depth, 7, strides=1,
                         padding='valid', activation=tf.nn.leaky_relu),
    tfkl.Conv2DTranspose(2 * base_depth, 5, strides=1,
                         padding='same', activation=tf.nn.leaky_relu),
    tfkl.Conv2DTranspose(2 * base_depth, 5, strides=2,
                         padding='same', activation=tf.nn.leaky_relu),
    tfkl.Conv2DTranspose(base_depth, 5, strides=1,
                         padding='same', activation=tf.nn.leaky_relu),
    tfkl.Conv2DTranspose(base_depth, 5, strides=2,
                         padding='same', activation=tf.nn.leaky_relu),
    tfkl.Conv2DTranspose(base_depth, 5, strides=1,
                         padding='same', activation=tf.nn.leaky_relu),
    tfkl.Conv2D(filters=1, kernel_size=5, strides=1,
                padding='same', activation=None),  # Per-pixel logits.
    tfkl.Flatten(),
    tfpl.IndependentBernoulli(input_shape, tfd.Bernoulli.logits),  # p(x|z): one Bernoulli per pixel.
])
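
Symmetrically, the decoder maps latent codes to an image-shaped distribution whose event shape matches the inputs. A quick sketch with dummy codes:

In [0]:
px = decoder(tf.zeros([4, encoded_size]))
print(px)               # An Independent(Bernoulli) distribution.
print(px.mean().shape)  # (4, 28, 28, 1)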

In [0]:
vae = tfk.Model(inputs=encoder.inputs,
                outputs=decoder(encoder.outputs[0]))
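
End to end, vae maps an image batch to a distribution over reconstructions. Training it with the negative log-likelihood loss below therefore minimizes the negative ELBO: the decoder's log_prob supplies the reconstruction term, while the encoder's KLDivergenceRegularizer adds (by default) a sampling-based estimate of KL(q(z|x) || p(z)). A hand-rolled sketch of the same objective, using the names defined above:

In [0]:
def neg_elbo(x):
  qz = encoder(x)                           # q(z|x)
  z = qz.sample()                           # One posterior sample.
  px = decoder(z)                           # p(x|z)
  nll = -px.log_prob(x)                     # Reconstruction term.
  kl = qz.log_prob(z) - prior.log_prob(z)   # Single-sample KL estimate.
  return tf.reduce_mean(nll + kl)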

Train model.


In [18]:
negloglik = lambda x, rv_x: -rv_x.log_prob(x)

vae.compile(optimizer=tf.optimizers.Adam(learning_rate=1e-3),
            loss=negloglik)

vae.fit(train_dataset,
        epochs=5,
        validation_data=eval_dataset)


Epoch 1/5
235/235 [==============================] - 25s 106ms/step - loss: 318.3373 - val_loss: 0.0000e+00
Epoch 2/5
235/235 [==============================] - 21s 91ms/step - loss: 267.4640 - val_loss: 265.9156
Epoch 3/5
235/235 [==============================] - 22s 92ms/step - loss: 261.6352 - val_loss: 261.4726
Epoch 4/5
235/235 [==============================] - 22s 92ms/step - loss: 258.9843 - val_loss: 261.3817
Epoch 5/5
235/235 [==============================] - 22s 92ms/step - loss: 257.2519 - val_loss: 259.1017
Out[18]:
<tensorflow.python.keras.callbacks.History at 0x7f935f953c18>

Look Ma, No Hands... Tensors!


In [0]:
# We'll just examine ten random images from the test set.
x = next(iter(eval_dataset))[0][:10]
xhat = vae(x)
assert isinstance(xhat, tfd.Distribution)
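
Since xhat is a distribution over images, we can also score the inputs under it directly:

In [0]:
print(xhat.log_prob(x).shape)  # (10,) -- one log-likelihood per image.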

In [0]:

def display_imgs(x, y=None):
  if not isinstance(x, (np.ndarray, np.generic)):
    x = np.array(x)
  plt.ioff()
  n = x.shape[0]
  fig, axs = plt.subplots(1, n, figsize=(n, 1))
  if y is not None:
    fig.suptitle(np.argmax(y, axis=1))
  for i in range(n): # xrange is python2 only
    axs.flat[i].imshow(x[i].squeeze(), interpolation='none', cmap='gray')
    axs.flat[i].axis('off')
  plt.show()
  plt.close()
  plt.ion()

In [21]:
print('Originals:')
display_imgs(x)

print('Decoded Random Samples:')
display_imgs(xhat.sample())

print('Decoded Modes:')
display_imgs(xhat.mode())

print('Decoded Means:')
display_imgs(xhat.mean())


[Output: four rows of ten images each -- originals, decoded random samples, decoded modes, and decoded means.]

In [0]:
# Now, let's generate ten never-before-seen images.
z = prior.sample(10)
xtilde = decoder(z)
assert isinstance(xtilde, tfd.Distribution)

In [23]:
print('Randomly Generated Samples:')
display_imgs(xtilde.sample())

print('Randomly Generated Modes:')
display_imgs(xtilde.mode())

print('Randomly Generated Means:')
display_imgs(xtilde.mean())


[Output: three rows of ten generated images -- random samples, modes, and means.]
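
A natural follow-up experiment is to interpolate between two prior draws and decode the path, which visualizes how smoothly the latent space maps onto images. A minimal sketch using the names above:

In [0]:
z0, z1 = prior.sample(2)                     # Two endpoints in latent space.
w = tf.linspace(0., 1., 10)[:, tf.newaxis]   # Interpolation weights, shape (10, 1).
z_path = (1. - w) * z0 + w * z1              # (10, 16) codes along the segment.
display_imgs(decoder(z_path).mean())         # Decode and show the mean images.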