Data Science Summer School - Split '17

5. Generating images of digits with Generative Adversarial Networks


In [ ]:
%matplotlib inline
%load_ext autoreload
%autoreload 2

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import os, util

Goals:

  1. Implement the model from "Generative Adversarial Networks" by Goodfellow et al. (1284 citations since 2014.)

  2. Understand how the model learns to generate realistic images

In ~two hours.

5.1 Downloading the datasets and previewing data


In [ ]:
data_folder = 'data'; dataset = 'mnist'  # the folder in which the dataset will be stored, and the dataset name

download_folder = util.download_mnist(data_folder, dataset)
images, labels = util.load_mnist(download_folder)

print("Folder:", download_folder)
print("Image shape:", images.shape) # greyscale, so the last dimension (color channel) = 1
print("Label shape:", labels.shape) # one-hot encoded

In [ ]:
show_n_images = 25
sample_images, mode = util.get_sample_images(images, n=show_n_images)
mnist_sample = util.images_square_grid(sample_images, mode)
plt.imshow(mnist_sample, cmap='gray')

In [ ]:
sample = images[3]*50  # scale the [0, 1] pixel values up so the integer printout below is readable
sample = sample.reshape((28, 28))
print(np.array2string(sample.astype(int), max_line_width=100, separator=',', precision=0))

In [ ]:
plt.imshow(sample, cmap='gray')

What are we going to do with the data?

  • We have $70000$ images of hand-written digits generated from some distribution $X \sim P_{real}$
  • We have $70000$ labels $y_i \in \{0,..., 9\}$ indicating which digit is written on the image $x_i$

Problem: Imagine that the number of images we have is not enough - a common issue in computer vision and machine learning.

  1. We can pay experts to create new images
    • Expensive
    • Slow
    • Reliable
  2. We can generate new images ourselves
    • Cheap
    • Fast
    • Unreliable?

Problem: Not every image that we generate is going to be perfect (or even close to perfect). Therefore, we need some method to determine which images are realistic.

  1. We can pay experts to determine which images are good enough
    • Expensive
    • Slow
    • Reliable
  2. We can train a model to determine which images are good enough
    • Cheap
    • Fast
    • Unreliable?

Formalization

  • $X \sim P_{real}$ : existing images of shape $s$
  • $Z \sim P_z$ : a $k$-dimensional random vector
  • $G(z; \theta_G): Z \to \hat{X}$ : the generator, a function that transforms the random vector $z$ into an image of shape $s$
  • $D(x; \theta_D): X \to (Real, Fake)$ : the discriminator, a function that, given an image of shape $s$, decides whether the image is real or fake

Details

The existing images $X$ in our setup are images from the mnist dataset. We will arbitrarily decide that vectors $z$ will be sampled from a uniform distribution, and $G$ and $D$ will both be 'deep' neural networks.

For simplicity, and since we are using the mnist dataset, both $G$ and $D$ will be multi-layer perceptrons (and not deep convolutional networks) with one hidden layer. The generated images $G(z) \sim P_{fake}$ as well as real images $x \sim P_{real}$ will be passed on to the discriminator, which will classify them into $(Real, Fake)$.

Figure 1. Generative adversarial network architecture

Discriminator

The goal of the discriminator is to successfully recognize which image is sampled from the true distribution, and which image is sampled from the generator.

Figure 2. Discriminator network sketch

Generator

The goal of the generator is to produce images that the discriminator misclassifies as having come from the true distribution.

Figure 3. Generator network sketch

5.2 Data transformation

Since we are using a fully connected network (not local convolutional filters), we flatten the input images for simplicity. The pixel values have also already been scaled to the interval $[0,1]$ beforehand.

We will also use a pre-made Dataset class to iterate over the dataset in batches. The class is defined in util.py, and only consists of a constructor and a method next_batch.

Question: Having seen the architecture of the network, why are the pixels scaled to $[0,1]$ and not, for example, $[-1, 1]$, or left at $[0, 255]$?

Answer:


In [ ]:

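A minimal sketch of the transformation, under the assumption that the Dataset constructor in util.py takes the flattened image array (its exact signature is not shown here); the variable name mnist matches the one used in the training loop further below.

In [ ]:
# Flatten the 28x28x1 images into vectors of length 784 for the fully
# connected networks; the pixel values are already scaled to [0, 1].
images_flat = images.reshape(images.shape[0], -1)
print("Flattened image shape:", images_flat.shape)

# Wrap the data in the pre-made Dataset class so we can iterate over it in
# minibatches with next_batch (constructor signature assumed here).
mnist = util.Dataset(images_flat)
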
5.3 The generator network


In [ ]:
class Generator:
    """The generator network
    
    the generator network takes as input a vector z of dimension input_dim, and transforms it 
    to a vector of size output_dim. The network has one hidden layer of size hidden_dim.
    
    We will define the following methods: 
    
    __init__: initializes all variables by using tf.get_variable(...) 
                and stores them on the class, as well as in a list, self.theta
    forward: defines the forward pass of the network - how do the variables
                interact with respect to the inputs
    """
    
    def __init__(self, input_dim, hidden_dim, output_dim):
        """Constructor for the generator network. In the constructor, we will
        just initialize all the variables in the network.
        
        Args:
            input_dim: The dimension of the input data vector (z).
            hidden_dim: The dimension of the hidden layer of the neural network (h)
            output_dim: The dimension of the output layer (equivalent to the size of the image)            
            
        """
        
        with tf.variable_scope("generator"):
            pass
    
    def forward(self, z):
        """The forward pass of the network -- here we will define the logic of how we combine
        the variables through multiplication and activation functions in order to get the
        output.
        
        """
        pass

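For reference, a minimal sketch of one way the class above could be filled in (not necessarily the reference solution): a single hidden ReLU layer and a sigmoid output so the generated pixels land in $[0,1]$, with weights created by the Xavier initializer discussed in the intermezzo below.

In [ ]:
class Generator:
    """A possible implementation sketch of the generator network."""

    def __init__(self, input_dim, hidden_dim, output_dim):
        with tf.variable_scope("generator"):
            xavier = tf.contrib.layers.xavier_initializer()
            self.G_W1 = tf.get_variable("G_W1", shape=[input_dim, hidden_dim], initializer=xavier)
            self.G_b1 = tf.get_variable("G_b1", shape=[hidden_dim], initializer=tf.zeros_initializer())
            self.G_W2 = tf.get_variable("G_W2", shape=[hidden_dim, output_dim], initializer=xavier)
            self.G_b2 = tf.get_variable("G_b2", shape=[output_dim], initializer=tf.zeros_initializer())
            self.theta = [self.G_W1, self.G_b1, self.G_W2, self.G_b2]

    def forward(self, z):
        # one hidden ReLU layer, then a sigmoid so the generated pixels
        # land in [0, 1], the same range as the real (rescaled) images
        h = tf.nn.relu(tf.matmul(z, self.G_W1) + self.G_b1)
        return tf.nn.sigmoid(tf.matmul(h, self.G_W2) + self.G_b2)
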
5.4 The basic network for the discriminator


In [ ]:
class Discriminator:
    """The discriminator network
    
    the discriminator network takes as input a vector x of dimension input_dim, and transforms it 
    to a vector of size output_dim. The network has one hidden layer of size hidden_dim.
    
    You will define the following methods: 
    
    __init__: initializes all variables by using tf.get_variable(...) 
                and stores them on the class, as well as in a list, self.theta
    forward: defines the forward pass of the network - how do the variables
                interact with respect to the inputs
    """
    
    def __init__(self, input_dim, hidden_dim, output_dim):
        
        with tf.variable_scope("discriminator"):
            pass

    def forward(self, x):
        """The forward pass of the network -- here we will define the logic of how we combine
        the variables through multiplication and activation functions in order to get the
        output.
        
        Along with the probabilities, also return the unnormalized probabilities
        (the values in the output layer before being passed through the sigmoid function)
        """
        pass

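Analogously, a minimal sketch of one possible implementation (again not necessarily the reference solution). Note that forward returns both the probability and the unnormalized logit, as the docstring asks; the logit will be needed by the loss later on.

In [ ]:
class Discriminator:
    """A possible implementation sketch of the discriminator network."""

    def __init__(self, input_dim, hidden_dim, output_dim):
        with tf.variable_scope("discriminator"):
            xavier = tf.contrib.layers.xavier_initializer()
            self.D_W1 = tf.get_variable("D_W1", shape=[input_dim, hidden_dim], initializer=xavier)
            self.D_b1 = tf.get_variable("D_b1", shape=[hidden_dim], initializer=tf.zeros_initializer())
            self.D_W2 = tf.get_variable("D_W2", shape=[hidden_dim, output_dim], initializer=xavier)
            self.D_b2 = tf.get_variable("D_b2", shape=[output_dim], initializer=tf.zeros_initializer())
            self.theta = [self.D_W1, self.D_b1, self.D_W2, self.D_b2]

    def forward(self, x):
        # one hidden ReLU layer; return both the probability (after the
        # sigmoid) and the unnormalized logit
        h = tf.nn.relu(tf.matmul(x, self.D_W1) + self.D_b1)
        logit = tf.matmul(h, self.D_W2) + self.D_b2
        prob = tf.nn.sigmoid(logit)
        return prob, logit
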
Intermezzo: Xavier initialization of weights

Glorot, X., & Bengio, Y. (2010, March). Understanding the difficulty of training deep feedforward neural networks. In Proceedings of the Thirteenth International Conference on Artificial Intelligence and Statistics (pp. 249-256).

Implemented in tensorflow, as part of the standard library: https://www.tensorflow.org/api_docs/python/tf/contrib/layers/xavier_initializer

1. Idea:

  • If the weights in a network are initialized with values that are too small, the signal shrinks as it passes through each layer until it is too tiny to be useful.
  • If the weights are initialized with values that are too large, the signal grows as it passes through each layer until it is too massive to be useful.

2. Goal:

  • We need initial weight values that are just right for the signal not to explode or vanish during the forward pass

3. Math

  • Trivial

4. Solution

  • $v = \frac{2}{n_{in} + n_{out}}$

In the case of a Gaussian distribution, we set the variance to $v$.

In the case of a uniform distribution, we sample from the interval $\pm\sqrt{3v} = \pm\sqrt{6/(n_{in} + n_{out})}$, which also has variance $v$ (the default distribution in tensorflow is the uniform one).

http://andyljones.tumblr.com/post/110998971763/an-explanation-of-xavier-initialization
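
A quick sketch comparing the formula above with the built-in initializer (the names n_in, n_out, W_manual, and W_tf are illustrative only):

In [ ]:
# Xavier/Glorot initialization by hand vs. the built-in initializer.
n_in, n_out = 784, 128
variance = 2.0 / (n_in + n_out)
limit = np.sqrt(3.0 * variance)  # half-width of the uniform interval with that variance
W_manual = np.random.uniform(-limit, limit, size=(n_in, n_out))

# The equivalent built-in initializer (uniform by default):
W_tf = tf.get_variable("W_xavier_example", shape=[n_in, n_out],
                       initializer=tf.contrib.layers.xavier_initializer())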

5.5 Define the model parameters

We will take a brief break to set the values for the parameters of the model. Since we know the dataset we are working with, as well as the shape of the generator and discriminator networks, your task is to fill in the values of the following variables.


In [ ]:
image_dim = # The dimension of the input image vector to the discriminator
discriminator_hidden_dim = # The dimension of the hidden layer of the discriminator
discriminator_output_dim = # The dimension of the output layer of the discriminator 

random_sample_dim =  # The dimension of the random noise vector z
generator_hidden_dim = # The dimension of the hidden layer of the generator
generator_output_dim = # The dimension of the output layer of the generator

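For orientation, one consistent set of values is sketched below. The image and output dimensions follow from the MNIST image size, and random_sample_dim matches the sample_Z(16, 100) call further below; the hidden-layer sizes of 128 are a common choice for this small MLP, not something the notebook prescribes.

In [ ]:
image_dim = 784                   # 28 * 28 pixels, flattened
discriminator_hidden_dim = 128    # a common choice (assumption)
discriminator_output_dim = 1      # a single real/fake probability
random_sample_dim = 100           # dimension of z, matching sample_Z(16, 100) below
generator_hidden_dim = 128        # a common choice (assumption)
generator_output_dim = 784        # must match the flattened image size
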
5.6 Check the implementation of the classes


In [ ]:
d = Discriminator(image_dim, discriminator_hidden_dim, discriminator_output_dim)
for param in d.theta:
    print (param)

In [ ]:
g = Generator(random_sample_dim, generator_hidden_dim, generator_output_dim)
for param in g.theta:
    print (param)

Drawing samples from the latent space


In [ ]:
def sample_Z(m, n):
    pass

plt.imshow(sample_Z(16, 100), cmap='gray')

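A minimal sketch of sample_Z, assuming we draw $z$ uniformly from $[-1, 1]$ (the notebook only says the latent distribution is uniform; the exact interval is an assumption):

In [ ]:
def sample_Z(m, n):
    # m latent vectors of dimension n, drawn uniformly from [-1, 1]
    return np.random.uniform(-1., 1., size=[m, n])
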
5.7 Define the model loss -- Vanilla GAN

The objective for the vanilla version of the GAN was defined as follows:

$\min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{real}} [\log(D(x))] + \mathbb{E}_{z \sim p_{z}} [\log(1 - D(G(z)))]$

The function contains a minimax formulation, and cannot be directly optimized. However, if we freeze $D$, we can derive the loss for $G$ and vice versa.

Discriminator loss:

$p_{fake} = G(p_z)$
$D_{loss} = -\big(\mathbb{E}_{x \sim p_{real}} [\log(D(x))] + \mathbb{E}_{\hat{x} \sim p_{fake}} [\log(1 - D(\hat{x}))]\big)$

Note the leading minus sign: the discriminator maximizes this part of $V(D, G)$, so the quantity we minimize is its negation.

We estimate the expectations over each minibatch and arrive at the following formulation:

$D_{loss} = -\frac{1}{m}\sum_{i=1}^{m} \log(D(x_i)) - \frac{1}{m}\sum_{i=1}^{m} \log(1 - D(\hat{x}_i))$

Generator loss:

In practice we do not minimize $\log(1 - D(G(z)))$ directly, because it saturates when the discriminator confidently rejects the generated samples early in training; instead, the generator maximizes $\log(D(G(z)))$, i.e. it minimizes:

$G_{loss} = -\mathbb{E}_{z \sim p_{z}} [\log(D(G(z)))]$

$G_{loss} = -\frac{1}{m}\sum_{i=1}^{m} \log(D(G(z_i)))$

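The two minibatch losses above translate almost literally into TensorFlow. A minimal sketch, assuming D_real and D_fake are the discriminator's probability outputs for real and generated images (these names, the helper naive_gan_losses, and the small eps for numerical safety are illustrative; the notebook itself uses the more stable sigmoid-cross-entropy formulation introduced below):

In [ ]:
def naive_gan_losses(D_real, D_fake, eps=1e-8):
    # Direct transcription of the minibatch losses above; eps guards against log(0).
    D_loss = -tf.reduce_mean(tf.log(D_real + eps) + tf.log(1. - D_fake + eps))
    G_loss = -tf.reduce_mean(tf.log(D_fake + eps))
    return D_loss, G_loss
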
Model loss, translated from math

The discriminator wants to:

  • maximize the (log) probability of a real image being classified as real,
  • minimize the (log) probability of a fake image being classified as real.

The generator wants to:

  • maximize the (log) probability of a fake image being classified as real.

Model loss, translated to practical machine learning

The output of the discriminator is a scalar, $p$, which we interpret as the probability that an input image is real ($1-p$ is the probability that the image is fake).

The discriminator takes as input:

  • a minibatch of images from our training set, with a vector of ones as class labels, yielding $D_{loss\_real}$;
  • a minibatch of images from the generator, with a vector of zeros as class labels, yielding $D_{loss\_fake}$;
  • a minibatch of images from the generator, with a vector of ones as class labels, which is what produces $G_{loss}$.

The generator takes as input:

  • a minibatch of vectors sampled from the latent space and transforms them to a minibatch of generated images

In [ ]:

Intermezzo: sigmoid cross entropy with logits

We defined the loss of the model in terms of the log of the model's probabilities, yet neither a $\log$ function nor those probabilities appears anywhere in the code. Why?

Enter sigmoid cross entropy with logits: https://www.tensorflow.org/api_docs/python/tf/nn/sigmoid_cross_entropy_with_logits

From the tensorflow documentation

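Combining the label conventions from the list above with sigmoid cross entropy with logits, one possible sketch of gan_model_loss looks as follows (the exact reference implementation may differ; it assumes that the discriminator's forward method returns the pair (probability, logit), as described in its docstring):

In [ ]:
def gan_model_loss(X, Z, discriminator, generator):
    """Build the generated sample and both losses (a sketch)."""
    G_sample = generator.forward(Z)

    D_real_prob, D_real_logit = discriminator.forward(X)
    D_fake_prob, D_fake_logit = discriminator.forward(G_sample)

    # Real images should be classified as real (label 1)
    D_loss_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
        logits=D_real_logit, labels=tf.ones_like(D_real_logit)))
    # Generated images should be classified as fake (label 0)
    D_loss_fake = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
        logits=D_fake_logit, labels=tf.zeros_like(D_fake_logit)))
    D_loss = D_loss_real + D_loss_fake

    # The generator wants its images to be classified as real (label 1)
    G_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
        logits=D_fake_logit, labels=tf.ones_like(D_fake_logit)))

    return G_sample, D_loss, G_loss
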
Putting it all together


In [ ]:
X = tf.placeholder(tf.float32, name="input", shape=[None, image_dim])
Z = tf.placeholder(tf.float32, name="latent_sample", shape=[None, random_sample_dim])

G_sample, D_loss, G_loss = gan_model_loss(X, Z, d, g)

with tf.variable_scope('optim'):
    D_solver = tf.train.AdamOptimizer(name='discriminator').minimize(D_loss, var_list=d.theta)
    G_solver = tf.train.AdamOptimizer(name='generator').minimize(G_loss, var_list=g.theta)

In [ ]:
saver = tf.train.Saver()

# Some runtime parameters predefined for you
minibatch_size = 128 # The size of the minibatch

num_epoch = 500 # For how many epochs do we run the training
plot_every_epochs = 5 # After this many epochs we will save & display samples of generated images 
print_every_batches = 1000 # After this many minibatches we will print the losses

restore = True
checkpoint = 'fc_2layer_e100_2.170.ckpt'
model = 'gan'
model_save_folder = os.path.join('data', 'chkp', model)
print ("Model checkpoints will be saved to:", model_save_folder)
image_save_folder = os.path.join('data', 'model_output', model)
print ("Image samples will be saved to:", image_save_folder)

In [ ]:
minibatch_counter = 0
epoch_counter = 0

d_losses = []
g_losses = []

with tf.device("/gpu:0"), tf.Session() as sess:
    sess.run(tf.global_variables_initializer())

    if restore:
        saver.restore(sess, os.path.join(model_save_folder, checkpoint))
        print("Restored model:", checkpoint, "from:", model_save_folder)
                      
    while epoch_counter < num_epoch:
            
        new_epoch, X_mb = mnist.next_batch(minibatch_size)

        _, D_loss_curr = sess.run([D_solver, D_loss], 
                                  feed_dict={
                                      X: X_mb, 
                                      Z: sample_Z(minibatch_size, random_sample_dim)
                                    })
                      
        _, G_loss_curr = sess.run([G_solver, G_loss], 
                                  feed_dict={
                                      Z: sample_Z(minibatch_size, random_sample_dim)
                                  })

        # Plotting and saving images and the model
        if new_epoch and epoch_counter % plot_every_epochs == 0:
            samples = sess.run(G_sample, feed_dict={Z: sample_Z(16, random_sample_dim)})

            fig = util.plot(samples)
            figname = '{}.png'.format(str(minibatch_counter).zfill(3))
            plt.savefig(os.path.join(image_save_folder, figname), bbox_inches='tight')
            plt.show()
            plt.close(fig)
            
            im = util.plot_single(samples[0], epoch_counter)
            plt.savefig(os.path.join(image_save_folder, 'single_' + figname), bbox_inches='tight')
            plt.show()
            
            chkpname = "fc_2layer_e{}_{:.3f}.ckpt".format(epoch_counter, G_loss_curr)
            saver.save(sess, os.path.join(model_save_folder, chkpname))

        # Printing runtime statistics
        if minibatch_counter % print_every_batches == 0:
            print('Epoch: {}/{}'.format(epoch_counter, num_epoch))
            print('Iter: {}/{}'.format(mnist.position_in_epoch, mnist.n))
            print('Discriminator loss: {:.4}'. format(D_loss_curr))
            print('Generator loss: {:.4}'.format(G_loss_curr))
            print()
        
        # Bookkeeping
        minibatch_counter += 1
        if new_epoch:
            epoch_counter += 1
        
        d_losses.append(D_loss_curr)
        g_losses.append(G_loss_curr)
        
    # Save the final model
    chkpname = "fc_2layer_e{}_{:.3f}.ckpt".format(epoch_counter, G_loss_curr)
    saver.save(sess, os.path.join(model_save_folder, chkpname))

In [ ]:
disc_line, = plt.plot(range(len(d_losses[:10000])), d_losses[:10000], c='b', label="Discriminator loss")
gen_line, = plt.plot(range(len(d_losses[:10000])), g_losses[:10000], c='r', label="Generator loss")
plt.legend([disc_line, gen_line], ["Discriminator loss", "Generator loss"])

Mode collapse

Mode collapse occurs when the generator maps many different latent vectors $z$ to nearly identical outputs, so the generated samples cover only a few modes of the real data distribution.

Example of mode collapse in a GAN

Second image is from: Reed, S., van den Oord, A., Kalchbrenner, N., Bapst, V., Botvinick, M., & de Freitas, N. (2016). Generating interpretable images with controllable structure.


In [ ]: