In [1]:
%load_ext autoreload
%autoreload 2
In [2]:
import tensorflow as tf
import numpy as np
In [3]:
batch_size = 64
In [4]:
seed = 40
np.random.seed(seed)
tf.set_random_seed(seed)
In [5]:
def sample_generator(uniform_range=8, N=batch_size):
    samples = np.linspace(-uniform_range, uniform_range, N) + \
        np.random.random(N) * 0.01  # jitter
    samples = np.reshape(samples, (N, 1))
    return samples
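As a quick sanity check (a hypothetical snippet, not part of the original run), the generator's input is a sorted, lightly jittered grid over [-uniform_range, uniform_range]:

z = sample_generator()
print(z.shape)           # (64, 1)
print(z.min(), z.max())  # roughly -8 and 8, plus up to 0.01 of jitter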
In [6]:
def generator(input, h_dim):
    h0 = tf.nn.softplus(linear(input, h_dim, 'g0'))
    h1 = linear(h0, 1, 'g1')
    return h1
In [7]:
def sample_data(mu=3, sigma=2, N=batch_size):
    samples = np.random.normal(mu, sigma, N)
    samples.sort()
    samples = np.reshape(samples, (N, 1))
    return samples
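Similarly (hypothetical check), the "real" data is a sorted Gaussian with mean 3 and standard deviation 2:

x = sample_data()
print(x.shape)            # (64, 1)
print(x.mean(), x.std())  # roughly 3 and 2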
In [8]:
def discriminator(input, h_dim):
    h0 = tf.tanh(linear(input, h_dim * 2, scope='d0'))
    h1 = tf.tanh(linear(h0, h_dim * 2, scope='d1'))
    h2 = tf.tanh(linear(h1, h_dim * 2, scope='d2'))
    # h2 = minibatch(h1, scope='minibatch')
    h3 = tf.sigmoid(linear(h2, 1, 'd3'))
    return h3
In [9]:
# def minibatch(input, num_kernels=5, kernel_dim=3, scope='minibatch'):
#     x = linear(input, num_kernels * kernel_dim, scope=scope, stddev=0.02)
#     activation = tf.reshape(x, (-1, num_kernels, kernel_dim))
#     # Calculate the L1 distance between each pair of samples, per kernel
#     diffs = tf.expand_dims(activation, 3) \
#         - tf.expand_dims(tf.transpose(activation, [1, 2, 0]), 0)
#     abs_diffs = tf.reduce_sum(tf.abs(diffs), 2)
#     # Apply a negative exponential and sum over the batch
#     minibatch_features = tf.reduce_sum(tf.exp(-abs_diffs), 2)
#     return tf.concat([input, minibatch_features], 1)
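The commented-out minibatch discrimination trick gives D an extra feature per kernel describing how close each sample is to the rest of the batch, which helps it spot generator mode collapse. A hypothetical NumPy sketch of the same computation, with made-up shapes (batch B=4, K=5 kernels of dimension KD=3):

B, K, KD = 4, 5, 3
act = np.random.randn(B, K, KD)                                   # per-sample kernel activations
diffs = act[:, :, :, None] - np.transpose(act, (1, 2, 0))[None]   # (B, K, KD, B) pairwise differences
abs_diffs = np.abs(diffs).sum(axis=2)                             # (B, K, B) L1 distance per kernel
features = np.exp(-abs_diffs).sum(axis=2)                         # (B, K) similarity to the whole batch
print(features.shape)                                             # (4, 5)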
In [10]:
def linear(input, output_dim, scope=None, stddev=1.0):
    norm = tf.random_normal_initializer(stddev=stddev)
    const = tf.constant_initializer(0.0)
    with tf.variable_scope(scope or 'linear'):
        W = tf.get_variable('W', [input.get_shape()[1], output_dim], initializer=norm)
        b = tf.get_variable('b', [output_dim], initializer=const)
        return tf.matmul(input, W) + b
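A hypothetical shape check for the fully connected helper above, built in a throwaway graph so it does not add variables to the main one:

with tf.Graph().as_default():
    demo_in = tf.placeholder(tf.float32, shape=(batch_size, 1))
    demo_out = linear(demo_in, 4, scope='demo')
    print(demo_out.get_shape())  # (64, 4)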
In [11]:
global_step = tf.Variable(0, name='global_step', trainable=False)
In [12]:
def optimizer(loss, var_list):
    initial_learning_rate = 0.01
    decay = 0.95
    num_decay_steps = 200
    learning_rate = tf.train.exponential_decay(
        initial_learning_rate,
        global_step,
        num_decay_steps,
        decay,
        staircase=True)
    optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(
        loss,
        global_step=global_step,
        var_list=var_list)
    return optimizer
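With staircase=True the schedule is piecewise constant, dropping by a factor of 0.95 every 200 global steps, i.e. lr(step) = 0.01 * 0.95 ** (step // 200). A hypothetical illustration:

for s in (0, 200, 1000):
    print(s, 0.01 * 0.95 ** (s // 200))  # 0.01, 0.0095, ~0.0077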
In [13]:
hidden_size = 4
with tf.variable_scope('D_pre'):
    pre_input = tf.placeholder(tf.float32, shape=(batch_size, 1))
    pre_labels = tf.placeholder(tf.float32, shape=(batch_size, 1))
    D_pre = discriminator(pre_input, hidden_size)
    loss_d_pre = tf.reduce_mean(tf.square(D_pre - pre_labels))
    opt_d_pre = optimizer(loss_d_pre, None)
with tf.variable_scope('G'):
    z_placeholder = tf.placeholder(tf.float32, shape=(batch_size, 1))
    G = generator(z_placeholder, hidden_size)
with tf.variable_scope('D') as scope:
    x_placeholder = tf.placeholder(tf.float32, shape=(batch_size, 1))
    D1 = discriminator(x_placeholder, hidden_size)
    scope.reuse_variables()
    D2 = discriminator(G, hidden_size)
loss_d = tf.reduce_mean(-tf.log(D1) - tf.log(1 - D2), name='loss_d')
loss_g = tf.reduce_mean(-tf.log(D2), name='loss_g')
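loss_d is the standard GAN discriminator loss, while loss_g is the non-saturating generator loss -log D(G(z)) rather than log(1 - D(G(z))), which gives the generator stronger gradients early in training. A hypothetical numeric check with made-up discriminator outputs:

d1, d2 = 0.9, 0.1                    # D is confident: real looks real, fake looks fake
print(-np.log(d1) - np.log(1 - d2))  # loss_d is small (~0.21)
print(-np.log(d2))                   # loss_g is large (~2.30), pushing G to fool D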
In [14]:
vars = tf.trainable_variables()
d_pre_params = [v for v in vars if v.name.startswith('D_pre/')]
d_params = [v for v in vars if v.name.startswith('D/')]
g_params = [v for v in vars if v.name.startswith('G/')]
opt_d = optimizer(loss_d, d_params)
opt_g = optimizer(loss_g, g_params)
In [15]:
# Number of steps to apply to the discriminator
k = 1
# Number of iterations to run training
n_epochs = 5000
In [16]:
def samples(session, num_points=10000, num_bins=batch_size):
    # num_points: number of points used for the decision boundary and histograms
    xs = np.linspace(-8, 8, num_points)
    bins = np.linspace(-8, 8, num_bins)
    # Decision boundary
    db = np.zeros((num_points, 1))
    for i in range(num_points // batch_size):
        db[batch_size * i:batch_size * (i + 1)] = session.run(D1, {
            x_placeholder: np.reshape(
                xs[batch_size * i:batch_size * (i + 1)], (batch_size, 1)
            )
        })
    # Data distribution
    d = sample_data(N=num_points)
    pd, _ = np.histogram(d, bins=bins, density=True)
    # Generated samples
    zs = np.linspace(-8, 8, num_points)
    g = np.zeros((num_points, 1))
    for i in range(num_points // batch_size):
        g[batch_size * i:batch_size * (i + 1)] = session.run(G, {
            z_placeholder: np.reshape(
                zs[batch_size * i:batch_size * (i + 1)], (batch_size, 1)
            )
        })
    pg, _ = np.histogram(g, bins=bins, density=True)
    return db, pd, pg
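samples() therefore returns the discriminator's decision boundary db (shape (num_points, 1)) evaluated on a grid over [-8, 8], plus density histograms pd and pg (each of length num_bins - 1) for the real and generated data; these are the per-frame quantities collected into anim_frames below.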
In [17]:
from tqdm import tqdm
from scipy.stats import norm
from utils import run_animation
In [18]:
anim_frames = []
with tf.Session() as sess:
    init = tf.global_variables_initializer()
    sess.run(init)
    # Pretrain the discriminator on (sample, target pdf value) pairs
    num_pretrain_steps = 1000
    for step in range(num_pretrain_steps):
        d = (np.random.random(batch_size) - 0.5) * 10.0
        labels = norm.pdf(d, loc=4, scale=0.5)
        pretrain_loss, _ = sess.run([loss_d_pre, opt_d_pre], {
            pre_input: np.reshape(d, (batch_size, 1)),
            pre_labels: np.reshape(labels, (batch_size, 1))
        })
    weights_d = sess.run(d_pre_params)
    # Copy weights from pre-training over to the adversarial D network
    for i, v in enumerate(d_params):
        sess.run(v.assign(weights_d[i]))
    for e in tqdm(range(n_epochs)):
        # Update the discriminator k times per generator update
        for _ in range(k):
            x, z = sample_data(), sample_generator()
            loss_d_val, _, i = sess.run([loss_d, opt_d, global_step],
                                        {x_placeholder: x, z_placeholder: z})
            assert not np.isnan(loss_d_val), 'Model diverged with loss_d_val = NaN'
        # Update the generator
        z = sample_generator()
        loss_g_val, _, i = sess.run([loss_g, opt_g, global_step], {z_placeholder: z})
        assert not np.isnan(loss_g_val), 'Model diverged with loss_g_val = NaN'
        # Record a frame for the animation every 20 optimizer steps
        if i % 20 == 0:
            anim_frames.append(samples(sess))
In [19]:
run_animation(anim_frames)
Out[19]: