Copyright 2017 Google LLC.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at

https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.

A version of this notebook with an accompanying cloud instance can be found at

Latent Constraints: Conditional Generation from Unconditional Generative Models


Abstract: Deep generative neural networks have proven effective at both conditional and unconditional modeling of complex data distributions. Conditional generation enables interactive control, but creating new controls often requires expensive retraining. In this paper, we develop a method to condition generation without retraining the model. By post-hoc learning latent constraints, value functions that identify regions in latent space that generate outputs with desired attributes, we can conditionally sample from these regions with gradient-based optimization or amortized actor functions. Combining attribute constraints with a universal realism constraint, which enforces similarity to the data distribution, we generate realistic conditional images from an unconditional variational autoencoder. Further, using gradient-based optimization, we demonstrate identity-preserving transformations that make the minimal adjustment in latent space to modify the attributes of an image. Finally, with discrete sequences of musical notes, we demonstrate zero-shot conditional generation, learning latent constraints in the absence of labeled data or a differentiable reward function.
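As a toy illustration of the core idea (not the paper's actual models), conditional sampling via gradient-based optimization can be sketched as ascending a value function over latent space, starting from a prior sample. The quadratic `value` below is a hypothetical stand-in for a learned latent constraint:

```python
import numpy as np

def value(z, target):
    # Hypothetical smooth constraint: highest where z is near `target`,
    # which stands in for an attribute-satisfying region of latent space.
    return -np.sum((z - target) ** 2)

def grad_value(z, target):
    return -2.0 * (z - target)

rng = np.random.RandomState(0)
z = rng.randn(4)       # sample from the prior p(z)
target = np.ones(4)
for _ in range(200):
    z = z + 0.05 * grad_value(z, target)  # gradient ascent on the value fn
```

After a few hundred steps, `z` lands in the high-value region while starting from an unconditional prior sample.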



This notebook contains code for running experiments related to the paper. First, we load pretrained checkpoints:

  • VAE models trained on CelebA with pixelwise Gaussian data likelihoods of $\mathcal{N}(\mu(z), \sigma_x=0.1)$ and $\mathcal{N}(\mu(z), \sigma_x=1)$.
  • Embeddings of the training and eval sets from the VAE models.
  • A generator ($G$) and discriminator ($D$) from a conditional-GAN, trained to shift samples from the prior to new points in latent space that satisfy the realism constraint ($r$) and attribute constraints ($r_{attr}$).
  • Versions of $G$ and $D$ trained both with no distance penalty and with a penalty of 1e-1.
  • A separately trained attribute classifier in both z-space ($D_{attr}$) and pixel space ($Classifier$).
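To see why a smaller $\sigma_x$ yields sharper reconstructions, note that a Gaussian likelihood penalizes a given pixel error in proportion to $1/\sigma_x^2$. A quick NumPy check (illustrative only):

```python
import numpy as np

def gaussian_logpdf(x, mu, sigma):
    """Log-density of N(mu, sigma) evaluated at x."""
    return -0.5 * np.log(2 * np.pi * sigma ** 2) - (x - mu) ** 2 / (2 * sigma ** 2)

err = 0.1  # a fixed per-pixel reconstruction error
# Log-likelihood drop caused by the error, relative to a perfect reconstruction.
drop_sharp = gaussian_logpdf(err, 0.0, 0.1) - gaussian_logpdf(0.0, 0.0, 0.1)
drop_blurry = gaussian_logpdf(err, 0.0, 1.0) - gaussian_logpdf(0.0, 0.0, 1.0)
# sigma_x = 0.1 penalizes the same error 100x more than sigma_x = 1.0.
```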

We then proceed to:

  • Demonstrate that VAE reconstructions sharpen as $\sigma_x$ decreases, at the expense of sample quality, which latent constraints compensate for.
  • Plot conditional generation using CGANs ($D$, $G$) both with and without distance penalty.
  • Perform identity-preserving transformations by running SGD in z-space with respect to $D_{attr}$.
  • Evaluate the accuracy of generating images with conditional attributes.

Training loops are also provided for demonstration purposes at the end of the notebook.


This Colab notebook is self-contained and should run natively on Google Cloud. The code and checkpoints can be downloaded separately and run locally, which is recommended if you want to train your own model. Pretrained model checkpoints are available at download.magenta.tensorflow.org/models/latent_constraints/latent_constraints.tar.

Download and extract the files to /tmp/, where this notebook assumes they exist.

Tip: Don't forget you can navigate with the Table of Contents in the left-hand sidebar, and collapse all sections with Ctrl + Shift + ].


In [0]:
# This notebook requires DeepMind's sonnet library, which itself
# requires the nightly build of TensorFlow. The command below 
# installs both.
!pip install -q -U dm-sonnet tf-nightly

import os
import PIL

import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import sklearn.metrics
import sonnet as snt
import tensorflow as tf

ds = tf.contrib.distributions

%matplotlib inline

Load the Data


In [0]:
basepath = '/tmp/'

# Load CelebA embeddings
# VAE with x_sigma = 0.1
train_mu = np.load(basepath + 'train_mu.npy')
train_sigma = np.load(basepath + 'train_sigma.npy')
eval_mu = np.load(basepath + 'eval_mu.npy')
eval_sigma = np.load(basepath + 'eval_sigma.npy')

# VAE with x_sigma = 1.0
eval_mu_xsigma1 = np.load(basepath + 'eval_mu_xsigma1.npy')
eval_sigma_xsigma1 = np.load(basepath + 'eval_sigma_xsigma1.npy')

np.random.seed(10003)
n_train = train_mu.shape[0]
n_eval = eval_mu.shape[0]

# Load Attributes
# Only use 10 salient attributes
attr_train = np.load(basepath + 'attr_train.npy')
attr_eval = np.load(basepath + 'attr_eval.npy')
attr_test = np.load(basepath + 'attr_test.npy')

attr_mask = [4, 8, 9, 11, 15, 20, 24, 31, 35, 39]
attribute_names = [
    'Bald',
    'Black_Hair',
    'Blond_Hair',
    'Brown_Hair',
    'Eyeglasses',
    'Male',
    'No_Beard',
    'Smiling',
    'Wearing_Hat',
    'Young',
]

attr_train = attr_train[:, attr_mask]
attr_eval = attr_eval[:, attr_mask]
attr_test = attr_test[:, attr_mask]

Define the Graph

All of the functions with variables are wrapped in Sonnet modules, as is the model that ties them together.

Tensors (endpoints) are accessible as attributes of the model.
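The endpoint-exposure trick at the end of `Model._build` (copying `locals()` into `self.__dict__`) can be illustrated without TensorFlow or Sonnet; `EndpointModel` and its fields are hypothetical:

```python
# Minimal sketch of the endpoint-exposure pattern used by Model below:
# every local variable created in the build method becomes an attribute.
class EndpointModel(object):
    def build(self):
        x = 1.0          # hypothetical endpoint
        loss = 2.0 * x   # hypothetical endpoint
        # Expose all locals (except self) as attributes of the instance.
        for k, v in locals().items():
            if k != 'self':
                setattr(self, k, v)

m = EndpointModel()
m.build()
# m.x and m.loss are now accessible, just as m.x_mean or m.train_vae are below.
```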


In [0]:
class Encoder(snt.AbstractModule):
  '''VAE Convolutional Encoder.'''
  def __init__(self,
               n_latent,
               layers=((256, 5, 2),
                       (512, 5, 2),
                       (1024, 3, 2),
                       (2048, 3, 2)),
               name='encoder'):
    super(Encoder, self).__init__(name=name)
    self.n_latent = n_latent
    self.layers = layers

  def _build(self, x):
    h = x
    for l in self.layers:
      h = tf.nn.relu(snt.Conv2D(l[0], l[1], l[2])(h))

    h_shape = h.get_shape().as_list()
    h = tf.reshape(h, [-1, h_shape[1] * h_shape[2] * h_shape[3]])
    pre_z = snt.Linear(2 * self.n_latent)(h)
    mu = pre_z[:, :self.n_latent]
    sigma = tf.nn.softplus(pre_z[:, self.n_latent:])
    return mu, sigma


class Decoder(snt.AbstractModule):
  '''VAE Convolutional Decoder.'''
  def __init__(self,
               layers=((2048, 4, 4),
                       (1024, 3, 2),
                       (512, 3, 2),
                       (256, 5, 2),
                       (3, 5, 2)),
               name='decoder'):
    super(Decoder, self).__init__(name=name)
    self.layers = layers

  def _build(self, x):
    for i, l in enumerate(self.layers):
      if i == 0:
        h = snt.Linear(l[1] * l[2] * l[0])(x)
        h = tf.reshape(h, [-1, l[1], l[2], l[0]])
      elif i == len(self.layers) - 1:
        h = snt.Conv2DTranspose(l[0], None, l[1], l[2])(h)
      else:
        h = tf.nn.relu(snt.Conv2DTranspose(l[0], None, l[1], l[2])(h))
    logits = h
    return logits


class G(snt.AbstractModule):
  '''CGAN Generator. Maps from z-space to z-space.'''
  def __init__(self,
               n_latent,
               layers=(2048,)*4,
               name='generator'):
    super(G, self).__init__(name=name)
    self.layers = layers
    self.n_latent = n_latent

  def _build(self, z_and_labels):
    z, labels = z_and_labels
    labels = tf.cast(labels, tf.float32)
    size = self.layers[0]
    x = tf.concat([z, snt.Linear(size)(labels)], axis=-1)
    for l in self.layers:
      x = tf.nn.relu(snt.Linear(l)(x))
    x = snt.Linear(2 * self.n_latent)(x)
    dz = x[:, :self.n_latent]
    gates = tf.nn.sigmoid(x[:, self.n_latent:])
    z_prime = (1-gates) * z + gates * dz
    return z_prime


class D(snt.AbstractModule):
  '''CGAN Discriminator.'''
  def __init__(self,
               output_size=1,
               layers=(2048,)*4,
               name='D'):
    super(D, self).__init__(name=name)
    self.layers = layers
    self.output_size = output_size

  def _build(self, z_and_labels):
    z, labels = z_and_labels
    labels = tf.cast(labels, tf.float32)
    size = self.layers[0]
    x = tf.concat([z, snt.Linear(size)(labels)], axis=-1)
    for l in self.layers:
      x = tf.nn.relu(snt.Linear(l)(x))
    logits = snt.Linear(self.output_size)(x)
    return logits


class DAttr(snt.AbstractModule):
  '''Attribute Classifier from z-space.'''
  def __init__(self,
               output_size=1,
               layers=(2048,)*4,
               name='D'):
    super(DAttr, self).__init__(name=name)
    self.layers = layers
    self.output_size = output_size

  def _build(self, x):
    for l in self.layers:
      x = tf.nn.relu(snt.Linear(l)(x))
    logits = snt.Linear(self.output_size)(x)
    return logits
  
  
class Classifier(snt.AbstractModule):
  '''Convolutional Attribute Classifier from Pixels.'''
  def __init__(self,
               output_size,
               layers=((256, 5, 2),
                       (256, 3, 1),
                       (512, 5, 2),
                       (512, 3, 1),
                       (1024, 3, 2),
                       (2048, 3, 2)),
               name='classifier'):
    super(Classifier, self).__init__(name=name)
    self.output_size = output_size
    self.layers = layers

  def _build(self, x):
    h = x
    for l in self.layers:
      h = tf.nn.relu(snt.Conv2D(l[0], l[1], l[2])(h))

    h_shape = h.get_shape().as_list()
    h = tf.reshape(h, [-1, h_shape[1] * h_shape[2] * h_shape[3]])
    logits = snt.Linear(self.output_size)(h)
    return logits
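The gated residual update in `G` above, $z' = (1-g) \odot z + g \odot dz$, interpolates per-dimension between keeping the input and replacing it. A quick numeric check:

```python
import numpy as np

# Gate = 0 keeps the original z unchanged; gate = 1 fully replaces it with dz.
z = np.array([1.0, 1.0])
dz = np.array([3.0, 3.0])
gates = np.array([0.0, 1.0])
z_prime = (1 - gates) * z + gates * dz  # -> [1.0, 3.0]
```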

In [0]:
class Model(snt.AbstractModule):
  '''All the components glued together.'''
  def __init__(self, config, name=''):
    super(Model, self).__init__(name=name)
    self.config = config

  def _build(self, unused_input=None):
    config = self.config

    # Constants
    batch_size = config['batch_size']
    n_latent = config['n_latent']
    img_width = config['img_width']
    half_batch = int(batch_size / 2)
    n_labels = 10

    #---------------------------------------------------------------------------
    ### Placeholders
    #---------------------------------------------------------------------------
    x = tf.placeholder(tf.float32, 
                       shape=(None, img_width, img_width, 3), name='x')
    # Attributes
    labels = tf.placeholder(tf.int32, shape=(None, n_labels), name='labels')
    # Real / fake label reward
    r = tf.placeholder(tf.float32, shape=(None, 1), name='D_label')
    # Transform through optimization
    z0 = tf.placeholder(tf.float32, shape=(None, n_latent), name='z0')
    z_prime = tf.get_variable('z_prime', 
                              shape=(half_batch, n_latent), dtype=tf.float32)

    #---------------------------------------------------------------------------
    ### Modules with parameters
    #---------------------------------------------------------------------------
    encoder = Encoder(n_latent=n_latent, name='encoder')
    decoder = Decoder(name='decoder')
    g = G(n_latent=n_latent, name='generator')
    d = D(output_size=1, name='d_z')
    d_attr = DAttr(output_size=n_labels, name='d_attr')
    classifier = Classifier(output_size=n_labels, name='classifier')


    #---------------------------------------------------------------------------
    ### VAE
    #---------------------------------------------------------------------------
    # Encode
    mu, sigma = encoder(x)
    q_z = ds.Normal(loc=mu, scale=sigma)

    # Optimize / Amortize or feedthrough
    q_z_sample = q_z.sample()

    transform = tf.constant(False)
    z = tf.cond(transform, lambda: z_prime, lambda: q_z_sample)

    amortize = tf.constant(False)
    z = tf.cond(amortize, lambda: g((z, labels)), lambda: z)

    # Decode
    logits = decoder(z)
    x_sigma = tf.constant(config['x_sigma'])
    p_x = ds.Normal(loc=tf.nn.sigmoid(logits), scale=x_sigma)
    x_mean = p_x.mean()

    # Reconstruction Loss
    recons = tf.reduce_sum(p_x.log_prob(x), axis=[1, 2, 3])

    mean_recons = tf.reduce_mean(recons)

    # Prior
    p_z = ds.Normal(loc=0., scale=1.)
    prior_sample = p_z.sample(sample_shape=[batch_size, n_latent])

    # KL Loss
    KL_qp = ds.kl_divergence(q_z, p_z)
    KL = tf.reduce_sum(KL_qp, axis=-1)
    mean_KL = tf.reduce_mean(KL)

    beta = tf.constant(config['beta'])

    # VAE Loss
    vae_loss = -mean_recons + mean_KL * beta

    #---------------------------------------------------------------------------
    ### Discriminator Constraint in Img and Z space and Digit space (implicit)
    #---------------------------------------------------------------------------
    d_logits = d([z, labels])

    r_pred = tf.nn.sigmoid(d_logits)  # r = [0 prior, 1 data]
    d_loss = tf.losses.sigmoid_cross_entropy(r, d_logits)

    # Mean over examples
    d_loss = tf.reduce_mean(d_loss)

    # Gradient Penalty
    real_data = z[:half_batch]
    fake_data = z[half_batch:batch_size]
    alpha = tf.random_uniform(shape=[half_batch, n_latent], minval=0., maxval=1.)
    differences = fake_data - real_data
    interpolates = real_data + (alpha * differences)
    interp_pred = d([interpolates, labels[:half_batch]])
    gradients = tf.gradients(interp_pred, [interpolates])[0]
    slopes = tf.sqrt(tf.reduce_sum(tf.square(gradients), axis=[1]) + 1e-10)
    gradient_penalty = tf.reduce_mean((slopes - 1.)**2)

    # Add penalty
    lambda_weight = tf.constant(config['lambda_weight'])
    d_loss_training = d_loss + lambda_weight * gradient_penalty

    
    #---------------------------------------------------------------------------
    ### Discriminator Attribute classification (implicit constraint)
    #---------------------------------------------------------------------------
    # Z-Space
    attr_weights = tf.constant(np.ones([1, n_labels]).astype(np.float32))
    logits_attr = d_attr(z)
    pred_attr = tf.nn.sigmoid(logits_attr)
    d_loss_attr =  tf.losses.sigmoid_cross_entropy(labels, 
                                                   logits=logits_attr, 
                                                   weights=attr_weights)

    
    #---------------------------------------------------------------------------
    ### OPTIMIZATION TRANSFORMATION (SGD)
    #---------------------------------------------------------------------------
    # Realism Constraint
    transform_r_weight = tf.constant(1.0)
    loss_transform = transform_r_weight * tf.reduce_mean(d_loss)

    # Attribute Constraint
    transform_attr_weight = tf.constant(0.0)
    loss_transform += transform_attr_weight * d_loss_attr
    
    # Distance Penalty
    transform_penalty_weight = tf.constant(0.0)
    z_sigma_mean = tf.constant(np.ones([1, n_latent]).astype(np.float32))
    transform_penalty = tf.log(1 + (z_prime - z0)**2)
    transform_penalty = transform_penalty * z_sigma_mean**-2
    loss_transform += tf.reduce_mean(transform_penalty_weight * transform_penalty)


    #---------------------------------------------------------------------------
    ### AMORTIZED TRANSFORMATION (Generator)
    #---------------------------------------------------------------------------
    # Realism and Attribute Constraint
    g_loss = -tf.log(tf.clip_by_value(r_pred, 1e-15, 1 - 1e-15))
    g_loss = tf.reduce_mean(g_loss)

    # Distance Penalty
    g_penalty_weight = tf.constant(0.0)
    g_penalty = tf.log(1 + (z - q_z_sample)**2)
    g_penalty = g_penalty * z_sigma_mean**-2
    g_penalty = tf.reduce_mean(g_penalty) 
    g_loss += g_penalty_weight * g_penalty

    #---------------------------------------------------------------------------
    ### Classify Attributes from pixels
    #---------------------------------------------------------------------------
    logits_classifier = classifier(x)
    pred_classifier = tf.nn.sigmoid(logits_classifier)
    classifier_loss =  tf.losses.sigmoid_cross_entropy(labels, 
                                                       logits=logits_classifier)

    
    #---------------------------------------------------------------------------
    ### Training
    #---------------------------------------------------------------------------
    # Learning rates
    d_lr = tf.constant(3e-4)
    d_attr_lr = tf.constant(3e-4)
    vae_lr = tf.constant(3e-4)
    g_lr = tf.constant(3e-4)
    classifier_lr = tf.constant(3e-4)
    transform_lr = tf.constant(3e-4)

    # Training Ops
    vae_vars = list(encoder.get_variables())
    vae_vars.extend(decoder.get_variables())
    train_vae = tf.train.AdamOptimizer(vae_lr).minimize(vae_loss, var_list=vae_vars)

    d_vars = d.get_variables()
    train_d = tf.train.AdamOptimizer(d_lr, beta1=0, beta2=0.9).minimize(
        d_loss_training, var_list=d_vars)

    classifier_vars = classifier.get_variables()
    train_classifier = tf.train.AdamOptimizer(classifier_lr).minimize(
        classifier_loss, var_list=classifier_vars)

    g_vars = g.get_variables()
    train_g = tf.train.AdamOptimizer(g_lr, beta1=0, beta2=0.9).minimize(
        g_loss, var_list=g_vars)

    d_attr_vars = d_attr.get_variables()
    train_d_attr = tf.train.AdamOptimizer(d_attr_lr).minimize(
        d_loss_attr, var_list=d_attr_vars)

    train_transform = tf.train.AdamOptimizer(transform_lr).minimize(
            loss_transform, var_list=[z_prime])
    
    # Savers
    vae_saver = tf.train.Saver(vae_vars, max_to_keep=100)
    g_saver = tf.train.Saver(g_vars, max_to_keep=1000)
    d_saver = tf.train.Saver(d_vars, max_to_keep=1000)
    d_attr_saver = tf.train.Saver(d_attr_vars, max_to_keep=1000)
    classifier_saver = tf.train.Saver(classifier_vars, max_to_keep=1000)

    # Add all endpoints as object attributes
    for k, v in locals().items():
      self.__dict__[k] = v

Load all models


In [0]:
config = {
    'n_latent': 1024,
    'img_width': 64,
    'crop_width': 64,
    # Optimization parameters
    'batch_size': 128,
    'beta': 1.0,
    'x_sigma': 0.1,
    'lambda_weight': 10.0,
    'penalty_weight': 0.0,
}

In [0]:
tf.reset_default_graph()
sess = tf.Session()

# Declare
m = Model(config)
# Build
_ = m()
# Initialize
sess.run(tf.global_variables_initializer())

In [0]:
# Load VAE
ckpt = os.path.join(basepath, 'vae_best_celeba_0_crop128_beta1.ckpt')
m.vae_saver.restore(sess, ckpt)

In [0]:
# Load D
ckpt = os.path.join(basepath, 'D_d_2_conditional_penalty1e-1_399000.ckpt')
m.d_saver.restore(sess, ckpt)

In [0]:
# Load G
ckpt = os.path.join(basepath, 'G_d_2_conditional_penalty1e-1_399000.ckpt')
m.g_saver.restore(sess, ckpt)

In [0]:
# Load D_attr
ckpt = os.path.join(basepath, 'D_attr_best_d_attr_0.ckpt')
m.d_attr_saver.restore(sess, ckpt)

In [0]:
# Load Classifier
ckpt = os.path.join(basepath, 'classifier_best_classifier_0.ckpt')
m.classifier_saver.restore(sess, ckpt)

GENERATE PLOTS


In [0]:
def im(x):
  plt.imshow(np.maximum(0, np.minimum(1, x)), interpolation='none')
  plt.xticks([])
  plt.yticks([])    
  
def batch_image(b, max_images=64, rows=None, cols=None):
  """Turn a batch of images into a single image mosaic."""
  mb = min(b.shape[0], max_images)
  if rows is None:
    rows = int(np.ceil(np.sqrt(mb)))
    cols = rows
  diff = rows * cols - mb
  b = np.vstack([b[:mb], np.zeros([diff, b.shape[1], b.shape[2], b.shape[3]])])
  tmp = b.reshape(-1, cols * b.shape[1], b.shape[2], b.shape[3])
  img = np.hstack([tmp[i] for i in range(rows)])
  return img

In [0]:
# A list of attributes on which to condition generation.
# Each list element corresponds to a different fully-specified condition.

cond_attr_list = [
    [
        (0, 'Bald'),
        (0, 'Black_Hair'),
        (1, 'Blond_Hair'),
        (0, 'Brown_Hair'),
        (0, 'Eyeglasses'),
        (0, 'Male'),
        (1, 'No_Beard'),
        (1, 'Smiling'),
        (0, 'Wearing_Hat'),
        (1, 'Young'),
    ],
    [
        (0, 'Bald'),
        (0, 'Black_Hair'),
        (0, 'Blond_Hair'),
        (1, 'Brown_Hair'),
        (0, 'Eyeglasses'),
        (0, 'Male'),
        (1, 'No_Beard'),
        (1, 'Smiling'),
        (0, 'Wearing_Hat'),
        (1, 'Young'),
    ],
    [
        (0, 'Bald'),
        (1, 'Black_Hair'),
        (0, 'Blond_Hair'),
        (0, 'Brown_Hair'),
        (0, 'Eyeglasses'),
        (0, 'Male'),
        (1, 'No_Beard'),
        (1, 'Smiling'),
        (0, 'Wearing_Hat'),
        (1, 'Young'),
    ],
    [
        (0, 'Bald'),
        (1, 'Black_Hair'),
        (0, 'Blond_Hair'),
        (0, 'Brown_Hair'),
        (0, 'Eyeglasses'),
        (1, 'Male'),
        (1, 'No_Beard'),
        (1, 'Smiling'),
        (0, 'Wearing_Hat'),
        (1, 'Young'),
    ],
    [
        (0, 'Bald'),
        (1, 'Black_Hair'),
        (0, 'Blond_Hair'),
        (0, 'Brown_Hair'),
        (0, 'Eyeglasses'),
        (1, 'Male'),
        (0, 'No_Beard'),
        (1, 'Smiling'),
        (0, 'Wearing_Hat'),
        (1, 'Young'),
    ],
    [
        (0, 'Bald'),
        (1, 'Black_Hair'),
        (0, 'Blond_Hair'),
        (0, 'Brown_Hair'),
        (1, 'Eyeglasses'),
        (1, 'Male'),
        (0, 'No_Beard'),
        (1, 'Smiling'),
        (0, 'Wearing_Hat'),
        (1, 'Young'),
    ],
    [
        (1, 'Bald'),
        (0, 'Black_Hair'),
        (0, 'Blond_Hair'),
        (0, 'Brown_Hair'),
        (0, 'Eyeglasses'),
        (1, 'Male'),
        (0, 'No_Beard'),
        (1, 'Smiling'),
        (0, 'Wearing_Hat'),
        (1, 'Young'),
    ],
    [
        (1, 'Bald'),
        (0, 'Black_Hair'),
        (0, 'Blond_Hair'),
        (0, 'Brown_Hair'),
        (0, 'Eyeglasses'),
        (1, 'Male'),
        (0, 'No_Beard'),
        (0, 'Smiling'),
        (0, 'Wearing_Hat'),
        (0, 'Young'),
    ],
]
    
  
cond_attrs = []
for attrs in cond_attr_list:
  cond_attrs.append(np.repeat(
      np.array([[a[0] for a in attrs]]).astype(np.int32), m.batch_size, axis=0))

VAE Reconstructions


In [0]:
# Load D
ckpt = os.path.join(basepath, 'D_d_2_conditional_399000.ckpt')
m.d_saver.restore(sess, ckpt)
# Load G
ckpt = os.path.join(basepath, 'G_d_2_conditional_399000.ckpt')
m.g_saver.restore(sess, ckpt)

In [0]:
z_prior = np.random.randn(9, m.n_latent)
labels = attr_eval[:9]

ckpt = os.path.join(basepath, 'vae_best_celeba_0_crop128_beta1_xsigma1.ckpt')
m.vae_saver.restore(sess, ckpt)

z_eval_blurry = (eval_mu_xsigma1[:9] + 
                 eval_sigma_xsigma1[:9] * np.random.randn(9, m.n_latent))
b_eval_recon_blurry = sess.run(m.x_mean, {m.z:z_eval_blurry, m.labels:np.zeros([m.batch_size, m.n_labels])})
b_samples_blurry = sess.run(m.x_mean, {m.z:z_prior})


ckpt = os.path.join(basepath, 'vae_best_celeba_0_crop128_beta1.ckpt')
m.vae_saver.restore(sess, ckpt)

z_eval_sharp = (eval_mu[:9] + 
                eval_sigma[:9] * np.random.randn(9, m.n_latent))
b_eval_recon_sharp = sess.run(m.x_mean, {m.z:z_eval_sharp, m.labels:np.zeros([m.batch_size, m.n_labels])})
b_samples_sharp = sess.run(m.x_mean, {m.z:z_prior})
b_samples_refined = sess.run(m.x_mean, {m.q_z_sample:z_prior, 
                                        m.amortize:True, 
                                        m.labels:labels})

In [0]:
# Visualize Reconstructions
tot = 6
row = 5
plt.figure(figsize=[10, 10])
for i in range(row):
  plt.subplot(tot, row, 1 +  i+5*1)
  im(b_eval_recon_blurry[i])
  plt.title('Recon')
  plt.subplot(tot, row, 1 +  i+5*2)
  im(b_samples_blurry[i])
  plt.title('Sample')
  plt.subplot(tot, row, 1 +  i+5*3)
  im(b_eval_recon_sharp[i])
  plt.title('Recon')
  plt.subplot(tot, row, 1 + i+5*4)
  im(b_samples_sharp[i])
  plt.title('Sample')
  plt.subplot(tot, row, 1 + i+5*5)
  im(b_samples_refined[i])
  plt.title('Refinement')

Conditional Generation

Z-Penalty = 0.0


In [0]:
# Load D
ckpt = os.path.join(basepath, 'D_d_2_conditional_399000.ckpt')
m.d_saver.restore(sess, ckpt)
# Load G
ckpt = os.path.join(basepath, 'G_d_2_conditional_399000.ckpt')
m.g_saver.restore(sess, ckpt)



In [0]:
# Compute the Conditional Samples
z_original = sess.run(m.prior_sample)
z_new = [z_original]
b_new = [sess.run(m.x_mean, {m.z:z_original})]

for cond_attr in cond_attrs:
  z_new.append(sess.run(m.z, {m.q_z_sample: z_original, m.amortize:True, m.labels:cond_attr}))
  b_new.append(sess.run(m.x_mean, {m.z:z_new[-1]}))

In [0]:
# Plot them
plt.figure(figsize=(18, 12))
n_b = len(b_new)
tot = 16
for i, b in enumerate(b_new):
  plt.subplot(n_b, 1, i + 1)
  im(batch_image(b, max_images=tot, rows=tot, cols=1))

Z-Penalty = 0.1


In [0]:
# Load D
ckpt = os.path.join(basepath, 'D_d_2_conditional_penalty1e-1_399000.ckpt')
m.d_saver.restore(sess, ckpt)
# Load G
ckpt = os.path.join(basepath, 'G_d_2_conditional_penalty1e-1_399000.ckpt')
m.g_saver.restore(sess, ckpt)

In [0]:
# Compute the Conditional Samples
z_original = sess.run(m.prior_sample)
z_new = [z_original]
b_new = [sess.run(m.x_mean, {m.z:z_original})]

for cond_attr in cond_attrs:
  z_new.append(sess.run(m.z, {m.q_z_sample: z_original, m.amortize:True, m.labels:cond_attr}))
  b_new.append(sess.run(m.x_mean, {m.z:z_new[-1]}))

In [0]:
# Plot them
plt.figure(figsize=(18, 12))
n_b = len(b_new)
tot = 16
for i, b in enumerate(b_new):
  plt.subplot(n_b, 1, i + 1)
  im(batch_image(b, max_images=tot, rows=tot, cols=1))

Identity Preserving Transformations


In [0]:
# Load D
ckpt = os.path.join(basepath, 'D_d_2_conditional_penalty1e-1_399000.ckpt')
m.d_saver.restore(sess, ckpt)
# Load G
ckpt = os.path.join(basepath, 'G_d_2_conditional_penalty1e-1_399000.ckpt')
m.g_saver.restore(sess, ckpt)

In [0]:
def transform(z_original, 
              labels,
              z0=None,
              lr=1e-1, 
              n_opt=100, 
              penalty_weight=0.0,
              r_weight=1.0, 
              attr_weight=1.0, 
              attr_weights=np.ones([1, m.n_labels]),
              r_threshold=0.9,
              attr_threshold=0.9,
              adaptive=False,
             ):

  if z0 is None:
    z0 = z_original
  _ = sess.run(tf.assign(m.z_prime, z_original))
  z_new = np.zeros([m.half_batch, m.n_latent])
  i_threshold = np.zeros([m.half_batch])
  z_trace = []
  for i in range(n_opt):
    res = sess.run([m.train_transform, 
                    m.loss_transform, 
                    m.transform_penalty,
                    m.z_prime,
                    m.r_pred,
                    m.pred_attr], 
                   {m.z0: z0,
                    m.r: np.ones([m.half_batch, 1]),
                    m.labels: labels,
                    m.transform: True, 
                    m.transform_lr: lr,
                    m.transform_penalty_weight: penalty_weight,
                    m.transform_r_weight: r_weight,
                    m.transform_attr_weight: attr_weight,
                    m.attr_weights: attr_weights,
                    m.x: np.zeros([m.half_batch, m.img_width, m.img_width, 3])
                   })

    z_prime = res[3]
    z_trace.append(z_prime)
    pred_r = res[-2]
    pred_attr = res[-1]
    attr_acc = 1.0 - np.mean(attr_weights * np.abs(labels - pred_attr), axis=1)
    transform_penalty = np.mean(res[2])
    check_idx = np.where(i_threshold == 0)[0]
    if len(check_idx) == 0:
      break
    for idx in check_idx:
      if pred_r[idx] > r_threshold and attr_acc[idx] > attr_threshold:
        z_new[idx] = z_prime[idx]
        i_threshold[idx] = i
    if adaptive:
      r_weight = 1 - np.mean(pred_r)
      attr_weight = 1 - np.mean(attr_acc)

    if i % 100 == 1:
      print( 'Step %d, NotConverged: %d, Loss: %0.3e, Penalty: %0.3f, '
            'r: %0.3f, r_min: %0.3f, '
            'attr: %0.3f, attr_min:%0.3f, ' % (
                i,
                len(check_idx), 
                res[1], 
                transform_penalty, 
                np.mean(pred_r), 
                np.min(pred_r),
                np.mean(attr_acc), 
                np.min(attr_acc),
           ))
      

  check_idx = np.where(i_threshold == 0)[0]
  print('%d did not converge' % len(check_idx))
  for idx in check_idx:
    z_new[idx] = z_prime[idx]
  return z_new, i_threshold, z_trace

In [0]:
# BETA1 TRANSFORMATION
z_original = (eval_mu[:m.half_batch] + 
              eval_sigma[:m.half_batch] * np.random.randn(m.half_batch, m.n_latent))
z = z_original

label_list = (
    (1, 1, 1e-2, 1e-3, 0.05), 
    (2, 1, 1e-2, 1e-3, 0.05), 
    (3, 1, 1e-2, 1e-3, 0.05), 
    (4, 1, 1e-2, 1e-5, 0.05), 
    (5, 0, 1e-2, 1e-3, 0.05), 
    (5, 1, 1e-2, 1e-4, 0.05), 
    (6, 0, 5e-3, 1e-4, 0.005), 
    (7, 0, 1e-2, 1e-3, 0.05), 
    (7, 1, 1e-2, 1e-3, 0.05), 
    (9, 0, 3e-2, 1e-4, 0.01), 
    (9, 1, 3e-2, 1e-4, 0.01)
)
z_list = [z_original]
b_list = [sess.run(m.x_mean, {m.z: z_original})]


for attr, value, lr, r_thresh, r_weight in label_list:
  print('Label: %d, Value: %d' % (attr, value))
  attr_weights=np.zeros([1, m.n_labels])
  labels = attr_eval[:m.half_batch].copy()
  labels[:, attr] = value
  attr_weights[:, attr] = 1

  z, i_threshold, z_trace = transform(
      z_original, 
      labels, 
      lr=lr,
      n_opt=300,
      r_weight=r_weight,
      attr_weight=1.0,
      attr_weights=attr_weights,
      r_threshold=0.3,
      attr_threshold=(1 - r_thresh),
  )
  z_list.append(z)
  b_list.append(sess.run(m.x_mean, {m.z: z}))

In [0]:
labels = attr_eval[:m.half_batch].copy()
idxs = (0, 2, 5, 10, 7, 13,  3,)
labels = labels[idxs, :]

plt.figure(figsize=(24, 18))
n_b = len(b_list)
tot = 6
barr = np.array(b_list)
barr = np.swapaxes(barr, 0, 1)
barr = barr[idxs, :, :, :, :]
for i, b in enumerate(barr):
  for j, (k, v, _, _, _) in enumerate(label_list):
    if labels[i, k] == v:
      barr[i, j+1] = 0.

for i, b in enumerate(barr):
  plt.subplot(12, 1, i + 1)
  im(batch_image(b, max_images=n_b, rows=n_b, cols=1))

Attribute Classification Accuracy


In [0]:
# Load D
ckpt = os.path.join(basepath, 'D_d_2_conditional_penalty1e-1_399000.ckpt')
m.d_saver.restore(sess, ckpt)
# Load G
ckpt = os.path.join(basepath, 'G_d_2_conditional_penalty1e-1_399000.ckpt')
m.g_saver.restore(sess, ckpt)

Original Images


In [0]:
batch_size = 256
# Prediction Accuracy
train_pred = []
eval_pred = []
test_pred = []

for i in range(n_train // 10 // batch_size):
  start = i * batch_size
  end = start + batch_size
  batch_images = train_data[start:end]

  res = sess.run([m.pred_classifier], {m.x: batch_images})
  train_pred.append(res[0])
train_pred = np.vstack(train_pred)

for i in range(n_eval // batch_size):
  start = (i * batch_size)
  end = start + batch_size
  batch_images = eval_data[start:end]

  res = sess.run([m.pred_classifier], {m.x: batch_images})
  eval_pred.append(res[0])
eval_pred = np.vstack(eval_pred)

train_acc = (train_pred > 0.5) == attr_train[:train_pred.shape[0]]
eval_acc = (eval_pred > 0.5) == attr_eval[:eval_pred.shape[0]]

print("Train Accuracy: %.4f" % (np.mean(train_acc) * 100))
print("Eval Accuracy: %.4f" % (np.mean(eval_acc) * 100))

In [0]:
y_true = attr_eval[:eval_pred.shape[0]]
y_pred = eval_pred >= 0.5

prec, recall, f1, support = sklearn.metrics.precision_recall_fscore_support(y_true, y_pred)
report = sklearn.metrics.classification_report(y_true, y_pred, digits=3, target_names=attribute_names)
print(report)

Conditional Generation


In [0]:
# Prediction Accuracy
train_pred = []
eval_pred = []
test_pred = []

for i in range(n_eval // batch_size):
  start = i * batch_size
  end = start + batch_size
  labels = attr_eval[start:end]
  batch_z = np.random.randn(labels.shape[0], config["n_latent"])
  xsamp, z_prime = sess.run([m.x_mean, m.z],
                  {m.q_z_sample: batch_z, 
                   m.amortize:True, 
                   m.labels:labels})
  res = sess.run([m.pred_classifier], {m.x: xsamp})
  eval_pred.append(res[0])
eval_pred = np.vstack(eval_pred)

eval_acc = (eval_pred > 0.5) == attr_eval[:eval_pred.shape[0]]


print("Eval Accuracy: %.4f" % (np.mean(eval_acc) * 100))

In [0]:
y_true = attr_eval[:eval_pred.shape[0]]
y_pred = eval_pred >= 0.5

prec, recall, f1, support = sklearn.metrics.precision_recall_fscore_support(y_true, y_pred)
report = sklearn.metrics.classification_report(y_true, y_pred, digits=3, target_names=attribute_names)
print(report)

TRAINING MODELS

This code is for demonstration purposes.

Training the VAE and the D and G networks takes about a day on a TitanX GPU.

To train models from scratch, you will need to download the entire CelebA dataset to a base directory (assumed to be ~/Desktop/CelebA/). The steps for preprocessing the data are provided below.

(img_align_celeba/, list_attr_celeba.txt, list_eval_partition.txt)
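The preprocessing cells below assume these entries live directly under the base directory. A small sketch of the expected paths (`celeba_layout` is an illustrative helper, not part of the notebook's API):

```python
import os

def celeba_layout(basepath):
    # The three entries the preprocessing code reads from the base directory.
    return [os.path.join(basepath, name) for name in
            ('img_align_celeba', 'list_attr_celeba.txt', 'list_eval_partition.txt')]

paths = celeba_layout(os.path.expanduser('~/Desktop/CelebA'))
```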


In [0]:
# Running Average
running_N = 100
running_N_eval = 10
rmean = lambda data: np.mean(data[-running_N:])
rmeane = lambda data: np.mean(data[-running_N_eval:])
batch_size = 256
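The `rmean`/`rmeane` helpers smooth noisy per-step traces by averaging only the most recent window of a growing list, for example:

```python
import numpy as np

running_N = 100
rmean = lambda data: np.mean(data[-running_N:])

trace = list(range(1000))   # a growing metric trace: 0, 1, ..., 999
tail_mean = rmean(trace)    # averages only the last 100 entries (900..999)
```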

Prepare the Data (Crop and Downsample)


In [0]:
basepath = os.path.expanduser('~/Desktop/CelebA/')
save_path = basepath

In [0]:
partition = np.loadtxt(basepath + 'list_eval_partition.txt', usecols=(1,))
train_mask = (partition == 0)
eval_mask = (partition == 1)
test_mask = (partition == 2)

print("Train: %d, Validation: %d, Test: %d, Total: %d" % (train_mask.sum(), eval_mask.sum(), test_mask.sum(), partition.shape[0]))

In [0]:
attributes = pd.read_csv(basepath + 'list_attr_celeba.txt', skiprows=1, delim_whitespace=True, usecols=range(1, 41))
attribute_names = attributes.columns.values
attribute_values = attributes.values

In [0]:
attr_train = attribute_values[train_mask]
attr_eval = attribute_values[eval_mask]
attr_test = attribute_values[test_mask]

attr_train[attr_train == -1] = 0
attr_eval[attr_eval == -1] = 0
attr_test[attr_test == -1] = 0

np.save(basepath + 'attr_train.npy', attr_train)
np.save(basepath + 'attr_eval.npy', attr_eval)
np.save(basepath + 'attr_test.npy', attr_test)
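CelebA ships attributes coded as ±1; the cell above maps them to {0, 1} so they can serve directly as classifier targets. In miniature:

```python
import numpy as np

attrs = np.array([[1, -1, -1],
                  [-1, 1, 1]])
attrs[attrs == -1] = 0    # -1 (attribute absent) becomes 0; +1 (present) stays 1
```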

In [0]:
def pil_crop_downsample(x, width, out_width):
  half_shape = tuple((i - width) // 2 for i in x.size)
  x = x.crop([half_shape[0], half_shape[1], half_shape[0] + width, half_shape[1] + width])
  return x.resize([out_width, out_width], resample=PIL.Image.ANTIALIAS)

def load_and_adjust_file(filename, width, outwidth):
  img = PIL.Image.open(filename)
  img = pil_crop_downsample(img, width, outwidth)
  img = np.array(img, np.float32) / 255.
  return img
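`pil_crop_downsample` centers a square crop before resizing; integer division keeps the box coordinates valid pixel indices. The box arithmetic in isolation (`center_crop_box` is an illustrative name):

```python
def center_crop_box(size, width):
    # size uses PIL's (width, height) ordering
    left, top = ((i - width) // 2 for i in size)
    return (left, top, left + width, top + width)

box = center_crop_box((178, 218), 128)   # aligned CelebA images are 178 x 218
```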

In [0]:
# CELEBA images are (218 x 178) originally
filenames = np.sort(glob(basepath + 'img_align_celeba/*.jpg'))

crop_width = 128
img_width = 64
postfix = '_crop_%d_res_%d.npy' % (crop_width, img_width)

n_files = len(filenames)
all_data = np.zeros([n_files, img_width, img_width, 3], np.float32)
for i, fname in enumerate(filenames):
  all_data[i, :, :] = load_and_adjust_file(fname, crop_width, img_width)
  if i % 10000 == 0:
    print('%.2f percent done' % (float(i)/n_files * 100.0))
train_data = all_data[train_mask]
eval_data = all_data[eval_mask]
test_data = all_data[test_mask]
np.save(basepath + 'train' + postfix, train_data)
np.save(basepath + 'eval' + postfix, eval_data)
np.save(basepath + 'test' + postfix, test_data)

Train the VAE


In [0]:
sess.run(tf.variables_initializer(var_list=m.vae_vars))

# Train the VAE
results = []
results_eval = []

traces = {'i': [],
          'i_eval': [],
          'loss': [],
          'loss_eval': [],
          'recons': [],
          'recons_eval': [],
          'kl': [],
          'kl_eval': []}

n_iters = 200000
vae_lr_ = np.logspace(np.log10(3e-4), np.log10(1e-6), n_iters)

for i in range(n_iters):
  start = (i * batch_size) % n_train
  end = start + batch_size
  batch = train_data[start:end]

  res = sess.run([m.train_vae, 
                  m.vae_loss, 
                  m.mean_recons, 
                  m.mean_KL], 
                 {m.x: batch,
                  m.vae_lr: vae_lr_[i],
                  m.amortize: False,
                  m.labels: attr_train[start:end]})
  
  traces['loss'].append(res[1])
  traces['recons'].append(res[2])
  traces['kl'].append(res[3])
  traces['i'].append(i)

  if i % 10 == 0:
    start = (i * batch_size) % n_eval
    end = start + batch_size
    batch = eval_data[start:end]
    res_eval = sess.run([m.vae_loss, m.mean_recons, m.mean_KL], 
                        {m.x: batch, m.labels: attr_eval[start:end]})
    traces['loss_eval'].append(res_eval[0])
    traces['recons_eval'].append(res_eval[1])
    traces['kl_eval'].append(res_eval[2])
    traces['i_eval'].append(i)

    print('Step %5d \t TRAIN \t Loss: %0.3f, Recon: %0.3f, KL: %0.3f '
          '\t EVAL \t  Loss: %0.3f, Recon: %0.3f, KL: %0.3f' % (i, 
                                                                rmean(traces['loss']), 
                                                                rmean(traces['recons']), 
                                                                rmean(traces['kl']), 
                                                                rmeane(traces['loss_eval']), 
                                                                rmeane(traces['recons_eval']), 
                                                                rmeane(traces['kl_eval']) ))
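The schedule built with `np.logspace` decays the learning rate geometrically from 3e-4 down to 1e-6 over training, so early steps move fast and late steps fine-tune:

```python
import numpy as np

n_iters = 200000
vae_lr_ = np.logspace(np.log10(3e-4), np.log10(1e-6), n_iters)
```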

In [0]:
plt.figure(figsize=(18,6))

plt.subplot(131)
plt.plot(traces['i'], traces['loss'])
plt.plot(traces['i_eval'], traces['loss_eval'])
plt.title('Loss')
# plt.ylim(30, 100)

plt.subplot(132)
plt.plot(traces['i'], traces['recons'])
plt.plot(traces['i_eval'], traces['recons_eval'])
plt.title('Recons')
# plt.ylim(-100, -30)

plt.subplot(133)
plt.plot(traces['i'], traces['kl'])
plt.plot(traces['i_eval'], traces['kl_eval'])
plt.title('KL')
# plt.ylim(10, 100)

Train D and G jointly


In [0]:
# Precompute means and vars
train_mu = []
train_sigma = []
n_batches = int(np.ceil(float(n_train) / batch_size))
for i in range(n_batches):
  if i % 1000 == 0:
    print('%.1f Done' % (float(i) / n_train * batch_size * 100))
  start = i * batch_size
  end = start + batch_size
  res = sess.run([m.mu, m.sigma], {m.x: train_data[start:end]})
  train_mu.append(res[0])
  train_sigma.append(res[1])
train_mu = np.vstack(train_mu)
train_sigma = np.vstack(train_sigma)
sigma_mean = train_sigma.mean(0, keepdims=True)
print(train_mu.shape, train_sigma.shape, train_data.shape)
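Caching mu and sigma per image lets the D/G training loop below draw fresh posterior samples without re-running the encoder, via the VAE reparameterization z = mu + sigma * eps with eps ~ N(0, I):

```python
import numpy as np

rng = np.random.RandomState(0)
mu = np.zeros((4, 3))              # cached posterior means (toy values)
sigma = np.full((4, 3), 0.5)       # cached posterior std-devs (toy values)
z = mu + sigma * rng.randn(4, 3)   # a fresh sample from each posterior
```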

In [0]:
# Precompute means and vars
eval_mu = []
eval_sigma = []
n_batches = int(np.ceil(float(n_eval) / batch_size))
for i in range(n_batches):
  if i % 1000 == 0:
    print('%.1f Done' % (float(i) / n_eval * batch_size * 100))
  start = i * batch_size
  end = start + batch_size
  res = sess.run([m.mu, m.sigma], {m.x: eval_data[start:end]})
  eval_mu.append(res[0])
  eval_sigma.append(res[1])
eval_mu = np.vstack(eval_mu)
eval_sigma = np.vstack(eval_sigma)
sigma_mean_eval = eval_sigma.mean(0, keepdims=True)
print(eval_mu.shape, eval_sigma.shape, eval_data.shape)

In [0]:
plt.plot(np.sort(sigma_mean.flatten()))
plt.plot(np.sort(sigma_mean_eval.flatten()))

In [0]:
# With eval loss, looking for overfitting
sess.run(tf.variables_initializer(var_list=m.d_vars))
sess.run(tf.variables_initializer(var_list=m.g_vars))

# Declare hyperparameters
results_D = []
results_G = []
traces = {'i': [], 
          'i_pred': [], 
          'D': [], 
          'G': [], 
          'G_real': [],
          'g_penalty': [],
          'pred_train': [], 
          'pred_eval': [], 
          'pred_prior': [], 
          'pred_gen': [],
          'z_dist_eval': [],
          'attr_loss': [],
          'attr_acc': [],
         }


n_iters = 200000
d_lr_ = np.logspace(-4, -4, n_iters)
g_lr_ = np.logspace(-4, -4, n_iters)

g_penalty_weight_ = 0.1
lambda_weight_ = 10
z_sigma_mean_ = sigma_mean

percentage_prior_fake = 0.1
N_between_update_G = 10
N_between_eval = 100

n_train = train_mu.shape[0]
n_eval = eval_mu.shape[0]

# Training Loop
for i in range(n_iters):
  start = (i * batch_size // 2) % n_train
  end = start + batch_size // 2
  fake_start = np.random.choice(np.arange(n_train - batch_size // 2))
  fake_end = fake_start + batch_size // 2

  real_img = train_data[start:end]
  n_batch = real_img.shape[0]
  if n_batch == batch_size // 2 and start != fake_start:
    # Compare real vs. fake
    fake_z_prior = np.random.randn(batch_size // 2, n_latent)
    real_attr = attr_train[start:end].astype(np.int32)
    fake_attr = attr_train[fake_start:fake_end].astype(np.int32)
    real_z = train_mu[start:end] + train_sigma[start:end] * np.random.randn(batch_size // 2, n_latent)

    if np.random.rand(1) < percentage_prior_fake:
      # Use Prior for fake_samples      
      all_z = np.vstack([real_z, fake_z_prior, real_z])
    else:
      # Use Generator to make fake_samples 
      fake_z_gen = sess.run(m.z, {m.q_z_sample: fake_z_prior, 
                                  m.amortize:True, 
                                  m.labels: real_attr,})      
      all_z = np.vstack([real_z, fake_z_gen, real_z])
    all_attr = np.vstack([real_attr, real_attr, fake_attr]) 

      
    # Train Discriminator
    real_r = np.ones([batch_size // 2, 1])
    fake_r = np.zeros([batch_size // 2, 1])
    all_r = np.concatenate([real_r, fake_r, fake_r])
    res_d = sess.run([m.train_d, m.d_loss], {m.z: all_z, 
                                         m.r: all_r,
                                         m.d_lr: d_lr_[i],
                                         m.lambda_weight: lambda_weight_,
                                         m.labels: all_attr,})
    
    # Train Generator
    if i % N_between_update_G == 0:
      if g_penalty_weight_ > 0:
        # Train on real data
        res_g_real = sess.run([m.train_g, m.g_loss, m.g_penalty], 
                              {m.q_z_sample: real_z,
                               m.amortize: True,
                               m.g_penalty_weight: g_penalty_weight_,
                               m.z_sigma_mean: z_sigma_mean_,
                               m.g_lr: g_lr_[i],
                               m.labels: real_attr,})
        traces['G_real'].append(res_g_real[1])

      # Train on generated data
      res_g = sess.run([m.train_g, m.g_loss, m.g_penalty], 
                       {m.q_z_sample: fake_z_prior,
                        m.amortize: True,
                        m.g_penalty_weight: g_penalty_weight_,
                        m.z_sigma_mean: z_sigma_mean_,
                        m.g_lr: g_lr_[i],
                        m.labels: real_attr,})



    traces['i'].append(i)
    traces['D'].append(res_d[1])
    traces['G'].append(res_g[1])
    traces['g_penalty'].append(res_g[2])

    if i % N_between_eval == 0:
      eval_start = np.random.choice(np.arange(n_eval - batch_size // 2))
      eval_end = eval_start + batch_size // 2
      real_attr_eval = attr_eval[eval_start:eval_end].astype(np.int32)
      real_z_eval = eval_mu[eval_start:eval_end] + eval_sigma[eval_start:eval_end] * np.random.randn(batch_size // 2, n_latent)
      z_eval_gen = sess.run(m.z, {m.q_z_sample: real_z_eval, 
                                  m.amortize:True, 
                                  m.labels: real_attr,})      
      fake_z_gen = sess.run(m.z, {m.q_z_sample: fake_z_prior, 
                                  m.amortize:True, 
                                  m.labels: real_attr,})      
      
      pred_train_ = np.mean(sess.run([m.r_pred], {m.z: real_z, m.labels: real_attr}))
      pred_eval_ = np.mean(sess.run([m.r_pred], {m.z: real_z_eval, m.labels: real_attr_eval}))
      pred_prior_ = np.mean(sess.run([m.r_pred], {m.z: fake_z_prior, m.labels: real_attr}))
      pred_gen_ = np.mean(sess.run([m.r_pred], {m.z: fake_z_gen, m.labels: real_attr}))

      traces['i_pred'].append(i)
      traces['pred_train'].append(pred_train_)
      traces['pred_eval'].append(pred_eval_)
      traces['pred_prior'].append(pred_prior_)
      traces['pred_gen'].append(pred_gen_)
      traces['z_dist_eval'].append(np.mean(((z_eval_gen - real_z_eval)/z_sigma_mean_)**2))
      print('PRED Step %d, \t TRAIN: %.2e \t EVAL: %.2e \t PRIOR: %.2e \t GEN: %.2e' % (
          i,
          rmeanp(traces['pred_train']),
          rmeanp(traces['pred_eval']),
          rmeanp(traces['pred_prior']),
          rmeanp(traces['pred_gen'])))
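The discriminator batch above stacks three equal groups: real latents with their true labels (target 1), fake latents from the prior or generator paired with real labels (target 0), and real latents paired with mismatched labels (target 0), so D must judge both realism and label agreement. The layout in miniature (toy shapes, not the real model):

```python
import numpy as np

half = 4                                    # stands in for batch_size // 2
real_z = np.ones((half, 2))                 # latents from the encoder posterior
fake_z = np.zeros((half, 2))                # latents from the prior or generator
real_attr = np.ones((half, 3), np.int32)
fake_attr = np.zeros((half, 3), np.int32)   # attributes from a shifted batch

all_z = np.vstack([real_z, fake_z, real_z])
all_attr = np.vstack([real_attr, real_attr, fake_attr])
all_r = np.concatenate([np.ones((half, 1)),     # real z, matching labels
                        np.zeros((half, 1)),    # fake z, real labels
                        np.zeros((half, 1))])   # real z, mismatched labels
```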

In [0]:
plt.figure(figsize=(18,12))

plt.subplot(4, 1, 1)
plt.plot(traces['i_pred'], traces['pred_train'], label='train')
plt.plot(traces['i_pred'], traces['pred_eval'], label='eval')
plt.plot(traces['i_pred'], traces['pred_prior'], label='prior')
plt.plot(traces['i_pred'], traces['pred_gen'], label='gen')
plt.ylabel('Predictions')
plt.legend(loc='upper right')

plt.subplot(4, 1, 2)
plt.plot(traces['i'], traces['G'])
plt.ylabel('G Loss')

plt.subplot(4, 1, 3)
# plt.semilogy(traces['i'], traces['D'])
plt.plot(traces['i'], traces['D'])
plt.ylabel('D Loss')

plt.subplot(4, 1, 4)
plt.semilogy(traces['i_pred'], traces['z_dist_eval'])
plt.ylabel('Weighted Z Distance Eval')
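The 'Weighted Z Distance Eval' panel plots how far the generator moves latents, a mean squared difference normalized per dimension by the average posterior sigma (`weighted_z_distance` is an illustrative name for the expression used in the training loop):

```python
import numpy as np

def weighted_z_distance(z_gen, z_real, sigma_mean):
    # Mean squared displacement, measured in units of posterior std-dev.
    return np.mean(((z_gen - z_real) / sigma_mean) ** 2)

# Moving every dimension by 1.0 where sigma_mean is 0.5 costs (1 / 0.5)^2 = 4:
d = weighted_z_distance(np.full((2, 3), 1.0), np.zeros((2, 3)), np.full((1, 3), 0.5))
```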

Train Classifier


In [0]:
sess.run(tf.variables_initializer(var_list=m.classifier_vars))
# Train the Classifier
results = []
results_eval = []

running_N = 100
running_loss = 1
running_loss_eval = 1

classifier_lr_ = 3e-4

# Train
for i in range(40000):
  start = (i * batch_size) % n_train
  end = start + batch_size
  batch_images = train_data[start:end]
  batch_labels = attr_train[start:end]

  res = sess.run([m.train_classifier, 
                  m.classifier_loss], 
                 {m.x: batch_images, 
                  m.labels: batch_labels.astype(np.int32),
                  m.classifier_lr: classifier_lr_})
  running_loss += (res[1] - running_loss) / running_N 
  if i % 10 == 1:
    start = (i * batch_size) % n_eval
    end = start + batch_size
    eval_images = eval_data[start:end]
    eval_labels = attr_eval[start:end]
    res_eval = sess.run([m.classifier_loss], 
                        {m.x: eval_images, 
                         m.labels: eval_labels.astype(np.int32)})
    running_loss_eval += (res_eval[0] - running_loss_eval) / (running_N / 10)
      
    results.append([i] + res[1:])
    results_eval.append([i] + res_eval[0:])

  if i % 10 == 1:
    print('Step %d, \t TRAIN \t Loss: %0.3f \t EVAL \t Loss: %0.3f' % (i, running_loss, running_loss_eval))
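The classifier loop smooths its printed losses with an incremental update, `running += (new - running) / N`, an exponential moving average with effective window N. It decays an initial error geometrically:

```python
def update_running(running, new, N):
    # Each step moves the estimate 1/N of the way toward the new value.
    return running + (new - running) / N

running = 1.0
for _ in range(1000):
    running = update_running(running, 0.0, 100)   # running = 0.99 ** 1000
```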

In [0]:
plot_train = np.array(results).T
plot_eval = np.array(results_eval).T
plt.figure(figsize=(18,6))
plt.plot(plot_train[0],plot_train[1])
plt.plot(plot_eval[0],plot_eval[1])
plt.ylim(1e-1, 1)
plt.title('Loss')

In [0]:
# Prediction Accuracy
train_pred = []
eval_pred = []
test_pred = []

for i in range(n_train // 10 // batch_size):
  start = (i * batch_size)
  end = start + batch_size
  batch_images = train_data[start:end]
  res = sess.run([m.pred_classifier], {m.x: batch_images})
  train_pred.append(res[0])
train_pred = np.vstack(train_pred)

for i in range(n_eval // batch_size):
  start = (i * batch_size)
  end = start + batch_size
  batch_images = eval_data[start:end]

  res = sess.run([m.pred_classifier], {m.x: batch_images})
  eval_pred.append(res[0])
eval_pred = np.vstack(eval_pred)

train_acc = (train_pred > 0.5) == attr_train[:train_pred.shape[0]]
eval_acc = (eval_pred > 0.5) == attr_eval[:eval_pred.shape[0]]

print("Train Accuracy: %.4f" % (np.mean(train_acc) * 100))
print("Eval Accuracy: %.4f" % (np.mean(eval_acc) * 100))

In [0]:
y_true = attr_eval[:eval_pred.shape[0]]
y_pred = eval_pred >= 0.5

prec, recall, f1, support = sklearn.metrics.precision_recall_fscore_support(y_true, y_pred)
report = sklearn.metrics.classification_report(y_true, y_pred, digits=2, target_names=attribute_names)
print(report)

Train Attribute Classifier (D_attr) in z space


In [0]:
# Train the Discriminator
sess.run(tf.variables_initializer(var_list=m.d_attr_vars))
traces = {"i": [],
          "i_eval": [],
          "D_loss": [],
          "D_loss_eval": [],
          "accuracy": [],
         }
running_N = 800
running_N_eval = 80

n_iters = 10000
d_attr_lr_ = np.logspace(-4, -4, n_iters)

for i in range(n_iters):
  start = (i * batch_size // 2) % n_train
  end = min(start + batch_size // 2, n_train)
  batch = train_mu[start:end] + train_sigma[start:end] * np.random.randn(end - start, n_latent)
  if batch.shape[0] == batch_size // 2:
    labels = attr_train[start:end]

    # Train
    res = sess.run([m.train_d_attr, m.d_loss_attr], 
                   {m.z: batch, 
                    m.labels: labels, 
                    m.d_attr_lr: d_attr_lr_[i],})

    traces['i'].append(i)
    traces['D_loss'].append(res[1])

  if i % 10 == 1:
    start = (i * batch_size // 2) % n_eval
    end = min(start + batch_size // 2, n_eval)
    batch = eval_mu[start:end] + eval_sigma[start:end] * np.random.randn(end - start, n_latent)
    if batch.shape[0] == batch_size // 2:
      labels = attr_eval[start:end]

      res_eval = sess.run([m.d_loss_attr, m.pred_attr], 
                          {m.z: batch,
                           m.labels: labels,})
      
      y_true = labels
      y_pred = (res_eval[1] > 0.5)
      accuracy = np.mean(y_true == y_pred)
      
      traces['i_eval'].append(i)
      traces['D_loss_eval'].append(res_eval[0])
      traces['accuracy'].append(accuracy)
      

  if i % 100 == 0:
    print('Step %d, \t TRAIN \t Loss: %0.3f \t EVAL \t Loss: %0.3f \t Accuracy: %0.3f' % (
        i,
        np.mean(traces['D_loss'][-running_N_eval:]),
        np.mean(traces['D_loss_eval'][-running_N_eval:]),
        np.mean(traces['accuracy'][-running_N_eval:])))
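The accuracy reported above is elementwise over all attributes: each predicted probability is thresholded at 0.5, compared to its binary label, and the agreements are averaged. In miniature:

```python
import numpy as np

probs = np.array([[0.9, 0.2],
                  [0.4, 0.8]])
labels = np.array([[1, 0],
                   [1, 1]])
acc = np.mean((probs > 0.5) == labels)   # 3 of 4 attribute decisions correct
```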

In [0]:
plt.figure(figsize=(18,18))
plt.subplot(3, 1, 1)
plt.semilogy(traces['i'], traces['D_loss'], label='train')
plt.semilogy(traces['i_eval'], traces['D_loss_eval'], label='eval')
plt.ylabel('Loss')
plt.legend(loc='upper right')

plt.subplot(3, 1, 2)
plt.plot(traces['i_eval'], traces['accuracy'], label="eval")
plt.ylabel('Prediction Accuracy')
plt.legend(loc='upper right')

In [0]:
# Prediction Accuracy
train_pred = []
eval_pred = []
test_pred = []

for i in range(n_train // 10 // batch_size):
  start = (i * batch_size)
  end = start + batch_size
  batch_images = train_data[start:end]
  res = sess.run([m.pred_attr], {m.x: batch_images,
                                 m.labels: attr_train[start:end]})
  train_pred.append(res[0])
train_pred = np.vstack(train_pred)

for i in range(n_eval // batch_size):
  start = (i * batch_size)
  end = start + batch_size
  batch_images = eval_data[start:end]

  res = sess.run([m.pred_attr], {m.x: batch_images,
                                 m.labels: attr_eval[start:end]})
  eval_pred.append(res[0])
eval_pred = np.vstack(eval_pred)

train_acc = (train_pred > 0.5) == attr_train[:train_pred.shape[0]]
eval_acc = (eval_pred > 0.5) == attr_eval[:eval_pred.shape[0]]

print("Train Accuracy: %.4f" % (np.mean(train_acc) * 100))
print("Eval Accuracy: %.4f" % (np.mean(eval_acc) * 100))

In [0]:
y_true = attr_eval[:eval_pred.shape[0]]
y_pred = eval_pred >= 0.5

prec, recall, f1, support = sklearn.metrics.precision_recall_fscore_support(y_true, y_pred)
report = sklearn.metrics.classification_report(y_true, y_pred, digits=2, target_names=attribute_names)
print(report)