Denoising Autoencoder

This notebook applies the autoencoder to another task: denoising images. (Denoising was not discussed in the paper this work originated from.)

The code below is a light rearrangement of the code from this excellent tutorial: https://blog.keras.io/building-autoencoders-in-keras.html


In [1]:
from keras.datasets import mnist
import numpy as np

(x_train, _), (x_test, _) = mnist.load_data()

x_train = x_train.astype('float32') / 255.
x_test = x_test.astype('float32') / 255.
x_train = np.reshape(x_train, (len(x_train), 28, 28, 1))  # adapt this if using `channels_first` image data format
x_test = np.reshape(x_test, (len(x_test), 28, 28, 1))  # adapt this if using `channels_first` image data format

noise_factor = 0.5
x_train_noisy = x_train + noise_factor * np.random.normal(loc=0.0, scale=1.0, size=x_train.shape) 
x_test_noisy = x_test + noise_factor * np.random.normal(loc=0.0, scale=1.0, size=x_test.shape) 

x_train_noisy = np.clip(x_train_noisy, 0., 1.)
x_test_noisy = np.clip(x_test_noisy, 0., 1.)


Using Theano backend.
WARNING (theano.sandbox.cuda): The cuda backend is deprecated and will be removed in the next release (v0.10).  Please switch to the gpuarray backend. You can get more information about how to switch at this URL:
 https://github.com/Theano/Theano/wiki/Converting-to-the-new-gpu-back-end%28gpuarray%29

Using gpu device 0: GeForce GTX 760 (CNMeM is disabled, cuDNN not available)
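The corruption step in the first cell adds zero-mean Gaussian noise scaled by `noise_factor` and clips the result back to [0, 1]. Below is a minimal sketch of the same step as a reusable function. It assumes NumPy's `default_rng` generator for reproducibility (the notebook itself uses the global `np.random` state), and `add_gaussian_noise` is a hypothetical helper name, not part of Keras or the tutorial.

```python
import numpy as np

def add_gaussian_noise(images, noise_factor=0.5, seed=None):
    """Corrupt images in [0, 1] with scaled Gaussian noise, then clip to [0, 1].

    Hypothetical helper mirroring the notebook's corruption step.
    """
    rng = np.random.default_rng(seed)
    noisy = images + noise_factor * rng.normal(loc=0.0, scale=1.0, size=images.shape)
    # Clip so pixels stay valid targets/inputs for a sigmoid + binary cross-entropy model,
    # and cast back to the input dtype (adding float64 noise would otherwise promote it).
    return np.clip(noisy, 0.0, 1.0).astype(images.dtype)

# Usage on a dummy batch shaped like the MNIST tensors (N, 28, 28, 1).
batch = np.zeros((4, 28, 28, 1), dtype='float32')
noisy_batch = add_gaussian_noise(batch, noise_factor=0.5, seed=0)
print(noisy_batch.shape, noisy_batch.dtype)
```

Passing a fixed `seed` makes the noisy training set reproducible between runs, which the notebook's version does not guarantee.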

In [2]:
import matplotlib.pyplot as plt

n = 10
plt.figure(figsize=(20, 2))
for i in range(n):
    which = np.random.randint(0, len(x_test_noisy))  # random index over the whole test set
    ax = plt.subplot(1, n, i + 1)
    plt.imshow(x_test_noisy[which].reshape(28, 28))
    plt.gray()
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)
plt.show()
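The preview draws one random index per panel, so the same digit can appear more than once. If distinct images are preferred, NumPy's `Generator.choice` with `replace=False` returns `n` unique indices. This is a small sketch; `pick_indices` is a hypothetical helper name.

```python
import numpy as np

def pick_indices(num_items, n, seed=None):
    """Return n distinct indices drawn uniformly from range(num_items)."""
    rng = np.random.default_rng(seed)
    return rng.choice(num_items, size=n, replace=False)

idx = pick_indices(10000, 10, seed=42)  # 10000 = size of the MNIST test set
print(idx)
```

The resulting array can replace the per-iteration `which` in the plotting loop, e.g. `which = idx[i]`.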



In [3]:
from keras.layers import Input, Conv2D, MaxPooling2D, UpSampling2D
from keras.models import Model

input_img = Input(shape=(28, 28, 1))  # adapt this if using `channels_first` image data format

x = Conv2D(32, (3, 3), activation='relu', padding='same')(input_img)
x = MaxPooling2D((2, 2), padding='same')(x)
x = Conv2D(32, (3, 3), activation='relu', padding='same')(x)
encoded = MaxPooling2D((2, 2), padding='same')(x)

# at this point the representation is (7, 7, 32)

x = Conv2D(32, (3, 3), activation='relu', padding='same')(encoded)
x = UpSampling2D((2, 2))(x)
x = Conv2D(32, (3, 3), activation='relu', padding='same')(x)
x = UpSampling2D((2, 2))(x)
decoded = Conv2D(1, (3, 3), activation='sigmoid', padding='same')(x)

autoencoder = Model(input_img, decoded)
autoencoder.compile(optimizer='adadelta', loss='binary_crossentropy')
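The comment `# at this point the representation is (7, 7, 32)` follows from the layer arithmetic: with `padding='same'`, a stride-1 3x3 convolution keeps height and width and sets the channel count to its filter count, a 2x2 max-pool with `padding='same'` produces ceil(n/2), and `UpSampling2D((2, 2))` doubles each side. The pure-Python sketch below traces these shapes without calling Keras; `trace_shapes` and the `('conv', ...)`-style layer tuples are hypothetical stand-ins for the real layers.

```python
import math

def trace_shapes(input_shape, layers):
    """Propagate (height, width, channels) through a simplified layer list.

    Each layer is ('conv', filters), ('pool', size) or ('up', size); convs are
    assumed to be 3x3, stride 1, padding='same'.
    """
    h, w, c = input_shape
    shapes = []
    for kind, arg in layers:
        if kind == 'conv':
            c = arg  # 'same' padding, stride 1: spatial size unchanged
        elif kind == 'pool':
            h, w = math.ceil(h / arg), math.ceil(w / arg)  # 'same' padding rounds up
        elif kind == 'up':
            h, w = h * arg, w * arg  # nearest-neighbour repeat
        else:
            raise ValueError(f"unknown layer kind: {kind}")
        shapes.append((h, w, c))
    return shapes

encoder = [('conv', 32), ('pool', 2), ('conv', 32), ('pool', 2)]
decoder = [('conv', 32), ('up', 2), ('conv', 32), ('up', 2), ('conv', 1)]

enc_shapes = trace_shapes((28, 28, 1), encoder)
dec_shapes = trace_shapes(enc_shapes[-1], decoder)
print(enc_shapes[-1])  # (7, 7, 32)
print(dec_shapes[-1])  # (28, 28, 1)
```

The two pooling steps halve 28 to 14 and then 7, and the two upsampling steps restore 28, so the decoder output matches the input shape, as the loss requires. In Keras itself, `autoencoder.summary()` prints the same information.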

In [4]:
autoencoder.fit(x_train_noisy, x_train,
                epochs=30,
                batch_size=128,
                shuffle=True,
                validation_data=(x_test_noisy, x_test))


Train on 60000 samples, validate on 10000 samples
Epoch 1/30
60000/60000 [==============================] - 105s - loss: 0.2030 - val_loss: 0.1349
Epoch 2/30
60000/60000 [==============================] - 106s - loss: 0.1251 - val_loss: 0.1175
Epoch 3/30
60000/60000 [==============================] - 109s - loss: 0.1159 - val_loss: 0.1109
Epoch 4/30
60000/60000 [==============================] - 107s - loss: 0.1113 - val_loss: 0.1111
Epoch 5/30
60000/60000 [==============================] - 107s - loss: 0.1086 - val_loss: 0.1066
Epoch 6/30
60000/60000 [==============================] - 108s - loss: 0.1066 - val_loss: 0.1053
Epoch 7/30
60000/60000 [==============================] - 109s - loss: 0.1054 - val_loss: 0.1028
Epoch 8/30
60000/60000 [==============================] - 108s - loss: 0.1041 - val_loss: 0.1018
Epoch 9/30
60000/60000 [==============================] - 107s - loss: 0.1033 - val_loss: 0.1021
Epoch 10/30
60000/60000 [==============================] - 107s - loss: 0.1026 - val_loss: 0.1007
Epoch 11/30
60000/60000 [==============================] - 107s - loss: 0.1021 - val_loss: 0.1019
Epoch 12/30
60000/60000 [==============================] - 107s - loss: 0.1018 - val_loss: 0.1002
Epoch 13/30
60000/60000 [==============================] - 107s - loss: 0.1014 - val_loss: 0.1006
Epoch 14/30
60000/60000 [==============================] - 109s - loss: 0.1010 - val_loss: 0.0994
Epoch 15/30
60000/60000 [==============================] - 107s - loss: 0.1005 - val_loss: 0.0994
Epoch 16/30
60000/60000 [==============================] - 107s - loss: 0.1002 - val_loss: 0.0996
Epoch 17/30
60000/60000 [==============================] - 107s - loss: 0.1000 - val_loss: 0.0992
Epoch 18/30
60000/60000 [==============================] - 107s - loss: 0.0996 - val_loss: 0.0994
Epoch 19/30
60000/60000 [==============================] - 107s - loss: 0.0995 - val_loss: 0.0980
Epoch 20/30
60000/60000 [==============================] - 6713s - loss: 0.0991 - val_loss: 0.0980
Epoch 21/30
60000/60000 [==============================] - 106s - loss: 0.0989 - val_loss: 0.0981
Epoch 22/30
60000/60000 [==============================] - 107s - loss: 0.0988 - val_loss: 0.0986
Epoch 23/30
60000/60000 [==============================] - 110s - loss: 0.0987 - val_loss: 0.0982
Epoch 24/30
60000/60000 [==============================] - 110s - loss: 0.0984 - val_loss: 0.0985
Epoch 25/30
60000/60000 [==============================] - 110s - loss: 0.0982 - val_loss: 0.0982
Epoch 26/30
60000/60000 [==============================] - 110s - loss: 0.0980 - val_loss: 0.0971
Epoch 27/30
60000/60000 [==============================] - 108s - loss: 0.0980 - val_loss: 0.0977
Epoch 28/30
60000/60000 [==============================] - 107s - loss: 0.0979 - val_loss: 0.0975
Epoch 29/30
60000/60000 [==============================] - 107s - loss: 0.0978 - val_loss: 0.0971
Epoch 30/30
60000/60000 [==============================] - 107s - loss: 0.0976 - val_loss: 0.0977
Out[4]:
<keras.callbacks.History at 0x24880748>
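The `loss` and `val_loss` values in the log are the binary cross-entropy between the reconstruction and the clean target, averaged over all pixels. Below is a NumPy sketch of that quantity. The clipping constant `eps=1e-7` is an assumption chosen to avoid `log(0)`, and `pixelwise_bce` is a hypothetical name, not the Keras implementation.

```python
import numpy as np

def pixelwise_bce(y_true, y_pred, eps=1e-7):
    """Mean binary cross-entropy over all pixels (and all images in the batch)."""
    y_pred = np.clip(y_pred, eps, 1.0 - eps)  # avoid log(0)
    losses = -(y_true * np.log(y_pred) + (1.0 - y_true) * np.log(1.0 - y_pred))
    return float(losses.mean())

clean = np.array([[0.0, 1.0], [1.0, 0.0]])
print(pixelwise_bce(clean, clean.copy()))             # close to 0
print(pixelwise_bce(clean, np.full_like(clean, 0.5)))  # ln 2, about 0.693

# A grey-level target cannot be matched with zero loss:
grey = np.full((2, 2), 0.5)
print(pixelwise_bce(grey, grey))                       # also ln 2
```

Because MNIST targets are grey levels in [0, 1] rather than strictly 0 or 1, even a perfect reconstruction has nonzero cross-entropy (the floor is the mean binary entropy of the target pixels). This helps explain why the loss in the log levels off near 0.097 instead of approaching 0.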

In [7]:
decoded_imgs = autoencoder.predict(x_test_noisy)

n = 10  # how many digits we will display
plt.figure(figsize=(20, 4))
for i in range(n):
    which = np.random.randint(0, len(x_test_noisy))  # random index over the whole test set

    # display noisy input
    ax = plt.subplot(2, n, i + 1)
    plt.imshow(x_test_noisy[which].reshape(28, 28))
    plt.gray()
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)

    # display reconstruction
    ax = plt.subplot(2, n, i + 1 + n)
    plt.imshow(decoded_imgs[which].reshape(28, 28))
    plt.gray()
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)
plt.show()
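The figure gives a qualitative check. A simple quantitative complement, not part of the original notebook, is to compare mean squared error or PSNR against the clean test images for both the noisy input and the reconstruction. The sketch below assumes images scaled to [0, 1]; `mse` and `psnr` are hypothetical helpers, and the comment shows how they might be applied to the notebook's `x_test`, `x_test_noisy`, and `decoded_imgs`.

```python
import numpy as np

def mse(clean, estimate):
    """Mean squared error between two arrays of the same shape."""
    return float(np.mean((clean - estimate) ** 2))

def psnr(clean, estimate, max_val=1.0):
    """Peak signal-to-noise ratio in dB for images scaled to [0, max_val]."""
    err = mse(clean, estimate)
    if err == 0.0:
        return float('inf')
    return float(10.0 * np.log10(max_val ** 2 / err))

# In the notebook, one would compare e.g.
#   psnr(x_test, x_test_noisy)  versus  psnr(x_test, decoded_imgs)
clean = np.zeros((2, 2))
estimate = np.full((2, 2), 0.1)
print(psnr(clean, estimate))  # MSE = 0.01, so about 20 dB
```

A higher PSNR for `decoded_imgs` than for `x_test_noisy` would indicate that the autoencoder removes more noise than it introduces.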