In [2]:
%matplotlib inline
In [ ]:
import numpy as np
import matplotlib.pyplot as plt

from keras.layers import Input, Dense, Lambda
from keras.models import Model, Sequential
from keras import backend as K
from keras.datasets import mnist
from keras.activations import softmax
from keras.objectives import binary_crossentropy as bce
batch_size = 100
data_dim = 784  # flattened 28x28 MNIST images
M = 10          # number of classes per categorical latent variable
N = 30          # number of categorical latent variables
nb_epoch = 100
anneal_rate = 0.0003   # temperature decay rate
min_temperature = 0.5  # floor for the annealed temperature
tau = K.variable(5.0, name="temperature")  # Gumbel-Softmax temperature
x = Input(batch_shape=(batch_size, data_dim))
h = Dense(512, activation='relu')(x)
h = Dense(256, activation='relu')(h)
logits_y = Dense(M * N)(h)  # unnormalized log-probabilities of the N x M latent code
def sampling(logits_y):
    # add Gumbel noise g = -log(-log(U)) to the logits, then take a
    # temperature-controlled softmax over the M classes of each variable
    U = K.random_uniform(K.shape(logits_y), 0, 1)
    y = logits_y - K.log(-K.log(U + 1e-20) + 1e-20)
    y = softmax(K.reshape(y, (-1, N, M)) / tau)
    return K.reshape(y, (-1, N * M))
z = Lambda(sampling, output_shape=(M*N,))(logits_y)
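# The Lambda layer above implements the Gumbel-Softmax (a.k.a. Concrete)
# relaxation of categorical sampling:
#     y_i = exp((log(pi_i) + g_i) / tau) / sum_j exp((log(pi_j) + g_j) / tau),
#     g_i = -log(-log(u_i)),  u_i ~ Uniform(0, 1)
# As tau -> 0 the samples approach one-hot vectors; larger tau gives smoother
# samples with lower-variance gradients, which is why tau is annealed below.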
generator = Sequential()
generator.add(Dense(256, activation='relu', input_shape=(N*M, )))
generator.add(Dense(512, activation='relu'))
generator.add(Dense(data_dim, activation='sigmoid'))
x_hat = generator(z)
# x_hat = Dense(data_dim, activation='softmax')(Dense(512, activation='relu')(Dense(256, activation='relu')(z)))
def gumbel_loss(x, x_hat):
    # KL divergence between the categorical posterior q(y|x) and the
    # uniform prior p(y) = 1/M, summed over the N latent variables
    q_y = K.reshape(logits_y, (-1, N, M))
    q_y = softmax(q_y)
    log_q_y = K.log(q_y + 1e-20)
    kl_tmp = q_y * (log_q_y - K.log(1.0 / M))
    KL = K.sum(kl_tmp, axis=(1, 2))
    # minimize the negative ELBO: reconstruction cross-entropy plus KL
    return data_dim * bce(x, x_hat) + KL
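# The loss is the negative evidence lower bound with a uniform prior:
#     loss = -E_q[log p(x|y)] + sum_n KL(q(y_n | x) || Uniform(M)),
# where KL(q || Uniform(M)) = sum_i q_i * (log(q_i) + log(M)) in closed form.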
vae = Model(x, x_hat)
vae.compile(optimizer='adam', loss=gumbel_loss)
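# Note: gumbel_loss reads logits_y from the enclosing scope, since Keras loss
# functions only receive (y_true, y_pred); the loss is therefore tied to this
# particular graph and cannot be reused with another model.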
# train the VAE on MNIST digits
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train = x_train.astype('float32') / 255.
x_test = x_test.astype('float32') / 255.
x_train = x_train.reshape((len(x_train), np.prod(x_train.shape[1:])))
x_test = x_test.reshape((len(x_test), np.prod(x_test.shape[1:])))
for e in range(nb_epoch):
    vae.fit(x_train, x_train,
            shuffle=True,
            nb_epoch=1,
            batch_size=batch_size,
            validation_data=(x_test, x_test))
    # anneal the temperature towards min_temperature after every epoch
    K.set_value(tau,
                np.max([K.get_value(tau) * np.exp(-anneal_rate * e),
                        min_temperature]))
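After each epoch the temperature is decayed and clipped from below,

$$\tau \leftarrow \max\left(\tau \, e^{-r e},\ \tau_{\min}\right), \qquad r = 3 \times 10^{-4},\ \tau_{\min} = 0.5,$$

so the relaxation hardens progressively while early training still benefits from low-variance gradients.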
In [6]:
# at test time, harden the code: each latent variable's class is the
# argmax of its M logits, giving a boolean one-hot (N, M) code
argmax_y = K.max(K.reshape(logits_y, (-1, N, M)), axis=-1, keepdims=True)
argmax_y = K.equal(K.reshape(logits_y, (-1, N, M)), argmax_y)
encoder = K.function([x], [argmax_y, x_hat])
In [7]:
code, x_hat_test = encoder([x_test[:100]])
plt.subplot(131)
plt.imshow(x_test[0].reshape(28, 28), cmap='gray'); plt.axis('off')
plt.subplot(132)
plt.imshow(code[0].reshape(N, M), cmap='gray'); plt.axis('off')
plt.subplot(133)
plt.imshow(x_hat_test[0].reshape(28, 28), cmap='gray'); plt.axis('off')
Out[7]:
[figure: test digit (left), its 30x10 one-hot latent code (center), and the reconstruction (right)]
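Because the generator was built as a standalone Sequential model, it can also decode latent codes drawn from the uniform categorical prior. A minimal sketch, not part of the original notebook (variable names here are illustrative):

In [ ]:
# sample one of M classes for each of the N latent variables (uniform prior)
idx = np.random.randint(0, M, size=N)
code_sample = np.zeros((1, N, M), dtype='float32')
code_sample[0, np.arange(N), idx] = 1.0  # one-hot encode each variable
x_gen = generator.predict(code_sample.reshape(1, N * M))
plt.imshow(x_gen.reshape(28, 28), cmap='gray'); plt.axis('off')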
In [8]:
# encode the first 6000 test digits into hard binary codes, in batches of 100
C = np.zeros((6000, N * M))
for i in range(0, 6000, 100):
    c, _ = encoder([x_test[i:i + 100]])
    C[i:i + 100] = c.reshape(100, -1)
In [9]:
from sklearn.manifold import TSNE

# Hamming distance is the natural metric for binary one-hot codes
tsne = TSNE(metric='hamming')
viz = tsne.fit_transform(C)
In [11]:
from agnez import embedding2dplot
_ = embedding2dplot(viz, y_test[:6000], show_median=False)