Since our inputs are images, it makes sense to use convolutional neural networks (convnets) as encoders and decoders. In practical settings, autoencoders applied to images are almost always convolutional autoencoders -- they simply perform much better.
Let's implement one. The encoder will consist of a stack of Conv2D and MaxPooling2D layers (max pooling being used for spatial down-sampling), while the decoder will consist of a stack of Conv2D and UpSampling2D layers.
In [4]:
from keras.layers import Input, Dense, Conv2D, MaxPooling2D, UpSampling2D
from keras.models import Model
from keras import backend as K
input_img = Input(shape=(32, 32, 3)) # adapt this if using `channels_first` image data format
x1 = Conv2D(16, (3, 3), activation='relu', padding='same')(input_img)
x2 = MaxPooling2D((2, 2), padding='same')(x1)
x3 = Conv2D(8, (3, 3), activation='relu', padding='same')(x2)
x4 = MaxPooling2D((2, 2), padding='same')(x3)
x5 = Conv2D(8, (3, 3), activation='relu', padding='same')(x4)
encoded = MaxPooling2D((2, 2), padding='same')(x5)
# at this point the representation is (4, 4, 8) i.e. 128-dimensional
x6 = Conv2D(8, (3, 3), activation='relu', padding='same')(encoded)
x7 = UpSampling2D((2, 2))(x6)
x8 = Conv2D(8, (3, 3), activation='relu', padding='same')(x7)
x9 = UpSampling2D((2, 2))(x8)
x10 = Conv2D(16, (3, 3), activation='relu', padding='same')(x9)
x11 = UpSampling2D((2, 2))(x10)
decoded = Conv2D(3, (3, 3), activation='sigmoid', padding='same')(x11)
autoencoder = Model(input_img, decoded)
autoencoder.compile(optimizer='adagrad', loss='binary_crossentropy')
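Before training, it is worth confirming that the encoder really bottlenecks down to the (4, 4, 8) shape noted in the comment above. A quick sanity check (the exact layer names in the printout will depend on your Keras version):
In [ ]:
autoencoder.summary()  # the last MaxPooling2D layer should report an output shape of (None, 4, 4, 8)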
In [5]:
from keras.datasets import cifar10
import numpy as np
(x_train, _), (x_test, _) = cifar10.load_data()
x_train = x_train.astype('float32') / 255.
x_test = x_test.astype('float32') / 255.
x_train = np.reshape(x_train, (len(x_train), 32, 32, 3)) # adapt this if using `channels_first` image data format
x_test = np.reshape(x_test, (len(x_test), 32, 32, 3)) # adapt this if using `channels_first` image data format
In [6]:
autoencoder.fit(x_train, x_train,
                epochs=50,
                batch_size=128,
                shuffle=True,
                validation_data=(x_test, x_test))
Out[6]:
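To double-check the final reconstruction loss on the held-out set rather than reading it off the training log, you can evaluate the trained model directly (a minimal sketch, using the test images as both inputs and targets):
In [ ]:
test_loss = autoencoder.evaluate(x_test, x_test, batch_size=128)
print(test_loss)  # binary cross-entropy averaged over the test set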
The model converges to a loss of 0.094, significantly better than any of our previous models (this is in large part due to the higher entropic capacity of the encoded representation, 128 dimensions vs. 32 previously). Let's take a look at the reconstructed images:
In [7]:
from keras.models import load_model
autoencoder.save('cifar10_autoencoders.h5')  # creates an HDF5 file 'cifar10_autoencoders.h5'
# del autoencoder  # deletes the existing model
In [50]:
# returns a compiled model
# identical to the previous one
autoencoder = load_model('cifar10_autoencoders.h5')
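Since the interesting part of the model is the 128-dimensional code, it can also be handy to expose the encoder half as its own model. A minimal sketch, assuming the input_img and encoded tensors from In [4] are still in scope:
In [ ]:
encoder = Model(input_img, encoded)  # maps an input image to its (4, 4, 8) code
codes = encoder.predict(x_test)
print(codes.shape)                   # expected: (10000, 4, 4, 8), i.e. 128 values per image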
In [51]:
import matplotlib.pyplot as plt
decoded_imgs = autoencoder.predict(x_test)
n = 10
plt.figure(figsize=(20, 4))
for i in range(n):
    # display original
    ax = plt.subplot(2, n, i + 1)
    plt.imshow(x_test[i].reshape(32, 32, 3))
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)
    # display reconstruction
    ax = plt.subplot(2, n, i + n + 1)
    plt.imshow(decoded_imgs[i].reshape(32, 32, 3))
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)
plt.show()
In [10]:
import matplotlib.pyplot as plt
n = 10
for i in range(n):
    fig = plt.figure(figsize=(4, 10))
    # kernel of the first Conv2D layer has shape (3, 3, 3, 16): take input channel 0 of filter i
    conv_1 = autoencoder.layers[1].get_weights()[0][:, :, 0, i]
    ax = fig.add_subplot(111)
    plt.imshow(conv_1.transpose())
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)
    plt.show()
In [11]:
from keras import backend as K
In [44]:
# K.learning_phase() is a flag that indicates whether the network is in training or
# prediction phase. It allows layers such as Dropout to only be applied during training.
inputs = [K.learning_phase()] + autoencoder.inputs

_layer1_f = K.function(inputs, [x2])
def convout1_f(X):
    # The [0] disables the training phase flag
    return _layer1_f([0] + [X])

_layer2_f = K.function(inputs, [x4])
def convout2_f(X):
    # The [0] disables the training phase flag
    return _layer2_f([0] + [X])

_layer3_f = K.function(inputs, [encoded])
def convout3_f(X):
    # The [0] disables the training phase flag
    return _layer3_f([0] + [X])

_up_layer1_f = K.function(inputs, [x6])
def convout4_f(X):
    # The [0] disables the training phase flag
    return _up_layer1_f([0] + [X])

_up_layer2_f = K.function(inputs, [x8])
def convout5_f(X):
    # The [0] disables the training phase flag
    return _up_layer2_f([0] + [X])

_up_layer3_f = K.function(inputs, [x10])
def convout6_f(X):
    # The [0] disables the training phase flag
    return _up_layer3_f([0] + [X])

_up_layer4_f = K.function(inputs, [decoded])
def convout7_f(X):
    # The [0] disables the training phase flag
    return _up_layer4_f([0] + [X])
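An alternative worth knowing about: instead of building one K.function per tapped layer, a single Keras Model can return all of the intermediate activations in one predict call. A sketch, assuming the tensors x2, x4, encoded, x6, x8, x10 and decoded from In [4] are still in scope:
In [ ]:
activation_model = Model(inputs=autoencoder.input,
                         outputs=[x2, x4, encoded, x6, x8, x10, decoded])
activations = activation_model.predict(x_test[1:2])  # list of 7 arrays, one per tapped layer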
In [13]:
x2
Out[13]:
In [14]:
i = 1
x = x_test[i:i+1]
In [15]:
np.squeeze(np.squeeze(np.array(convout1_f(x)),0),0).shape
Out[15]:
In [73]:
# Plotting conv_1 feature maps
i = 1
x = x_test[i:i+1]
check = np.squeeze(np.squeeze(np.array(convout1_f(x)), 0), 0)
temp = x[0, :, :, :]
fig, axes = plt.subplots(1, 1, figsize=(3, 3))
plt.imshow(temp)
plt.show()
k = 0
while k < check.shape[2]:
    fig, axes = plt.subplots(4, 4, figsize=(5, 5))
    for i in range(4):
        for j in range(4):
            axes[i, j].imshow(check[:, :, k])
            k += 1
plt.show()
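The grid-plotting pattern above is repeated almost verbatim for every layer below; a small helper along these lines (a hypothetical plot_feature_maps, not part of the original notebook) would remove the duplication:
In [ ]:
def plot_feature_maps(feature_maps, nrows, ncols, figsize=(5, 5)):
    # feature_maps: array of shape (height, width, channels), with channels == nrows * ncols
    fig, axes = plt.subplots(nrows, ncols, figsize=figsize)
    for k, ax in enumerate(axes.flat):
        ax.imshow(feature_maps[:, :, k])
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)
    plt.show()

# e.g. plot_feature_maps(check, 4, 4) reproduces the 4x4 grid above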
In [17]:
check.shape
Out[17]:
In [71]:
i = 1
x = x_test[i:i+1]
check = np.squeeze(np.squeeze(np.array(convout2_f(x)),0),0)
check.shape
Out[71]:
In [72]:
# Plotting conv_2 feature maps
temp = x[0, :, :, :]
fig, axes = plt.subplots(1, 1, figsize=(3, 3))
plt.imshow(temp)
plt.show()
k = 0
while k < check.shape[2]:
    fig, axes = plt.subplots(2, 4, figsize=(5, 5))
    for i in range(2):
        for j in range(4):
            axes[i, j].imshow(check[:, :, k])
            k += 1
plt.show()
In [69]:
i = 1
x = x_test[i:i+1]
check = np.squeeze(np.squeeze(np.array(convout3_f(x)),0),0)
check.shape
Out[69]:
In [70]:
# Plotting conv_3 feature maps
temp = x[0, :, :, :]
fig, axes = plt.subplots(1, 1, figsize=(3, 3))
plt.imshow(temp)
plt.show()
k = 0
while k < check.shape[2]:
    fig, axes = plt.subplots(2, 4, figsize=(5, 5))
    for i in range(2):
        for j in range(4):
            axes[i, j].imshow(check[:, :, k])
            k += 1
plt.show()
In [67]:
i = 1
x = x_test[i:i+1]
check = np.squeeze(np.squeeze(np.array(convout4_f(x)),0),0)
check.shape
Out[67]:
In [68]:
# Plotting conv_4 feature maps
temp = x[0, :, :, :]
fig, axes = plt.subplots(1, 1, figsize=(3, 3))
plt.imshow(temp)
plt.show()
k = 0
while k < check.shape[2]:
    fig, axes = plt.subplots(2, 4, figsize=(5, 5))
    for i in range(2):
        for j in range(4):
            axes[i, j].imshow(check[:, :, k])
            k += 1
plt.show()
In [64]:
i = 1
x = x_test[i:i+1]
check = np.squeeze(np.squeeze(np.array(convout5_f(x)),0),0)
check.shape
Out[64]:
In [66]:
# Plotting conv_5 feature maps
temp = x[0, :, :, :]
fig, axes = plt.subplots(1, 1, figsize=(3, 3))
plt.imshow(temp)
plt.show()
k = 0
while k < check.shape[2]:
    fig, axes = plt.subplots(2, 4, figsize=(5, 5))
    for i in range(2):
        for j in range(4):
            axes[i, j].imshow(check[:, :, k])
            k += 1
plt.show()
In [61]:
i = 1
x = x_test[i:i+1]
check = np.squeeze(np.squeeze(np.array(convout6_f(x)),0),0)
check.shape
Out[61]:
In [63]:
# Plotting conv_6 feature maps
temp = x[0, :, :, :]
fig, axes = plt.subplots(1, 1, figsize=(3, 3))
plt.imshow(temp)
plt.show()
k = 0
while k < check.shape[2]:
    fig, axes = plt.subplots(4, 4, figsize=(5, 5))
    for i in range(4):
        for j in range(4):
            axes[i, j].imshow(check[:, :, k])
            k += 1
plt.show()
In [45]:
i = 1
x = x_test[i:i+1]
check = np.squeeze(np.squeeze(np.array(convout7_f(x)),0),0)
check.shape
Out[45]:
In [54]:
# Plotting the final (decoded) layer, which is already a 3-channel image
temp = x[0, :, :, :]
fig, axes = plt.subplots(1, 1, figsize=(3, 3))
plt.imshow(temp)
plt.show()
fig, axes = plt.subplots(1, 1, figsize=(3, 3))
plt.imshow(check)
plt.show()