Generative Adversarial Network

In this notebook, we'll be building a generative adversarial network (GAN) trained on the MNIST dataset. From this, we'll be able to generate new handwritten digits!

GANs were first reported in 2014 by Ian Goodfellow and others in Yoshua Bengio's lab. Since then, GANs have exploded in popularity.

The idea behind GANs is that you have two networks, a generator $G$ and a discriminator $D$, competing against each other. The generator makes fake data to pass to the discriminator. The discriminator also sees real data and predicts whether the data it receives is real or fake. The generator is trained to fool the discriminator: it wants to output data that looks as close as possible to real data. The discriminator, in turn, is trained to figure out which data is real and which is fake. What ends up happening is that the generator learns to make data that the discriminator cannot distinguish from real data.

The general structure of a GAN is shown in the diagram above, using MNIST images as data. The latent sample is a random vector the generator uses to construct its fake images. As the generator learns through training, it figures out how to map these random vectors to recognizable images that can fool the discriminator.
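As a small illustration of what a latent sample looks like, the sketch below draws a batch of random vectors with components uniform in [-1, 1], the same distribution the training loop later in this notebook uses. It is plain Python rather than NumPy or TensorFlow, and the helper name `sample_latent` and the sizes are assumptions made up for this example.

```python
import random

def sample_latent(batch_size, z_dim, seed=None):
    """Return batch_size latent vectors, each with z_dim components in [-1, 1]."""
    rng = random.Random(seed)
    return [[rng.uniform(-1.0, 1.0) for _ in range(z_dim)]
            for _ in range(batch_size)]

batch_z = sample_latent(batch_size=4, z_dim=8, seed=0)
print(len(batch_z), len(batch_z[0]))  # 4 vectors of length 8
```

Each row is one latent vector; the generator's job is to turn every such row into one image.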

The discriminator's output passes through a sigmoid, so it lies between 0 and 1, where 0 indicates a fake image and 1 indicates a real image. If you're interested only in generating new images, you can throw out the discriminator after training. Now, let's see how we build this thing in TensorFlow.


In [1]:
%matplotlib inline

import pickle as pkl
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
print(tf.__version__)


1.1.0

In [2]:
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets('MNIST_data')


Successfully downloaded train-images-idx3-ubyte.gz 9912422 bytes.
Extracting MNIST_data\train-images-idx3-ubyte.gz
Successfully downloaded train-labels-idx1-ubyte.gz 28881 bytes.
Extracting MNIST_data\train-labels-idx1-ubyte.gz
Successfully downloaded t10k-images-idx3-ubyte.gz 1648877 bytes.
Extracting MNIST_data\t10k-images-idx3-ubyte.gz
Successfully downloaded t10k-labels-idx1-ubyte.gz 4542 bytes.
Extracting MNIST_data\t10k-labels-idx1-ubyte.gz

Model Inputs

First we need to create the inputs for our graph. We need two inputs, one for the discriminator and one for the generator. Here we'll call the discriminator input inputs_real and the generator input inputs_z. We'll assign them the appropriate sizes for each of the networks.

Exercise: Finish the model_inputs function below. Create the placeholders for inputs_real and inputs_z using the input sizes real_dim and z_dim respectively.


In [14]:
def model_inputs(real_dim, z_dim):
    inputs_real = tf.placeholder(tf.float32, shape=(None, real_dim), name="inputs_real")
    inputs_z = tf.placeholder(tf.float32, shape=(None, z_dim), name="inputs_z")
    
    return inputs_real, inputs_z

Generator network

Here we'll build the generator network. To make this network a universal function approximator, we'll need at least one hidden layer. We should use a leaky ReLU to allow gradients to flow backwards through the layer unimpeded. A leaky ReLU is like a normal ReLU, except that there is a small non-zero output for negative input values.

Variable Scope

Here we need to use tf.variable_scope for two reasons. Firstly, we're going to make sure all the variable names start with generator. Similarly, we'll prepend discriminator to the discriminator variables. This will help out later when we're training the separate networks.

We could just use tf.name_scope to set the names, but we also want to reuse these networks with different inputs. For the generator, we're going to train it, but also sample from it as we're training and after training. The discriminator will need to share variables between the fake and real input images. So, we can use the reuse keyword for tf.variable_scope to tell TensorFlow to reuse the variables instead of creating new ones if we build the graph again.

To use tf.variable_scope, you use a with statement:

with tf.variable_scope('scope_name', reuse=False):
    # code here

Here's more from the TensorFlow documentation to get another look at using tf.variable_scope.
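The reuse behaviour can be pictured with a toy variable store. This is only an analogy in plain Python, not TensorFlow's implementation: `ToyVariableStore` and its `get_variable` method are hypothetical names invented for this sketch. It mimics the rule that creating a variable that already exists fails unless reuse is requested, and that asking to reuse a variable that was never created also fails.

```python
class ToyVariableStore:
    """Toy stand-in for a variable scope; not TensorFlow's real machinery."""

    def __init__(self):
        self._vars = {}

    def get_variable(self, name, reuse=False, init=0.0):
        if reuse:
            if name not in self._vars:
                raise ValueError("{} does not exist, cannot reuse".format(name))
            return self._vars[name]          # same object every time
        if name in self._vars:
            raise ValueError("{} already exists, pass reuse=True".format(name))
        self._vars[name] = [init]            # a mutable box standing in for a weight
        return self._vars[name]

store = ToyVariableStore()
w_first = store.get_variable("generator/w")               # created
w_again = store.get_variable("generator/w", reuse=True)   # shared, not copied
w_first[0] = 1.5
print(w_again[0])  # 1.5, both names point at the same weight
```

This is why the second discriminator call can see the weights trained through the first one: with reuse switched on, both calls get the same variables back.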

Leaky ReLU

The TensorFlow version used here (1.1) doesn't provide an operation for leaky ReLUs, so we'll need to make one. To do this, take the outputs from a linear fully connected layer and pass them to tf.maximum. Typically, a parameter alpha sets the magnitude of the output for negative values. So, the output for negative input (x) values is alpha*x, and the output for positive x is x: $$ f(x) = \max(\alpha x, x) $$
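The formula translates almost word for word into code. Here is a minimal scalar sketch in plain Python (the notebook itself applies the same idea to whole tensors with tf.maximum); it assumes 0 < alpha < 1, which is what makes the max pick the right branch.

```python
def leaky_relu(x, alpha=0.01):
    """Leaky ReLU: x for positive inputs, alpha * x for negative inputs."""
    # For 0 < alpha < 1, alpha * x is larger than x exactly when x is negative
    return max(alpha * x, x)

for value in (-2.0, 0.0, 3.0):
    print(value, leaky_relu(value))
```

Negative inputs are scaled down rather than zeroed, so their gradient is alpha instead of 0.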

Tanh Output

Generators have been found to perform best with a $\tanh$ output. This means that we'll have to rescale the MNIST images to be between -1 and 1, instead of 0 and 1.
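The rescaling is a simple linear map. The sketch below shows it on a few pixel values in plain Python; the two helper names are made up for this example, and the inverse map is included only as a convenience for converting generated values back for display.

```python
import math

def rescale_to_tanh_range(pixel):
    """Map a pixel in [0, 1] to [-1, 1], matching the tanh output range."""
    return pixel * 2.0 - 1.0

def rescale_to_unit_range(value):
    """Inverse map from [-1, 1] back to [0, 1], e.g. for display."""
    return (value + 1.0) / 2.0

pixels = [0.0, 0.25, 1.0]
scaled = [rescale_to_tanh_range(p) for p in pixels]
print(scaled)                             # [-1.0, -0.5, 1.0]
print(math.tanh(-5.0), math.tanh(5.0))    # tanh saturates close to -1 and 1
```

With this scaling, real images and generator outputs live in the same range, so the discriminator can't tell them apart on range alone.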

Exercise: Implement the generator network in the function below. You'll need to return the tanh output. Make sure to wrap your code in a variable scope, with 'generator' as the scope name, and pass the reuse keyword argument from the function to tf.variable_scope.


In [36]:
def generator(z, out_dim, n_units=128, reuse=False,  alpha=0.01):
    ''' Build the generator network.
    
        Arguments
        ---------
        z : Input tensor for the generator
        out_dim : Shape of the generator output
        n_units : Number of units in hidden layer
        reuse : Reuse the variables with tf.variable_scope
        alpha : leak parameter for leaky ReLU
        
        Returns
        -------
        out : tanh output of the generator, with values between -1 and 1
    '''
    with tf.variable_scope('Generator', reuse=reuse):
        # Hidden layer
        h1 = tf.layers.dense(z, n_units, activation = None)
        # Leaky ReLU
        h1 = tf.maximum( (alpha * h1),  h1)
        
        # Logits and tanh output
        logits = tf.layers.dense(h1, out_dim, activation = None)
        out = tf.tanh(logits)
        
        return out

Discriminator

The discriminator network is almost exactly the same as the generator network, except that we're using a sigmoid output layer.
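To make the sigmoid output concrete, here is a small plain-Python sketch of the function applied to a few logits. The branch on the sign of the logit is a common trick to avoid overflow in exp; it's an implementation choice for this example, not something taken from the notebook's code.

```python
import math

def sigmoid(logit):
    """Numerically stable logistic sigmoid, mapping a real logit to a value between 0 and 1."""
    if logit >= 0:
        return 1.0 / (1.0 + math.exp(-logit))
    exp_l = math.exp(logit)          # safe: logit is negative here
    return exp_l / (1.0 + exp_l)

for logit in (-4.0, 0.0, 4.0):
    print(logit, round(sigmoid(logit), 4))
```

Large negative logits land near 0 ("fake"), large positive logits near 1 ("real"), and a logit of 0 means the discriminator is undecided.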

Exercise: Implement the discriminator network in the function below. This time you'll need to return both the sigmoid output and the logits. Make sure to wrap your code in a variable scope, with 'discriminator' as the scope name, and pass the reuse keyword argument from the function arguments to tf.variable_scope.


In [40]:
def discriminator(x, n_units=128, reuse=False, alpha=0.01):
    ''' Build the discriminator network.
    
        Arguments
        ---------
        x : Input tensor for the discriminator
        n_units: Number of units in hidden layer
        reuse : Reuse the variables with tf.variable_scope
        alpha : leak parameter for leaky ReLU
        
        Returns
        -------
        out, logits : sigmoid output and the raw logits of the discriminator
    '''
    with tf.variable_scope('Discriminator', reuse=reuse):
        # Hidden layer
        h1 = tf.layers.dense(x,  n_units, activation = None)
        # Leaky ReLU
        h1 = tf.maximum ( (alpha * h1), h1)
        
        logits = tf.layers.dense(h1, 1, activation = None)
        out = tf.sigmoid(logits)
        
        return out, logits

Hyperparameters


In [96]:
# Size of input image to discriminator
input_size = 784 # 28x28 MNIST images flattened
# Size of latent vector to generator
z_size = 784
# Sizes of hidden layers in generator and discriminator
# (these only take effect if passed as n_units to generator() and discriminator())
g_hidden_size = 256
d_hidden_size = 256
# Leak factor for leaky ReLU
alpha = 0.01
# Label smoothing 
smooth = 0.1

Build network

Now we're building the network from the functions defined above.

First, we get our inputs, input_real and input_z, from model_inputs using the sizes of the input and z.

Then, we'll create the generator, generator(input_z, input_size). This builds the generator with the appropriate input and output sizes.

Then the discriminators. We'll build two of them, one for real data and one for fake data. Since we want the weights to be the same for both real and fake data, we need to reuse the variables. For the fake data, we're getting it from the generator as g_model. So the real data discriminator is discriminator(input_real) while the fake discriminator is discriminator(g_model, reuse=True).

Exercise: Build the network from the functions you defined earlier.


In [97]:
tf.reset_default_graph()
# Create our input placeholders
input_real, input_z = model_inputs(input_size, z_size)

# Generator network here
g_model = generator(input_z, input_size)
# g_model is the generator output

# Discriminator network here
d_model_real, d_logits_real = discriminator(input_real)
d_model_fake, d_logits_fake = discriminator(g_model, reuse=True)

Discriminator and Generator Losses

Now we need to calculate the losses, which is a little tricky. For the discriminator, the total loss is the sum of the losses for real and fake images, d_loss = d_loss_real + d_loss_fake. The losses will be sigmoid cross-entropies, which we can get with tf.nn.sigmoid_cross_entropy_with_logits. We'll also wrap that in tf.reduce_mean to get the mean over all the images in the batch. So the losses will look something like

tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=logits, labels=labels))

For the real image logits, we'll use d_logits_real which we got from the discriminator in the cell above. For the labels, we want them to be all ones, since these are all real images. To help the discriminator generalize better, the labels are reduced a bit from 1.0 to 0.9, for example, using the parameter smooth. This is known as label smoothing, typically used with classifiers to improve performance. In TensorFlow, it looks something like labels = tf.ones_like(tensor) * (1 - smooth)
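To see what smoothing changes, it helps to write the per-example sigmoid cross-entropy out explicitly. With logit $x$, label $y$, and $\sigma$ the logistic sigmoid (these symbols are notation introduced for this note, not names from the code):

$$ \ell(x, y) = -\,y \log \sigma(x) - (1 - y) \log\bigl(1 - \sigma(x)\bigr) $$

Differentiating with respect to the logit gives

$$ \frac{\partial \ell}{\partial x} = \sigma(x) - y $$

so the loss is smallest when $\sigma(x) = y$. With a hard label $y = 1$ that only happens as $x \to \infty$, whereas with the smoothed label $y = 1 - \text{smooth} = 0.9$ the best output for a real image is 0.9. The discriminator is no longer rewarded for pushing its logits on real images arbitrarily high.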

The discriminator loss for the fake data is similar. The logits are d_logits_fake, which we got from passing the generator output to the discriminator. These fake logits are used with labels of all zeros. Remember that we want the discriminator to output 1 for real images and 0 for fake images, so we need to set up the losses to reflect that.

Finally, the generator loss uses d_logits_fake, the fake image logits, but now the labels are all ones. The generator is trying to fool the discriminator, so it wants the discriminator to output ones for fake images.
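Here is how the three losses fit together on a handful of made-up logits, in plain Python so it runs without TensorFlow. The function `sigmoid_cross_entropy` is a stand-in written for this sketch; it uses the numerically stable form max(x, 0) - x*z + log(1 + exp(-|x|)) that the TensorFlow documentation gives for tf.nn.sigmoid_cross_entropy_with_logits. The logit values are invented purely for illustration.

```python
import math

def sigmoid_cross_entropy(logit, label):
    """Per-example loss in the stable form max(x, 0) - x*z + log(1 + exp(-|x|))."""
    return max(logit, 0.0) - logit * label + math.log1p(math.exp(-abs(logit)))

def mean(values):
    return sum(values) / len(values)

smooth = 0.1
d_logits_real = [2.0, 1.5, 3.0]    # made-up discriminator logits on real images
d_logits_fake = [-1.0, 0.5, -2.0]  # made-up discriminator logits on generated images

# Discriminator: real images labelled (1 - smooth), fake images labelled 0
d_loss_real = mean([sigmoid_cross_entropy(x, 1.0 - smooth) for x in d_logits_real])
d_loss_fake = mean([sigmoid_cross_entropy(x, 0.0) for x in d_logits_fake])
d_loss = d_loss_real + d_loss_fake

# Generator: same fake logits, but labelled 1 because it wants them judged real
g_loss = mean([sigmoid_cross_entropy(x, 1.0) for x in d_logits_fake])

print(round(d_loss, 4), round(g_loss, 4))
```

Notice that d_loss_fake and g_loss are computed from the same logits with opposite labels, which is exactly the tug-of-war between the two networks.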

Exercise: Calculate the losses for the discriminator and the generator. There are two discriminator losses, one for real images and one for fake images. For the real image loss, use the real logits and (smoothed) labels of ones. For the fake image loss, use the fake logits with labels of all zeros. The total discriminator loss is the sum of those two losses. Finally, the generator loss again uses the fake logits from the discriminator, but this time the labels are all ones because the generator wants to fool the discriminator.


In [98]:
# Calculate losses

# Labels of (smoothed) ones for the real images seen by the discriminator
real_labels = tf.ones_like(d_logits_real) * (1 - smooth)

d_loss_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
    logits=d_logits_real, labels=real_labels))

# Labels of zeros for the fake images seen by the discriminator
fake_labels = tf.zeros_like(d_logits_fake)

d_loss_fake = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
    logits=d_logits_fake, labels=fake_labels))

d_loss = d_loss_real + d_loss_fake

# Labels of ones for the fake images when training the generator
generated_labels = tf.ones_like(d_logits_fake)

g_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
    logits=d_logits_fake, labels=generated_labels))

Optimizers

We want to update the generator and discriminator variables separately. So we need to get the variables for each part and build optimizers for the two parts. To get all the trainable variables, we use tf.trainable_variables(). This returns a list of all the trainable variables we've defined in our graph.

For the generator optimizer, we only want the generator variables. Our past selves were nice and used a variable scope, so all of our generator variable names start with the generator's scope name. So, we just need to iterate through the list from tf.trainable_variables() and keep the variables whose names start with that prefix. The match is case-sensitive, so the prefix has to be spelled exactly as in tf.variable_scope (the functions in this notebook use 'Generator' and 'Discriminator'). Each variable object has an attribute name which holds the name of the variable as a string (var.name == 'weights_0' for instance).

We can do something similar with the discriminator. All the variables in the discriminator start with the discriminator's scope name.

Then, in the optimizer we pass the variable lists to var_list in the minimize method. This tells the optimizer to only update the listed variables. Something like tf.train.AdamOptimizer().minimize(loss, var_list=var_list) will only train the variables in var_list.
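The name-based split can be tried out without building a graph. In the sketch below, `FakeVariable` is a hypothetical stand-in that only carries a `.name`, and the variable names are illustrative guesses at what a scoped dense layer might produce; the real names depend on how the graph was built.

```python
from collections import namedtuple

# Stand-in for a TensorFlow variable; only the .name attribute matters here
FakeVariable = namedtuple("FakeVariable", ["name"])

# Illustrative names only; the real names depend on how the graph was built
t_vars = [
    FakeVariable("Generator/dense/kernel:0"),
    FakeVariable("Generator/dense/bias:0"),
    FakeVariable("Discriminator/dense/kernel:0"),
    FakeVariable("Discriminator/dense/bias:0"),
]

g_vars = [var for var in t_vars if var.name.startswith("Generator")]
d_vars = [var for var in t_vars if var.name.startswith("Discriminator")]

# startswith is case-sensitive, so a lowercase prefix matches nothing here
lowercase_match = [var for var in t_vars if var.name.startswith("generator")]

print(len(g_vars), len(d_vars), len(lowercase_match))  # 2 2 0
```

An empty var_list is worth checking for before training: it usually means the prefix doesn't match the scope name.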

Exercise: Below, implement the optimizers for the generator and discriminator. First you'll need to get a list of trainable variables, then split that list into two lists, one for the generator variables and another for the discriminator variables. Finally, using AdamOptimizer, create an optimizer for each network that updates the network variables separately.


In [99]:
# Optimizers
learning_rate = 0.002

# Get the trainable_variables, split into G and D parts
t_vars = tf.trainable_variables()
g_vars = [var for var in t_vars if var.name.startswith('Generator')]
d_vars = [var for var in t_vars if var.name.startswith('Discriminator')]

d_train_opt = tf.train.AdamOptimizer(learning_rate).minimize(d_loss, var_list = d_vars)
g_train_opt = tf.train.AdamOptimizer(learning_rate).minimize(g_loss, var_list = g_vars)

Training


In [100]:
batch_size = 100
epochs = 80
samples = []
losses = []
saver = tf.train.Saver(var_list = g_vars)
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for e in range(epochs):
        for ii in range(mnist.train.num_examples//batch_size):
            batch = mnist.train.next_batch(batch_size)
            
            # Get images, reshape and rescale to pass to D
            batch_images = batch[0].reshape((batch_size, 784))
            batch_images = batch_images*2 - 1
            
            # Sample random noise for G
            batch_z = np.random.uniform(-1, 1, size=(batch_size, z_size))
            
            # Run optimizers
            _ = sess.run(d_train_opt, feed_dict={input_real: batch_images, input_z: batch_z})
            _ = sess.run(g_train_opt, feed_dict={input_z: batch_z})
        
        # At the end of each epoch, get the losses and print them out
        train_loss_d = sess.run(d_loss, {input_z: batch_z, input_real: batch_images})
        train_loss_g = g_loss.eval({input_z: batch_z})
            
        print("Epoch {}/{}...".format(e+1, epochs),
              "Discriminator Loss: {:.4f}...".format(train_loss_d),
              "Generator Loss: {:.4f}".format(train_loss_g),
              "Difference Loss: {:.4f}...".format(train_loss_d-train_loss_g),
             )    
        # Save losses to view after training
        losses.append((train_loss_d, train_loss_g))
        
        # Sample from generator as we're training for viewing afterwards
        sample_z = np.random.uniform(-1, 1, size=(16, z_size))
        gen_samples = sess.run(
                       generator(input_z, input_size, reuse=True),
                       feed_dict={input_z: sample_z})
        samples.append(gen_samples)
        saver.save(sess, './checkpoints/generator.ckpt')

# Save training generator samples
with open('train_samples.pkl', 'wb') as f:
    pkl.dump(samples, f)


Epoch 1/80... Discriminator Loss: 0.4425... Generator Loss: 2.8760 Difference Loss: -2.4335...
Epoch 2/80... Discriminator Loss: 0.8649... Generator Loss: 4.6707 Difference Loss: -3.8059...
Epoch 3/80... Discriminator Loss: 0.5695... Generator Loss: 3.7862 Difference Loss: -3.2167...
Epoch 4/80... Discriminator Loss: 1.1659... Generator Loss: 2.9606 Difference Loss: -1.7947...
Epoch 5/80... Discriminator Loss: 1.8704... Generator Loss: 1.8585 Difference Loss: 0.0119...
Epoch 6/80... Discriminator Loss: 1.8201... Generator Loss: 1.9842 Difference Loss: -0.1641...
Epoch 7/80... Discriminator Loss: 1.6348... Generator Loss: 0.9816 Difference Loss: 0.6532...
Epoch 8/80... Discriminator Loss: 2.0758... Generator Loss: 1.6577 Difference Loss: 0.4181...
Epoch 9/80... Discriminator Loss: 1.7692... Generator Loss: 2.2681 Difference Loss: -0.4989...
Epoch 10/80... Discriminator Loss: 1.4375... Generator Loss: 1.6568 Difference Loss: -0.2193...
Epoch 11/80... Discriminator Loss: 1.3330... Generator Loss: 3.2855 Difference Loss: -1.9526...
Epoch 12/80... Discriminator Loss: 1.0370... Generator Loss: 1.7489 Difference Loss: -0.7119...
Epoch 13/80... Discriminator Loss: 2.4128... Generator Loss: 1.7430 Difference Loss: 0.6697...
Epoch 14/80... Discriminator Loss: 0.9600... Generator Loss: 1.8819 Difference Loss: -0.9220...
Epoch 15/80... Discriminator Loss: 1.4469... Generator Loss: 1.1186 Difference Loss: 0.3282...
Epoch 16/80... Discriminator Loss: 1.4471... Generator Loss: 1.6183 Difference Loss: -0.1713...
Epoch 17/80... Discriminator Loss: 1.0488... Generator Loss: 1.7481 Difference Loss: -0.6993...
Epoch 18/80... Discriminator Loss: 1.5066... Generator Loss: 2.4485 Difference Loss: -0.9419...
Epoch 19/80... Discriminator Loss: 1.1833... Generator Loss: 1.8220 Difference Loss: -0.6387...
Epoch 20/80... Discriminator Loss: 1.3408... Generator Loss: 1.2924 Difference Loss: 0.0484...
Epoch 21/80... Discriminator Loss: 1.3709... Generator Loss: 1.0623 Difference Loss: 0.3086...
Epoch 22/80... Discriminator Loss: 1.3033... Generator Loss: 1.3027 Difference Loss: 0.0007...
Epoch 23/80... Discriminator Loss: 1.2793... Generator Loss: 1.3861 Difference Loss: -0.1068...
Epoch 24/80... Discriminator Loss: 1.0635... Generator Loss: 2.6765 Difference Loss: -1.6130...
Epoch 25/80... Discriminator Loss: 0.9273... Generator Loss: 2.5718 Difference Loss: -1.6445...
Epoch 26/80... Discriminator Loss: 1.2080... Generator Loss: 1.5664 Difference Loss: -0.3584...
Epoch 27/80... Discriminator Loss: 0.9275... Generator Loss: 2.1867 Difference Loss: -1.2592...
Epoch 28/80... Discriminator Loss: 1.1091... Generator Loss: 1.6780 Difference Loss: -0.5689...
Epoch 29/80... Discriminator Loss: 1.1394... Generator Loss: 1.5004 Difference Loss: -0.3610...
Epoch 30/80... Discriminator Loss: 1.1007... Generator Loss: 1.6363 Difference Loss: -0.5356...
Epoch 31/80... Discriminator Loss: 0.9284... Generator Loss: 1.6571 Difference Loss: -0.7287...
Epoch 32/80... Discriminator Loss: 1.0327... Generator Loss: 1.5925 Difference Loss: -0.5598...
Epoch 33/80... Discriminator Loss: 1.0271... Generator Loss: 1.5354 Difference Loss: -0.5083...
Epoch 34/80... Discriminator Loss: 0.9506... Generator Loss: 1.6436 Difference Loss: -0.6931...
Epoch 35/80... Discriminator Loss: 0.9758... Generator Loss: 1.8441 Difference Loss: -0.8683...
Epoch 36/80... Discriminator Loss: 1.3653... Generator Loss: 1.3617 Difference Loss: 0.0036...
Epoch 37/80... Discriminator Loss: 1.1974... Generator Loss: 1.1603 Difference Loss: 0.0371...
Epoch 38/80... Discriminator Loss: 1.1631... Generator Loss: 1.2691 Difference Loss: -0.1060...
Epoch 39/80... Discriminator Loss: 1.2434... Generator Loss: 1.4074 Difference Loss: -0.1640...
Epoch 40/80... Discriminator Loss: 0.9168... Generator Loss: 1.8218 Difference Loss: -0.9050...
Epoch 41/80... Discriminator Loss: 1.1142... Generator Loss: 1.3254 Difference Loss: -0.2113...
Epoch 42/80... Discriminator Loss: 1.1721... Generator Loss: 1.4403 Difference Loss: -0.2682...
Epoch 43/80... Discriminator Loss: 1.2639... Generator Loss: 1.4644 Difference Loss: -0.2005...
Epoch 44/80... Discriminator Loss: 1.3026... Generator Loss: 1.3177 Difference Loss: -0.0151...
Epoch 45/80... Discriminator Loss: 1.2626... Generator Loss: 1.4706 Difference Loss: -0.2081...
Epoch 46/80... Discriminator Loss: 1.0161... Generator Loss: 1.6525 Difference Loss: -0.6364...
Epoch 47/80... Discriminator Loss: 1.1347... Generator Loss: 1.4505 Difference Loss: -0.3158...
Epoch 48/80... Discriminator Loss: 1.3174... Generator Loss: 1.3299 Difference Loss: -0.0126...
Epoch 49/80... Discriminator Loss: 1.0591... Generator Loss: 1.4575 Difference Loss: -0.3984...
Epoch 50/80... Discriminator Loss: 1.1734... Generator Loss: 1.9260 Difference Loss: -0.7527...
Epoch 51/80... Discriminator Loss: 1.0059... Generator Loss: 1.7600 Difference Loss: -0.7541...
Epoch 52/80... Discriminator Loss: 1.2851... Generator Loss: 1.1668 Difference Loss: 0.1183...
Epoch 53/80... Discriminator Loss: 1.1780... Generator Loss: 1.2625 Difference Loss: -0.0846...
Epoch 54/80... Discriminator Loss: 1.0170... Generator Loss: 1.5435 Difference Loss: -0.5265...
Epoch 55/80... Discriminator Loss: 1.1132... Generator Loss: 1.3062 Difference Loss: -0.1929...
Epoch 56/80... Discriminator Loss: 1.1084... Generator Loss: 1.3557 Difference Loss: -0.2473...
Epoch 57/80... Discriminator Loss: 1.2846... Generator Loss: 1.2462 Difference Loss: 0.0384...
Epoch 58/80... Discriminator Loss: 1.0062... Generator Loss: 1.4142 Difference Loss: -0.4080...
Epoch 59/80... Discriminator Loss: 1.2745... Generator Loss: 1.1870 Difference Loss: 0.0875...
Epoch 60/80... Discriminator Loss: 1.2649... Generator Loss: 1.2455 Difference Loss: 0.0195...
Epoch 61/80... Discriminator Loss: 1.2507... Generator Loss: 1.3175 Difference Loss: -0.0667...
Epoch 62/80... Discriminator Loss: 1.1241... Generator Loss: 1.2867 Difference Loss: -0.1626...
Epoch 63/80... Discriminator Loss: 1.1645... Generator Loss: 1.2192 Difference Loss: -0.0547...
Epoch 64/80... Discriminator Loss: 1.2231... Generator Loss: 1.0268 Difference Loss: 0.1963...
Epoch 65/80... Discriminator Loss: 1.0558... Generator Loss: 1.5315 Difference Loss: -0.4757...
Epoch 66/80... Discriminator Loss: 1.0592... Generator Loss: 1.4043 Difference Loss: -0.3450...
Epoch 67/80... Discriminator Loss: 1.1810... Generator Loss: 1.4129 Difference Loss: -0.2319...
Epoch 68/80... Discriminator Loss: 1.1422... Generator Loss: 1.4156 Difference Loss: -0.2734...
Epoch 69/80... Discriminator Loss: 1.0945... Generator Loss: 1.3094 Difference Loss: -0.2148...
Epoch 70/80... Discriminator Loss: 1.1504... Generator Loss: 1.3257 Difference Loss: -0.1753...
Epoch 71/80... Discriminator Loss: 1.1245... Generator Loss: 1.1868 Difference Loss: -0.0622...
Epoch 72/80... Discriminator Loss: 1.0467... Generator Loss: 1.2523 Difference Loss: -0.2056...
Epoch 73/80... Discriminator Loss: 1.1899... Generator Loss: 1.1791 Difference Loss: 0.0108...
Epoch 74/80... Discriminator Loss: 1.0329... Generator Loss: 1.4332 Difference Loss: -0.4003...
Epoch 75/80... Discriminator Loss: 1.2118... Generator Loss: 1.2174 Difference Loss: -0.0056...
Epoch 76/80... Discriminator Loss: 1.1392... Generator Loss: 1.2048 Difference Loss: -0.0656...
Epoch 77/80... Discriminator Loss: 1.1396... Generator Loss: 1.5353 Difference Loss: -0.3957...
Epoch 78/80... Discriminator Loss: 1.0808... Generator Loss: 1.6528 Difference Loss: -0.5719...
Epoch 79/80... Discriminator Loss: 1.1598... Generator Loss: 1.4892 Difference Loss: -0.3294...
Epoch 80/80... Discriminator Loss: 1.1646... Generator Loss: 1.3657 Difference Loss: -0.2011...

Results with 128 hidden units

Epoch 72/100... Discriminator Loss: 1.2292... Generator Loss: 1.0937 Difference Loss: 0.1355...
Epoch 73/100... Discriminator Loss: 1.1977... Generator Loss: 1.0838 Difference Loss: 0.1139...
Epoch 74/100... Discriminator Loss: 1.0160... Generator Loss: 1.4791 Difference Loss: -0.4632...
Epoch 75/100... Discriminator Loss: 1.1122... Generator Loss: 1.0486 Difference Loss: 0.0637...
Epoch 76/100... Discriminator Loss: 1.0662... Generator Loss: 1.5303 Difference Loss: -0.4641...
Epoch 77/100... Discriminator Loss: 1.1943... Generator Loss: 1.1728 Difference Loss: 0.0215...
Epoch 78/100... Discriminator Loss: 1.1579... Generator Loss: 1.3853 Difference Loss: -0.2274...
Epoch 79/100... Discriminator Loss: 1.1481... Generator Loss: 1.1773 Difference Loss: -0.0292...
Epoch 80/100... Discriminator Loss: 1.1529... Generator Loss: 1.6801 Difference Loss: -0.5272...

Training loss

Here we'll check out the training losses for the generator and discriminator.
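As a rough reference point when reading these curves: if the discriminator were completely unsure and output 0.5 for every image, each cross-entropy term would equal ln 2. The sketch below just evaluates that case in plain Python. It's a back-of-the-envelope check under that assumption, not a statement about where this particular training run should end up, and `binary_cross_entropy` is a helper written for this example.

```python
import math

def binary_cross_entropy(prob, label):
    """Cross-entropy between a predicted probability and a (possibly soft) label."""
    return -(label * math.log(prob) + (1.0 - label) * math.log(1.0 - prob))

smooth = 0.1
undecided = 0.5  # discriminator output when it cannot tell real from fake

d_loss_ref = (binary_cross_entropy(undecided, 1.0 - smooth)   # real term
              + binary_cross_entropy(undecided, 0.0))         # fake term
g_loss_ref = binary_cross_entropy(undecided, 1.0)

print(round(d_loss_ref, 4), round(g_loss_ref, 4))  # 1.3863 0.6931
```

Losses hovering far from these values suggest one network is currently ahead of the other.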


In [71]:
%matplotlib inline

import matplotlib.pyplot as plt

In [72]:
# With 128 hidden
fig, ax = plt.subplots()
losses = np.array(losses)
plt.plot(losses.T[0], label='Discriminator')
plt.plot(losses.T[1], label='Generator')
plt.title("Training Losses")
plt.legend()


Out[72]:
<matplotlib.legend.Legend at 0x2bab019bba8>

In [101]:
# With 256 hidden
fig, ax = plt.subplots()
losses = np.array(losses)
plt.plot(losses.T[0], label='Discriminator')
plt.plot(losses.T[1], label='Generator')
plt.title("Training Losses")
plt.legend()


Out[101]:
<matplotlib.legend.Legend at 0x2bab09f7748>

Generator samples from training

Here we can view samples of images from the generator. First we'll look at images taken while training.


In [102]:
def view_samples(epoch, samples):
    fig, axes = plt.subplots(figsize=(7,7), nrows=4, ncols=4, sharey=True, sharex=True)
    for ax, img in zip(axes.flatten(), samples[epoch]):
        ax.xaxis.set_visible(False)
        ax.yaxis.set_visible(False)
        im = ax.imshow(img.reshape((28,28)), cmap='Greys_r')
    
    return fig, axes

In [103]:
# Load samples from generator taken while training
with open('train_samples.pkl', 'rb') as f:
    samples = pkl.load(f)

In [104]:
plt.imshow(mnist.train.images[3].reshape(28,28), cmap='Greys_r')


Out[104]:
<matplotlib.image.AxesImage at 0x2bab4a00c18>

These are samples from the final training epoch. You can see the generator is able to reproduce numbers like 5, 7, 3, 0, 9. Since this is just a sample, it isn't representative of the full range of images this generator can make.


In [75]:
# with 128
_ = view_samples(-1, samples)



In [105]:
# with 256
_ = view_samples(-1, samples)


Below I'm showing the generated images as the network was training, at evenly spaced epochs (every 8 epochs for the 80-epoch run above, since the code steps by len(samples)/rows). With bonus optical illusion!


In [106]:
# with 256
rows, cols = 10, 6
fig, axes = plt.subplots(figsize=(7,12), nrows=rows, ncols=cols, sharex=True, sharey=True)

for sample, ax_row in zip(samples[::int(len(samples)/rows)], axes):
    for img, ax in zip(sample[::int(len(sample)/cols)], ax_row):
        ax.imshow(img.reshape((28,28)), cmap='Greys_r')
        ax.xaxis.set_visible(False)
        ax.yaxis.set_visible(False)



In [76]:
# with 128
rows, cols = 10, 6
fig, axes = plt.subplots(figsize=(7,12), nrows=rows, ncols=cols, sharex=True, sharey=True)

for sample, ax_row in zip(samples[::int(len(samples)/rows)], axes):
    for img, ax in zip(sample[::int(len(sample)/cols)], ax_row):
        ax.imshow(img.reshape((28,28)), cmap='Greys_r')
        ax.xaxis.set_visible(False)
        ax.yaxis.set_visible(False)


It starts out as all noise. Then it learns to make only the center white and the rest black. You can start to see some number-like structures appear out of the noise. Looks like 1, 9, and 8 show up first. Then, it learns 5 and 3.
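The grid cells above pick their rows and columns with strided slices, so it can be handy to check which epochs and images actually end up on screen. The sketch below replays the index arithmetic with plain lists, assuming 80 saved sample batches of 16 images each, as in the 80-epoch run; the list names are stand-ins for this example.

```python
rows, cols = 10, 6
epochs_sampled = list(range(80))   # stand-in for the 80 per-epoch sample batches
images_in_batch = list(range(16))  # stand-in for the 16 images saved per epoch

row_step = int(len(epochs_sampled) / rows)    # 8
col_step = int(len(images_in_batch) / cols)   # 2

shown_epochs = epochs_sampled[::row_step]            # one entry per grid row
shown_images = images_in_batch[::col_step][:cols]    # zip() stops after cols axes

print(shown_epochs)   # [0, 8, 16, 24, 32, 40, 48, 56, 64, 72]
print(shown_images)   # [0, 2, 4, 6, 8, 10]
```

Under these assumptions the last row comes from epoch index 72, so the final epoch's samples don't appear in the grid.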

Sampling from the generator

We can also get completely new images from the generator by using the checkpoint we saved after training. We just need to pass in a new latent vector $z$ and we'll get new samples!


In [77]:
saver = tf.train.Saver(var_list=g_vars)
with tf.Session() as sess:
    saver.restore(sess, tf.train.latest_checkpoint('checkpoints'))
    sample_z = np.random.uniform(-1, 1, size=(16, z_size))
    gen_samples = sess.run(
                   generator(input_z, input_size, reuse=True),
                   feed_dict={input_z: sample_z})
view_samples(0, [gen_samples])


INFO:tensorflow:Restoring parameters from checkpoints\generator.ckpt