In [ ]:
from __future__ import division
from __future__ import print_function
import os.path

import numpy as np
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

# Reproducibility: seed the TF graph-level RNG (weight initialisers and the
# epsilon sampling below) and NumPy's global RNG before any ops or data readers
# are created. NOTE(review): the MNIST reader presumably shuffles batches with
# NumPy's RNG -- verify for the installed TF version.
SEED = 0
np.random.seed(SEED)
tf.set_random_seed(SEED)

mnist = input_data.read_data_sets('MNIST')

input_dim = 784            # flattened 28x28 MNIST image
hidden_encoder_dim = 400   # width of the encoder's hidden ReLU layer
hidden_decoder_dim = 400   # width of the decoder's hidden ReLU layer
latent_dim = 20            # dimensionality of the latent code z
lam = 0                    # weight of the L2 penalty on weight matrices (0 = off)

def weight_variable(shape, stddev=0.001):
    """Create a trainable weight variable initialised from a truncated normal.

    Args:
        shape: shape of the variable, e.g. ``[fan_in, fan_out]``.
        stddev: standard deviation of the truncated-normal initialiser.
            Defaults to 0.001.

    Returns:
        A ``tf.Variable`` holding the initial values.
    """
    initial = tf.truncated_normal(shape, stddev=stddev)
    return tf.Variable(initial)

def bias_variable(shape):
    """Return a trainable bias variable of the given shape, filled with zeros."""
    return tf.Variable(tf.constant(0., shape=shape))

# ---------------------------------------------------------------------------
# Model graph: a VAE with one ReLU hidden layer in the encoder and the decoder.
# ---------------------------------------------------------------------------
x = tf.placeholder("float", shape=[None, input_dim])
# Running sum of L2 penalties on every weight matrix (scaled by `lam` below).
l2_loss = tf.constant(0.0)

W_encoder_input_hidden = weight_variable([input_dim,hidden_encoder_dim])
b_encoder_input_hidden = bias_variable([hidden_encoder_dim])
l2_loss += tf.nn.l2_loss(W_encoder_input_hidden)

# Hidden layer encoder
hidden_encoder = tf.nn.relu(tf.matmul(x, W_encoder_input_hidden) + b_encoder_input_hidden)

W_encoder_hidden_mu = weight_variable([hidden_encoder_dim,latent_dim])
b_encoder_hidden_mu = bias_variable([latent_dim])
l2_loss += tf.nn.l2_loss(W_encoder_hidden_mu)

# Mean of q(z|x)
mu_encoder = tf.matmul(hidden_encoder, W_encoder_hidden_mu) + b_encoder_hidden_mu

W_encoder_hidden_logvar = weight_variable([hidden_encoder_dim,latent_dim])
b_encoder_hidden_logvar = bias_variable([latent_dim])
l2_loss += tf.nn.l2_loss(W_encoder_hidden_logvar)

# Log-variance of q(z|x): the network predicts log(sigma^2), not sigma, so the
# variance exp(logvar) is always positive without any constraint.
logvar_encoder = tf.matmul(hidden_encoder, W_encoder_hidden_logvar) + b_encoder_hidden_logvar

# Reparameterisation trick: z = mu + sigma * epsilon with epsilon ~ N(0, I),
# which keeps z differentiable w.r.t. mu and logvar.
epsilon = tf.random_normal(tf.shape(logvar_encoder), name='epsilon')
std_encoder = tf.exp(0.5 * logvar_encoder)  # sigma = exp(logvar / 2)
z = mu_encoder + tf.multiply(std_encoder, epsilon)

W_decoder_z_hidden = weight_variable([latent_dim,hidden_decoder_dim])
b_decoder_z_hidden = bias_variable([hidden_decoder_dim])
l2_loss += tf.nn.l2_loss(W_decoder_z_hidden)

# Hidden layer decoder
hidden_decoder = tf.nn.relu(tf.matmul(z, W_decoder_z_hidden) + b_decoder_z_hidden)

W_decoder_hidden_reconstruction = weight_variable([hidden_decoder_dim, input_dim])
b_decoder_hidden_reconstruction = bias_variable([input_dim])
l2_loss += tf.nn.l2_loss(W_decoder_hidden_reconstruction)

# Per-example closed-form KL(q(z|x) || N(0, I)) for a diagonal Gaussian q.
KLD = -0.5 * tf.reduce_sum(1 + logvar_encoder - tf.square(mu_encoder) - tf.exp(logvar_encoder), axis=1)

# x_hat holds *logits*; apply tf.sigmoid to obtain pixel probabilities.
x_hat = tf.matmul(hidden_decoder, W_decoder_hidden_reconstruction) + b_decoder_hidden_reconstruction
# Per-example reconstruction term: Bernoulli negative log-likelihood of x.
BCE = tf.reduce_sum(tf.nn.sigmoid_cross_entropy_with_logits(logits=x_hat, labels=x), axis=1)

# Negative ELBO averaged over the batch -- the quantity being minimised.
loss = tf.reduce_mean(BCE + KLD)

regularized_loss = loss + lam * l2_loss

# NOTE: the tag "lowerbound" is kept so existing TensorBoard runs line up, but
# the logged value is the *negative* ELBO (i.e. an upper bound on -log p(x)).
loss_summ = tf.summary.scalar("lowerbound", loss)
learning_rate = 0.01  # Adam step size
train_step = tf.train.AdamOptimizer(learning_rate).minimize(regularized_loss)

# add op for merging summary
summary_op = tf.summary.merge_all()

# add Saver ops (used by the training loop to write checkpoints)
saver = tf.train.Saver()

# ---------------------------------------------------------------------------
# Training loop
# ---------------------------------------------------------------------------
n_steps = int(1e6)          # total number of optimisation steps
batch_size = 100
log_every = 50              # print the training loss every `log_every` steps
checkpoint_every = 10000    # write a checkpoint every `checkpoint_every` steps
log_dir = 'experiment'      # TensorBoard event files and checkpoints
# NOTE(review): relies on the FileWriter below creating `log_dir` if missing
# (TF1 event-writer behaviour) before Saver.save writes into it.
checkpoint_path = os.path.join(log_dir, 'model.ckpt')

with tf.Session() as sess:
    summary_writer = tf.summary.FileWriter(log_dir,
                                           graph=sess.graph)
    try:
        print("Initializing parameters")
        sess.run(tf.global_variables_initializer())

        # range's stop is exclusive, so `n_steps + 1` makes the loop perform
        # exactly n_steps updates, numbered 1..n_steps.
        for step in range(1, n_steps + 1):
            batch = mnist.train.next_batch(batch_size)
            feed_dict = {x: batch[0]}  # images only; labels are unused by the VAE
            _, cur_loss, summary_str = sess.run([train_step, loss, summary_op], feed_dict=feed_dict)
            summary_writer.add_summary(summary_str, step)

            if step % log_every == 0:
                print("Step {0} | Loss: {1}".format(step, cur_loss))

            if step % checkpoint_every == 0:
                saver.save(sess, checkpoint_path, global_step=step)

        # Persist the final weights unless the loop's last step already did.
        if n_steps % checkpoint_every != 0:
            saver.save(sess, checkpoint_path, global_step=n_steps)
    finally:
        # Flush pending events and release the file even if training is
        # interrupted (e.g. a KeyboardInterrupt from the notebook UI).
        summary_writer.close()


Successfully downloaded train-images-idx3-ubyte.gz 9912422 bytes.
Extracting MNIST/train-images-idx3-ubyte.gz
Successfully downloaded train-labels-idx1-ubyte.gz 28881 bytes.
Extracting MNIST/train-labels-idx1-ubyte.gz
Successfully downloaded t10k-images-idx3-ubyte.gz 1648877 bytes.
Extracting MNIST/t10k-images-idx3-ubyte.gz
Successfully downloaded t10k-labels-idx1-ubyte.gz 4542 bytes.
Extracting MNIST/t10k-labels-idx1-ubyte.gz
Initializing parameters
Step 50 | Loss: 208.06788635253906
Step 100 | Loss: 180.06475830078125
Step 150 | Loss: 203.51260375976562
Step 200 | Loss: 145.33279418945312
Step 250 | Loss: 152.0058135986328
Step 300 | Loss: 132.89476013183594
Step 350 | Loss: 143.39283752441406
Step 400 | Loss: 140.32923889160156
Step 450 | Loss: 134.5644989013672
Step 500 | Loss: 137.3271026611328
Step 550 | Loss: 139.12562561035156
Step 600 | Loss: 130.7576904296875
Step 650 | Loss: 128.86993408203125
Step 700 | Loss: 121.57710266113281
Step 750 | Loss: 121.71798706054688
Step 800 | Loss: 132.4248809814453
Step 850 | Loss: 134.16351318359375
Step 900 | Loss: 128.24562072753906
Step 950 | Loss: 122.6649398803711
Step 1000 | Loss: 121.69792938232422
Step 1050 | Loss: 125.6659927368164
Step 1100 | Loss: 118.87915802001953
Step 1150 | Loss: 125.58793640136719
Step 1200 | Loss: 118.60771179199219
Step 1250 | Loss: 123.4453353881836
Step 1300 | Loss: 132.7337646484375
Step 1350 | Loss: 123.17501831054688
Step 1400 | Loss: 124.4575424194336
Step 1450 | Loss: 118.01510620117188
Step 1500 | Loss: 116.33611297607422
Step 1550 | Loss: 129.4385528564453
Step 1600 | Loss: 120.55036926269531
Step 1650 | Loss: 130.09632873535156
Step 1700 | Loss: 125.3007583618164
Step 1750 | Loss: 122.07498931884766
Step 1800 | Loss: 121.6435546875
Step 1850 | Loss: 118.83889770507812
Step 1900 | Loss: 119.53717041015625
Step 1950 | Loss: 122.92294311523438
Step 2000 | Loss: 120.94338989257812
Step 2050 | Loss: 121.55541229248047
Step 2100 | Loss: 118.6678237915039
Step 2150 | Loss: 123.08565521240234
Step 2200 | Loss: 125.24755859375
Step 2250 | Loss: 115.99163818359375
Step 2300 | Loss: 121.4999771118164
Step 2350 | Loss: 119.26903533935547
Step 2400 | Loss: 120.16402435302734
Step 2450 | Loss: 113.10002136230469
Step 2500 | Loss: 122.14546966552734
Step 2550 | Loss: 122.36914825439453
Step 2600 | Loss: 119.39669799804688
Step 2650 | Loss: 124.0689468383789
Step 2700 | Loss: 122.1514663696289
Step 2750 | Loss: 120.1622543334961
Step 2800 | Loss: 112.24446105957031
Step 2850 | Loss: 107.25701904296875
Step 2900 | Loss: 118.7703628540039
Step 2950 | Loss: 125.22842407226562
Step 3000 | Loss: 125.98233032226562
Step 3050 | Loss: 122.34664154052734
Step 3100 | Loss: 117.41297912597656
Step 3150 | Loss: 120.1456069946289
Step 3200 | Loss: 122.1562728881836
Step 3250 | Loss: 115.88784790039062
Step 3300 | Loss: 123.56607055664062
Step 3350 | Loss: 120.23521423339844
Step 3400 | Loss: 122.25651550292969
Step 3450 | Loss: 114.21204376220703
Step 3500 | Loss: 119.462890625
Step 3550 | Loss: 120.56626892089844
Step 3600 | Loss: 117.36223602294922
Step 3650 | Loss: 110.40019226074219
Step 3700 | Loss: 114.15618133544922
Step 3750 | Loss: 116.40580749511719
Step 3800 | Loss: 115.3848648071289
Step 3850 | Loss: 119.44683837890625
Step 3900 | Loss: 116.95742797851562
Step 3950 | Loss: 118.3202133178711
Step 4000 | Loss: 117.63365173339844
Step 4050 | Loss: 112.71672058105469
Step 4100 | Loss: 115.12691497802734
Step 4150 | Loss: 116.4258804321289
Step 4200 | Loss: 118.01799011230469
Step 4250 | Loss: 115.61293029785156
Step 4300 | Loss: 117.46105194091797
Step 4350 | Loss: 118.05750274658203
Step 4400 | Loss: 114.39136505126953
Step 4450 | Loss: 116.23037719726562
Step 4500 | Loss: 117.78424072265625
Step 4550 | Loss: 118.01495361328125
Step 4600 | Loss: 122.74853515625
Step 4650 | Loss: 117.59159851074219
Step 4700 | Loss: 113.462890625
Step 4750 | Loss: 123.14982604980469
Step 4800 | Loss: 116.58440399169922
Step 4850 | Loss: 118.15299987792969
Step 4900 | Loss: 122.93072509765625
Step 4950 | Loss: 115.26317596435547
Step 5000 | Loss: 119.47685241699219
Step 5050 | Loss: 119.48384857177734
Step 5100 | Loss: 118.50972747802734
Step 5150 | Loss: 110.04559326171875
Step 5200 | Loss: 116.88732147216797
Step 5250 | Loss: 121.12173461914062
Step 5300 | Loss: 122.22545623779297
Step 5350 | Loss: 111.15548706054688
Step 5400 | Loss: 110.47991943359375
Step 5450 | Loss: 117.27287292480469
Step 5500 | Loss: 120.08091735839844
Step 5550 | Loss: 120.08263397216797
Step 5600 | Loss: 112.765625
Step 5650 | Loss: 117.5071792602539
Step 5700 | Loss: 116.00777435302734
Step 5750 | Loss: 117.9403305053711
Step 5800 | Loss: 118.1498794555664
Step 5850 | Loss: 115.70416259765625
Step 5900 | Loss: 121.04488372802734
Step 5950 | Loss: 117.7016372680664
Step 6000 | Loss: 116.72579956054688
Step 6050 | Loss: 117.94696044921875
Step 6100 | Loss: 111.77839660644531
Step 6150 | Loss: 117.66061401367188
Step 6200 | Loss: 118.19465637207031
Step 6250 | Loss: 119.76089477539062
Step 6300 | Loss: 111.56034851074219
Step 6350 | Loss: 115.82870483398438
Step 6400 | Loss: 113.86763763427734
Step 6450 | Loss: 116.8159408569336
Step 6500 | Loss: 117.59591674804688
Step 6550 | Loss: 119.50300598144531
Step 6600 | Loss: 114.06906127929688
Step 6650 | Loss: 117.50646209716797
Step 6700 | Loss: 118.0813217163086
Step 6750 | Loss: 114.482421875
Step 6800 | Loss: 117.34271240234375
Step 6850 | Loss: 114.98595428466797
Step 6900 | Loss: 118.73905944824219
Step 6950 | Loss: 113.58523559570312
Step 7000 | Loss: 113.98788452148438
Step 7050 | Loss: 121.17752075195312
Step 7100 | Loss: 117.97523498535156
Step 7150 | Loss: 121.0462875366211
Step 7200 | Loss: 114.75642395019531
Step 7250 | Loss: 116.09710693359375
Step 7300 | Loss: 117.36449432373047
Step 7350 | Loss: 117.70667266845703
Step 7400 | Loss: 118.22289276123047
Step 7450 | Loss: 113.93761444091797
Step 7500 | Loss: 111.05577087402344
Step 7550 | Loss: 123.77330017089844
Step 7600 | Loss: 117.89901733398438
Step 7650 | Loss: 117.20639038085938
Step 7700 | Loss: 115.84173583984375
Step 7750 | Loss: 117.5258560180664
Step 7800 | Loss: 109.52965545654297
Step 7850 | Loss: 118.94139862060547
Step 7900 | Loss: 119.11836242675781
Step 7950 | Loss: 113.66670227050781
Step 8000 | Loss: 115.4374771118164
Step 8050 | Loss: 110.00509643554688
Step 8100 | Loss: 118.42096710205078
Step 8150 | Loss: 117.70271301269531
Step 8200 | Loss: 110.50149536132812
Step 8250 | Loss: 116.86150360107422
Step 8300 | Loss: 118.66195678710938
Step 8350 | Loss: 115.85578155517578
Step 8400 | Loss: 119.81375885009766
Step 8450 | Loss: 117.7624282836914
Step 8500 | Loss: 115.34953308105469
Step 8550 | Loss: 117.8129653930664
Step 8600 | Loss: 116.90750122070312
Step 8650 | Loss: 111.59365844726562
Step 8700 | Loss: 118.17662048339844
Step 8750 | Loss: 120.67341613769531
Step 8800 | Loss: 112.72423553466797
Step 8850 | Loss: 119.2901382446289
Step 8900 | Loss: 107.13825988769531
Step 8950 | Loss: 118.46914672851562
Step 9000 | Loss: 117.04972839355469
Step 9050 | Loss: 116.61032104492188
Step 9100 | Loss: 121.25959777832031
Step 9150 | Loss: 117.86730194091797
Step 9200 | Loss: 115.31476593017578
Step 9250 | Loss: 115.94746398925781
Step 9300 | Loss: 123.08782196044922
Step 9350 | Loss: 118.65192413330078
Step 9400 | Loss: 117.62959289550781
Step 9450 | Loss: 113.04339599609375
Step 9500 | Loss: 118.16273498535156
Step 9550 | Loss: 119.29137420654297
Step 9600 | Loss: 114.10884857177734
Step 9650 | Loss: 116.94169616699219
Step 9700 | Loss: 111.13204956054688
Step 9750 | Loss: 116.00695037841797
Step 9800 | Loss: 116.18476867675781
Step 9850 | Loss: 112.22908020019531
Step 9900 | Loss: 114.25606536865234
Step 9950 | Loss: 108.73168182373047
Step 10000 | Loss: 114.33287048339844
Step 10050 | Loss: 121.33039093017578
Step 10100 | Loss: 123.09110260009766
Step 10150 | Loss: 119.38517761230469
Step 10200 | Loss: 116.14971923828125
Step 10250 | Loss: 121.53748321533203
Step 10300 | Loss: 120.43973541259766
Step 10350 | Loss: 116.59193420410156
Step 10400 | Loss: 119.08804321289062
Step 10450 | Loss: 113.619140625
Step 10500 | Loss: 112.19953918457031
Step 10550 | Loss: 111.6685562133789
Step 10600 | Loss: 114.97622680664062
Step 10650 | Loss: 119.60580444335938
Step 10700 | Loss: 119.49198913574219
Step 10750 | Loss: 118.43195343017578
Step 10800 | Loss: 122.50007629394531
Step 10850 | Loss: 116.86662292480469
Step 10900 | Loss: 115.89854431152344
Step 10950 | Loss: 118.72781372070312
Step 11000 | Loss: 108.7948989868164
Step 11050 | Loss: 118.50257873535156
Step 11100 | Loss: 110.60926055908203
Step 11150 | Loss: 120.69537353515625
Step 11200 | Loss: 114.28993225097656
Step 11250 | Loss: 115.44013977050781
Step 11300 | Loss: 110.51482391357422
Step 11350 | Loss: 117.07335662841797
Step 11400 | Loss: 117.82571411132812
Step 11450 | Loss: 116.74909210205078
Step 11500 | Loss: 116.3514633178711
Step 11550 | Loss: 118.20392608642578
Step 11600 | Loss: 118.54083251953125
Step 11650 | Loss: 120.45144653320312
Step 11700 | Loss: 112.87490844726562
Step 11750 | Loss: 114.39193725585938
Step 11800 | Loss: 114.47300720214844
Step 11850 | Loss: 114.05213165283203
Step 11900 | Loss: 115.93722534179688
Step 11950 | Loss: 115.79370880126953
Step 12000 | Loss: 121.48307800292969
Step 12050 | Loss: 112.57547760009766
Step 12100 | Loss: 114.92308807373047
Step 12150 | Loss: 118.25947570800781
Step 12200 | Loss: 114.76628875732422
Step 12250 | Loss: 113.80509948730469
Step 12300 | Loss: 124.15151977539062
Step 12350 | Loss: 122.06372833251953
Step 12400 | Loss: 117.94454193115234
Step 12450 | Loss: 117.29557800292969
Step 12500 | Loss: 112.87556457519531
Step 12550 | Loss: 119.79658508300781
Step 12600 | Loss: 117.26377868652344
Step 12650 | Loss: 116.1286392211914
Step 12700 | Loss: 119.35836029052734
Step 12750 | Loss: 118.67609405517578
Step 12800 | Loss: 118.87919616699219
Step 12850 | Loss: 119.9450912475586
Step 12900 | Loss: 117.23912048339844
Step 12950 | Loss: 116.15021514892578
Step 13000 | Loss: 116.6678695678711
Step 13050 | Loss: 119.83767700195312
Step 13100 | Loss: 110.35836029052734
Step 13150 | Loss: 108.51178741455078
Step 13200 | Loss: 115.44359588623047
Step 13250 | Loss: 115.70832061767578
Step 13300 | Loss: 117.70867156982422
Step 13350 | Loss: 111.4530258178711
Step 13400 | Loss: 125.64757537841797
Step 13450 | Loss: 111.11370849609375
Step 13500 | Loss: 119.33474731445312
Step 13550 | Loss: 115.30855560302734
Step 13600 | Loss: 110.55035400390625
Step 13650 | Loss: 116.16586303710938
Step 13700 | Loss: 121.72457122802734
Step 13750 | Loss: 117.2103500366211
Step 13800 | Loss: 114.72425079345703
Step 13850 | Loss: 110.55821228027344
Step 13900 | Loss: 111.8604507446289
Step 13950 | Loss: 116.0633316040039
Step 14000 | Loss: 114.9267349243164
Step 14050 | Loss: 114.90238189697266
Step 14100 | Loss: 113.99726867675781
Step 14150 | Loss: 118.53600311279297
Step 14200 | Loss: 118.81285095214844
Step 14250 | Loss: 124.75857543945312
Step 14300 | Loss: 119.23839569091797
Step 14350 | Loss: 114.2392578125
Step 14400 | Loss: 119.28097534179688
Step 14450 | Loss: 112.29423522949219
Step 14500 | Loss: 117.7653579711914
Step 14550 | Loss: 112.31188201904297
Step 14600 | Loss: 120.16435241699219
Step 14650 | Loss: 115.26638793945312
Step 14700 | Loss: 118.45664978027344
Step 14750 | Loss: 119.6676254272461
Step 14800 | Loss: 116.76998901367188
Step 14850 | Loss: 114.6698989868164
Step 14900 | Loss: 120.42974853515625
Step 14950 | Loss: 121.3724594116211
Step 15000 | Loss: 114.86787414550781
Step 15050 | Loss: 120.95048522949219
Step 15100 | Loss: 118.16832733154297
Step 15150 | Loss: 110.7814712524414
Step 15200 | Loss: 119.54023742675781
Step 15250 | Loss: 119.33824920654297
Step 15300 | Loss: 119.40328216552734
Step 15350 | Loss: 118.04212951660156
Step 15400 | Loss: 113.81212615966797
Step 15450 | Loss: 110.48765563964844
Step 15500 | Loss: 116.61421966552734
Step 15550 | Loss: 110.74839782714844
Step 15600 | Loss: 112.31326293945312
Step 15650 | Loss: 115.25269317626953
Step 15700 | Loss: 123.32250213623047
Step 15750 | Loss: 114.82608032226562
Step 15800 | Loss: 122.91587829589844
Step 15850 | Loss: 121.05029296875
Step 15900 | Loss: 117.34131622314453
Step 15950 | Loss: 115.32041931152344
Step 16000 | Loss: 111.40978240966797
Step 16050 | Loss: 114.45780181884766
Step 16100 | Loss: 114.42517852783203
Step 16150 | Loss: 119.44152069091797
Step 16200 | Loss: 118.29029083251953
Step 16250 | Loss: 114.74148559570312
Step 16300 | Loss: 117.47858428955078
Step 16350 | Loss: 113.23002624511719
Step 16400 | Loss: 109.5634994506836
Step 16450 | Loss: 115.4550552368164
Step 16500 | Loss: 118.49871063232422
Step 16550 | Loss: 115.15196228027344
Step 16600 | Loss: 116.4708480834961
Step 16650 | Loss: 111.94804382324219
Step 16700 | Loss: 113.8924789428711
Step 16750 | Loss: 123.36951446533203
Step 16800 | Loss: 114.1649398803711
Step 16850 | Loss: 115.64714813232422
Step 16900 | Loss: 117.13810729980469
Step 16950 | Loss: 117.12565612792969
Step 17000 | Loss: 119.39795684814453
Step 17050 | Loss: 117.11557006835938
Step 17100 | Loss: 119.05950164794922
Step 17150 | Loss: 118.0962905883789
Step 17200 | Loss: 115.04369354248047
Step 17250 | Loss: 115.11548614501953
Step 17300 | Loss: 119.72411346435547
Step 17350 | Loss: 121.55140686035156
Step 17400 | Loss: 112.74524688720703
Step 17450 | Loss: 118.12261962890625
Step 17500 | Loss: 112.08924102783203
Step 17550 | Loss: 120.76656341552734
Step 17600 | Loss: 112.10700988769531
Step 17650 | Loss: 114.62565612792969
Step 17700 | Loss: 120.84262084960938
Step 17750 | Loss: 110.32308959960938
Step 17800 | Loss: 115.84912109375
Step 17850 | Loss: 112.08287048339844

In [1]:
import collections
# NOTE(review): this cell previously ended in an incomplete attribute access
# (`collections.`), which raised a SyntaxError and broken Restart & Run All.
# It appears to be leftover scratch work; consider deleting the cell.