In [1]:
import os
os.environ['KERAS_BACKEND'] = 'theano'

In [2]:
'''This script demonstrates how to build a variational autoencoder
with Keras and deconvolution layers.

Reference: "Auto-Encoding Variational Bayes" https://arxiv.org/abs/1312.6114
'''
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import norm

from keras.layers import Input, Dense, Lambda, Flatten, Reshape, Layer
from keras.layers import Conv2D, Conv2DTranspose
from keras.models import Model
from keras import backend as K
from keras import metrics
from keras.datasets import mnist


Using Theano backend.

In [3]:
import keras
print(keras.__version__)


2.0.3

In [4]:
# input image dimensions
img_rows, img_cols, img_chns = 28, 28, 1
# number of convolutional filters to use
filters = 64
# convolution kernel size
num_conv = 3

batch_size = 100
if K.image_data_format() == 'channels_first':
    original_img_size = (img_chns, img_rows, img_cols)
else:
    original_img_size = (img_rows, img_cols, img_chns)
latent_dim = 2
intermediate_dim = 128
epsilon_std = 1.0
epochs = 5

x = Input(batch_shape=(batch_size,) + original_img_size)
conv_1 = Conv2D(img_chns,
                kernel_size=(2, 2),
                padding='same', activation='relu')(x)
conv_2 = Conv2D(filters,
                kernel_size=(2, 2),
                padding='same', activation='relu',
                strides=(2, 2))(conv_1)
conv_3 = Conv2D(filters,
                kernel_size=num_conv,
                padding='same', activation='relu',
                strides=1)(conv_2)
conv_4 = Conv2D(filters,
                kernel_size=num_conv,
                padding='same', activation='relu',
                strides=1)(conv_3)
flat = Flatten()(conv_4)
hidden = Dense(intermediate_dim, activation='relu')(flat)

z_mean = Dense(latent_dim)(hidden)
z_log_var = Dense(latent_dim)(hidden)


def sampling(args):
    z_mean, z_log_var = args
    epsilon = K.random_normal(shape=(batch_size, latent_dim),
                              mean=0., stddev=epsilon_std)
    return z_mean + K.exp(z_log_var) * epsilon

# note that "output_shape" isn't necessary with the TensorFlow backend
# so you could write `Lambda(sampling)([z_mean, z_log_var])`
z = Lambda(sampling, output_shape=(latent_dim,))([z_mean, z_log_var])

# we instantiate these layers separately so as to reuse them later
decoder_hid = Dense(intermediate_dim, activation='relu')
decoder_upsample = Dense(filters * 14 * 14, activation='relu')

if K.image_data_format() == 'channels_first':
    output_shape = (batch_size, filters, 14, 14)
else:
    output_shape = (batch_size, 14, 14, filters)

decoder_reshape = Reshape(output_shape[1:])
decoder_deconv_1 = Conv2DTranspose(filters,
                                   kernel_size=num_conv,
                                   padding='same',
                                   strides=1,
                                   activation='relu')
decoder_deconv_2 = Conv2DTranspose(filters, num_conv,
                                   padding='same',
                                   strides=1,
                                   activation='relu')
if K.image_data_format() == 'channels_first':
    output_shape = (batch_size, filters, 29, 29)
else:
    output_shape = (batch_size, 29, 29, filters)
decoder_deconv_3_upsamp = Conv2DTranspose(filters,
                                          kernel_size=(3, 3),
                                          strides=(2, 2),
                                          padding='valid',
                                          activation='relu')
decoder_mean_squash = Conv2D(img_chns,
                             kernel_size=2,
                             padding='valid',
                             activation='sigmoid')

hid_decoded = decoder_hid(z)
up_decoded = decoder_upsample(hid_decoded)
reshape_decoded = decoder_reshape(up_decoded)
deconv_1_decoded = decoder_deconv_1(reshape_decoded)
deconv_2_decoded = decoder_deconv_2(deconv_1_decoded)
x_decoded_relu = decoder_deconv_3_upsamp(deconv_2_decoded)
x_decoded_mean_squash = decoder_mean_squash(x_decoded_relu)


# Custom loss layer
class CustomVariationalLayer(Layer):
    def __init__(self, **kwargs):
        self.is_placeholder = True
        super(CustomVariationalLayer, self).__init__(**kwargs)

    def vae_loss(self, x, x_decoded_mean_squash):
        x = K.flatten(x)
        x_decoded_mean_squash = K.flatten(x_decoded_mean_squash)
        xent_loss = img_rows * img_cols * metrics.binary_crossentropy(x, x_decoded_mean_squash)
        kl_loss = - 0.5 * K.mean(1 + z_log_var - K.square(z_mean) - K.exp(z_log_var), axis=-1)
        return K.mean(xent_loss + kl_loss)

    def call(self, inputs):
        x = inputs[0]
        x_decoded_mean_squash = inputs[1]
        loss = self.vae_loss(x, x_decoded_mean_squash)
        self.add_loss(loss, inputs=inputs)
        # We don't use this output.
        return x


y = CustomVariationalLayer()([x, x_decoded_mean_squash])
vae = Model(x, y)
vae.compile(optimizer='rmsprop', loss=None)
vae.summary()

# train the VAE on MNIST digits
(x_train, _), (x_test, y_test) = mnist.load_data()

x_train = x_train.astype('float32') / 255.
x_train = x_train.reshape((x_train.shape[0],) + original_img_size)
x_test = x_test.astype('float32') / 255.
x_test = x_test.reshape((x_test.shape[0],) + original_img_size)

print('x_train.shape:', x_train.shape)


/home/mike/ve/keras/lib/python3.5/site-packages/keras/engine/topology.py:1519: UserWarning: Model inputs must come from a Keras Input layer, they cannot be the output of a previous non-Input layer. Here, a tensor specified as input to "model_1" was not an Input tensor, it was generated by layer custom_variational_layer_1.
Note that input tensors are instantiated via `tensor = Input(shape)`.
The tensor that caused the issue was: /input_1
  str(x.name))
/home/mike/ve/keras/lib/python3.5/site-packages/ipykernel/__main__.py:117: UserWarning: Output "custom_variational_layer_1" missing from loss dictionary. We assume this was done on purpose, and we will not be expecting any data to be passed to "custom_variational_layer_1" during training.
____________________________________________________________________________________________________
Layer (type)                     Output Shape          Param #     Connected to                     
====================================================================================================
input_1 (InputLayer)             (100, 28, 28, 1)      0                                            
____________________________________________________________________________________________________
conv2d_1 (Conv2D)                (100, 28, 28, 1)      5                                            
____________________________________________________________________________________________________
conv2d_2 (Conv2D)                (100, 14, 14, 64)     320                                          
____________________________________________________________________________________________________
conv2d_3 (Conv2D)                (100, 14, 14, 64)     36928                                        
____________________________________________________________________________________________________
conv2d_4 (Conv2D)                (100, 14, 14, 64)     36928                                        
____________________________________________________________________________________________________
flatten_1 (Flatten)              (100, 12544)          0                                            
____________________________________________________________________________________________________
dense_1 (Dense)                  (100, 128)            1605760                                      
____________________________________________________________________________________________________
dense_2 (Dense)                  (100, 2)              258                                          
____________________________________________________________________________________________________
dense_3 (Dense)                  (100, 2)              258                                          
____________________________________________________________________________________________________
lambda_1 (Lambda)                (100, 2)              0                                            
____________________________________________________________________________________________________
dense_4 (Dense)                  (100, 128)            384                                          
____________________________________________________________________________________________________
dense_5 (Dense)                  (100, 12544)          1618176                                      
____________________________________________________________________________________________________
reshape_1 (Reshape)              (100, 14, 14, 64)     0                                            
____________________________________________________________________________________________________
conv2d_transpose_1 (Conv2DTransp (100, 14, 14, 64)     36928                                        
____________________________________________________________________________________________________
conv2d_transpose_2 (Conv2DTransp (100, 14, 14, 64)     36928                                        
____________________________________________________________________________________________________
conv2d_transpose_3 (Conv2DTransp (100, 29, 29, 64)     36928                                        
____________________________________________________________________________________________________
conv2d_5 (Conv2D)                (100, 28, 28, 1)      257                                          
____________________________________________________________________________________________________
custom_variational_layer_1 (Cust [(100, 28, 28, 1), (1 0                                            
====================================================================================================
Total params: 3,410,058
Trainable params: 3,410,058
Non-trainable params: 0
____________________________________________________________________________________________________
x_train.shape: (60000, 28, 28, 1)

In [5]:
vae.fit(x_train,
        shuffle=True,
        epochs=epochs,
        batch_size=batch_size,
        validation_data=(x_test, x_test))


---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-5-4c2dedca8577> in <module>()
      3         epochs=epochs,
      4         batch_size=batch_size,
----> 5         validation_data=(x_test, x_test))

/home/mike/ve/keras/lib/python3.5/site-packages/keras/engine/training.py in fit(self, x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, sample_weight, initial_epoch, **kwargs)
   1424                 sample_weight=val_sample_weight,
   1425                 check_batch_axis=False,
-> 1426                 batch_size=batch_size)
   1427             self._make_test_function()
   1428             val_f = self.test_function

/home/mike/ve/keras/lib/python3.5/site-packages/keras/engine/training.py in _standardize_user_data(self, x, y, sample_weight, class_weight, check_batch_axis, batch_size)
   1298                                     output_shapes,
   1299                                     check_batch_axis=False,
-> 1300                                     exception_prefix='model target')
   1301         sample_weights = _standardize_sample_weights(sample_weight,
   1302                                                      self._feed_output_names)

/home/mike/ve/keras/lib/python3.5/site-packages/keras/engine/training.py in _standardize_input_data(data, names, shapes, check_batch_axis, exception_prefix)
     98             raise ValueError('The model expects ' + str(len(names)) +
     99                              ' input arrays, but only received one array. '
--> 100                              'Found: array with shape ' + str(data.shape))
    101         arrays = [data]
    102 

ValueError: The model expects 0 input arrays, but only received one array. Found: array with shape (10000, 28, 28, 1)

In [ ]:
# build a model to project inputs on the latent space
encoder = Model(x, z_mean)

# display a 2D plot of the digit classes in the latent space
x_test_encoded = encoder.predict(x_test, batch_size=batch_size)
plt.figure(figsize=(6, 6))
plt.scatter(x_test_encoded[:, 0], x_test_encoded[:, 1], c=y_test)
plt.colorbar()
plt.show()

# build a digit generator that can sample from the learned distribution
decoder_input = Input(shape=(latent_dim,))
_hid_decoded = decoder_hid(decoder_input)
_up_decoded = decoder_upsample(_hid_decoded)
_reshape_decoded = decoder_reshape(_up_decoded)
_deconv_1_decoded = decoder_deconv_1(_reshape_decoded)
_deconv_2_decoded = decoder_deconv_2(_deconv_1_decoded)
_x_decoded_relu = decoder_deconv_3_upsamp(_deconv_2_decoded)
_x_decoded_mean_squash = decoder_mean_squash(_x_decoded_relu)
generator = Model(decoder_input, _x_decoded_mean_squash)

In [ ]:
# display a 2D manifold of the digits
n = 15  # figure with 15x15 digits
digit_size = 28
figure = np.zeros((digit_size * n, digit_size * n))
# linearly spaced coordinates on the unit square were transformed through the inverse CDF (ppf) of the Gaussian
# to produce values of the latent variables z, since the prior of the latent space is Gaussian
grid_x = norm.ppf(np.linspace(0.05, 0.95, n))
grid_y = norm.ppf(np.linspace(0.05, 0.95, n))

for i, yi in enumerate(grid_x):
    for j, xi in enumerate(grid_y):
        z_sample = np.array([[xi, yi]])
        z_sample = np.tile(z_sample, batch_size).reshape(batch_size, 2)
        x_decoded = generator.predict(z_sample, batch_size=batch_size)
        digit = x_decoded[0].reshape(digit_size, digit_size)
        figure[i * digit_size: (i + 1) * digit_size,
               j * digit_size: (j + 1) * digit_size] = digit

plt.figure(figsize=(10, 10))
plt.imshow(figure, cmap='Greys_r')
plt.show()