Non supervised learning

Autoencoders

Suppose we have only a set of unlabeled training examples $x_1,x_2,x_3, \dots $, where $x_i \in \Re^n$.

An autoencoder neural network is an unsupervised learning algorithm that applies backpropagation and uses a loss function that is optimal when setting the target values to be equal to the inputs, $y_i=x_i$.

To build an autoencoder, you need three things: an encoding function, a decoding function, and a distance function between the amount of information loss between the compressed representation of your data and the decompressed representation.

Source: https://blog.keras.io/building-autoencoders-in-keras.html

Two practical applications of autoencoders are data denoising, and dimensionality reduction for data visualization.

With appropriate dimensionality and sparsity constraints, autoencoders can learn data projections that are more interesting than PCA or other basic techniques.

We'll start simple, with a single fully-connected neural layer as encoder and as decoder:


In [1]:
# Source: Adapted from https://blog.keras.io/building-autoencoders-in-keras.html

from keras.layers import Input, Dense
from keras.models import Model

# this is the size of our encoded representations       
encoding_dim = 32  # 32 floats -> compression of factor 24.5, 
                   # assuming the input is 784 floats

input_img = Input(shape=(784,))

# encoded representation of the input
encoding_layer = Dense(encoding_dim, 
                       activation='relu')
encoded = encoding_layer(input_img)

# lossy reconstruction of the input
decoding_layer = Dense(784, 
                       activation='sigmoid')
decoded = decoding_layer(encoded)

# this model maps an input to its reconstruction
autoencoder = Model(input_img, decoded)


Using TensorFlow backend.

Let's also create a separate encoder model and a separate decoder model:


In [2]:
# this model maps an input to its encoded representation
encoding_model = Model(input_img, encoded)

# create a placeholder for an encoded input
# and create the decoder model
encoded_input = Input(shape=(encoding_dim,))
decoding_model = Model(encoded_input, decoding_layer(encoded_input))

autoencoder.compile(optimizer='adam', loss='mse')

Let's prepare our input data.


In [4]:
from keras.datasets import mnist
import numpy as np
(x_train, _), (x_test, _) = mnist.load_data()

x_train = x_train.astype('float32') / 255.
x_test = x_test.astype('float32') / 255.
x_train = x_train.reshape((len(x_train), np.prod(x_train.shape[1:])))
x_test = x_test.reshape((len(x_test), np.prod(x_test.shape[1:])))
print x_train.shape
print x_test.shape

autoencoder.fit(x_train, 
                x_train,
                nb_epoch=15,
                batch_size=256,
                shuffle=True,
                validation_data=(x_test, x_test))


(60000, 784)
(10000, 784)
Train on 60000 samples, validate on 10000 samples
Epoch 1/15
60000/60000 [==============================] - 3s - loss: 0.0160 - val_loss: 0.0147
Epoch 2/15
60000/60000 [==============================] - 4s - loss: 0.0143 - val_loss: 0.0133
Epoch 3/15
60000/60000 [==============================] - 4s - loss: 0.0131 - val_loss: 0.0123
Epoch 4/15
60000/60000 [==============================] - 4s - loss: 0.0122 - val_loss: 0.0115
Epoch 5/15
60000/60000 [==============================] - 4s - loss: 0.0116 - val_loss: 0.0110
Epoch 6/15
60000/60000 [==============================] - 4s - loss: 0.0112 - val_loss: 0.0107
Epoch 7/15
60000/60000 [==============================] - 5s - loss: 0.0109 - val_loss: 0.0105
Epoch 8/15
60000/60000 [==============================] - 5s - loss: 0.0107 - val_loss: 0.0103
Epoch 9/15
60000/60000 [==============================] - 4s - loss: 0.0106 - val_loss: 0.0102
Epoch 10/15
60000/60000 [==============================] - 4s - loss: 0.0105 - val_loss: 0.0101
Epoch 11/15
60000/60000 [==============================] - 4s - loss: 0.0104 - val_loss: 0.0101
Epoch 12/15
60000/60000 [==============================] - 4s - loss: 0.0104 - val_loss: 0.0100
Epoch 13/15
60000/60000 [==============================] - 4s - loss: 0.0103 - val_loss: 0.0100
Epoch 14/15
60000/60000 [==============================] - 4s - loss: 0.0103 - val_loss: 0.0099
Epoch 15/15
60000/60000 [==============================] - 4s - loss: 0.0102 - val_loss: 0.0099
Out[4]:
<keras.callbacks.History at 0x7faf08bf1510>

In [6]:
# encode and decode some digits
# note that we take them from the *test* set
encoded_imgs = encoding_model.predict(x_test)
decoded_imgs = decoding_model.predict(encoded_imgs)

import matplotlib.pyplot as plt
%matplotlib inline

n = 10  # how many digits we will display
plt.figure(figsize=(10, 4))
for i in range(n):
    # display original
    ax = plt.subplot(2, n, i + 1)
    plt.imshow(x_test[i].reshape(28, 28))
    plt.gray()
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)

    # display reconstruction
    ax = plt.subplot(2, n, i + 1 + n)
    plt.imshow(decoded_imgs[i].reshape(28, 28))
    plt.gray()
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)
plt.show()


Adding depth and sparsity constraint on the encoded representations

In the previous example, the representations were only constrained by the size of the hidden layer (32). In such a situation, what typically happens is that the hidden layer is learning an approximation of PCA (principal component analysis). But another way to constrain the representations to be compact is to add a sparsity contraint on the activity of the hidden representations, so fewer units would "fire" at a given time.

In Keras, this can be done by adding an activity_regularizer to our Dense layer:


In [15]:
#autoencoder.reset_states()
#encoder.reset_states()
#decoder.reset_states()

from keras import regularizers
from keras import optimizers
from keras.regularizers import l2, activity_l1
from keras.layers import Input, Dense
from keras.models import Model

input_img = Input(shape=(784,))
encoded = Dense(128, activation='relu')(input_img)
encoded = Dense(64, activation='relu')(encoded)
encoded = Dense(32, activation='relu')(encoded)
decoded = Dense(64, activation='relu')(encoded)
decoded = Dense(128, activation='relu')(decoded)
decoded = Dense(784, activation='sigmoid')(decoded)

autoencoder = Model(input_img, decoded)
autoencoder.compile(optimizer='adadelta', 
                    loss='binary_crossentropy',
                    activity_regularizer=regularizers.l1(10e-5))

autoencoder.fit(x_train, x_train,
                nb_epoch=100,
                batch_size=256,
                shuffle=True,
                validation_data=(x_test, x_test))


/usr/local/lib/python2.7/dist-packages/keras/backend/tensorflow_backend.py:1057: UserWarning: Expected no kwargs, you passed 1
kwargs passed to function are ignored with Tensorflow backend
  warnings.warn('\n'.join(msg))
Train on 60000 samples, validate on 10000 samples
Epoch 1/100
60000/60000 [==============================] - 4s - loss: 0.3545 - val_loss: 0.2636
Epoch 2/100
60000/60000 [==============================] - 4s - loss: 0.2583 - val_loss: 0.2516
Epoch 3/100
60000/60000 [==============================] - 4s - loss: 0.2428 - val_loss: 0.2302
Epoch 4/100
60000/60000 [==============================] - 4s - loss: 0.2180 - val_loss: 0.2071
Epoch 5/100
60000/60000 [==============================] - 4s - loss: 0.2032 - val_loss: 0.1977
Epoch 6/100
60000/60000 [==============================] - 4s - loss: 0.1944 - val_loss: 0.1875
Epoch 7/100
60000/60000 [==============================] - 4s - loss: 0.1848 - val_loss: 0.1803
Epoch 8/100
60000/60000 [==============================] - 4s - loss: 0.1779 - val_loss: 0.1741
Epoch 9/100
60000/60000 [==============================] - 4s - loss: 0.1718 - val_loss: 0.1670
Epoch 10/100
60000/60000 [==============================] - 4s - loss: 0.1661 - val_loss: 0.1622
Epoch 11/100
60000/60000 [==============================] - 4s - loss: 0.1609 - val_loss: 0.1577
Epoch 12/100
60000/60000 [==============================] - 4s - loss: 0.1570 - val_loss: 0.1534
Epoch 13/100
60000/60000 [==============================] - 4s - loss: 0.1541 - val_loss: 0.1521
Epoch 14/100
60000/60000 [==============================] - 4s - loss: 0.1519 - val_loss: 0.1482
Epoch 15/100
60000/60000 [==============================] - 4s - loss: 0.1500 - val_loss: 0.1473
Epoch 16/100
60000/60000 [==============================] - 4s - loss: 0.1484 - val_loss: 0.1459
Epoch 17/100
60000/60000 [==============================] - 4s - loss: 0.1468 - val_loss: 0.1441
Epoch 18/100
60000/60000 [==============================] - 4s - loss: 0.1447 - val_loss: 0.1425
Epoch 19/100
60000/60000 [==============================] - 4s - loss: 0.1429 - val_loss: 0.1408
Epoch 20/100
60000/60000 [==============================] - 4s - loss: 0.1413 - val_loss: 0.1390
Epoch 21/100
60000/60000 [==============================] - 4s - loss: 0.1399 - val_loss: 0.1374
Epoch 22/100
60000/60000 [==============================] - 4s - loss: 0.1386 - val_loss: 0.1367
Epoch 23/100
60000/60000 [==============================] - 4s - loss: 0.1372 - val_loss: 0.1349
Epoch 24/100
60000/60000 [==============================] - 4s - loss: 0.1359 - val_loss: 0.1336
Epoch 25/100
60000/60000 [==============================] - 4s - loss: 0.1347 - val_loss: 0.1323
Epoch 26/100
60000/60000 [==============================] - 4s - loss: 0.1334 - val_loss: 0.1315
Epoch 27/100
60000/60000 [==============================] - 4s - loss: 0.1321 - val_loss: 0.1309
Epoch 28/100
60000/60000 [==============================] - 4s - loss: 0.1311 - val_loss: 0.1292
Epoch 29/100
60000/60000 [==============================] - 4s - loss: 0.1300 - val_loss: 0.1272
Epoch 30/100
60000/60000 [==============================] - 4s - loss: 0.1288 - val_loss: 0.1266
Epoch 31/100
60000/60000 [==============================] - 4s - loss: 0.1277 - val_loss: 0.1255
Epoch 32/100
60000/60000 [==============================] - 4s - loss: 0.1267 - val_loss: 0.1250
Epoch 33/100
60000/60000 [==============================] - 4s - loss: 0.1257 - val_loss: 0.1236
Epoch 34/100
60000/60000 [==============================] - 4s - loss: 0.1250 - val_loss: 0.1222
Epoch 35/100
60000/60000 [==============================] - 4s - loss: 0.1240 - val_loss: 0.1215
Epoch 36/100
60000/60000 [==============================] - 4s - loss: 0.1233 - val_loss: 0.1224
Epoch 37/100
60000/60000 [==============================] - 4s - loss: 0.1224 - val_loss: 0.1201
Epoch 38/100
60000/60000 [==============================] - 4s - loss: 0.1217 - val_loss: 0.1202
Epoch 39/100
60000/60000 [==============================] - 4s - loss: 0.1211 - val_loss: 0.1191
Epoch 40/100
60000/60000 [==============================] - 4s - loss: 0.1204 - val_loss: 0.1181
Epoch 41/100
60000/60000 [==============================] - 4s - loss: 0.1195 - val_loss: 0.1182
Epoch 42/100
60000/60000 [==============================] - 4s - loss: 0.1190 - val_loss: 0.1172
Epoch 43/100
60000/60000 [==============================] - 4s - loss: 0.1182 - val_loss: 0.1186
Epoch 44/100
60000/60000 [==============================] - 4s - loss: 0.1177 - val_loss: 0.1180
Epoch 45/100
60000/60000 [==============================] - 4s - loss: 0.1170 - val_loss: 0.1146
Epoch 46/100
60000/60000 [==============================] - 4s - loss: 0.1164 - val_loss: 0.1155
Epoch 47/100
60000/60000 [==============================] - 4s - loss: 0.1160 - val_loss: 0.1141
Epoch 48/100
60000/60000 [==============================] - 4s - loss: 0.1155 - val_loss: 0.1146
Epoch 49/100
60000/60000 [==============================] - 4s - loss: 0.1149 - val_loss: 0.1130
Epoch 50/100
60000/60000 [==============================] - 4s - loss: 0.1144 - val_loss: 0.1122
Epoch 51/100
60000/60000 [==============================] - 4s - loss: 0.1139 - val_loss: 0.1130
Epoch 52/100
60000/60000 [==============================] - 4s - loss: 0.1135 - val_loss: 0.1127
Epoch 53/100
60000/60000 [==============================] - 4s - loss: 0.1130 - val_loss: 0.1122
Epoch 54/100
60000/60000 [==============================] - 5s - loss: 0.1126 - val_loss: 0.1108
Epoch 55/100
60000/60000 [==============================] - 4s - loss: 0.1122 - val_loss: 0.1104
Epoch 56/100
60000/60000 [==============================] - 4s - loss: 0.1117 - val_loss: 0.1102
Epoch 57/100
60000/60000 [==============================] - 4s - loss: 0.1114 - val_loss: 0.1123
Epoch 58/100
60000/60000 [==============================] - 4s - loss: 0.1109 - val_loss: 0.1105
Epoch 59/100
60000/60000 [==============================] - 4s - loss: 0.1108 - val_loss: 0.1091
Epoch 60/100
60000/60000 [==============================] - 5s - loss: 0.1103 - val_loss: 0.1082
Epoch 61/100
60000/60000 [==============================] - 5s - loss: 0.1099 - val_loss: 0.1086
Epoch 62/100
60000/60000 [==============================] - 4s - loss: 0.1096 - val_loss: 0.1090
Epoch 63/100
60000/60000 [==============================] - 4s - loss: 0.1092 - val_loss: 0.1071
Epoch 64/100
60000/60000 [==============================] - 4s - loss: 0.1088 - val_loss: 0.1086
Epoch 65/100
60000/60000 [==============================] - 4s - loss: 0.1086 - val_loss: 0.1067
Epoch 66/100
60000/60000 [==============================] - 4s - loss: 0.1081 - val_loss: 0.1069
Epoch 67/100
60000/60000 [==============================] - 4s - loss: 0.1080 - val_loss: 0.1065
Epoch 68/100
60000/60000 [==============================] - 4s - loss: 0.1075 - val_loss: 0.1065
Epoch 69/100
60000/60000 [==============================] - 4s - loss: 0.1072 - val_loss: 0.1073
Epoch 70/100
60000/60000 [==============================] - 4s - loss: 0.1070 - val_loss: 0.1068
Epoch 71/100
60000/60000 [==============================] - 4s - loss: 0.1066 - val_loss: 0.1044
Epoch 72/100
60000/60000 [==============================] - 4s - loss: 0.1064 - val_loss: 0.1050
Epoch 73/100
60000/60000 [==============================] - 4s - loss: 0.1061 - val_loss: 0.1044
Epoch 74/100
60000/60000 [==============================] - 5s - loss: 0.1058 - val_loss: 0.1047
Epoch 75/100
60000/60000 [==============================] - 4s - loss: 0.1056 - val_loss: 0.1040
Epoch 76/100
60000/60000 [==============================] - 4s - loss: 0.1052 - val_loss: 0.1035
Epoch 77/100
60000/60000 [==============================] - 4s - loss: 0.1049 - val_loss: 0.1025
Epoch 78/100
60000/60000 [==============================] - 4s - loss: 0.1046 - val_loss: 0.1024
Epoch 79/100
60000/60000 [==============================] - 4s - loss: 0.1044 - val_loss: 0.1035
Epoch 80/100
60000/60000 [==============================] - 4s - loss: 0.1042 - val_loss: 0.1039
Epoch 81/100
60000/60000 [==============================] - 4s - loss: 0.1040 - val_loss: 0.1041
Epoch 82/100
60000/60000 [==============================] - 4s - loss: 0.1037 - val_loss: 0.1034
Epoch 83/100
60000/60000 [==============================] - 4s - loss: 0.1034 - val_loss: 0.1027
Epoch 84/100
60000/60000 [==============================] - 4s - loss: 0.1032 - val_loss: 0.1025
Epoch 85/100
60000/60000 [==============================] - 4s - loss: 0.1031 - val_loss: 0.1026
Epoch 86/100
60000/60000 [==============================] - 4s - loss: 0.1028 - val_loss: 0.1011
Epoch 87/100
60000/60000 [==============================] - 4s - loss: 0.1026 - val_loss: 0.1013
Epoch 88/100
60000/60000 [==============================] - 4s - loss: 0.1024 - val_loss: 0.1016
Epoch 89/100
60000/60000 [==============================] - 4s - loss: 0.1023 - val_loss: 0.1028
Epoch 90/100
60000/60000 [==============================] - 4s - loss: 0.1019 - val_loss: 0.1009
Epoch 91/100
60000/60000 [==============================] - 4s - loss: 0.1019 - val_loss: 0.1020
Epoch 92/100
60000/60000 [==============================] - 4s - loss: 0.1016 - val_loss: 0.1012
Epoch 93/100
60000/60000 [==============================] - 5s - loss: 0.1016 - val_loss: 0.1005
Epoch 94/100
60000/60000 [==============================] - 4s - loss: 0.1013 - val_loss: 0.0994
Epoch 95/100
60000/60000 [==============================] - 4s - loss: 0.1010 - val_loss: 0.1004
Epoch 96/100
60000/60000 [==============================] - 4s - loss: 0.1010 - val_loss: 0.1003
Epoch 97/100
60000/60000 [==============================] - 4s - loss: 0.1007 - val_loss: 0.1001
Epoch 98/100
60000/60000 [==============================] - 4s - loss: 0.1006 - val_loss: 0.0992
Epoch 99/100
60000/60000 [==============================] - 4s - loss: 0.1005 - val_loss: 0.0992
Epoch 100/100
60000/60000 [==============================] - 4s - loss: 0.1004 - val_loss: 0.0996
Out[15]:
<keras.callbacks.History at 0x7ff407e34d10>

In [ ]:

Convolutional Autoencoders

Since our inputs are images, it makes sense to use convolutional neural networks (convnets) as encoders and decoders.


In [16]:
from keras.layers import Input, Dense, Conv2D, MaxPooling2D, UpSampling2D
from keras.models import Model
from keras import backend as K

input_img = Input(shape=(28, 28, 1))  

x = Conv2D(16, 3, 3, activation='relu', border_mode='same')(input_img)
x = MaxPooling2D((2, 2), border_mode='same')(x)
x = Conv2D(8, 3, 3, activation='relu', border_mode='same')(x)
x = MaxPooling2D((2, 2), border_mode='same')(x)
x = Conv2D(8, 3, 3, activation='relu', border_mode='same')(x)
encoded = MaxPooling2D((2, 2), border_mode='same')(x)

# at this point the representation is (4, 4, 8) i.e. 128-dimensional

x = Conv2D(8, 3, 3, activation='relu', border_mode='same')(encoded)
x = UpSampling2D((2, 2))(x)
x = Conv2D(8, 3, 3, activation='relu', border_mode='same')(x)
x = UpSampling2D((2, 2))(x)
x = Conv2D(16, 3, 3, activation='relu')(x)
x = UpSampling2D((2, 2))(x)
decoded = Conv2D(1, 3, 3, activation='sigmoid', border_mode='same')(x)

# at this point the representation is (28, 28, 1) i.e. 784-dimensional

autoencoder = Model(input_img, decoded)
autoencoder.compile(optimizer='adadelta', loss='binary_crossentropy')

In [17]:
from keras.datasets import mnist
import numpy as np

(x_train, _), (x_test, _) = mnist.load_data()

x_train = x_train.astype('float32') / 255.
x_test = x_test.astype('float32') / 255.
x_train = np.reshape(x_train, (len(x_train), 28, 28, 1))  # adapt this if using `channels_first` image data format
x_test = np.reshape(x_test, (len(x_test), 28, 28, 1))  # adapt this if using `channels_first` image data format

In [18]:
from keras.callbacks import TensorBoard

autoencoder.fit(x_train, x_train,
                nb_epoch=50,
                batch_size=128,
                shuffle=True,
                validation_data=(x_test, x_test),
                callbacks=[TensorBoard(log_dir='/tmp/autoencoder')])


Train on 60000 samples, validate on 10000 samples
Epoch 1/50
60000/60000 [==============================] - 67s - loss: 0.2324 - val_loss: 0.1744
Epoch 2/50
60000/60000 [==============================] - 65s - loss: 0.1585 - val_loss: 0.1449
Epoch 3/50
60000/60000 [==============================] - 67s - loss: 0.1427 - val_loss: 0.1370
Epoch 4/50
60000/60000 [==============================] - 68s - loss: 0.1338 - val_loss: 0.1284
Epoch 5/50
60000/60000 [==============================] - 69s - loss: 0.1276 - val_loss: 0.1256
Epoch 6/50
60000/60000 [==============================] - 90s - loss: 0.1234 - val_loss: 0.1198
Epoch 7/50
60000/60000 [==============================] - 91s - loss: 0.1202 - val_loss: 0.1174
Epoch 8/50
60000/60000 [==============================] - 68s - loss: 0.1179 - val_loss: 0.1173
Epoch 9/50
60000/60000 [==============================] - 68s - loss: 0.1161 - val_loss: 0.1137
Epoch 10/50
60000/60000 [==============================] - 67s - loss: 0.1143 - val_loss: 0.1130
Epoch 11/50
60000/60000 [==============================] - 68s - loss: 0.1126 - val_loss: 0.1116
Epoch 12/50
60000/60000 [==============================] - 68s - loss: 0.1119 - val_loss: 0.1104
Epoch 13/50
60000/60000 [==============================] - 70s - loss: 0.1109 - val_loss: 0.1103
Epoch 14/50
60000/60000 [==============================] - 68s - loss: 0.1101 - val_loss: 0.1101
Epoch 15/50
60000/60000 [==============================] - 69s - loss: 0.1091 - val_loss: 0.1093
Epoch 16/50
60000/60000 [==============================] - 67s - loss: 0.1085 - val_loss: 0.1066
Epoch 17/50
60000/60000 [==============================] - 68s - loss: 0.1080 - val_loss: 0.1061
Epoch 18/50
60000/60000 [==============================] - 68s - loss: 0.1070 - val_loss: 0.1043
Epoch 19/50
60000/60000 [==============================] - 68s - loss: 0.1064 - val_loss: 0.1042
Epoch 20/50
60000/60000 [==============================] - 95s - loss: 0.1060 - val_loss: 0.1039
Epoch 21/50
60000/60000 [==============================] - 101s - loss: 0.1056 - val_loss: 0.1050
Epoch 22/50
60000/60000 [==============================] - 98s - loss: 0.1052 - val_loss: 0.1037
Epoch 23/50
60000/60000 [==============================] - 86s - loss: 0.1047 - val_loss: 0.1034
Epoch 24/50
60000/60000 [==============================] - 68s - loss: 0.1044 - val_loss: 0.1019
Epoch 25/50
60000/60000 [==============================] - 68s - loss: 0.1041 - val_loss: 0.1028
Epoch 26/50
60000/60000 [==============================] - 67s - loss: 0.1035 - val_loss: 0.1021
Epoch 27/50
60000/60000 [==============================] - 68s - loss: 0.1033 - val_loss: 0.1034
Epoch 28/50
60000/60000 [==============================] - 68s - loss: 0.1030 - val_loss: 0.1041
Epoch 29/50
60000/60000 [==============================] - 67s - loss: 0.1027 - val_loss: 0.1023
Epoch 30/50
60000/60000 [==============================] - 69s - loss: 0.1023 - val_loss: 0.1009
Epoch 31/50
60000/60000 [==============================] - 69s - loss: 0.1021 - val_loss: 0.0998
Epoch 32/50
60000/60000 [==============================] - 68s - loss: 0.1017 - val_loss: 0.1004
Epoch 33/50
60000/60000 [==============================] - 67s - loss: 0.1016 - val_loss: 0.0998
Epoch 34/50
60000/60000 [==============================] - 71s - loss: 0.1016 - val_loss: 0.1015
Epoch 35/50
60000/60000 [==============================] - 68s - loss: 0.1012 - val_loss: 0.1000
Epoch 36/50
60000/60000 [==============================] - 68s - loss: 0.1008 - val_loss: 0.0993
Epoch 37/50
60000/60000 [==============================] - 68s - loss: 0.1008 - val_loss: 0.0992
Epoch 38/50
60000/60000 [==============================] - 69s - loss: 0.1006 - val_loss: 0.0990
Epoch 39/50
60000/60000 [==============================] - 70s - loss: 0.1004 - val_loss: 0.1000
Epoch 40/50
60000/60000 [==============================] - 101s - loss: 0.1002 - val_loss: 0.0985
Epoch 41/50
60000/60000 [==============================] - 102s - loss: 0.1000 - val_loss: 0.0997
Epoch 42/50
60000/60000 [==============================] - 102s - loss: 0.0997 - val_loss: 0.0993
Epoch 43/50
60000/60000 [==============================] - 101s - loss: 0.0996 - val_loss: 0.0987
Epoch 44/50
60000/60000 [==============================] - 102s - loss: 0.0992 - val_loss: 0.0990
Epoch 45/50
60000/60000 [==============================] - 101s - loss: 0.0991 - val_loss: 0.0976
Epoch 46/50
60000/60000 [==============================] - 98s - loss: 0.0988 - val_loss: 0.0981
Epoch 47/50
60000/60000 [==============================] - 98s - loss: 0.0989 - val_loss: 0.0967
Epoch 48/50
60000/60000 [==============================] - 97s - loss: 0.0987 - val_loss: 0.0983
Epoch 49/50
60000/60000 [==============================] - 102s - loss: 0.0985 - val_loss: 0.0975
Epoch 50/50
60000/60000 [==============================] - 101s - loss: 0.0983 - val_loss: 0.0967
Out[18]:
<keras.callbacks.History at 0x7ff407285b50>

In [20]:
decoded_imgs = autoencoder.predict(x_test)

import matplotlib.pyplot as plt
n = 10
plt.figure(figsize=(10, 2))
for i in range(1,n):
    # display original
    ax = plt.subplot(2, n, i)
    plt.imshow(x_test[i].reshape(28, 28))
    plt.gray()
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)

    # display reconstruction
    ax = plt.subplot(2, n, i + n)
    plt.imshow(decoded_imgs[i].reshape(28, 28))
    plt.gray()
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)
plt.show()


Example: Image denoising

(Source: https://blog.keras.io/building-autoencoders-in-keras.html)

In [21]:
import matplotlib.pyplot as plt
import numpy as np
from keras.datasets import mnist
from keras.layers import Dense
from keras.models import Sequential

(x_train, _), (x_test, _) = mnist.load_data()

x_train = x_train.astype('float32') / 255.
x_test = x_test.astype('float32') / 255.
x_train = np.reshape(x_train, (len(x_train), 784))
x_test = np.reshape(x_test, (len(x_test), 784))

noise_factor = 0.5
x_train_noisy = x_train + noise_factor * np.random.normal(loc=0.0, scale=1.0, size=x_train.shape) 
x_test_noisy = x_test + noise_factor * np.random.normal(loc=0.0, scale=1.0, size=x_test.shape) 
x_train_noisy = np.clip(x_train_noisy, 0., 1.)
x_test_noisy = np.clip(x_test_noisy, 0., 1.)

n = 10
plt.figure(figsize=(20, 2))
for i in range(n):
    ax = plt.subplot(1, n, i+1)
    plt.imshow(x_test_noisy[i].reshape(28, 28))
    plt.gray()
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)
plt.show()



In [24]:
model = Sequential()
model.add(Dense(128, activation='relu', input_dim=784))
model.add(Dense(64, activation='relu'))
model.add(Dense(32, activation='relu'))
model.add(Dense(64, activation='relu'))
model.add(Dense(128, activation='relu'))
model.add(Dense(784, activation='sigmoid'))
model.compile(optimizer='adam', loss='binary_crossentropy')

model.fit(x_train_noisy, x_train, 
          nb_epoch=100,
          batch_size=256,
          shuffle=True,
          validation_data=(x_test_noisy, x_test))


Train on 60000 samples, validate on 10000 samples
Epoch 1/100
60000/60000 [==============================] - 8s - loss: 0.2734 - val_loss: 0.2150
Epoch 2/100
60000/60000 [==============================] - 7s - loss: 0.1977 - val_loss: 0.1768
Epoch 3/100
60000/60000 [==============================] - 8s - loss: 0.1706 - val_loss: 0.1648
Epoch 4/100
60000/60000 [==============================] - 8s - loss: 0.1596 - val_loss: 0.1539
Epoch 5/100
60000/60000 [==============================] - 8s - loss: 0.1513 - val_loss: 0.1476
Epoch 6/100
60000/60000 [==============================] - 8s - loss: 0.1458 - val_loss: 0.1430
Epoch 7/100
60000/60000 [==============================] - 8s - loss: 0.1418 - val_loss: 0.1395
Epoch 8/100
60000/60000 [==============================] - 8s - loss: 0.1386 - val_loss: 0.1372
Epoch 9/100
60000/60000 [==============================] - 8s - loss: 0.1360 - val_loss: 0.1348
Epoch 10/100
60000/60000 [==============================] - 7s - loss: 0.1340 - val_loss: 0.1327
Epoch 11/100
60000/60000 [==============================] - 7s - loss: 0.1324 - val_loss: 0.1315
Epoch 12/100
60000/60000 [==============================] - 8s - loss: 0.1308 - val_loss: 0.1302
Epoch 13/100
60000/60000 [==============================] - 8s - loss: 0.1296 - val_loss: 0.1296
Epoch 14/100
60000/60000 [==============================] - 8s - loss: 0.1284 - val_loss: 0.1287
Epoch 15/100
60000/60000 [==============================] - 8s - loss: 0.1274 - val_loss: 0.1272
Epoch 16/100
60000/60000 [==============================] - 7s - loss: 0.1267 - val_loss: 0.1263
Epoch 17/100
60000/60000 [==============================] - 7s - loss: 0.1258 - val_loss: 0.1259
Epoch 18/100
60000/60000 [==============================] - 8s - loss: 0.1250 - val_loss: 0.1256
Epoch 19/100
60000/60000 [==============================] - 8s - loss: 0.1244 - val_loss: 0.1249
Epoch 20/100
60000/60000 [==============================] - 8s - loss: 0.1239 - val_loss: 0.1249
Epoch 21/100
60000/60000 [==============================] - 7s - loss: 0.1232 - val_loss: 0.1236
Epoch 22/100
60000/60000 [==============================] - 8s - loss: 0.1227 - val_loss: 0.1232
Epoch 23/100
60000/60000 [==============================] - 7s - loss: 0.1222 - val_loss: 0.1225
Epoch 24/100
60000/60000 [==============================] - 7s - loss: 0.1217 - val_loss: 0.1223
Epoch 25/100
60000/60000 [==============================] - 8s - loss: 0.1212 - val_loss: 0.1220
Epoch 26/100
60000/60000 [==============================] - 8s - loss: 0.1208 - val_loss: 0.1222
Epoch 27/100
60000/60000 [==============================] - 8s - loss: 0.1205 - val_loss: 0.1212
Epoch 28/100
60000/60000 [==============================] - 6s - loss: 0.1202 - val_loss: 0.1212
Epoch 29/100
60000/60000 [==============================] - 5s - loss: 0.1197 - val_loss: 0.1206
Epoch 30/100
60000/60000 [==============================] - 5s - loss: 0.1195 - val_loss: 0.1205
Epoch 31/100
60000/60000 [==============================] - 5s - loss: 0.1192 - val_loss: 0.1201
Epoch 32/100
60000/60000 [==============================] - 6s - loss: 0.1189 - val_loss: 0.1201
Epoch 33/100
60000/60000 [==============================] - 5s - loss: 0.1186 - val_loss: 0.1200
Epoch 34/100
60000/60000 [==============================] - 5s - loss: 0.1184 - val_loss: 0.1195
Epoch 35/100
60000/60000 [==============================] - 5s - loss: 0.1182 - val_loss: 0.1193
Epoch 36/100
60000/60000 [==============================] - 5s - loss: 0.1179 - val_loss: 0.1194
Epoch 37/100
60000/60000 [==============================] - 8s - loss: 0.1178 - val_loss: 0.1189
Epoch 38/100
60000/60000 [==============================] - 8s - loss: 0.1175 - val_loss: 0.1192
Epoch 39/100
60000/60000 [==============================] - 8s - loss: 0.1174 - val_loss: 0.1189
Epoch 40/100
60000/60000 [==============================] - 8s - loss: 0.1171 - val_loss: 0.1184
Epoch 41/100
60000/60000 [==============================] - 5s - loss: 0.1170 - val_loss: 0.1186
Epoch 42/100
60000/60000 [==============================] - 5s - loss: 0.1169 - val_loss: 0.1183
Epoch 43/100
60000/60000 [==============================] - 5s - loss: 0.1166 - val_loss: 0.1184
Epoch 44/100
60000/60000 [==============================] - 5s - loss: 0.1165 - val_loss: 0.1183
Epoch 45/100
60000/60000 [==============================] - 5s - loss: 0.1164 - val_loss: 0.1185
Epoch 46/100
60000/60000 [==============================] - 5s - loss: 0.1161 - val_loss: 0.1177
Epoch 47/100
60000/60000 [==============================] - 5s - loss: 0.1161 - val_loss: 0.1179
Epoch 48/100
60000/60000 [==============================] - 5s - loss: 0.1159 - val_loss: 0.1178
Epoch 49/100
60000/60000 [==============================] - 5s - loss: 0.1159 - val_loss: 0.1180
Epoch 50/100
60000/60000 [==============================] - 5s - loss: 0.1158 - val_loss: 0.1173
Epoch 51/100
60000/60000 [==============================] - 4s - loss: 0.1156 - val_loss: 0.1177
Epoch 52/100
60000/60000 [==============================] - 4s - loss: 0.1155 - val_loss: 0.1175
Epoch 53/100
60000/60000 [==============================] - 4s - loss: 0.1153 - val_loss: 0.1172
Epoch 54/100
60000/60000 [==============================] - 4s - loss: 0.1153 - val_loss: 0.1173
Epoch 55/100
60000/60000 [==============================] - 4s - loss: 0.1152 - val_loss: 0.1171
Epoch 56/100
60000/60000 [==============================] - 5s - loss: 0.1150 - val_loss: 0.1171
Epoch 57/100
60000/60000 [==============================] - 5s - loss: 0.1150 - val_loss: 0.1180
Epoch 58/100
60000/60000 [==============================] - 4s - loss: 0.1149 - val_loss: 0.1168
Epoch 59/100
60000/60000 [==============================] - 4s - loss: 0.1148 - val_loss: 0.1171
Epoch 60/100
60000/60000 [==============================] - 4s - loss: 0.1147 - val_loss: 0.1168
Epoch 61/100
60000/60000 [==============================] - 4s - loss: 0.1146 - val_loss: 0.1169
Epoch 62/100
60000/60000 [==============================] - 4s - loss: 0.1146 - val_loss: 0.1166
Epoch 63/100
60000/60000 [==============================] - 4s - loss: 0.1144 - val_loss: 0.1164
Epoch 64/100
60000/60000 [==============================] - 4s - loss: 0.1143 - val_loss: 0.1167
Epoch 65/100
60000/60000 [==============================] - 4s - loss: 0.1143 - val_loss: 0.1166
Epoch 66/100
60000/60000 [==============================] - 4s - loss: 0.1141 - val_loss: 0.1167
Epoch 67/100
60000/60000 [==============================] - 5s - loss: 0.1140 - val_loss: 0.1164
Epoch 68/100
60000/60000 [==============================] - 5s - loss: 0.1139 - val_loss: 0.1162
Epoch 69/100
60000/60000 [==============================] - 6s - loss: 0.1139 - val_loss: 0.1164
Epoch 70/100
60000/60000 [==============================] - 5s - loss: 0.1138 - val_loss: 0.1165
Epoch 71/100
60000/60000 [==============================] - 5s - loss: 0.1136 - val_loss: 0.1165
Epoch 72/100
60000/60000 [==============================] - 5s - loss: 0.1135 - val_loss: 0.1160
Epoch 73/100
60000/60000 [==============================] - 5s - loss: 0.1135 - val_loss: 0.1164
Epoch 74/100
60000/60000 [==============================] - 5s - loss: 0.1134 - val_loss: 0.1164
Epoch 75/100
60000/60000 [==============================] - 5s - loss: 0.1134 - val_loss: 0.1160
Epoch 76/100
60000/60000 [==============================] - 4s - loss: 0.1133 - val_loss: 0.1157
Epoch 77/100
60000/60000 [==============================] - 5s - loss: 0.1132 - val_loss: 0.1158
Epoch 78/100
60000/60000 [==============================] - 4s - loss: 0.1131 - val_loss: 0.1158
Epoch 79/100
60000/60000 [==============================] - 4s - loss: 0.1131 - val_loss: 0.1159
Epoch 80/100
60000/60000 [==============================] - 4s - loss: 0.1131 - val_loss: 0.1157
Epoch 81/100
60000/60000 [==============================] - 4s - loss: 0.1130 - val_loss: 0.1167
Epoch 82/100
60000/60000 [==============================] - 4s - loss: 0.1129 - val_loss: 0.1160
Epoch 83/100
60000/60000 [==============================] - 5s - loss: 0.1128 - val_loss: 0.1157
Epoch 84/100
60000/60000 [==============================] - 7s - loss: 0.1128 - val_loss: 0.1154
Epoch 85/100
60000/60000 [==============================] - 8s - loss: 0.1127 - val_loss: 0.1151
Epoch 86/100
60000/60000 [==============================] - 8s - loss: 0.1126 - val_loss: 0.1153
Epoch 87/100
60000/60000 [==============================] - 8s - loss: 0.1126 - val_loss: 0.1156
Epoch 88/100
60000/60000 [==============================] - 6s - loss: 0.1126 - val_loss: 0.1155
Epoch 89/100
60000/60000 [==============================] - 5s - loss: 0.1125 - val_loss: 0.1158
Epoch 90/100
60000/60000 [==============================] - 5s - loss: 0.1125 - val_loss: 0.1154
Epoch 91/100
60000/60000 [==============================] - 5s - loss: 0.1124 - val_loss: 0.1160
Epoch 92/100
60000/60000 [==============================] - 5s - loss: 0.1123 - val_loss: 0.1150
Epoch 93/100
60000/60000 [==============================] - 5s - loss: 0.1123 - val_loss: 0.1149
Epoch 94/100
60000/60000 [==============================] - 5s - loss: 0.1122 - val_loss: 0.1149
Epoch 95/100
60000/60000 [==============================] - 5s - loss: 0.1122 - val_loss: 0.1150
Epoch 96/100
60000/60000 [==============================] - 5s - loss: 0.1122 - val_loss: 0.1152
Epoch 97/100
60000/60000 [==============================] - 5s - loss: 0.1122 - val_loss: 0.1152
Epoch 98/100
60000/60000 [==============================] - 5s - loss: 0.1121 - val_loss: 0.1154
Epoch 99/100
60000/60000 [==============================] - 5s - loss: 0.1121 - val_loss: 0.1151
Epoch 100/100
60000/60000 [==============================] - 5s - loss: 0.1120 - val_loss: 0.1148
Out[24]:
<keras.callbacks.History at 0x7ff40629bf90>

In [26]:
decoded_imgs = model.predict(x_test)
n = 10
plt.figure(figsize=(20, 6))
for i in range(1, n):
    # display original
    ax = plt.subplot(3, n, i)
    plt.imshow(x_test[i].reshape(28, 28))
    plt.gray()
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)

    # display noisy
    ax = plt.subplot(3, n, i + n)
    plt.imshow(x_test_noisy[i].reshape(28, 28))
    plt.gray()
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)

    # display reconstruction
    ax = plt.subplot(3, n, i + 2*n)
    plt.imshow(decoded_imgs[i].reshape(28, 28))
    plt.gray()
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)
plt.show()


Variational Autoencoders

A variational autoencoder is an autoencoder that adds probabilistic constraints on the representations being learned.

When using probabilistic models, compressed representation is called latent variable model.

So, instead of learning a function this model is learning a probabilistic distribution function that models your data.

Why? Standard autoencoders are not suited to work as a generative model. If you pick a random value for your decoder you won't get necessarily a good reconstruction: the value can far away from any previous value the network has seen before! That's why attaching a probabilistic model to the compressed representation is a good idea!

For sake of simplicity, let's use a standard normal distribution to define the distribution of inputs ($\mathbf V$) the decoder will receive.

The architecture of a variational autoencoder (VAE) is thus:

(Source: http://ijdykeman.github.io/ml/2016/12/21/cvae.html)

We want the decoder to take any point taken from a standard normal distribution to return a reasonable element of our dataset:

(Source: http://ijdykeman.github.io/ml/2016/12/21/cvae.html)

Let's consider the encoder role in this architecture.

In a traditional autoencoder, the encoder model takes a sample from data and returns a single point in the latent space, which is then passed to the decoder.

What information is encoded in the latent space?

In VAE the encoder instead produces (the parameters of) a probability distribution in the latent space:

(Source: http://ijdykeman.github.io/ml/2016/12/21/cvae.html)

These distributions are (non standard) Gaussians of the same dimensionality as the latent space.

First, let’s implement the encoder net, which takes input $X$ and outputs two things: $\mu(X)$ and $\Sigma(X)$, the parameters of the Gaussian. Our encoder will be a neural net with one hidden layer.

Our latent variable is two dimensional, so that we could easily visualize it.


In [35]:
# vae architecture

from tensorflow.examples.tutorials.mnist import input_data
from keras.layers import Input, Dense, Lambda
from keras.models import Model
from keras.objectives import binary_crossentropy
from keras.callbacks import LearningRateScheduler

import numpy as np
import matplotlib.pyplot as plt
import keras.backend as K
import tensorflow as tf

m = 50
n_z = 2
n_epoch = 100

# encoder
inputs = Input(shape=(784,))
h_q = Dense(512, activation='relu')(inputs)
mu = Dense(n_z, activation='linear')(h_q)
log_sigma = Dense(n_z, activation='linear')(h_q)

Up to now we have an encoder that takes images and produce (the parameters of) a pdf in the latent space. The decoder takes points in the latent space and return reconstructions.

How do we connect both models? By sampling from the produced distribution!

(Source: http://ijdykeman.github.io/ml/2016/12/21/cvae.html)

To this end we will implement a random variate reparameterisation: the substitution of a random variable by a deterministic transformation of a simpler random variable.

There are several methods by which non-uniform random numbers, or random variates, can be generated. The most popular methods are the one-liners, which give us the simple tools to generate random variates in one line of code, following the classic paper by Luc Devroye (Luc Devroye, Random variate generation in one line of code, Proceedings of the 28th conference on Winter simulation, 1996).

In the case of a Gaussian, we can use the following algorithm:

  • Generate $\epsilon \sim \mathcal{N}(0;1)$.
  • Compute a sample from $\mathcal{N}(\mu; RR^T)$ as $\mu + R \epsilon$.

In [ ]:
def sample_z(args):
    mu, log_sigma = args
    eps = K.random_normal(shape=(m, n_z), mean=0., std=1.)
    return mu + K.exp(log_sigma / 2) * eps

# Sample z
z = Lambda(sample_z)([mu, log_sigma])

Now we can create the decoder net:


In [37]:
decoder_hidden = Dense(512, activation='relu')
h_p = decoder_hidden(z)

decoder_out = Dense(784, activation='sigmoid')
outputs = decoder_out(h_p)

Lastly, from this model, we can do three things: reconstruct inputs, encode inputs into latent variables, and generate data from latent variable.


In [38]:
# Overall VAE model, for reconstruction and training
vae = Model(inputs, outputs)

# Encoder model, to encode input into latent variable
# We use the mean as the output as it is the center point, the representative of the gaussian
encoder = Model(inputs, mu)

# Generator model, generate new data given latent variable z
d_in = Input(shape=(n_z,))
d_h = decoder_hidden(d_in)
d_out = decoder_out(d_h)
decoder = Model(d_in, d_out)

In order to be coherent with our previous definitions, we must assure that points sampled fron the latent space fit a standard normal distribition, but the encoder is producing non standard normal distributions. So, we must add a constraint for getting something like this:

(Source: http://ijdykeman.github.io/ml/2016/12/21/cvae.html)

In order to impose this constraint in the loss function by using the Kullback-Leibler divergence.

The Kullback–Leibler divergence is a measure of how one probability distribution diverges from a second expected probability distribution. For discrete probability distributions $P$ and $Q$, the Kullback–Leibler divergence from $Q$ to $P$ is defined to be $$ D_{\mathrm {KL} }(P\|Q)=\sum _{i}P(i)\,\log {\frac {P(i)}{Q(i)}}. $$

The rest of the loss function must take into account the "reconstruction" error.


In [39]:
def vae_loss(y_true, y_pred):
    """ 
    Calculate loss = reconstruction loss + 
    KL loss for each data in minibatch 
    """
    recon = K.sum(K.binary_crossentropy(y_pred, y_true), axis=1)
    # D_KL(Q(z|X) || P(z|X)); 
    # calculate in closed form as both dist. are Gaussian
    kl = 0.5 * K.sum(K.exp(log_sigma) + K.square(mu) 
                     - 1. - log_sigma, axis=1)
    return recon + kl

Training a VAE

How do we train a model that have a sampling step?

In fact this is not a problem! By using the one-liner method for sampling we have expressed the latent distribution in a way that its parameters are factored out of the parameters of the random variable so that backpropagation can be used to find the optimal parameters of the latent distribution. For this reason this method is called reparametrization trick.

By using this trick we can train end-to-end a VAE with backpropagation.


In [40]:
from keras.datasets import mnist

(x_train, _), (x_test, y_test) = mnist.load_data()

x_train = x_train.astype('float32') / 255.
x_test = x_test.astype('float32') / 255.
x_train = np.reshape(x_train, (len(x_train), 784))
x_test = np.reshape(x_test, (len(x_test), 784))

vae.compile(optimizer='adam', loss=vae_loss)
vae.fit(x_train, x_train, batch_size=m, nb_epoch=n_epoch)


Epoch 1/100
60000/60000 [==============================] - 19s - loss: 181.1754    
Epoch 2/100
60000/60000 [==============================] - 20s - loss: 164.1534    
Epoch 3/100
60000/60000 [==============================] - 20s - loss: 161.2699    
Epoch 4/100
60000/60000 [==============================] - 20s - loss: 159.3786    
Epoch 5/100
60000/60000 [==============================] - 21s - loss: 157.9847    
Epoch 6/100
60000/60000 [==============================] - 20s - loss: 156.8659    
Epoch 7/100
60000/60000 [==============================] - 20s - loss: 155.9391    
Epoch 8/100
60000/60000 [==============================] - 20s - loss: 155.1295    
Epoch 9/100
60000/60000 [==============================] - 20s - loss: 154.4790    
Epoch 10/100
60000/60000 [==============================] - 20s - loss: 153.8685    
Epoch 11/100
60000/60000 [==============================] - 20s - loss: 153.3425    
Epoch 12/100
60000/60000 [==============================] - 20s - loss: 152.8552    
Epoch 13/100
60000/60000 [==============================] - 20s - loss: 152.3775    
Epoch 14/100
60000/60000 [==============================] - 20s - loss: 151.9199    
Epoch 15/100
60000/60000 [==============================] - 20s - loss: 151.5140    
Epoch 16/100
60000/60000 [==============================] - 20s - loss: 151.1423    
Epoch 17/100
60000/60000 [==============================] - 20s - loss: 150.7706    
Epoch 18/100
60000/60000 [==============================] - 19s - loss: 150.4246    
Epoch 19/100
60000/60000 [==============================] - 25s - loss: 150.1033    
Epoch 20/100
60000/60000 [==============================] - 20s - loss: 149.8305    
Epoch 21/100
60000/60000 [==============================] - 19s - loss: 149.5420    
Epoch 22/100
60000/60000 [==============================] - 24s - loss: 149.2889    
Epoch 23/100
60000/60000 [==============================] - 24s - loss: 149.0635    
Epoch 24/100
60000/60000 [==============================] - 20s - loss: 148.7778    
Epoch 25/100
60000/60000 [==============================] - 20s - loss: 148.6028    
Epoch 26/100
60000/60000 [==============================] - 20s - loss: 148.3978    
Epoch 27/100
60000/60000 [==============================] - 20s - loss: 148.2029    
Epoch 28/100
60000/60000 [==============================] - 20s - loss: 147.9991    
Epoch 29/100
60000/60000 [==============================] - 20s - loss: 147.7992    
Epoch 30/100
60000/60000 [==============================] - 21s - loss: 147.5876    
Epoch 31/100
60000/60000 [==============================] - 20s - loss: 147.4915    
Epoch 32/100
60000/60000 [==============================] - 20s - loss: 147.2922    
Epoch 33/100
60000/60000 [==============================] - 20s - loss: 147.1312    
Epoch 34/100
60000/60000 [==============================] - 19s - loss: 146.9788    
Epoch 35/100
60000/60000 [==============================] - 19s - loss: 146.8458    
Epoch 36/100
60000/60000 [==============================] - 19s - loss: 146.6818    
Epoch 37/100
60000/60000 [==============================] - 21s - loss: 146.5508    
Epoch 38/100
60000/60000 [==============================] - 20s - loss: 146.4339    
Epoch 39/100
60000/60000 [==============================] - 19s - loss: 146.2798    
Epoch 40/100
60000/60000 [==============================] - 20s - loss: 146.1480    
Epoch 41/100
60000/60000 [==============================] - 20s - loss: 146.0455    
Epoch 42/100
60000/60000 [==============================] - 20s - loss: 145.9377    
Epoch 43/100
60000/60000 [==============================] - 20s - loss: 145.8114    
Epoch 44/100
60000/60000 [==============================] - 20s - loss: 145.6647    
Epoch 45/100
60000/60000 [==============================] - 20s - loss: 145.5653    
Epoch 46/100
60000/60000 [==============================] - 19s - loss: 145.4719    
Epoch 47/100
60000/60000 [==============================] - 19s - loss: 145.3320    
Epoch 48/100
60000/60000 [==============================] - 20s - loss: 145.2747    
Epoch 49/100
60000/60000 [==============================] - 20s - loss: 145.1338    
Epoch 50/100
60000/60000 [==============================] - 19s - loss: 145.0555    
Epoch 51/100
60000/60000 [==============================] - 20s - loss: 144.9251    
Epoch 52/100
60000/60000 [==============================] - 20s - loss: 144.8793    
Epoch 53/100
60000/60000 [==============================] - 20s - loss: 144.7669    
Epoch 54/100
60000/60000 [==============================] - 20s - loss: 144.7311    
Epoch 55/100
60000/60000 [==============================] - 20s - loss: 144.6316    
Epoch 56/100
60000/60000 [==============================] - 20s - loss: 144.5594    
Epoch 57/100
60000/60000 [==============================] - 20s - loss: 144.4807    
Epoch 58/100
60000/60000 [==============================] - 21s - loss: 144.3660    
Epoch 59/100
60000/60000 [==============================] - 21s - loss: 144.2666    
Epoch 60/100
60000/60000 [==============================] - 22s - loss: 144.2443    
Epoch 61/100
60000/60000 [==============================] - 22s - loss: 144.1875    
Epoch 62/100
60000/60000 [==============================] - 22s - loss: 144.0776    
Epoch 63/100
60000/60000 [==============================] - 22s - loss: 144.0327    
Epoch 64/100
60000/60000 [==============================] - 21s - loss: 143.9392    
Epoch 65/100
60000/60000 [==============================] - 21s - loss: 143.8760    
Epoch 66/100
60000/60000 [==============================] - 21s - loss: 143.8059    
Epoch 67/100
60000/60000 [==============================] - 21s - loss: 143.7386    
Epoch 68/100
60000/60000 [==============================] - 21s - loss: 143.7169    
Epoch 69/100
60000/60000 [==============================] - 21s - loss: 143.6495    
Epoch 70/100
60000/60000 [==============================] - 21s - loss: 143.5333    
Epoch 71/100
60000/60000 [==============================] - 21s - loss: 143.5224    
Epoch 72/100
60000/60000 [==============================] - 22s - loss: 143.4183    
Epoch 73/100
60000/60000 [==============================] - 23s - loss: 143.4063    
Epoch 74/100
60000/60000 [==============================] - 23s - loss: 143.3266    
Epoch 75/100
60000/60000 [==============================] - 23s - loss: 143.2667    
Epoch 76/100
60000/60000 [==============================] - 22s - loss: 143.1886    
Epoch 77/100
60000/60000 [==============================] - 22s - loss: 143.1858    
Epoch 78/100
60000/60000 [==============================] - 22s - loss: 143.1177    
Epoch 79/100
60000/60000 [==============================] - 22s - loss: 143.0715    
Epoch 80/100
60000/60000 [==============================] - 22s - loss: 142.9865    
Epoch 81/100
60000/60000 [==============================] - 22s - loss: 142.9690    
Epoch 82/100
60000/60000 [==============================] - 22s - loss: 142.9008    
Epoch 83/100
60000/60000 [==============================] - 22s - loss: 142.8432    
Epoch 84/100
60000/60000 [==============================] - 23s - loss: 142.8220    
Epoch 85/100
60000/60000 [==============================] - 23s - loss: 142.7384    
Epoch 86/100
60000/60000 [==============================] - 24s - loss: 142.7161    
Epoch 87/100
60000/60000 [==============================] - 27s - loss: 142.6970    
Epoch 88/100
60000/60000 [==============================] - 23s - loss: 142.6455    
Epoch 89/100
60000/60000 [==============================] - 26s - loss: 142.5980    
Epoch 90/100
60000/60000 [==============================] - 25s - loss: 142.5540    
Epoch 91/100
60000/60000 [==============================] - 23s - loss: 142.5116    
Epoch 92/100
60000/60000 [==============================] - 23s - loss: 142.4521    
Epoch 93/100
60000/60000 [==============================] - 22s - loss: 142.3983    
Epoch 94/100
60000/60000 [==============================] - 23s - loss: 142.3383    
Epoch 95/100
60000/60000 [==============================] - 23s - loss: 142.3157    
Epoch 96/100
60000/60000 [==============================] - 23s - loss: 142.3206    
Epoch 97/100
60000/60000 [==============================] - 26s - loss: 142.2342    
Epoch 98/100
60000/60000 [==============================] - 23s - loss: 142.2286    
Epoch 99/100
60000/60000 [==============================] - 24s - loss: 142.1708    
Epoch 100/100
60000/60000 [==============================] - 23s - loss: 142.1172    
Out[40]:
<keras.callbacks.History at 0x7ff3f204ff50>

In [41]:
encoded_imgs = encoder.predict(x_test)
decoded_imgs = decoder.predict(encoded_imgs)

import matplotlib.pyplot as plt
%matplotlib inline

n = 10  # how many digits we will display
plt.figure(figsize=(10, 4))
for i in range(n):
    # display original
    ax = plt.subplot(2, n, i + 1)
    plt.imshow(x_test[i].reshape(28, 28))
    plt.gray()
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)

    # display reconstruction
    ax = plt.subplot(2, n, i + 1 + n)
    plt.imshow(decoded_imgs[i].reshape(28, 28))
    plt.gray()
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)
plt.show()



In [43]:
plt.scatter(encoded_imgs[:,0], encoded_imgs[:,1], c=y_test, cmap=plt.cm.get_cmap("jet", 10))
plt.colorbar(ticks=range(10))


Out[43]:
<matplotlib.colorbar.Colorbar at 0x7ff405737790>

Conditional VAE

What about producing specific number instances on demand?

We can do this by adding an extra input (as a one-hot encoding) to both the encoder and the decoder:

(Source: http://ijdykeman.github.io/ml/2016/12/21/cvae.html)

To generate an image of a particular number just feed that number in the decoder along with a random point in the latent space.

(Source: http://ijdykeman.github.io/ml/2016/12/21/cvae.html)

The latent space is no longer encoding what number are you dealing with (because this is already encoded in the extra input!). Instead, it is encoding information such as stroke width, angle, etc.

Bibliography

  • Doersch, Carl. “Tutorial on variational autoencoders.” arXiv preprint arXiv:1606.05908 (2016).
  • Kingma, Diederik P., and Max Welling. “Auto-encoding variational bayes.” arXiv preprint arXiv:1312.6114 (2013).