Convolutional Variational Autoencoder - MNIST

Slightly modified from variational_autoencoder_deconv.py in the Keras examples folder:

https://github.com/keras-team/keras/blob/master/examples/variational_autoencoder_deconv.py


In [1]:
KERAS_MODEL_FILEPATH = '../../demos/data/mnist_vae/mnist_vae.h5'

In [2]:
import numpy as np
np.random.seed(1337)  # for reproducibility
import matplotlib.pyplot as plt
%matplotlib inline

from keras.layers import Input, Dense, Lambda, Flatten, Reshape
from keras.layers import Conv2D, Conv2DTranspose
from keras.models import Model
from keras import backend as K
from keras import objectives  # renamed to keras.losses in Keras 2; 'objectives' remains as a legacy alias
from keras.datasets import mnist
from keras.callbacks import EarlyStopping, ModelCheckpoint


Using TensorFlow backend.
/home/leon/miniconda3/lib/python3.6/importlib/_bootstrap.py:219: RuntimeWarning: compiletime version 3.5 of module 'tensorflow.python.framework.fast_tensor_util' does not match runtime version 3.6
  return f(*args, **kwds)

In [3]:
# input image dimensions
img_rows, img_cols, img_chns = 28, 28, 1
# number of convolutional filters to use
filters = 64
# convolution kernel size
num_conv = 3

batch_size = 200
original_img_size = (img_rows, img_cols, img_chns)
latent_dim = 2
intermediate_dim = 128
epsilon_std = 0.01  # stddev of the latent sampling noise

x = Input(batch_shape=(batch_size,) + original_img_size)
conv_1 = Conv2D(img_chns, kernel_size=(2,2), padding='same', activation='relu')(x)
conv_2 = Conv2D(filters, kernel_size=(2,2), strides=(2,2), padding='same', activation='relu')(conv_1)
conv_3 = Conv2D(filters, kernel_size=num_conv, strides=(1,1), padding='same', activation='relu')(conv_2)
conv_4 = Conv2D(filters, kernel_size=num_conv, strides=(1,1), padding='same', activation='relu')(conv_3)
flat = Flatten()(conv_4)
hidden = Dense(intermediate_dim, activation='relu')(flat)

z_mean = Dense(latent_dim)(hidden)
z_log_var = Dense(latent_dim)(hidden)


def sampling(args):
    z_mean, z_log_var = args
    epsilon = K.random_normal(shape=(batch_size, latent_dim), mean=0., stddev=epsilon_std)
    # reparameterization trick: z_log_var is the log *variance*, so the
    # stddev is exp(z_log_var / 2), matching the KL term in vae_loss below
    return z_mean + K.exp(z_log_var / 2) * epsilon

# note that "output_shape" isn't necessary with the TensorFlow backend
# so you could write `Lambda(sampling)([z_mean, z_log_var])`
z = Lambda(sampling, output_shape=(latent_dim,))([z_mean, z_log_var])

# we instantiate these layers separately so as to reuse them later
decoder_hid = Dense(intermediate_dim, activation='relu')
decoder_upsample = Dense(filters * 14 * 14, activation='relu')

decoder_reshape = Reshape((14, 14, filters))
decoder_deconv_1 = Conv2DTranspose(filters, kernel_size=num_conv, strides=(1,1),
                                   padding='same', activation='relu')
decoder_deconv_2 = Conv2DTranspose(filters, kernel_size=num_conv, strides=(1,1),
                                   padding='same', activation='relu')
decoder_deconv_3_upsamp = Conv2DTranspose(filters, kernel_size=(2,2), strides=(2,2),
                                          padding='valid', activation='relu')
decoder_mean_squash = Conv2D(img_chns, kernel_size=(2,2), padding='same', activation='sigmoid')

hid_decoded = decoder_hid(z)
up_decoded = decoder_upsample(hid_decoded)
reshape_decoded = decoder_reshape(up_decoded)
deconv_1_decoded = decoder_deconv_1(reshape_decoded)
deconv_2_decoded = decoder_deconv_2(deconv_1_decoded)
x_decoded_relu = decoder_deconv_3_upsamp(deconv_2_decoded)
x_decoded_mean_squash = decoder_mean_squash(x_decoded_relu)

def vae_loss(x, x_decoded_mean):
    # NOTE: binary_crossentropy expects tensors of shape (batch_size, dim)
    # for x and x_decoded_mean, so we must flatten them!
    x = K.flatten(x)
    x_decoded_mean = K.flatten(x_decoded_mean)
    xent_loss = img_rows * img_cols * objectives.binary_crossentropy(x, x_decoded_mean)
    # closed-form KL divergence to the unit-Gaussian prior
    kl_loss = - 0.5 * K.mean(1 + z_log_var - K.square(z_mean) - K.exp(z_log_var), axis=-1)
    return xent_loss + kl_loss

vae = Model(x, x_decoded_mean_squash)
vae.compile(optimizer='adam', loss=vae_loss)
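
For reference, vae_loss above is the negative evidence lower bound (ELBO) with a unit-Gaussian prior p(z) and a diagonal-Gaussian approximate posterior q(z|x); the KL term has the standard closed form

\[
\mathrm{KL}\big(q(z \mid x)\,\|\,p(z)\big) = -\tfrac{1}{2} \sum_{j=1}^{d} \left(1 + \log \sigma_j^2 - \mu_j^2 - \sigma_j^2\right),
\]

where \mu corresponds to z_mean and \log \sigma^2 to z_log_var. Note that the code averages over the latent dimensions (K.mean) rather than summing, which rescales the KL weight by a factor of 1/d.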

In [4]:
vae.summary()


__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_1 (InputLayer)            (200, 28, 28, 1)     0                                            
__________________________________________________________________________________________________
conv2d_1 (Conv2D)               (200, 28, 28, 1)     5           input_1[0][0]                    
__________________________________________________________________________________________________
conv2d_2 (Conv2D)               (200, 14, 14, 64)    320         conv2d_1[0][0]                   
__________________________________________________________________________________________________
conv2d_3 (Conv2D)               (200, 14, 14, 64)    36928       conv2d_2[0][0]                   
__________________________________________________________________________________________________
conv2d_4 (Conv2D)               (200, 14, 14, 64)    36928       conv2d_3[0][0]                   
__________________________________________________________________________________________________
flatten_1 (Flatten)             (200, 12544)         0           conv2d_4[0][0]                   
__________________________________________________________________________________________________
dense_1 (Dense)                 (200, 128)           1605760     flatten_1[0][0]                  
__________________________________________________________________________________________________
dense_2 (Dense)                 (200, 2)             258         dense_1[0][0]                    
__________________________________________________________________________________________________
dense_3 (Dense)                 (200, 2)             258         dense_1[0][0]                    
__________________________________________________________________________________________________
lambda_1 (Lambda)               (200, 2)             0           dense_2[0][0]                    
                                                                 dense_3[0][0]                    
__________________________________________________________________________________________________
dense_4 (Dense)                 (200, 128)           384         lambda_1[0][0]                   
__________________________________________________________________________________________________
dense_5 (Dense)                 (200, 12544)         1618176     dense_4[0][0]                    
__________________________________________________________________________________________________
reshape_1 (Reshape)             (200, 14, 14, 64)    0           dense_5[0][0]                    
__________________________________________________________________________________________________
conv2d_transpose_1 (Conv2DTrans (200, 14, 14, 64)    36928       reshape_1[0][0]                  
__________________________________________________________________________________________________
conv2d_transpose_2 (Conv2DTrans (200, 14, 14, 64)    36928       conv2d_transpose_1[0][0]         
__________________________________________________________________________________________________
conv2d_transpose_3 (Conv2DTrans (200, 28, 28, 64)    16448       conv2d_transpose_2[0][0]         
__________________________________________________________________________________________________
conv2d_5 (Conv2D)               (200, 28, 28, 1)     257         conv2d_transpose_3[0][0]         
==================================================================================================
Total params: 3,389,578
Trainable params: 3,389,578
Non-trainable params: 0
__________________________________________________________________________________________________

In [5]:
epochs = 100

# train the VAE on MNIST digits
(x_train, _), (x_test, y_test) = mnist.load_data()

x_train = x_train.astype('float32') / 255.
x_train = x_train.reshape((x_train.shape[0],) + original_img_size)
x_test = x_test.astype('float32') / 255.
x_test = x_test.reshape((x_test.shape[0],) + original_img_size)

print('x_train.shape:', x_train.shape)

# Early stopping
early_stopping = EarlyStopping(monitor='val_loss', verbose=1, patience=20)
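
# ModelCheckpoint (imported above) could be used alongside EarlyStopping to
# keep the best weights on disk; a sketch with a hypothetical output path:
# checkpointer = ModelCheckpoint('../../demos/data/mnist_vae/weights_best.h5',
#                                monitor='val_loss', save_best_only=True, verbose=1)
# ...and pass callbacks=[early_stopping, checkpointer] to fit() below.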

vae.fit(x_train, x_train,
        validation_data=(x_test, x_test),
        shuffle=True, epochs=epochs, batch_size=batch_size, verbose=2,
        callbacks=[early_stopping])


x_train.shape: (60000, 28, 28, 1)
Train on 60000 samples, validate on 10000 samples
Epoch 1/100
 - 8s - loss: 187.0373 - val_loss: 156.7114
Epoch 2/100
 - 7s - loss: 152.0429 - val_loss: 148.6483
Epoch 3/100
 - 7s - loss: 146.5286 - val_loss: 144.5534
Epoch 4/100
 - 7s - loss: 143.1876 - val_loss: 142.3928
Epoch 5/100
 - 7s - loss: 140.9927 - val_loss: 140.8551
Epoch 6/100
 - 7s - loss: 139.4741 - val_loss: 139.1728
Epoch 7/100
 - 7s - loss: 138.1855 - val_loss: 139.5438
Epoch 8/100
 - 7s - loss: 137.1530 - val_loss: 138.0423
Epoch 9/100
 - 7s - loss: 136.2261 - val_loss: 137.1981
Epoch 10/100
 - 7s - loss: 135.5949 - val_loss: 136.8874
Epoch 11/100
 - 7s - loss: 134.7879 - val_loss: 136.3957
Epoch 12/100
 - 7s - loss: 134.0974 - val_loss: 136.2742
Epoch 13/100
 - 8s - loss: 133.5830 - val_loss: 136.2701
Epoch 14/100
 - 7s - loss: 133.1168 - val_loss: 136.1554
Epoch 15/100
 - 7s - loss: 132.5904 - val_loss: 135.3253
Epoch 16/100
 - 7s - loss: 132.1790 - val_loss: 135.5948
Epoch 17/100
 - 7s - loss: 131.7488 - val_loss: 135.2553
Epoch 18/100
 - 7s - loss: 131.2587 - val_loss: 134.4164
Epoch 19/100
 - 7s - loss: 130.9092 - val_loss: 134.6242
Epoch 20/100
 - 7s - loss: 130.5510 - val_loss: 134.7760
Epoch 21/100
 - 7s - loss: 130.2725 - val_loss: 134.7138
Epoch 22/100
 - 7s - loss: 130.0032 - val_loss: 134.2895
Epoch 23/100
 - 7s - loss: 129.6128 - val_loss: 134.8942
Epoch 24/100
 - 7s - loss: 129.4343 - val_loss: 134.4236
Epoch 25/100
 - 7s - loss: 129.0382 - val_loss: 134.1219
Epoch 26/100
 - 7s - loss: 128.7270 - val_loss: 134.5024
Epoch 27/100
 - 7s - loss: 128.5309 - val_loss: 134.4658
Epoch 28/100
 - 7s - loss: 128.3248 - val_loss: 134.3021
Epoch 29/100
 - 7s - loss: 128.0661 - val_loss: 134.8726
Epoch 30/100
 - 7s - loss: 127.8949 - val_loss: 134.2355
Epoch 31/100
 - 7s - loss: 127.6775 - val_loss: 134.7689
Epoch 32/100
 - 7s - loss: 127.3580 - val_loss: 134.5752
Epoch 33/100
 - 7s - loss: 127.2949 - val_loss: 134.3039
Epoch 34/100
 - 7s - loss: 127.1502 - val_loss: 134.4167
Epoch 35/100
 - 7s - loss: 126.8103 - val_loss: 134.4313
Epoch 36/100
 - 7s - loss: 126.8565 - val_loss: 134.4056
Epoch 37/100
 - 7s - loss: 126.4918 - val_loss: 134.3834
Epoch 38/100
 - 7s - loss: 126.4045 - val_loss: 134.6673
Epoch 39/100
 - 7s - loss: 126.2089 - val_loss: 134.5927
Epoch 40/100
 - 7s - loss: 126.0172 - val_loss: 134.3165
Epoch 41/100
 - 7s - loss: 125.9181 - val_loss: 134.7828
Epoch 42/100
 - 7s - loss: 125.7069 - val_loss: 134.6580
Epoch 43/100
 - 7s - loss: 125.6588 - val_loss: 134.2014
Epoch 44/100
 - 7s - loss: 125.4340 - val_loss: 134.8716
Epoch 45/100
 - 7s - loss: 125.3487 - val_loss: 134.6684
Epoch 00045: early stopping
Out[5]:
<keras.callbacks.History at 0x7f0df811d9e8>

In [6]:
# build a model to project inputs on the latent space
encoder = Model(x, z_mean)

# display a 2D plot of the digit classes in the latent space
x_test_encoded = encoder.predict(x_test, batch_size=batch_size)
plt.figure(figsize=(10,10))
plt.scatter(x_test_encoded[:, 0], x_test_encoded[:, 1], c=y_test)
plt.colorbar()
plt.show()


Decoder generator

To make the decoder generator serializable, we define fresh layers and transfer the trained weights over, rather than sharing the decoder layers directly. Sharing layers would create additional inbound nodes, some with different output shapes, which causes problems for serialization.

Here we also set batch_size to 1, since we will generate one digit at a time.
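
For contrast, the layer-sharing approach taken in the upstream Keras example would reuse the trained layer objects directly. A sketch (it generates images fine, but per the note above it adds extra inbound nodes to those layers, which complicates saving the model):

shared_input = Input(shape=(latent_dim,))
_h = decoder_hid(shared_input)
_u = decoder_upsample(_h)
_r = decoder_reshape(_u)
_d1 = decoder_deconv_1(_r)
_d2 = decoder_deconv_2(_d1)
_x = decoder_deconv_3_upsamp(_d2)
shared_generator = Model(shared_input, decoder_mean_squash(_x))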


In [7]:
batch_size = 1

_hid_decoded = Dense(intermediate_dim, activation='relu')
_up_decoded = Dense(filters * 14 * 14, activation='relu')
_reshape_decoded = Reshape((14, 14, filters))
_deconv_1_decoded = Conv2DTranspose(filters, kernel_size=num_conv, strides=(1,1),
                                    padding='same', activation='relu')
_deconv_2_decoded = Conv2DTranspose(filters, kernel_size=num_conv, strides=(1,1),
                                    padding='same', activation='relu')
_x_decoded_relu = Conv2DTranspose(filters, kernel_size=(2,2), strides=(2,2),
                                  padding='valid', activation='relu')
_x_decoded_mean_squash = Conv2D(img_chns, kernel_size=(2,2), padding='same', activation='sigmoid')

decoder_input = Input(shape=(latent_dim,))
layer1 = _hid_decoded(decoder_input)
layer2 = _up_decoded(layer1)
layer3 = _reshape_decoded(layer2)
layer4 = _deconv_1_decoded(layer3)
layer5 = _deconv_2_decoded(layer4)
layer6 = _x_decoded_relu(layer5)
layer7 = _x_decoded_mean_squash(layer6)
generator = Model(decoder_input, layer7)

_hid_decoded.set_weights(decoder_hid.get_weights())
_up_decoded.set_weights(decoder_upsample.get_weights())
_deconv_1_decoded.set_weights(decoder_deconv_1.get_weights())
_deconv_2_decoded.set_weights(decoder_deconv_2.get_weights())
_x_decoded_relu.set_weights(decoder_deconv_3_upsamp.get_weights())
_x_decoded_mean_squash.set_weights(decoder_mean_squash.get_weights())
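
# Optional sanity check (not part of the original run): the transferred
# weights should match the trained decoder layers exactly.
for _src, _dst in [(decoder_hid, _hid_decoded),
                   (decoder_upsample, _up_decoded),
                   (decoder_deconv_1, _deconv_1_decoded),
                   (decoder_deconv_2, _deconv_2_decoded),
                   (decoder_deconv_3_upsamp, _x_decoded_relu),
                   (decoder_mean_squash, _x_decoded_mean_squash)]:
    for w_src, w_dst in zip(_src.get_weights(), _dst.get_weights()):
        assert np.array_equal(w_src, w_dst)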

# display a 2D manifold of the digits
n = 15  # figure with 15x15 digits
digit_size = 28
figure = np.zeros((digit_size * n, digit_size * n))
# we will sample n points within [-1, 1] standard deviations
grid_x = np.linspace(-1, 1, n)
grid_y = np.linspace(-1, 1, n)

for i, yi in enumerate(grid_y):
    for j, xi in enumerate(grid_x):
        z_sample = np.array([[xi, yi]])
        z_sample = np.tile(z_sample, batch_size).reshape(batch_size, 2)
        x_decoded = generator.predict(z_sample, batch_size=batch_size)
        digit = x_decoded[0].reshape(digit_size, digit_size)
        figure[i * digit_size: (i + 1) * digit_size,
               j * digit_size: (j + 1) * digit_size] = digit

plt.figure(figsize=(10,10))
plt.imshow(figure)
plt.show()
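
The grid above samples the latent space uniformly in [-1, 1]. An alternative, used in some versions of the upstream Keras example, spaces the grid points by quantiles of the Gaussian prior instead, so each cell covers equal prior probability mass (a sketch; scipy is assumed to be available):

from scipy.stats import norm
grid_x = norm.ppf(np.linspace(0.05, 0.95, n))
grid_y = norm.ppf(np.linspace(0.05, 0.95, n))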



In [8]:
generator.save(KERAS_MODEL_FILEPATH)
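
To check that the saved model round-trips, it can be reloaded (a minimal sketch, not part of the original run; the generator uses only standard layers, so load_model needs no custom_objects):

In [ ]:
from keras.models import load_model
reloaded = load_model(KERAS_MODEL_FILEPATH)
assert np.allclose(reloaded.predict(z_sample), generator.predict(z_sample))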
