Unrolled generative adversarial networks on a toy dataset

This notebook demos a simple implementation of unrolled generative adversarial networks on a 2D mixture-of-Gaussians dataset. See the paper for a better description of the technique, experiments, results, and other good stuff. Note that the architecture and hyperparameters used in this notebook are not identical to those in the paper.

Motivation

The GAN learning problem is to find the optimal parameters $\theta_G^*$ for a generator function $G\left( z; \theta_G\right)$ in a minimax objective, $$\begin{align} \theta_G^* &= \underset{\theta_G}{\text{argmin}} \underset{\theta_D}{\max} f\left(\theta_G, \theta_D\right) \\ &= \underset{\theta_G}{\text{argmin}} \;f\left(\theta_G, \theta_D^*\left(\theta_G\right)\right)\\ \theta_D^*\left(\theta_G\right) &= \underset{\theta_D}{\text{argmax}} \;f\left(\theta_G, \theta_D\right), \end{align}$$ where the saddle objective $f$ is the standard GAN loss: $$f\left(\theta_G, \theta_D\right) = \mathbb{E}_{x\sim p_{data}}\left[\log D\left(x; \theta_D\right)\right] + \mathbb{E}_{z \sim \mathcal{N}(0,I)}\left[\log\left(1 - D\left(G\left(z; \theta_G\right); \theta_D\right)\right)\right]. $$
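
In the implementation below the discriminator outputs a logit, with $D = \sigma(\mathrm{logit})$, so $-f$ can be written as a sum of sigmoid cross-entropies. A minimal sketch of that correspondence (real_logits and fake_logits are placeholder names for the discriminator's outputs on data and on generator samples):

# Since log D(x) = -softplus(-real_logits) and log(1 - D(G(z))) = -softplus(fake_logits),
# -f is the mean of two per-example sigmoid cross-entropies.
neg_f = tf.reduce_mean(
    tf.nn.sigmoid_cross_entropy_with_logits(logits=real_logits, labels=tf.ones_like(real_logits)) +
    tf.nn.sigmoid_cross_entropy_with_logits(logits=fake_logits, labels=tf.zeros_like(fake_logits)))

This is exactly the loss tensor built in the model-construction cell below: it equals $-f$, so the discriminator minimizes it, while the generator (which wants to minimize $f$) minimizes -unrolled_loss.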

In unrolled GANs, we approximate $\theta_D^*\left(\theta_G\right)$ using a few steps of gradient ascent: $$\theta_D^*\left(\theta_G\right) \approx \hat{\theta}_D\left(\theta_G\right) \equiv\text{ a few steps of SGD maximizing}\;f\left(\theta_G, \theta_D\right).$$
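
Concretely, each inner step is a gradient-ascent update on the discriminator parameters (the code below uses Adam updates rather than plain SGD, but the idea is the same): $$\theta_D^0 = \theta_D, \qquad \theta_D^{k+1} = \theta_D^k + \eta \, \frac{\partial f\left(\theta_G, \theta_D^k\right)}{\partial \theta_D^k},$$ and $\hat{\theta}_D\left(\theta_G\right)$ is the value $\theta_D^K$ reached after $K$ unrolling steps.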

We can then compute the update for the generator parameters, $\theta_G$, by taking the gradient of the saddle objective with respect to $\theta_G$, evaluated at the optimized discriminator parameters $\hat{\theta}_D\left(\theta_G\right)$: $$\frac{d}{d \theta_G} f\left(\theta_G, \hat{\theta}_D\left(\theta_G\right)\right).$$
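
Expanding this total derivative with the chain rule shows why differentiating through the unrolled optimization matters: $$\frac{d f\left(\theta_G, \hat{\theta}_D\left(\theta_G\right)\right)}{d \theta_G} = \frac{\partial f\left(\theta_G, \hat{\theta}_D\left(\theta_G\right)\right)}{\partial \theta_G} + \frac{\partial f\left(\theta_G, \hat{\theta}_D\left(\theta_G\right)\right)}{\partial \hat{\theta}_D\left(\theta_G\right)} \, \frac{d \hat{\theta}_D\left(\theta_G\right)}{d \theta_G}.$$ The first term is the usual GAN generator gradient; the second term captures how the (approximately) optimal discriminator reacts to a change in $\theta_G$, and is exactly what backpropagating through the unrolled updates adds.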

Implementation details

To backpropagate through the optimization process, we need to create a symbolic computational graph that includes all the operations from the initial weights to the optimized weights. TensorFlow's built-in optimizers use custom C++ code for efficiency and do not construct a symbolic graph that is differentiable. For this notebook, we instead use the optimization routines from Keras to compute the updates. We then use tf.contrib.graph_editor.graph_replace to build a copy of the subgraph that maps initial weights to updated weights after one optimization step, with the initial weights replaced by the previous iteration's updated weights.

Applying this substitution repeatedly yields a new graph that allows us to backprop from, say, $\theta_D^2$ back to $\theta_D^0$. We can then plug $\theta_D^2$ into the loss function to get the final objective that the generator optimizes. Using the magic of graph_replace, we can write the unrolled optimization procedure in just a few lines:

# update_dict contains a dictionary mapping from variables (\theta_D^0)
# to their values after one step of optimization (\theta_D^1)
cur_update_dict = update_dict
for i in xrange(params['unrolling_steps'] - 1):
    # Compute variable updates given the previous iteration's updated variable
    cur_update_dict = graph_replace(update_dict, cur_update_dict)
# Final unrolled loss uses the parameters at the last time step
unrolled_loss = graph_replace(loss, cur_update_dict)

Note that there are many other ways of implementing unrolled optimization that don't use graph rewriting. For example, if we created a function that takes weights as inputs and returns the updated weights, we could simply call that function iteratively, as in the sketch below.
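
A minimal sketch of that functional style, assuming a hypothetical helper f_loss(theta_G_vars, theta_D_tensors) that rebuilds the saddle objective from explicit (non-Variable) discriminator weight tensors, and using plain gradient ascent where the notebook uses Adam:

def unrolled_objective(theta_G_vars, theta_D_tensors, n_steps, lr):
    """Sketch: unroll n_steps of gradient ascent on f with respect to the
    discriminator weights, then evaluate f at the resulting weights."""
    cur = list(theta_D_tensors)
    for _ in range(n_steps):
        grads = tf.gradients(f_loss(theta_G_vars, cur), cur)
        # One ascent step on f; everything stays symbolic, so gradients with
        # respect to theta_G_vars flow back through the unrolled updates.
        cur = [w + lr * g for w, g in zip(cur, grads)]
    return f_loss(theta_G_vars, cur)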


In [1]:
%pylab inline
from collections import OrderedDict
import tensorflow as tf
ds = tf.contrib.distributions
slim = tf.contrib.slim
        
from keras.optimizers import Adam

try:
    from moviepy.video.io.bindings import mplfig_to_npimage
    import moviepy.editor as mpy
    generate_movie = True
except ImportError:
    print("Warning: moviepy not found.")
    generate_movie = False


Using TensorFlow backend.

Populating the interactive namespace from numpy and matplotlib

graph_replace is broken in TensorFlow 1.0 (see this issue). We work around it with an ugly hack that removes the problematic _original_op attribute from every op in the graph on each call to graph_replace.


In [2]:
_graph_replace = tf.contrib.graph_editor.graph_replace

def remove_original_op_attributes(graph):
    """Remove _original_op attribute from all operations in a graph."""
    for op in graph.get_operations():
        op._original_op = None
        
def graph_replace(*args, **kwargs):
    """Monkey patch graph_replace so that it works with TF 1.0"""
    remove_original_op_attributes(tf.get_default_graph())
    return _graph_replace(*args, **kwargs)

Utility functions


In [3]:
def extract_update_dict(update_ops):
    """Extract variables and their new values from Assign and AssignAdd ops.
    
    Args:
        update_ops: list of Assign and AssignAdd ops, typically computed using Keras' opt.get_updates()

    Returns:
        dict mapping from variable values to their updated value
    """
    name_to_var = {v.name: v for v in tf.global_variables()}
    updates = OrderedDict()
    for update in update_ops:
        var_name = update.op.inputs[0].name
        var = name_to_var[var_name]
        value = update.op.inputs[1]
        if update.op.type == 'Assign':
            updates[var.value()] = value
        elif update.op.type == 'AssignAdd':
            updates[var.value()] = var + value
        else:
            raise ValueError("Update op type (%s) must be of type Assign or AssignAdd" % update.op.type)
    return updates

Data creation


In [4]:
def sample_mog(batch_size, n_mixture=8, std=0.01, radius=1.0):
    # n_mixture equally spaced component means on a circle of the given radius
    # (endpoint=False so the first and last components do not coincide)
    thetas = np.linspace(0, 2 * np.pi, n_mixture, endpoint=False)
    xs, ys = radius * np.sin(thetas), radius * np.cos(thetas)
    cat = ds.Categorical(tf.zeros(n_mixture))
    comps = [ds.MultivariateNormalDiag([xi, yi], [std, std]) for xi, yi in zip(xs.ravel(), ys.ravel())]
    data = ds.Mixture(cat, comps)
    return data.sample(batch_size)

Generator and discriminator architectures


In [5]:
def generator(z, output_dim=2, n_hidden=128, n_layer=2):
    with tf.variable_scope("generator"):
        h = slim.stack(z, slim.fully_connected, [n_hidden] * n_layer, activation_fn=tf.nn.tanh)
        x = slim.fully_connected(h, output_dim, activation_fn=None)
    return x

def discriminator(x, n_hidden=128, n_layer=2, reuse=False):
    with tf.variable_scope("discriminator", reuse=reuse):
        h = slim.stack(x, slim.fully_connected, [n_hidden] * n_layer, activation_fn=tf.nn.tanh)
        log_d = slim.fully_connected(h, 1, activation_fn=None)
    return log_d

Hyperparameters


In [6]:
params = dict(
    batch_size=512,
    disc_learning_rate=1e-4,
    gen_learning_rate=1e-3,
    beta1=0.5,
    epsilon=1e-8,
    max_iter=25000,
    viz_every=5000,
    z_dim=256,
    x_dim=2,
    unrolling_steps=5,
)

Construct model and training ops


In [ ]:
tf.reset_default_graph()

data = sample_mog(params['batch_size'])

noise = ds.Normal(tf.zeros(params['z_dim']), 
                  tf.ones(params['z_dim'])).sample(params['batch_size'])
# Construct generator and discriminator nets
with slim.arg_scope([slim.fully_connected], weights_initializer=tf.orthogonal_initializer(gain=1.4)):
    samples = generator(noise, output_dim=params['x_dim'])
    real_score = discriminator(data)
    fake_score = discriminator(samples, reuse=True)
    
# Saddle objective    
loss = tf.reduce_mean(
    tf.nn.sigmoid_cross_entropy_with_logits(logits=real_score, labels=tf.ones_like(real_score)) +
    tf.nn.sigmoid_cross_entropy_with_logits(logits=fake_score, labels=tf.zeros_like(fake_score)))

gen_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, "generator")
disc_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, "discriminator")

# Vanilla discriminator update
d_opt = Adam(lr=params['disc_learning_rate'], beta_1=params['beta1'], epsilon=params['epsilon'])
updates = d_opt.get_updates(disc_vars, [], loss)
d_train_op = tf.group(*updates, name="d_train_op")

# Unroll the optimization of the discriminator
if params['unrolling_steps'] > 0:
    # Get dictionary mapping from variables to their update value after one optimization step
    update_dict = extract_update_dict(updates)
    cur_update_dict = update_dict
    for i in xrange(params['unrolling_steps'] - 1):
        # Compute variable updates given the previous iteration's updated variable
        cur_update_dict = graph_replace(update_dict, cur_update_dict)
    # Final unrolled loss uses the parameters at the last time step
    unrolled_loss = graph_replace(loss, cur_update_dict)
else:
    unrolled_loss = loss

# Optimize the generator on the unrolled loss
g_train_opt = tf.train.AdamOptimizer(params['gen_learning_rate'], beta1=params['beta1'], epsilon=params['epsilon'])
g_train_op = g_train_opt.minimize(-unrolled_loss, var_list=gen_vars)

Train!


In [8]:
sess = tf.InteractiveSession()
sess.run(tf.global_variables_initializer())

In [9]:
from tqdm import tqdm
xmax = 3
fs = []
frames = []
np_samples = []
n_batches_viz = 10
viz_every = params['viz_every']
for i in tqdm(xrange(params['max_iter'])):
    f, _, _ = sess.run([[loss, unrolled_loss], g_train_op, d_train_op])
    fs.append(f)
    if i % viz_every == 0:
        np_samples.append(np.vstack([sess.run(samples) for _ in xrange(n_batches_viz)]))
        xx, yy = sess.run([samples, data])
        fig = figure(figsize=(5,5))
        scatter(xx[:, 0], xx[:, 1], edgecolor='none')
        scatter(yy[:, 0], yy[:, 1], c='g', edgecolor='none')
        axis('off')
        if generate_movie:
            frames.append(mplfig_to_npimage(fig))
        show()


  0%|          | 0/25000 [00:00<?, ?it/s]
 20%|█▉        | 4999/25000 [02:58<12:08, 27.44it/s]
 40%|████      | 10000/25000 [05:56<09:03, 27.62it/s]
 60%|█████▉    | 14998/25000 [08:53<05:50, 28.53it/s]
 80%|███████▉  | 19998/25000 [11:50<02:54, 28.66it/s]
100%|██████████| 25000/25000 [14:48<00:00, 28.14it/s]

Visualize results


In [15]:
import seaborn as sns

In [16]:
np_samples_ = np_samples[::1]
cols = len(np_samples_)
bg_color  = sns.color_palette('Greens', n_colors=256)[0]
figure(figsize=(2*cols, 2))
for i, samps in enumerate(np_samples_):
    if i == 0:
        ax = subplot(1,cols,1)
    else:
        subplot(1,cols,i+1, sharex=ax, sharey=ax)
    ax2 = sns.kdeplot(samps[:, 0], samps[:, 1], shade=True, cmap='Greens', n_levels=20, clip=[[-xmax,xmax]]*2)
    ax2.set_axis_bgcolor(bg_color)
    xticks([]); yticks([])
    title('step %d'%(i*viz_every))
ax.set_ylabel('%d unrolling steps'%params['unrolling_steps'])
gcf().tight_layout()



In [19]:
fs = np.array(fs)
plot(fs)
legend(('loss', 'unrolled loss'))


Out[19]:
<matplotlib.legend.Legend at 0x7fd408752c50>

In [20]:
plot(fs[:, 0] - fs[:, 1])
legend(['initial loss - optimized loss'])


Out[20]:
<matplotlib.legend.Legend at 0x7fd40862cb90>

In [14]:
#clip = mpy.ImageSequenceClip(frames[::], fps=30)
#clip.ipython_display()