In [1]:
from gumbel_softmax import GumbelSoftmax, GumbelSoftmaxLayer
import theano.tensor as T
import numpy as np
In [2]:
# Toy example: ten logits evenly spaced over [-2, 2], shaped as a batch of one row.
temperature = 0.01
logits = np.linspace(-2, 2, 10)[None, :]

# Symbolic Gumbel-softmax sample at a very low temperature (samples should be
# close to one-hot), alongside the ordinary softmax of the same logits for comparison.
gumbel_softmax = GumbelSoftmax(t=temperature)(logits)
softmax = T.nnet.softmax(logits)
In [3]:
import matplotlib.pyplot as plt
%matplotlib inline
# Figure 1: 100 independent Gumbel-softmax draws over the same 10 logits.
# Each .eval() presumably draws fresh Gumbel noise (the averaging below relies on it).
fig, ax = plt.subplots()
ax.set_title('gumbel-softmax samples')
for _ in range(100):
    ax.plot(range(10), gumbel_softmax.eval()[0], marker='o', alpha=0.25)
ax.set_ylim(0, 1)
plt.show()

# Figure 2: the empirical mean of 500 draws next to the plain softmax of the logits.
fig, ax = plt.subplots()
ax.set_title('average over samples')
draws = np.stack([gumbel_softmax.eval()[0] for _ in range(500)])
ax.plot(range(10), draws.mean(axis=0),
        marker='o', label='gumbel-softmax average')
ax.plot(softmax.eval()[0], marker='+', label='regular softmax')
ax.legend(loc='best')
Out[3]:
In [4]:
from sklearn.datasets import load_digits
# Flattened 8x8 digit images, one 64-feature row per sample.
# Cast to float32: sklearn documents these pixels as small integers (0..16), so the
# cast is exact. A float32 array is accepted by Theano inputs whether floatX is
# float32 or float64 (a float64 array is rejected by float32 inputs unless
# allow_input_downcast=True).
X = load_digits().data.astype(np.float32)
In [5]:
import lasagne
from lasagne.layers import *
import theano
# Graph inputs and shared variables.
input_var = T.matrix()
# The temperature is a shared variable so the training loop can anneal it
# with set_value() without recompiling the Theano functions.
temp = theano.shared(np.float32(1), 'temperature', allow_downcast=True)

N_FEATURES = 64     # input width: flattened 8x8 digit images
CODE_SIZE = 32      # width of the discrete bottleneck
CATEGORY_SIZE = 4   # softmax is applied over consecutive groups of this many units
assert CODE_SIZE % CATEGORY_SIZE == 0, "bottleneck must split evenly into groups"

# Architecture: encoder.
nn = l_in = InputLayer((None, N_FEATURES), input_var)
nn = DenseLayer(nn, 64, nonlinearity=T.tanh)
nn = DenseLayer(nn, 32, nonlinearity=T.tanh)

# Bottleneck: linear logits, reshaped so the Gumbel-softmax acts on each group of
# CATEGORY_SIZE units, then flattened back to CODE_SIZE per example.
nn = DenseLayer(nn, CODE_SIZE, nonlinearity=None)
nn = reshape(nn, (-1, CATEGORY_SIZE))
nn = GumbelSoftmaxLayer(nn, t=temp)
nn = bottleneck = reshape(nn, (-1, CODE_SIZE))

# Decoder: linear output layer so reconstructions are not squashed into [-1, 1].
nn = DenseLayer(nn, 32, nonlinearity=T.tanh)
nn = DenseLayer(nn, 64, nonlinearity=T.tanh)
nn = DenseLayer(nn, N_FEATURES, nonlinearity=None)

# Loss and updates: mean squared reconstruction error.
# trainable=True is the standard Lasagne idiom. It keeps Adam away from any
# parameter a layer registers as non-trainable; depending on how
# GumbelSoftmaxLayer is implemented (not visible here), that may include the
# temperature.
loss = T.mean((get_output(nn) - input_var) ** 2)
params = get_all_params(nn, trainable=True)
updates = lasagne.updates.adam(loss, params)

# Compile: one function that takes an Adam step, one that only reports the loss.
train_step = theano.function([input_var], loss, updates=updates)
evaluate = theano.function([input_var], loss)
In [6]:
for i,t in enumerate(np.logspace(0,-2,10000)):
sample = X[np.random.choice(len(X),32)]
temp.set_value(t)
mse = train_step(sample)
if i %100 ==0:
print '%.3f'%evaluate(X),
In [7]:
# Compiled helpers for visualization.
def _compile(layer, **output_kwargs):
    """Compile a Theano function from an input batch to `layer`'s output.

    Extra keyword arguments are forwarded to lasagne's get_output.
    """
    return theano.function([input_var], get_output(layer, **output_kwargs))

# Reconstruction through the stochastic (soft) Gumbel-softmax code.
get_sample = _compile(nn)
# Reconstruction with hard_max=True passed through get_output; presumably
# consumed by GumbelSoftmaxLayer to emit one-hot codes.
get_sample_hard = _compile(nn, hard_max=True)
# Bottleneck activations themselves, with hard_max explicitly off.
get_code = _compile(bottleneck, hard_max=False)
In [8]:
def _show_reconstructions(x_row):
    """Plot one digit beside its soft and hard reconstructions and its code.

    x_row: a single input row of shape (1, 64), e.g. ``X[k, None, :]``.
    """
    # Share one intensity range across the three image panels so the original and
    # its reconstructions are directly comparable; without vmin/vmax, imshow
    # rescales each panel to its own min/max.
    image_kw = dict(interpolation='none', cmap='gray', vmin=X.min(), vmax=X.max())
    # The code panel keeps imshow's automatic scaling.
    code_kw = dict(interpolation='none', cmap='gray')
    panels = [
        ("original", x_row.reshape([8, 8]), image_kw),
        ("gumbel", get_sample(x_row).reshape([8, 8]), image_kw),
        ("hard-max", get_sample_hard(x_row).reshape([8, 8]), image_kw),
        # One row per Gumbel-softmax group of 4 units (matching the model's
        # reshape to (-1, 4)).
        ("code", get_code(x_row).reshape(-1, 4), code_kw),
    ]
    fig, axes = plt.subplots(1, len(panels), figsize=[12, 4])
    for ax, (title, image, kw) in zip(axes, panels):
        ax.set_title(title)
        ax.imshow(image, **kw)
    plt.show()

# Ten random digits, each shown with its reconstructions and discrete code.
for i in range(10):
    X_sample = X[np.random.randint(len(X)), None, :]
    _show_reconstructions(X_sample)
In [ ]: