This notebook demos a simple implementation of unrolled generative adversarial networks on a 2d mixture of Gaussians dataset. See the paper for a better description of the technique, experiments, results, and other good stuff. Note that the architecture and hyperparameters used in this notebook are not identical to the one in the paper.
The GAN learning problem is to find the optimal parameters $\theta_G^*$ for a generator function $G\left( z; \theta_G\right)$ in a minimax objective, $$\begin{align} \theta_G^* &= \underset{\theta_G}{\text{argmin}} \underset{\theta_D}{\max} f\left(\theta_G, \theta_D\right) \\ &= \underset{\theta_G}{\text{argmin}} \;f\left(\theta_G, \theta_D^*\left(\theta_G\right)\right)\\ \theta_D^*\left(\theta_G\right) &= \underset{\theta_D}{\max} \;f\left(\theta_G, \theta_D\right), \end{align}$$ where the saddle objective $f$ is the standard GAN loss: $$f\left(\theta_G, \theta_D\right) = \mathbb{E}_{x\sim p_{data}}\left[\mathrm{log}\left(D\left(x; \theta_D\right)\right)\right] + \mathbb{E}_{z \sim \mathcal{N}(0,I)}\left[\mathrm{log}\left(1 - D\left(G\left(z; \theta_G\right); \theta_D\right)\right)\right]. $$
In unrolled GANs, we approximate $\theta_D^*\left(\theta_G\right)$ using a few steps of gradient ascent: $$\theta_D^*\left(\theta_G\right) \approx \hat{\theta}_D\left(\theta_G\right) \equiv\text{ a few steps of SGD maximizing}\;f\left(\theta_G, \theta_D\right).$$
We can then compute the update for the generator parameters, $\theta_G$, by computing the gradient of the saddle objective with respect to $\theta_G$ and the optimized discriminator parameters, $\hat{\theta}_D$: $$\frac{d}{d \theta_G} f\left(\theta_G, \hat{\theta}_D\left(\theta_G\right)\right)$$.
To backpropagate through the optimization process, we need to create a symbolic computational graph that includes all the operations from the initial weights to the optimized weights. TensorFlow's built-in optimizers use custom C++ code for efficiency, and do not construct a symbolic graph that is differentiable. For this notebook, we use the optimization routines from keras to compute updates. Next, we use tf.contrib.graph_editor.graph_replace to build a copy of the graph containing the mapping from initial weights to updated weights after one optimization iteration, but replacing the initial weights with the last iteration's weights:
This yields a new graph that allows us to backprop from $\theta_D^2$ back to $\theta_D^0$. We can then plug $\theta_D^2$ into the loss function to get the final objective that the generator optimizes. Using the magic of graph_replace we can write the unrolled optimization procedure in just a few lines:
# update_dict contains a dictionary mapping from variables (\theta_D^0)
# to their values after one step of optimization (\theta_D^1)
cur_update_dict = update_dict
for i in xrange(params['unrolling_steps'] - 1):
# Compute variable updates given the previous iteration's updated variable
cur_update_dict = graph_replace(update_dict, cur_update_dict)
# Final unrolled loss uses the parameters at the last time step
unrolled_loss = graph_replace(loss, cur_update_dict)
Note there are many other ways of implementing unrolled optimization that don't use graph rewriting. For example, if we created a function that takes weights as inputs and returns the updated weights, we could just iteratively call that function.
In [1]:
%pylab inline
from collections import OrderedDict
import tensorflow as tf
ds = tf.contrib.distributions
slim = tf.contrib.slim
from keras.optimizers import Adam
try:
from moviepy.video.io.bindings import mplfig_to_npimage
import moviepy.editor as mpy
generate_movie = True
except:
print("Warning: moviepy not found.")
generate_movie = False
graph_replace is broken in TensorFlow 1.0 (see this issue). We get around this issue with an ugly hack that removes the problematic attribute from all ops in the graph on every call to graph_replace.
In [2]:
_graph_replace = tf.contrib.graph_editor.graph_replace
def remove_original_op_attributes(graph):
"""Remove _original_op attribute from all operations in a graph."""
for op in graph.get_operations():
op._original_op = None
def graph_replace(*args, **kwargs):
"""Monkey patch graph_replace so that it works with TF 1.0"""
remove_original_op_attributes(tf.get_default_graph())
return _graph_replace(*args, **kwargs)
In [3]:
def extract_update_dict(update_ops):
"""Extract variables and their new values from Assign and AssignAdd ops.
Args:
update_ops: list of Assign and AssignAdd ops, typically computed using Keras' opt.get_updates()
Returns:
dict mapping from variable values to their updated value
"""
name_to_var = {v.name: v for v in tf.global_variables()}
updates = OrderedDict()
for update in update_ops:
var_name = update.op.inputs[0].name
var = name_to_var[var_name]
value = update.op.inputs[1]
if update.op.type == 'Assign':
updates[var.value()] = value
elif update.op.type == 'AssignAdd':
updates[var.value()] = var + value
else:
raise ValueError("Update op type (%s) must be of type Assign or AssignAdd"%update_op.op.type)
return updates
In [4]:
def sample_mog(batch_size, n_mixture=8, std=0.01, radius=1.0):
thetas = np.linspace(0, 2 * np.pi, n_mixture)
xs, ys = radius * np.sin(thetas), radius * np.cos(thetas)
cat = ds.Categorical(tf.zeros(n_mixture))
comps = [ds.MultivariateNormalDiag([xi, yi], [std, std]) for xi, yi in zip(xs.ravel(), ys.ravel())]
data = ds.Mixture(cat, comps)
return data.sample(batch_size)
In [5]:
def generator(z, output_dim=2, n_hidden=128, n_layer=2):
with tf.variable_scope("generator"):
h = slim.stack(z, slim.fully_connected, [n_hidden] * n_layer, activation_fn=tf.nn.tanh)
x = slim.fully_connected(h, output_dim, activation_fn=None)
return x
def discriminator(x, n_hidden=128, n_layer=2, reuse=False):
with tf.variable_scope("discriminator", reuse=reuse):
h = slim.stack(x, slim.fully_connected, [n_hidden] * n_layer, activation_fn=tf.nn.tanh)
log_d = slim.fully_connected(h, 1, activation_fn=None)
return log_d
In [6]:
params = dict(
batch_size=512,
disc_learning_rate=1e-4,
gen_learning_rate=1e-3,
beta1=0.5,
epsilon=1e-8,
max_iter=25000,
viz_every=5000,
z_dim=256,
x_dim=2,
unrolling_steps=5,
)
In [ ]:
tf.reset_default_graph()
data = sample_mog(params['batch_size'])
noise = ds.Normal(tf.zeros(params['z_dim']),
tf.ones(params['z_dim'])).sample(params['batch_size'])
# Construct generator and discriminator nets
with slim.arg_scope([slim.fully_connected], weights_initializer=tf.orthogonal_initializer(gain=1.4)):
samples = generator(noise, output_dim=params['x_dim'])
real_score = discriminator(data)
fake_score = discriminator(samples, reuse=True)
# Saddle objective
loss = tf.reduce_mean(
tf.nn.sigmoid_cross_entropy_with_logits(logits=real_score, labels=tf.ones_like(real_score)) +
tf.nn.sigmoid_cross_entropy_with_logits(logits=fake_score, labels=tf.zeros_like(fake_score)))
gen_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, "generator")
disc_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, "discriminator")
# Vanilla discriminator update
d_opt = Adam(lr=params['disc_learning_rate'], beta_1=params['beta1'], epsilon=params['epsilon'])
updates = d_opt.get_updates(disc_vars, [], loss)
d_train_op = tf.group(*updates, name="d_train_op")
# Unroll optimization of the discrimiantor
if params['unrolling_steps'] > 0:
# Get dictionary mapping from variables to their update value after one optimization step
update_dict = extract_update_dict(updates)
cur_update_dict = update_dict
for i in xrange(params['unrolling_steps'] - 1):
# Compute variable updates given the previous iteration's updated variable
cur_update_dict = graph_replace(update_dict, cur_update_dict)
# Final unrolled loss uses the parameters at the last time step
unrolled_loss = graph_replace(loss, cur_update_dict)
else:
unrolled_loss = loss
# Optimize the generator on the unrolled loss
g_train_opt = tf.train.AdamOptimizer(params['gen_learning_rate'], beta1=params['beta1'], epsilon=params['epsilon'])
g_train_op = g_train_opt.minimize(-unrolled_loss, var_list=gen_vars)
In [8]:
sess = tf.InteractiveSession()
sess.run(tf.global_variables_initializer())
In [9]:
from tqdm import tqdm
xmax = 3
fs = []
frames = []
np_samples = []
n_batches_viz = 10
viz_every = params['viz_every']
for i in tqdm(xrange(params['max_iter'])):
f, _, _ = sess.run([[loss, unrolled_loss], g_train_op, d_train_op])
fs.append(f)
if i % viz_every == 0:
np_samples.append(np.vstack([sess.run(samples) for _ in xrange(n_batches_viz)]))
xx, yy = sess.run([samples, data])
fig = figure(figsize=(5,5))
scatter(xx[:, 0], xx[:, 1], edgecolor='none')
scatter(yy[:, 0], yy[:, 1], c='g', edgecolor='none')
axis('off')
if generate_movie:
frames.append(mplfig_to_npimage(fig))
show()
In [15]:
import seaborn as sns
In [16]:
np_samples_ = np_samples[::1]
cols = len(np_samples_)
bg_color = sns.color_palette('Greens', n_colors=256)[0]
figure(figsize=(2*cols, 2))
for i, samps in enumerate(np_samples_):
if i == 0:
ax = subplot(1,cols,1)
else:
subplot(1,cols,i+1, sharex=ax, sharey=ax)
ax2 = sns.kdeplot(samps[:, 0], samps[:, 1], shade=True, cmap='Greens', n_levels=20, clip=[[-xmax,xmax]]*2)
ax2.set_axis_bgcolor(bg_color)
xticks([]); yticks([])
title('step %d'%(i*viz_every))
ax.set_ylabel('%d unrolling steps'%params['unrolling_steps'])
gcf().tight_layout()
In [19]:
fs = np.array(fs)
plot(fs)
legend(('loss', 'unrolled loss'))
Out[19]:
In [20]:
plot(fs[:, 0] - fs[:, 1])
legend('optimized loss - initial loss')
Out[20]:
In [14]:
#clip = mpy.ImageSequenceClip(frames[::], fps=30)
#clip.ipython_display()