In [1]:
import pickle, gzip
import matplotlib.pyplot as plt
import numpy as np
import sys
%matplotlib inline
np.random.seed(0)
In [2]:
from vae import VAE
In [3]:
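with gzip.open('../resources/data/mnist.pkl.gz', 'rb') as f:
    # NOTE: the widely distributed mnist.pkl.gz stores its splits as
    # (train, validation, test); check that this unpacking order matches your file.
    train, test, val = pickle.load(f, encoding='latin1')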
mnist_train = train[0]
mnist_test = test[0]
A Variational Autoencoder (VAE) differs from other autoencoder networks in that it explicitly learns a latent variable representation of the data. Our modeling assumption is that the data (in this case, handwritten digits) is generated from a latent variable $z$ with a standard normal prior:
$$z \sim N(0, I)$$
We want to find the posterior $P(z \mid x)$, but computing it exactly is intractable, and sampling methods such as MCMC are computationally expensive. Instead we use variational inference: we learn an approximation $Q(z \mid x)$ of the posterior. The encoder outputs the parameters $\mu$ and $\sigma^2$ of this approximation, so that $z \mid x \sim N(\mu, \sigma^2 I)$, and we add a regularization term to the loss that penalizes the KL divergence between this distribution and the prior $N(0, I)$:
$$D_{KL}(q \,\|\, p) = -\frac{1}{2} \sum_j \left(1 + \log(\sigma_j^2) - \mu_j^2 - \sigma_j^2\right)$$
where the sum runs over the latent dimensions. The result is a generative model that can produce new images similar to the ones it has seen.
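The KL term above can be checked numerically. The following is a minimal sketch in NumPy, not the code used inside the `vae` module; the function name `kl_to_standard_normal` and the choice to pass the log-variance (rather than $\sigma^2$ itself) are assumptions made for illustration.

```python
import numpy as np

def kl_to_standard_normal(mu, log_var):
    """KL divergence from N(mu, diag(exp(log_var))) to N(0, I).

    Hypothetical helper: implements -1/2 * sum(1 + log(sigma^2) - mu^2 - sigma^2),
    summing over the last axis (the latent dimensions).
    """
    mu = np.asarray(mu, dtype=float)
    log_var = np.asarray(log_var, dtype=float)
    return -0.5 * np.sum(1.0 + log_var - mu**2 - np.exp(log_var), axis=-1)

# When the approximate posterior equals the prior, the penalty vanishes.
kl_same = kl_to_standard_normal([0.0, 0.0], [0.0, 0.0])

# Moving the mean away from the origin is penalized quadratically.
kl_shifted = kl_to_standard_normal([1.0, 0.0], [0.0, 0.0])

print(kl_same, kl_shifted)
```

Working with the log-variance is a common design choice because it keeps $\sigma^2$ positive without any constraint on the network output; whether the `VAE` class here does the same is not shown in this notebook.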
In [4]:
params = {
    'alpha': 0.02,
    'iter': 200,
    'activation': 'sigmoid',
    'loss': 'squared_error',
    'batch_size': 150
}
example = VAE([[784, 200], [200, 784]], 2, params)
Here we pass in a set of training digits, from which the network learns a latent variable representation.
In [5]:
example.learn(mnist_train)
Now let's compare the reconstructed images (bottom row) with the original test digits (top row).
In [6]:
fig, ax = plt.subplots(2, 3, figsize=(10, 8))
for i in range(3):
    # Add a leading batch axis so the digit has shape (1, 784)
    in_digit = mnist_test[i][None, :]
    out_digit = example.encode_decode(in_digit)
    ax[0, i].matshow(in_digit.reshape((28, 28)), cmap='gray', clim=(0, 1))
    ax[1, i].matshow(out_digit.reshape((28, 28)), cmap='gray', clim=(0, 1))
pass
We can also generate new images by decoding chosen points $z$ from the latent space.
In [7]:
fig, ax = plt.subplots(2, 2, figsize=(6, 6))
a = np.array([1, 3])
b = np.array([1, 3])
for i, z1 in enumerate(a):
    for j, z2 in enumerate(b):
        # Decode the latent point (z1, z2) into a 28x28 image
        image = example.generate(np.array([z1, z2])).reshape((28, 28))
        ax[i, j].matshow(image, cmap='gray', clim=(0, 1))
pass
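To explore more of the latent space than the four points above, one option is to build a regular grid of latent points, or to draw points from the prior $N(0, I)$. The sketch below only constructs the latent points. The helper `latent_grid` is hypothetical, and it assumes a 2-D latent space, as suggested by the 2-element arrays passed to `example.generate` above.

```python
import numpy as np

def latent_grid(values_z1, values_z2):
    """Return every (z1, z2) combination as rows of an array of shape (n1 * n2, 2)."""
    g1, g2 = np.meshgrid(values_z1, values_z2, indexing="ij")
    return np.stack([g1.ravel(), g2.ravel()], axis=1)

# A 5x5 grid covering the region where most of the prior mass lies.
grid = latent_grid(np.linspace(-2, 2, 5), np.linspace(-2, 2, 5))

# Alternatively, draw latent points directly from the prior N(0, I).
rng = np.random.default_rng(0)
prior_samples = rng.standard_normal(size=(4, 2))

print(grid.shape, prior_samples.shape)
```

Each row of `grid` or `prior_samples` could then be passed to `example.generate` in the same way as `np.array([z1, z2])` in the cell above.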