Generative Adversarial Nets

Training a small generative adversarial network (GAN) to sample from a one-dimensional Gaussian distribution.
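
The setup follows the minimax game of Goodfellow et al. (2014): a generator $G$ maps noise $z$ to samples, and a discriminator $D$ estimates the probability that its input came from the data rather than from $G$:

$$\min_G \max_D \; \mathbb{E}_{x \sim p_\text{data}}[\log D(x)] + \mathbb{E}_{z \sim p_z}[\log(1 - D(G(z)))]$$

The losses built in the "Build Graph" section below are the sample-mean versions of these two terms.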


In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import tensorflow as tf
import numpy as np

In [3]:
batch_size = 64

In [4]:
seed = 40
np.random.seed(seed)
tf.set_random_seed(seed)

Generator


In [5]:
def sample_generator(uniform_range=8, N=batch_size):
    # Noise prior z: a stratified sample over [-uniform_range, uniform_range]
    # (evenly spaced points plus a small random jitter)
    samples = np.linspace(-uniform_range, uniform_range, N) + \
        np.random.random(N) * 0.01  # jitter

    samples = np.reshape(samples, (N, 1))
    return samples

In [6]:
def generator(input, h_dim):    
    h0 = tf.nn.softplus(linear(input, h_dim, 'g0'))
    h1 = linear(h0, 1, 'g1')
    return h1
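
With the `linear` helper defined below, the generator is a two-layer perceptron,

$$G(z) = W_1\,\mathrm{softplus}(W_0 z + b_0) + b_1,$$

mapping a one-dimensional noise sample to a one-dimensional output.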

Discriminator


In [7]:
def sample_data(mu=3, sigma=2, N=batch_size):
    # Real data distribution: samples from N(mu, sigma)
    samples = np.random.normal(mu, sigma, N)
    samples.sort()
    samples = np.reshape(samples, (N, 1))
    return samples

In [8]:
def discriminator(input, h_dim):    
    h0 = tf.tanh(linear(input, h_dim * 2, scope='d0'))   
    h1 = tf.tanh(linear(h0, h_dim * 2, scope='d1'))    
    h2 = tf.tanh(linear(h1, h_dim * 2, scope='d2'))
#     h2 = minibatch(h1, scope='minibatch')    
    h3 = tf.sigmoid(linear(h2, 1, 'd3'))
    return h3

In [9]:
# def minibatch(input, num_kernels=5, kernel_dim=3, scope='minibatch'):
#     # Minibatch discrimination (Salimans et al., 2016): append features that
#     # measure how similar each example is to the rest of the batch, so the
#     # discriminator can penalize a collapsed generator. Disabled in this run.
#     x = linear(input, num_kernels * kernel_dim, scope=scope, stddev=0.02)
#     activation = tf.reshape(x, (-1, num_kernels, kernel_dim))
    
#     # Calculate L1-distance between examples
#     diffs = tf.expand_dims(activation, 3) \
#         - tf.expand_dims(tf.transpose(activation, [1, 2, 0]), 0)
#     abs_diffs = tf.reduce_sum(tf.abs(diffs), 2)
    
#     # Apply negative exponential
#     minibatch_features = tf.reduce_sum(tf.exp(-abs_diffs), 2)
#     return tf.concat([input, minibatch_features], 1)

In [10]:
def linear(input, output_dim, scope=None, stddev=1.0):
    norm = tf.random_normal_initializer(stddev=stddev)
    const = tf.constant_initializer(0.0)
    
    with tf.variable_scope(scope or 'linear'):
        W = tf.get_variable('W', [input.get_shape()[1], output_dim], initializer=norm)
        b = tf.get_variable('b', [output_dim], initializer=const)
        return tf.matmul(input, W) + b

Optimizer


In [11]:
global_step = tf.Variable(0, name='global_step', trainable=False)

In [12]:
def optimizer(loss, var_list):    
    initial_learning_rate = 0.01
    decay = 0.95
    num_decay_steps = 200
    
    learning_rate = tf.train.exponential_decay(
        initial_learning_rate,
        global_step,
        num_decay_steps,
        decay,
        staircase=True)
    
    optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(
        loss,
        global_step=global_step,
        var_list=var_list)

    return optimizer
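
With `staircase=True`, `tf.train.exponential_decay` produces a step-wise decayed rate,

$$\text{lr}(t) = 0.01 \cdot 0.95^{\lfloor t / 200 \rfloor},$$

where $t$ is `global_step`, so the learning rate drops by 5% every 200 optimizer steps.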

Build Graph


In [13]:
hidden_size = 4

with tf.variable_scope('D_pre'):
    pre_input = tf.placeholder(tf.float32, shape=(batch_size, 1))
    pre_labels = tf.placeholder(tf.float32, shape=(batch_size, 1))
    D_pre = discriminator(pre_input, hidden_size)
    loss_d_pre = tf.reduce_mean(tf.square(D_pre - pre_labels))
    opt_d_pre = optimizer(loss_d_pre, None)    

with tf.variable_scope('G'):
    z_placeholder = tf.placeholder(tf.float32, shape=(batch_size, 1))
    G = generator(z_placeholder, hidden_size)
    
with tf.variable_scope('D') as scope:
    x_placeholder = tf.placeholder(tf.float32, shape=(batch_size, 1))
    D1 = discriminator(x_placeholder, hidden_size)
    scope.reuse_variables()
    D2 = discriminator(G, hidden_size)

loss_d = tf.reduce_mean(-tf.log(D1) - tf.log(1 - D2), name='loss_d')
loss_g = tf.reduce_mean(-tf.log(D2), name='loss_g')
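
These are the sample-mean losses

$$\text{loss\_d} = -\tfrac{1}{m}\sum_i \big[\log D(x_i) + \log(1 - D(G(z_i)))\big], \qquad \text{loss\_g} = -\tfrac{1}{m}\sum_i \log D(G(z_i)).$$

Note that `loss_g` uses the non-saturating $-\log D(G(z))$ heuristic from the original paper rather than minimizing $\log(1 - D(G(z)))$ directly, which gives the generator stronger gradients early in training when the discriminator easily rejects its samples.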

In [14]:
t_vars = tf.trainable_variables()
d_pre_params = [v for v in t_vars if v.name.startswith('D_pre/')]
d_params = [v for v in t_vars if v.name.startswith('D/')]
g_params = [v for v in t_vars if v.name.startswith('G/')]

opt_d = optimizer(loss_d, d_params)
opt_g = optimizer(loss_g, g_params)
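
An optional sanity check (not part of the original run) to confirm that the variable scopes partition the parameters as intended:

for name, params in [('D_pre', d_pre_params), ('D', d_params), ('G', g_params)]:
    print(name, [v.name for v in params])

`D_pre` and `D` should each list `W`/`b` pairs for `d0` through `d3`, and `G` should list pairs for `g0` and `g1`.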

Setup Plots


In [15]:
# Number of steps to apply to the discriminator
k = 1

# Number of iterations to run training
n_epochs = 5000

In [16]:
def samples(session, num_points=10000, num_bins=batch_size):
    # num_points: number of points sampled from the data and from G for the histograms
    # num_bins: number of histogram bin edges over [-8, 8]
    xs = np.linspace(-8, 8, num_points)
    bins = np.linspace(-8, 8, num_bins)

    # Discriminator decision boundary over [-8, 8]
    db = np.zeros((num_points, 1))
    for i in range(num_points // batch_size):
        db[batch_size * i:batch_size * (i + 1)] = session.run(D1, {
            x_placeholder: np.reshape(
                xs[batch_size * i:batch_size * (i + 1)], (batch_size, 1)
            )
        })

    # Data distribution
    d = sample_data(N=num_points)
    pd, _ = np.histogram(d, bins=bins, density=True)

    # Generated samples
    zs = np.linspace(-8, 8, num_points)
    g = np.zeros((num_points, 1))
    for i in range(num_points // batch_size):
        g[batch_size * i:batch_size * (i + 1)] = session.run(G, {
            z_placeholder: np.reshape(
                zs[batch_size * i:batch_size * (i + 1)], (batch_size, 1)
            )
        })

    pg, _ = np.histogram(g, bins=bins, density=True)

    return db, pd, pg
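
`run_animation`, imported from the local `utils` module in the next section, is not shown here. For a static look at a single recorded frame, a sketch along these lines works (assuming `matplotlib` is available; `plot_frame` is an ad-hoc helper, not part of `utils`):

import matplotlib.pyplot as plt

def plot_frame(db, p_data, p_gen, sample_range=8):
    # db: discriminator output evaluated over [-sample_range, sample_range]
    # p_data, p_gen: histogram densities of real and generated samples
    db_x = np.linspace(-sample_range, sample_range, len(db))
    p_x = np.linspace(-sample_range, sample_range, len(p_data))
    plt.plot(db_x, db, label='decision boundary')
    plt.plot(p_x, p_data, label='real data')
    plt.plot(p_x, p_gen, label='generated data')
    plt.legend()
    plt.show()

After training, `plot_frame(*anim_frames[-1])` shows the final decision boundary and the two distributions.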

Train


In [17]:
from tqdm import tqdm
from scipy.stats import norm
from utils import run_animation

In [18]:
anim_frames = []

with tf.Session() as sess:
    init = tf.global_variables_initializer()
    sess.run(init)
    
    # Pretrain discriminator
    num_pretrain_steps = 1000
    for step in range(num_pretrain_steps):
        d = (np.random.random(batch_size) - 0.5) * 10.0
        # Target: the density of the data distribution N(3, 2), matching sample_data's defaults
        labels = norm.pdf(d, loc=3, scale=2)
        pretrain_loss, _ = sess.run([loss_d_pre, opt_d_pre], {
            pre_input: np.reshape(d, (batch_size, 1)),
            pre_labels: np.reshape(labels, (batch_size, 1))
        })
        
    weights_d = sess.run(d_pre_params)

    # Copy weights from pre-training over to new D network
    for i, v in enumerate(d_params):
        sess.run(v.assign(weights_d[i]))    
    
    for e in tqdm(range(n_epochs)):
        for _ in range(k):
            x, z = sample_data(), sample_generator()
            loss_d_val, _, step = sess.run([loss_d, opt_d, global_step],
                                           { x_placeholder: x, z_placeholder: z })

            assert not np.isnan(loss_d_val), 'Model diverged with loss_d_val = NaN'

        z = sample_generator()
        loss_g_val, _, step = sess.run([loss_g, opt_g, global_step], { z_placeholder: z })

        assert not np.isnan(loss_g_val), 'Model diverged with loss_g_val = NaN'

        # Record a frame (decision boundary, data and generated histograms) every 20 steps
        if step % 20 == 0:
            anim_frames.append(samples(sess))


100%|██████████| 5000/5000 [00:46<00:00, 106.52it/s]

In [19]:
run_animation(anim_frames)


Out[19]: