In [1]:
from __future__ import print_function

import warnings
warnings.filterwarnings('ignore')

import numpy as np
import holoviews as hv
%reload_ext holoviews.ipython
%output size=150 fig='svg'


The Gumbel Distribution


In [2]:
random_num_generator = np.random.RandomState(42)

In [3]:
def gumbel_dist(*args):
    u = random_num_generator.rand(*args)
    return -np.log(-np.log(u))

In [4]:
hv.Histogram(np.histogram(gumbel_dist(100000), bins=50))


Out[4]:

Gumbel-max Trick


In [5]:
def softmax(x, axis=None):
    x = np.exp(x)
    return (x/np.sum(x, axis=axis))

In [6]:
alpha = 3*random_num_generator.rand(4)
print(softmax(alpha))


[ 0.32032081  0.27257147  0.16078961  0.24631812]

In [7]:
hv.Histogram(np.histogram(gumbel_dist(100000), bins=50))\
*hv.VLine(np.log(alpha)[0])*hv.VLine(np.log(alpha)[1])\
*hv.VLine(np.log(alpha)[2])*hv.VLine(np.log(alpha)[3])


Out[7]:

In [8]:
%opts Histogram (alpha=0.5)

diff_hists = hv.HoloMap(kdims=['Alpha'])

for i in range(len(alpha)):
    diff_hists[i] = hv.Histogram(np.histogram(np.log(alpha[i])\
                        + gumbel_dist(10000), bins=50))
    
diff_hists.overlay('Alpha')


Out[8]:

In [9]:
def gumbel_max(alpha, N=100000):
    d = len(alpha)
    return np.argmax(np.log(alpha.reshape(d,1)) + gumbel_dist(d*N).reshape(d,N),axis=0)

def count(counts):
    bins = {i:0 for i in range(max(counts)+1)}
    for i in bins:
        for c in counts:
            if i == c:
                bins[i] += 1
    return [float(bins[i]) for i in range(max(counts)+1)]

In [12]:
samples = gumbel_max(alpha)
print("Some samples from our categorical:", samples[:10])


Some samples from our categorical: [1 2 0 1 0 3 0 2 3 2]

In [13]:
proportions = count(samples)/np.sum(count(samples))
print("Proportions observed in our samples:", proportions)
print("Probabilities we expected:", alpha/np.sum(alpha))


Proportions observed in our samples: [ 0.29465  0.27001  0.18262  0.25272]
Probabilities we expected: [ 0.2975302   0.26996494  0.17983448  0.25267038]

Concrete Distributions


In [14]:
def gumbel_softmax(alpha, temperature=1, N=100000):
    lmbda = temperature
    d = len(alpha)
    return softmax((np.log(alpha.reshape(d,1)) + gumbel_dist(d*N).reshape(d,N))/lmbda,axis=0)

In [15]:
samples = gumbel_softmax(alpha)
gumbel_softmax(alpha,N=1000).T[0]


Out[15]:
array([ 0.71756818,  0.10186108,  0.15072575,  0.029845  ])

In [16]:
def barycentric(p):
    x = 0.0
    y = 0.0
    y += p[1]
    y += p[2]
    x += p[2]
    x += p[3]
    return x,y

In [17]:
barycentric(gumbel_softmax(alpha, N=1000).T[0])


Out[17]:
(0.6678769070214583, 0.47276603080533042)

In [18]:
points = np.array([barycentric(s) for s in gumbel_softmax(alpha, N=100000).T])
h = np.histogram2d(points[:,1], points[:,0], bins=20)
hv.QuadMesh((h[1], h[2], h[0])).hist()


Out[18]:

In [19]:
vary_temperature = hv.HoloMap(kdims=['Lambda'])
for l in np.linspace(0.5, 5.0, 10):
    points = np.array([barycentric(s) for s in gumbel_softmax(alpha, temperature=l, N=100000).T])
    h = np.histogram2d(points[:,1], points[:,0], bins=20)
    vary_temperature[l] = hv.QuadMesh((h[1], h[2], h[0]))

In [20]:
vary_temperature.layout('Lambda').cols(2)


Out[20]:

VAE using Gumbel-Softmax/Concrete


In [31]:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.examples.tutorials.mnist import input_data

In [14]:
slim=tf.contrib.slim
Bernoulli = tf.contrib.distributions.Bernoulli
OneHotCategorical = tf.contrib.distributions.OneHotCategorical
RelaxedOneHotCategorical = tf.contrib.distributions.RelaxedOneHotCategorical

In [15]:
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)


Extracting MNIST_data/train-images-idx3-ubyte.gz
Extracting MNIST_data/train-labels-idx1-ubyte.gz
Extracting MNIST_data/t10k-images-idx3-ubyte.gz
Extracting MNIST_data/t10k-labels-idx1-ubyte.gz

In [16]:
batch_size = 100
tau0 = 1.0 # initial temperature
K = 10 # number of classes
N = 200//K # number of categorical distributions
straight_through = False # if True, use Straight-through Gumbel-Softmax
kl_type = 'relaxed' # choose between ('relaxed', 'categorical')
learn_temp = False

In [17]:
x=tf.placeholder(tf.float32, shape=(batch_size,784), name='x')
net = tf.cast(tf.random_uniform(tf.shape(x)) < x, x.dtype) # dynamic binarization
net = slim.stack(net,slim.fully_connected,[512,256])
logits_y = tf.reshape(slim.fully_connected(net,K*N,activation_fn=None),[-1,N,K])
tau = tf.Variable(tau0,name="temperature",trainable=learn_temp)
q_y = RelaxedOneHotCategorical(tau,logits_y)
y = q_y.sample()
if straight_through:
  y_hard = tf.cast(tf.one_hot(tf.argmax(y,-1),K), y.dtype)
  y = tf.stop_gradient(y_hard - y) + y
net = slim.flatten(y)
net = slim.stack(net,slim.fully_connected,[256,512])
logits_x = slim.fully_connected(net,784,activation_fn=None)
p_x = Bernoulli(logits=logits_x)
x_mean = p_x.mean()

In [27]:
recons = tf.reduce_sum(p_x.log_prob(x),1)
logits_py = tf.ones_like(logits_y) * 1./K

if kl_type=='categorical' or straight_through:
  # Analytical KL with Categorical prior
  p_cat_y = OneHotCategorical(logits=logits_py)
  q_cat_y = OneHotCategorical(logits=logits_y)
  KL_qp = tf.contrib.distributions.kl(q_cat_y, p_cat_y)
else:
  # Monte Carlo KL with Relaxed prior
  p_y = RelaxedOneHotCategorical(tau,logits=logits_py)
  KL_qp = q_y.log_prob(y) - p_y.log_prob(y)

In [28]:
KL = tf.reduce_sum(KL_qp,1)
mean_recons = tf.reduce_mean(recons)
mean_KL = tf.reduce_mean(KL)
loss = -tf.reduce_mean(recons-KL)

In [29]:
train_op=tf.train.AdamOptimizer(learning_rate=3e-4).minimize(loss)

In [30]:
data = []
from tqdm import tnrange, tqdm_notebook

with tf.train.MonitoredSession() as sess:
  for i in range(1,50000):
    batch = mnist.train.next_batch(batch_size)
    res = sess.run([train_op, loss, tau, mean_recons, mean_KL], {x : batch[0]})
    if i % 100 == 1:
      data.append([i] + res[1:])
    if i % 1000 == 1:
      print('Step %d, Loss: %0.3f' % (i,res[1]))
  # end training - do an eval
  batch = mnist.test.next_batch(batch_size)
  np_x = sess.run(x_mean, {x : batch[0]})


---------------------------------------------------------------------------
FailedPreconditionError                   Traceback (most recent call last)
<ipython-input-30-aef5cbeb8fbe> in <module>()
      5   for i in range(1,50000):
      6     batch = mnist.train.next_batch(batch_size)
----> 7     res = sess.run([train_op, loss, tau, mean_recons, mean_KL], {x : batch[0]})
      8     if i % 100 == 1:
      9       data.append([i] + res[1:])

/export/mlrg/vthangar/anaconda2/envs/concrete-env/lib/python2.7/site-packages/tensorflow/python/training/monitored_session.pyc in run(self, fetches, feed_dict, options, run_metadata)
    482                           feed_dict=feed_dict,
    483                           options=options,
--> 484                           run_metadata=run_metadata)
    485 
    486   def should_stop(self):

/export/mlrg/vthangar/anaconda2/envs/concrete-env/lib/python2.7/site-packages/tensorflow/python/training/monitored_session.pyc in run(self, fetches, feed_dict, options, run_metadata)
    818                               feed_dict=feed_dict,
    819                               options=options,
--> 820                               run_metadata=run_metadata)
    821       except _PREEMPTION_ERRORS as e:
    822         logging.info('An error was raised. This may be due to a preemption in '

/export/mlrg/vthangar/anaconda2/envs/concrete-env/lib/python2.7/site-packages/tensorflow/python/training/monitored_session.pyc in run(self, *args, **kwargs)
    774 
    775   def run(self, *args, **kwargs):
--> 776     return self._sess.run(*args, **kwargs)
    777 
    778 

/export/mlrg/vthangar/anaconda2/envs/concrete-env/lib/python2.7/site-packages/tensorflow/python/training/monitored_session.pyc in run(self, fetches, feed_dict, options, run_metadata)
    928                                   feed_dict=feed_dict,
    929                                   options=options,
--> 930                                   run_metadata=run_metadata)
    931 
    932     for hook in self._hooks:

/export/mlrg/vthangar/anaconda2/envs/concrete-env/lib/python2.7/site-packages/tensorflow/python/training/monitored_session.pyc in run(self, *args, **kwargs)
    774 
    775   def run(self, *args, **kwargs):
--> 776     return self._sess.run(*args, **kwargs)
    777 
    778 

/export/mlrg/vthangar/anaconda2/envs/concrete-env/lib/python2.7/site-packages/tensorflow/python/client/session.pyc in run(self, fetches, feed_dict, options, run_metadata)
    776     try:
    777       result = self._run(None, fetches, feed_dict, options_ptr,
--> 778                          run_metadata_ptr)
    779       if run_metadata:
    780         proto_data = tf_session.TF_GetBuffer(run_metadata_ptr)

/export/mlrg/vthangar/anaconda2/envs/concrete-env/lib/python2.7/site-packages/tensorflow/python/client/session.pyc in _run(self, handle, fetches, feed_dict, options, run_metadata)
    980     if final_fetches or final_targets:
    981       results = self._do_run(handle, final_targets, final_fetches,
--> 982                              feed_dict_string, options, run_metadata)
    983     else:
    984       results = []

/export/mlrg/vthangar/anaconda2/envs/concrete-env/lib/python2.7/site-packages/tensorflow/python/client/session.pyc in _do_run(self, handle, target_list, fetch_list, feed_dict, options, run_metadata)
   1030     if handle is None:
   1031       return self._do_call(_run_fn, self._session, feed_dict, fetch_list,
-> 1032                            target_list, options, run_metadata)
   1033     else:
   1034       return self._do_call(_prun_fn, self._session, handle, feed_dict,

/export/mlrg/vthangar/anaconda2/envs/concrete-env/lib/python2.7/site-packages/tensorflow/python/client/session.pyc in _do_call(self, fn, *args)
   1050         except KeyError:
   1051           pass
-> 1052       raise type(e)(node_def, op, message)
   1053 
   1054   def _extend_graph(self):

FailedPreconditionError: Attempting to use uninitialized value Stack_2/fully_connected_1/weights
	 [[Node: Stack_2/fully_connected_1/weights/read = Identity[T=DT_FLOAT, _class=["loc:@Stack_2/fully_connected_1/weights"], _device="/job:localhost/replica:0/task:0/cpu:0"](Stack_2/fully_connected_1/weights)]]

Caused by op u'Stack_2/fully_connected_1/weights/read', defined at:
  File "/export/mlrg/vthangar/anaconda2/envs/concrete-env/lib/python2.7/runpy.py", line 174, in _run_module_as_main
    "__main__", fname, loader, pkg_name)
  File "/export/mlrg/vthangar/anaconda2/envs/concrete-env/lib/python2.7/runpy.py", line 72, in _run_code
    exec code in run_globals
  File "/export/mlrg/vthangar/anaconda2/envs/concrete-env/lib/python2.7/site-packages/ipykernel_launcher.py", line 16, in <module>
    app.launch_new_instance()
  File "/export/mlrg/vthangar/anaconda2/envs/concrete-env/lib/python2.7/site-packages/traitlets/config/application.py", line 658, in launch_instance
    app.start()
  File "/export/mlrg/vthangar/anaconda2/envs/concrete-env/lib/python2.7/site-packages/ipykernel/kernelapp.py", line 477, in start
    ioloop.IOLoop.instance().start()
  File "/export/mlrg/vthangar/anaconda2/envs/concrete-env/lib/python2.7/site-packages/zmq/eventloop/ioloop.py", line 177, in start
    super(ZMQIOLoop, self).start()
  File "/export/mlrg/vthangar/anaconda2/envs/concrete-env/lib/python2.7/site-packages/tornado/ioloop.py", line 888, in start
    handler_func(fd_obj, events)
  File "/export/mlrg/vthangar/anaconda2/envs/concrete-env/lib/python2.7/site-packages/tornado/stack_context.py", line 277, in null_wrapper
    return fn(*args, **kwargs)
  File "/export/mlrg/vthangar/anaconda2/envs/concrete-env/lib/python2.7/site-packages/zmq/eventloop/zmqstream.py", line 440, in _handle_events
    self._handle_recv()
  File "/export/mlrg/vthangar/anaconda2/envs/concrete-env/lib/python2.7/site-packages/zmq/eventloop/zmqstream.py", line 472, in _handle_recv
    self._run_callback(callback, msg)
  File "/export/mlrg/vthangar/anaconda2/envs/concrete-env/lib/python2.7/site-packages/zmq/eventloop/zmqstream.py", line 414, in _run_callback
    callback(*args, **kwargs)
  File "/export/mlrg/vthangar/anaconda2/envs/concrete-env/lib/python2.7/site-packages/tornado/stack_context.py", line 277, in null_wrapper
    return fn(*args, **kwargs)
  File "/export/mlrg/vthangar/anaconda2/envs/concrete-env/lib/python2.7/site-packages/ipykernel/kernelbase.py", line 283, in dispatcher
    return self.dispatch_shell(stream, msg)
  File "/export/mlrg/vthangar/anaconda2/envs/concrete-env/lib/python2.7/site-packages/ipykernel/kernelbase.py", line 235, in dispatch_shell
    handler(stream, idents, msg)
  File "/export/mlrg/vthangar/anaconda2/envs/concrete-env/lib/python2.7/site-packages/ipykernel/kernelbase.py", line 399, in execute_request
    user_expressions, allow_stdin)
  File "/export/mlrg/vthangar/anaconda2/envs/concrete-env/lib/python2.7/site-packages/ipykernel/ipkernel.py", line 196, in do_execute
    res = shell.run_cell(code, store_history=store_history, silent=silent)
  File "/export/mlrg/vthangar/anaconda2/envs/concrete-env/lib/python2.7/site-packages/ipykernel/zmqshell.py", line 533, in run_cell
    return super(ZMQInteractiveShell, self).run_cell(*args, **kwargs)
  File "/export/mlrg/vthangar/anaconda2/envs/concrete-env/lib/python2.7/site-packages/IPython/core/interactiveshell.py", line 2717, in run_cell
    interactivity=interactivity, compiler=compiler, result=result)
  File "/export/mlrg/vthangar/anaconda2/envs/concrete-env/lib/python2.7/site-packages/IPython/core/interactiveshell.py", line 2821, in run_ast_nodes
    if self.run_code(code, result):
  File "/export/mlrg/vthangar/anaconda2/envs/concrete-env/lib/python2.7/site-packages/IPython/core/interactiveshell.py", line 2881, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-17-62162a5a4a58>", line 3, in <module>
    net = slim.stack(net,slim.fully_connected,[512,256])
  File "/export/mlrg/vthangar/anaconda2/envs/concrete-env/lib/python2.7/site-packages/tensorflow/contrib/layers/python/layers/layers.py", line 1985, in stack
    outputs = layer(outputs, *layer_args, **kwargs)
  File "/export/mlrg/vthangar/anaconda2/envs/concrete-env/lib/python2.7/site-packages/tensorflow/contrib/framework/python/ops/arg_scope.py", line 181, in func_with_args
    return func(*args, **current_args)
  File "/export/mlrg/vthangar/anaconda2/envs/concrete-env/lib/python2.7/site-packages/tensorflow/contrib/layers/python/layers/layers.py", line 1433, in fully_connected
    outputs = layer.apply(inputs)
  File "/export/mlrg/vthangar/anaconda2/envs/concrete-env/lib/python2.7/site-packages/tensorflow/python/layers/base.py", line 320, in apply
    return self.__call__(inputs, **kwargs)
  File "/export/mlrg/vthangar/anaconda2/envs/concrete-env/lib/python2.7/site-packages/tensorflow/python/layers/base.py", line 286, in __call__
    self.build(input_shapes[0])
  File "/export/mlrg/vthangar/anaconda2/envs/concrete-env/lib/python2.7/site-packages/tensorflow/python/layers/core.py", line 123, in build
    trainable=True)
  File "/export/mlrg/vthangar/anaconda2/envs/concrete-env/lib/python2.7/site-packages/tensorflow/python/ops/variable_scope.py", line 1049, in get_variable
    use_resource=use_resource, custom_getter=custom_getter)
  File "/export/mlrg/vthangar/anaconda2/envs/concrete-env/lib/python2.7/site-packages/tensorflow/python/ops/variable_scope.py", line 948, in get_variable
    use_resource=use_resource, custom_getter=custom_getter)
  File "/export/mlrg/vthangar/anaconda2/envs/concrete-env/lib/python2.7/site-packages/tensorflow/python/ops/variable_scope.py", line 349, in get_variable
    validate_shape=validate_shape, use_resource=use_resource)
  File "/export/mlrg/vthangar/anaconda2/envs/concrete-env/lib/python2.7/site-packages/tensorflow/python/ops/variable_scope.py", line 1389, in wrapped_custom_getter
    *args, **kwargs)
  File "/export/mlrg/vthangar/anaconda2/envs/concrete-env/lib/python2.7/site-packages/tensorflow/python/layers/base.py", line 275, in variable_getter
    variable_getter=functools.partial(getter, **kwargs))
  File "/export/mlrg/vthangar/anaconda2/envs/concrete-env/lib/python2.7/site-packages/tensorflow/python/layers/base.py", line 228, in _add_variable
    trainable=trainable and self.trainable)
  File "/export/mlrg/vthangar/anaconda2/envs/concrete-env/lib/python2.7/site-packages/tensorflow/contrib/layers/python/layers/layers.py", line 1334, in layer_variable_getter
    return _model_variable_getter(getter, *args, **kwargs)
  File "/export/mlrg/vthangar/anaconda2/envs/concrete-env/lib/python2.7/site-packages/tensorflow/contrib/layers/python/layers/layers.py", line 1326, in _model_variable_getter
    custom_getter=getter, use_resource=use_resource)
  File "/export/mlrg/vthangar/anaconda2/envs/concrete-env/lib/python2.7/site-packages/tensorflow/contrib/framework/python/ops/arg_scope.py", line 181, in func_with_args
    return func(*args, **current_args)
  File "/export/mlrg/vthangar/anaconda2/envs/concrete-env/lib/python2.7/site-packages/tensorflow/contrib/framework/python/ops/variables.py", line 262, in model_variable
    use_resource=use_resource)
  File "/export/mlrg/vthangar/anaconda2/envs/concrete-env/lib/python2.7/site-packages/tensorflow/contrib/framework/python/ops/arg_scope.py", line 181, in func_with_args
    return func(*args, **current_args)
  File "/export/mlrg/vthangar/anaconda2/envs/concrete-env/lib/python2.7/site-packages/tensorflow/contrib/framework/python/ops/variables.py", line 217, in variable
    use_resource=use_resource)
  File "/export/mlrg/vthangar/anaconda2/envs/concrete-env/lib/python2.7/site-packages/tensorflow/python/ops/variable_scope.py", line 341, in _true_getter
    use_resource=use_resource)
  File "/export/mlrg/vthangar/anaconda2/envs/concrete-env/lib/python2.7/site-packages/tensorflow/python/ops/variable_scope.py", line 714, in _get_single_variable
    validate_shape=validate_shape)
  File "/export/mlrg/vthangar/anaconda2/envs/concrete-env/lib/python2.7/site-packages/tensorflow/python/ops/variables.py", line 197, in __init__
    expected_shape=expected_shape)
  File "/export/mlrg/vthangar/anaconda2/envs/concrete-env/lib/python2.7/site-packages/tensorflow/python/ops/variables.py", line 316, in _init_from_args
    self._snapshot = array_ops.identity(self._variable, name="read")
  File "/export/mlrg/vthangar/anaconda2/envs/concrete-env/lib/python2.7/site-packages/tensorflow/python/ops/gen_array_ops.py", line 1338, in identity
    result = _op_def_lib.apply_op("Identity", input=input, name=name)
  File "/export/mlrg/vthangar/anaconda2/envs/concrete-env/lib/python2.7/site-packages/tensorflow/python/framework/op_def_library.py", line 768, in apply_op
    op_def=op_def)
  File "/export/mlrg/vthangar/anaconda2/envs/concrete-env/lib/python2.7/site-packages/tensorflow/python/framework/ops.py", line 2336, in create_op
    original_op=self._default_original_op, op_def=op_def)
  File "/export/mlrg/vthangar/anaconda2/envs/concrete-env/lib/python2.7/site-packages/tensorflow/python/framework/ops.py", line 1228, in __init__
    self._traceback = _extract_stack()

FailedPreconditionError (see above for traceback): Attempting to use uninitialized value Stack_2/fully_connected_1/weights
	 [[Node: Stack_2/fully_connected_1/weights/read = Identity[T=DT_FLOAT, _class=["loc:@Stack_2/fully_connected_1/weights"], _device="/job:localhost/replica:0/task:0/cpu:0"](Stack_2/fully_connected_1/weights)]]

In [ ]:
data = np.array(data).T

In [10]:
f,axarr=plt.subplots(1,4,figsize=(18,6))
axarr[0].plot(data[0],data[1])
axarr[0].set_title('Loss')

axarr[1].plot(data[0],data[2])
axarr[1].set_title('Temperature')

axarr[2].plot(data[0],data[3])
axarr[2].set_title('Recons')

axarr[3].plot(data[0],data[4])
axarr[3].set_title('KL')


Out[10]:
<matplotlib.text.Text at 0x7f30fba5d390>

In [11]:
tmp = np.reshape(np_x,(-1,280,28)) # (10,280,28)
img = np.hstack([tmp[i] for i in range(10)])
plt.imshow(img)
plt.grid('off')


---------------------------------------------------------------------------
NameError                                 Traceback (most recent call last)
<ipython-input-11-f8c8d6c03e98> in <module>()
----> 1 tmp = np.reshape(np_x,(-1,280,28)) # (10,280,28)
      2 img = np.hstack([tmp[i] for i in range(10)])
      3 plt.imshow(img)
      4 plt.grid('off')

NameError: name 'np_x' is not defined

In [ ]: