In [8]:
import edward as ed
import tensorflow as tf
from edward.models import Normal, OneHotCategorical, Multinomial
from scipy.misc import imsave
from tensorflow.contrib import slim
from tensorflow.examples.tutorials.mnist import input_data
In [9]:
# Model / training configuration.
M = 1  # minibatch size
H = 20  # hidden layer size
D = 28**2  # number of features (28x28 images, flattened)
K = 10  # number of class labels
SEED = 42  # global random seed for reproducible runs

# Seed NumPy and TensorFlow before any random ops are created so that the
# random_normal initialisations, KLqp's Monte Carlo samples and minibatch
# shuffling are reproducible under "Restart & Run All".
# NOTE(review): Edward's set_seed is meant to run on an empty graph; keep this
# cell ahead of the model-building cells.
ed.set_seed(SEED)
In [ ]:
In [10]:
# Load the MNIST training/validation/test splits from data/mnist, with the
# class labels one-hot encoded (one column per class).
mnist = input_data.read_data_sets(train_dir="data/mnist", one_hot=True)
In [ ]:
In [12]:
def _standard_normal_prior(shape):
    """Return an Edward Normal random variable with zero loc and unit scale of `shape`."""
    return Normal(loc=tf.zeros(shape), scale=tf.ones(shape))


# Priors over the weights and biases of a one-hidden-layer network.
W_0 = _standard_normal_prior([D, H])
W_1 = _standard_normal_prior([H, K])
b_0 = _standard_normal_prior([H])
b_1 = _standard_normal_prior([K])


def neural_network(x):
    """Map inputs `x` to K class logits through a single tanh hidden layer.

    The weights are the module-level random variables W_0, b_0 (input -> hidden)
    and W_1, b_1 (hidden -> logits).
    """
    hidden = tf.nn.tanh(tf.matmul(x, W_0) + b_0)
    return tf.matmul(hidden, W_1) + b_1


# Inputs: a minibatch of M rows with D features each.
x = tf.placeholder(tf.float32, [M, D])
# Likelihood over one-hot class labels, parameterised by the network's logits.
y = OneHotCategorical(logits=neural_network(x))
In [13]:
def _normal_posterior(shape, name):
    """Build a Normal variational factor of the given shape with trainable parameters.

    Each element gets an independent Normal whose location and pre-softplus
    scale are tf.Variables initialised from a standard normal; softplus keeps
    the scale positive. The variables are created under the name scope `name`
    so every factor is labelled the same way in the graph (previously qW_0's
    variables were unnamed while the others used "loc"/"scale").

    Args:
        shape: shape of the factor, e.g. [D, H].
        name: name scope for the factor's variables, e.g. "qW_0".

    Returns:
        An edward.models.Normal random variable.
    """
    with tf.name_scope(name):
        loc = tf.Variable(tf.random_normal(shape), name="loc")
        raw_scale = tf.Variable(tf.random_normal(shape), name="scale")
    return Normal(loc=loc, scale=tf.nn.softplus(raw_scale))


# Variational approximations, one per prior random variable.
qW_0 = _normal_posterior([D, H], "qW_0")
qW_1 = _normal_posterior([H, K], "qW_1")
qb_0 = _normal_posterior([H], "qb_0")
qb_1 = _normal_posterior([K], "qb_1")
In [15]:
# INFERENCE
# Minibatch of observed one-hot labels, bound to the likelihood `y` below.
# (int32, presumably to match OneHotCategorical's default sample dtype.)
y_ph = tf.placeholder(tf.int32, [M, K])
# Pair each prior random variable with its variational approximation.
latent_vars = dict(zip((W_0, b_0, W_1, b_1), (qW_0, qb_0, qW_1, qb_1)))
observed = {y: y_ph}
inference = ed.KLqp(latent_vars=latent_vars, data=observed)
In [16]:
# Minibatch variational inference: n_epoch passes over the training set.
N = mnist.train.num_examples
n_batch = N // M  # full minibatches per epoch
n_epoch = 5
# scale={y: N / M} scales the minibatch log-likelihood up to the size of the
# full training set; n_samples=5 Monte Carlo samples per update.
inference.initialize(n_iter=n_batch * n_epoch, n_samples=5, scale={y: N / M})

# Use Edward's session explicitly rather than relying on an implicit default.
sess = ed.get_session()
sess.run(tf.global_variables_initializer())

for _ in range(inference.n_iter):
    x_train, y_train = mnist.train.next_batch(M)
    info_dict = inference.update(feed_dict={x: x_train, y_ph: y_train})
    inference.print_progress(info_dict)

# A manual initialize/update loop must finalize itself (inference.run() does
# this automatically); it releases resources such as summary writers.
inference.finalize()
In [ ]:
In [ ]:
# NOTE(review): this cell held the incomplete expression `inference.`, which is
# a SyntaxError and stops "Restart & Run All". Complete the intended call or
# delete this cell.
In [9]:
# Display the likelihood random variable `y` (bare trailing expression).
# NOTE(review): this cell is numbered In [9] yet follows In [16], so it was run
# out of order; it only inspects `y` and looks like leftover scratch work.
y
Out[9]:
In [ ]:
In [ ]:
In [ ]:
In [ ]:
In [ ]:
In [ ]:
In [ ]:
In [ ]:
In [ ]:
In [ ]: