Convolutional autoencoder

Since our inputs are images, it makes sense to use convolutional neural networks (convnets) as encoders and decoders. In practical settings, autoencoders applied to images are always convolutional autoencoders — they simply perform much better.

Let's implement one. The encoder will consist of a stack of Conv2D and MaxPooling2D layers (max pooling being used for spatial down-sampling), while the decoder will consist of a stack of Conv2D and UpSampling2D layers.


In [4]:
from keras.layers import Input, Dense, Conv2D, MaxPooling2D, UpSampling2D
from keras.models import Model
from keras import backend as K

# Convolutional autoencoder for 32x32 RGB images (per-sample shapes below omit
# the batch axis; the (H, W, C) values match the activation shapes printed by
# the visualization cells further down).
input_img = Input(shape=(32, 32, 3))  # adapt this if using `channels_first` image data format

# Encoder: three Conv2D (ReLU, 'same' padding) + 2x2 max-pooling stages, each
# halving the spatial size: 32 -> 16 -> 8 -> 4.
# NOTE: the intermediate tensor names (x2, x4, encoded, x6, x8, x10, decoded)
# and `input_img` are reused later to build activation-visualization
# functions, so keep them stable.
x1 = Conv2D(16, (3, 3), activation='relu', padding='same')(input_img)
x2 = MaxPooling2D((2, 2), padding='same')(x1)  # -> (16, 16, 16)
x3 = Conv2D(8, (3, 3), activation='relu', padding='same')(x2)
x4 = MaxPooling2D((2, 2), padding='same')(x3)  # -> (8, 8, 8)
x5 = Conv2D(8, (3, 3), activation='relu', padding='same')(x4)
encoded = MaxPooling2D((2, 2), padding='same')(x5)  # -> (4, 4, 8)

# at this point the representation is (4, 4, 8) i.e. 128-dimensional

# Decoder: mirrors the encoder. Each UpSampling2D doubles the spatial size
# (4 -> 8 -> 16 -> 32); the final 3-filter Conv2D with a sigmoid activation
# maps back to an RGB image with values in [0, 1].
x6 = Conv2D(8, (3, 3), activation='relu', padding='same')(encoded)  # -> (4, 4, 8)
x7 = UpSampling2D((2, 2))(x6)  # -> (8, 8, 8)
x8 = Conv2D(8, (3, 3), activation='relu', padding='same')(x7)  # -> (8, 8, 8)
x9 = UpSampling2D((2, 2))(x8)  # -> (16, 16, 8)
x10 = Conv2D(16, (3, 3), activation='relu', padding='same')(x9)  # -> (16, 16, 16)
x11 = UpSampling2D((2, 2))(x10)  # -> (32, 32, 16)
decoded = Conv2D(3, (3, 3), activation='sigmoid', padding='same')(x11)  # -> (32, 32, 3)

# The model is trained to reproduce its input. Pixels are scaled to [0, 1] in
# the data-loading cell, so binary cross-entropy against the sigmoid outputs
# serves as a per-pixel reconstruction loss.
autoencoder = Model(input_img, decoded)
autoencoder.compile(optimizer='adagrad', loss='binary_crossentropy')

In [5]:
from keras.datasets import cifar10
import numpy as np

# Labels are not needed: an autoencoder's targets are its own inputs.
(x_train, _), (x_test, _) = cifar10.load_data()


def _to_unit_float(images):
    """Cast uint8 pixel data to float32 in [0, 1] with an explicit NHWC shape.

    The reshape keeps the channels_last layout. If you switch to the
    `channels_first` image data format, the axes must be transposed instead:
    a plain reshape would scramble the pixels.
    """
    scaled = images.astype('float32') / 255.
    return np.reshape(scaled, (len(scaled), 32, 32, 3))


x_train = _to_unit_float(x_train)
x_test = _to_unit_float(x_test)

In [6]:
# Train the autoencoder to reproduce its input: the images are both inputs
# and targets, and the held-out test images are used as validation data.
autoencoder.fit(x=x_train,
                y=x_train,
                batch_size=128,
                epochs=50,
                shuffle=True,
                validation_data=(x_test, x_test))


Train on 50000 samples, validate on 10000 samples
Epoch 1/50
50000/50000 [==============================] - 143s - loss: 0.6273 - val_loss: 0.6111
Epoch 2/50
50000/50000 [==============================] - 126s - loss: 0.6075 - val_loss: 0.6041
Epoch 3/50
50000/50000 [==============================] - 131s - loss: 0.6020 - val_loss: 0.6011
Epoch 4/50
50000/50000 [==============================] - 130s - loss: 0.5977 - val_loss: 0.5963
Epoch 5/50
50000/50000 [==============================] - 134s - loss: 0.5938 - val_loss: 0.5933
Epoch 6/50
50000/50000 [==============================] - 132s - loss: 0.5921 - val_loss: 0.5934
Epoch 7/50
50000/50000 [==============================] - 130s - loss: 0.5909 - val_loss: 0.5910
Epoch 8/50
50000/50000 [==============================] - 136s - loss: 0.5901 - val_loss: 0.5904
Epoch 9/50
50000/50000 [==============================] - 133s - loss: 0.5894 - val_loss: 0.5897
Epoch 10/50
50000/50000 [==============================] - 131s - loss: 0.5888 - val_loss: 0.5897
Epoch 11/50
50000/50000 [==============================] - 139s - loss: 0.5883 - val_loss: 0.5888
Epoch 12/50
50000/50000 [==============================] - 138s - loss: 0.5879 - val_loss: 0.5884
Epoch 13/50
50000/50000 [==============================] - 140s - loss: 0.5875 - val_loss: 0.5882
Epoch 14/50
50000/50000 [==============================] - 137s - loss: 0.5871 - val_loss: 0.5882
Epoch 15/50
50000/50000 [==============================] - 121s - loss: 0.5867 - val_loss: 0.5872
Epoch 16/50
50000/50000 [==============================] - 122s - loss: 0.5865 - val_loss: 0.5872
Epoch 17/50
50000/50000 [==============================] - 122s - loss: 0.5861 - val_loss: 0.5868
Epoch 18/50
50000/50000 [==============================] - 121s - loss: 0.5858 - val_loss: 0.5871
Epoch 19/50
50000/50000 [==============================] - 152s - loss: 0.5856 - val_loss: 0.5868
Epoch 20/50
50000/50000 [==============================] - 194s - loss: 0.5853 - val_loss: 0.5859
Epoch 21/50
50000/50000 [==============================] - 197s - loss: 0.5850 - val_loss: 0.5857
Epoch 22/50
50000/50000 [==============================] - 191s - loss: 0.5847 - val_loss: 0.5854
Epoch 23/50
50000/50000 [==============================] - 192s - loss: 0.5846 - val_loss: 0.5852
Epoch 24/50
50000/50000 [==============================] - 161s - loss: 0.5844 - val_loss: 0.5852
Epoch 25/50
50000/50000 [==============================] - 167s - loss: 0.5842 - val_loss: 0.5848
Epoch 26/50
50000/50000 [==============================] - 156s - loss: 0.5841 - val_loss: 0.5852
Epoch 27/50
50000/50000 [==============================] - 158s - loss: 0.5839 - val_loss: 0.5845
Epoch 28/50
50000/50000 [==============================] - 152s - loss: 0.5838 - val_loss: 0.5844
Epoch 29/50
50000/50000 [==============================] - 151s - loss: 0.5837 - val_loss: 0.5842
Epoch 30/50
50000/50000 [==============================] - 152s - loss: 0.5834 - val_loss: 0.5843
Epoch 31/50
50000/50000 [==============================] - 151s - loss: 0.5833 - val_loss: 0.5841
Epoch 32/50
50000/50000 [==============================] - 154s - loss: 0.5832 - val_loss: 0.5838
Epoch 33/50
50000/50000 [==============================] - 153s - loss: 0.5831 - val_loss: 0.5838
Epoch 34/50
50000/50000 [==============================] - 153s - loss: 0.5829 - val_loss: 0.5837
Epoch 35/50
50000/50000 [==============================] - 147s - loss: 0.5829 - val_loss: 0.5835
Epoch 36/50
50000/50000 [==============================] - 147s - loss: 0.5827 - val_loss: 0.5834
Epoch 37/50
50000/50000 [==============================] - 147s - loss: 0.5826 - val_loss: 0.5833
Epoch 38/50
50000/50000 [==============================] - 146s - loss: 0.5825 - val_loss: 0.5832
Epoch 39/50
50000/50000 [==============================] - 147s - loss: 0.5824 - val_loss: 0.5832
Epoch 40/50
50000/50000 [==============================] - 154s - loss: 0.5823 - val_loss: 0.5832
Epoch 41/50
50000/50000 [==============================] - 155s - loss: 0.5822 - val_loss: 0.5829
Epoch 42/50
50000/50000 [==============================] - 154s - loss: 0.5821 - val_loss: 0.5828
Epoch 43/50
50000/50000 [==============================] - 149s - loss: 0.5819 - val_loss: 0.5827
Epoch 44/50
50000/50000 [==============================] - 151s - loss: 0.5819 - val_loss: 0.5826
Epoch 45/50
50000/50000 [==============================] - 153s - loss: 0.5818 - val_loss: 0.5825
Epoch 46/50
50000/50000 [==============================] - 155s - loss: 0.5817 - val_loss: 0.5824
Epoch 47/50
50000/50000 [==============================] - 159s - loss: 0.5817 - val_loss: 0.5826
Epoch 48/50
50000/50000 [==============================] - 154s - loss: 0.5815 - val_loss: 0.5829
Epoch 49/50
50000/50000 [==============================] - 162s - loss: 0.5814 - val_loss: 0.5822
Epoch 50/50
50000/50000 [==============================] - 157s - loss: 0.5814 - val_loss: 0.5821
Out[6]:
<keras.callbacks.History at 0x2bb89c77588>

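The model converges to a training loss of about 0.581 and a validation loss of about 0.582 after 50 epochs (see the log above). These values cannot approach zero. Binary cross-entropy against continuous pixel targets in [0, 1] is bounded below by the entropy of the targets themselves, so the absolute number is not directly comparable with losses on near-binary images. The encoded representation is 128-dimensional (4 × 4 × 8) vs. 32 previously, i.e. it has a higher capacity. Let's take a look at the reconstructed images: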


In [7]:
from keras.models import load_model

# Persist the trained model to disk; the next cell restores it with load_model.
autoencoder.save('cifar10_autoencoders.h5')  # creates a HDF5 file 'cifar10_autoencoders.h5'
#del autoencoder  # optionally delete the in-memory model before reloading it

In [50]:
# Restore the autoencoder saved above. load_model returns a compiled model
# identical to the one that was saved.
autoencoder = load_model(filepath='cifar10_autoencoders.h5')

In [51]:
import matplotlib.pyplot as plt

# Reconstruct the test set, then compare the first `n` images:
# top row = original, bottom row = autoencoder reconstruction.
decoded_imgs = autoencoder.predict(x_test)

n = 10
plt.figure(figsize=(20, 4))
for i in range(n):
    for row, images in enumerate((x_test, decoded_imgs)):
        # Subplot indices are 1-based; row 1 starts after the n panels of row 0.
        ax = plt.subplot(2, n, row * n + i + 1)
        plt.imshow(images[i].reshape(32, 32, 3))
        plt.gray()
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)
plt.show()


Plotting the weights of the first convolutional layer (the input-channel-0 slice of the first 10 filters)


In [10]:
import matplotlib.pyplot as plt

# Show the 3x3 kernels of the first Conv2D layer, one figure per filter.
n = 10  # number of filters to display

# get_weights() returns the list [kernel, bias]. Index the list directly:
# wrapping it in np.asarray() builds a ragged array from arrays of different
# shapes, which NumPy >= 1.24 rejects with a ValueError (older versions
# silently produced a dtype=object array).
first_layer_kernel = autoencoder.layers[1].get_weights()[0]

# Never ask for more filters than the layer has.
for i in range(min(n, first_layer_kernel.shape[-1])):
    fig = plt.figure(figsize=(4, 10))
    # Slice of filter i for input channel 0 (assumes the Keras kernel layout
    # (rows, cols, in_channels, out_channels) — TODO confirm for other backends).
    conv_1 = first_layer_kernel[:, :, 0, i]
    ax = fig.add_subplot(111)
    plt.imshow(conv_1.transpose())
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)
    plt.show()


<matplotlib.figure.Figure at 0x2bb8a11b208>

In [11]:
from keras import backend as K

In [44]:
# K.learning_phase() is a flag that tells the backend whether the network runs
# in training or predict phase, so layers such as Dropout are only applied
# during training. Every function below is called with the flag set to 0
# (predict phase).
#
# The tensors x2, x4, ... were built on top of `input_img`, so they must be fed
# through that same placeholder. Do not use `autoencoder.inputs` here: once
# `autoencoder` has been replaced by `load_model(...)`, the reloaded model has
# its own input placeholder that is not connected to x2, x4, ..., and the
# backend cannot evaluate them from it.
inputs = [K.learning_phase(), input_img]


def _predict_phase(backend_fn):
    """Wrap a backend function so it is always called in predict phase.

    The returned callable takes a batch ``X`` and returns whatever
    ``backend_fn`` returns: a list with one array per requested output tensor.
    """
    def run(X):
        # The leading 0 disables the training-phase flag
        return backend_fn([0, X])
    return run


# Encoder activations (outputs of the three max-pooling stages).
_layer1_f = K.function(inputs, [x2])
convout1_f = _predict_phase(_layer1_f)

_layer2_f = K.function(inputs, [x4])
convout2_f = _predict_phase(_layer2_f)

_layer3_f = K.function(inputs, [encoded])
convout3_f = _predict_phase(_layer3_f)

# Decoder activations (outputs of the decoder convolutions).
_up_layer1_f = K.function(inputs, [x6])
convout4_f = _predict_phase(_up_layer1_f)

_up_layer2_f = K.function(inputs, [x8])
convout5_f = _predict_phase(_up_layer2_f)

_up_layer3_f = K.function(inputs, [x10])
convout6_f = _predict_phase(_up_layer3_f)

_up_layer4_f = K.function(inputs, [decoded])
convout7_f = _predict_phase(_up_layer4_f)

In [13]:
# Display the symbolic tensor for the first max-pooling output (graph name and static shape)
x2


Out[13]:
<tf.Tensor 'max_pooling2d_4/MaxPool:0' shape=(?, 16, 16, 16) dtype=float32>

In [14]:
# Pick one test image; slicing (rather than indexing) keeps the batch axis.
i = 1
x = x_test[i:i + 1]

Visualizing the output of the first conv + max-pooling stage (layer_1) for a sample test image (index 1, i.e. the second image of the test set)


In [15]:
# convout1_f returns a one-element list; np.array stacks it into a
# (1, 1, H, W, C) array and both leading singleton axes are dropped at once.
np.squeeze(np.array(convout1_f(x)), axis=(0, 1)).shape


Out[15]:
(16, 16, 16)

In [73]:
#Plotting conv_1
i = 1
x = x_test[i:i + 1]
# Drop the list and batch axes: (1, 1, H, W, C) -> (H, W, C).
check = np.squeeze(np.array(convout1_f(x)), axis=(0, 1))

# Show the input image first, for reference.
temp = x[0, :, :, :]
fig, axes = plt.subplots(1, 1, figsize=(3, 3))
plt.imshow(temp)
plt.show()

# Then one panel per channel of the feature map on a 4-column grid. The row
# count follows from the channel count, so any number of channels fits, and
# panels past the last channel are hidden. Dedicated loop names are used so
# the sample index `i` keeps its value.
n_channels = check.shape[2]
n_cols = 4
n_rows = max(1, -(-n_channels // n_cols))  # ceiling division
fig, axes = plt.subplots(n_rows, n_cols, figsize=(5, 5), squeeze=False)
for channel, panel in enumerate(axes.flat):
    if channel < n_channels:
        panel.imshow(check[:, :, channel])
    else:
        panel.axis('off')
plt.show()



In [17]:
# Shape of the most recently extracted feature map, as (H, W, C)
np.shape(check)


Out[17]:
(16, 16, 16)

Visualizing the output of the second conv + max-pooling stage (layer_2) for the same sample test image


In [71]:
i = 1
x = x_test[i:i + 1]  # a one-image batch
# convout2_f returns a one-element list; np.array stacks it into a
# (1, 1, H, W, C) array and both leading singleton axes are dropped at once.
check = np.squeeze(np.array(convout2_f(x)), axis=(0, 1))
check.shape


Out[71]:
(8, 8, 8)

In [72]:
#Plotting conv_2

# Show the input image first, for reference.
temp = x[0, :, :, :]
fig, axes = plt.subplots(1, 1, figsize=(3, 3))
plt.imshow(temp)
plt.show()

# Then one panel per channel of the feature map on a 4-column grid. The row
# count follows from the channel count, so any number of channels fits, and
# panels past the last channel are hidden. Dedicated loop names are used so
# the sample index `i` keeps its value.
n_channels = check.shape[2]
n_cols = 4
n_rows = max(1, -(-n_channels // n_cols))  # ceiling division
fig, axes = plt.subplots(n_rows, n_cols, figsize=(5, 5), squeeze=False)
for channel, panel in enumerate(axes.flat):
    if channel < n_channels:
        panel.imshow(check[:, :, channel])
    else:
        panel.axis('off')
plt.show()


Visualizing the output of the third conv + max-pooling stage (layer_3), i.e. the (4, 4, 8) encoded representation, for the same sample test image


In [69]:
i = 1
x = x_test[i:i + 1]  # a one-image batch
# convout3_f returns a one-element list; np.array stacks it into a
# (1, 1, H, W, C) array and both leading singleton axes are dropped at once.
check = np.squeeze(np.array(convout3_f(x)), axis=(0, 1))
check.shape


Out[69]:
(4, 4, 8)

In [70]:
#Plotting conv_3

# Show the input image first, for reference.
temp = x[0, :, :, :]
fig, axes = plt.subplots(1, 1, figsize=(3, 3))
plt.imshow(temp)
plt.show()

# Then one panel per channel of the feature map on a 4-column grid. The row
# count follows from the channel count, so any number of channels fits, and
# panels past the last channel are hidden. Dedicated loop names are used so
# the sample index `i` keeps its value.
n_channels = check.shape[2]
n_cols = 4
n_rows = max(1, -(-n_channels // n_cols))  # ceiling division
fig, axes = plt.subplots(n_rows, n_cols, figsize=(5, 5), squeeze=False)
for channel, panel in enumerate(axes.flat):
    if channel < n_channels:
        panel.imshow(check[:, :, channel])
    else:
        panel.axis('off')
plt.show()


Visualizing the output of the fourth convolution (the first decoder convolution, layer_4) for the same sample test image


In [67]:
i = 1
x = x_test[i:i + 1]  # a one-image batch
# convout4_f returns a one-element list; np.array stacks it into a
# (1, 1, H, W, C) array and both leading singleton axes are dropped at once.
check = np.squeeze(np.array(convout4_f(x)), axis=(0, 1))
check.shape


Out[67]:
(4, 4, 8)

In [68]:
#Plotting conv_4

# Show the input image first, for reference.
temp = x[0, :, :, :]
fig, axes = plt.subplots(1, 1, figsize=(3, 3))
plt.imshow(temp)
plt.show()

# Then one panel per channel of the feature map on a 4-column grid. The row
# count follows from the channel count, so any number of channels fits, and
# panels past the last channel are hidden. Dedicated loop names are used so
# the sample index `i` keeps its value.
n_channels = check.shape[2]
n_cols = 4
n_rows = max(1, -(-n_channels // n_cols))  # ceiling division
fig, axes = plt.subplots(n_rows, n_cols, figsize=(5, 5), squeeze=False)
for channel, panel in enumerate(axes.flat):
    if channel < n_channels:
        panel.imshow(check[:, :, channel])
    else:
        panel.axis('off')
plt.show()


Visualizing the output of the fifth convolution (the second decoder convolution, layer_5) for the same sample test image


In [64]:
i = 1
x = x_test[i:i + 1]  # a one-image batch
# convout5_f returns a one-element list; np.array stacks it into a
# (1, 1, H, W, C) array and both leading singleton axes are dropped at once.
check = np.squeeze(np.array(convout5_f(x)), axis=(0, 1))
check.shape


Out[64]:
(8, 8, 8)

In [66]:
#Plotting conv_5

# Show the input image first, for reference.
temp = x[0, :, :, :]
fig, axes = plt.subplots(1, 1, figsize=(3, 3))
plt.imshow(temp)
plt.show()

# Then one panel per channel of the feature map on a 4-column grid. The row
# count follows from the channel count, so any number of channels fits, and
# panels past the last channel are hidden. Dedicated loop names are used so
# the sample index `i` keeps its value.
n_channels = check.shape[2]
n_cols = 4
n_rows = max(1, -(-n_channels // n_cols))  # ceiling division
fig, axes = plt.subplots(n_rows, n_cols, figsize=(5, 5), squeeze=False)
for channel, panel in enumerate(axes.flat):
    if channel < n_channels:
        panel.imshow(check[:, :, channel])
    else:
        panel.axis('off')
plt.show()


Visualizing the output of the sixth convolution (the third decoder convolution, layer_6) for the same sample test image


In [61]:
i = 1
x = x_test[i:i + 1]  # a one-image batch
# convout6_f returns a one-element list; np.array stacks it into a
# (1, 1, H, W, C) array and both leading singleton axes are dropped at once.
check = np.squeeze(np.array(convout6_f(x)), axis=(0, 1))
check.shape


Out[61]:
(16, 16, 16)

In [63]:
#Plotting conv_6

# Show the input image first, for reference.
temp = x[0, :, :, :]
fig, axes = plt.subplots(1, 1, figsize=(3, 3))
plt.imshow(temp)
plt.show()

# Then one panel per channel of the feature map on a 4-column grid. The row
# count follows from the channel count, so any number of channels fits, and
# panels past the last channel are hidden. Dedicated loop names are used so
# the sample index `i` keeps its value.
n_channels = check.shape[2]
n_cols = 4
n_rows = max(1, -(-n_channels // n_cols))  # ceiling division
fig, axes = plt.subplots(n_rows, n_cols, figsize=(5, 5), squeeze=False)
for channel, panel in enumerate(axes.flat):
    if channel < n_channels:
        panel.imshow(check[:, :, channel])
    else:
        panel.axis('off')
plt.show()


Visualizing the final decoded output (the reconstructed image) for the same sample test image


In [45]:
i = 1
x = x_test[i:i + 1]  # a one-image batch
# convout7_f returns a one-element list; np.array stacks it into a
# (1, 1, H, W, C) array and both leading singleton axes are dropped at once.
check = np.squeeze(np.array(convout7_f(x)), axis=(0, 1))
check.shape


Out[45]:
(32, 32, 3)

In [54]:
#Plotting final layer
# First figure: the input test image. Second figure: its reconstruction.

temp = x[0, :, :, :]
fig, axes = plt.subplots(1, 1, figsize=(3, 3))
plt.imshow(temp)
plt.show()

# The final layer is a 3-filter Conv2D with a sigmoid activation, so `check`
# is already an (H, W, 3) image with values in [0, 1] and is shown directly,
# exactly once, regardless of loop bookkeeping.
fig, axes = plt.subplots(1, 1, figsize=(3, 3))
plt.imshow(check)
plt.show()



In [ ]: