This notebook contains the second code sample found in Chapter 8, Section 5 of Deep Learning with Python. Note that the original text features far more content, in particular further explanations and figures: in this notebook, you will only find source code and related comments.
In what follows, we explain how to implement a GAN in Keras in its barest form. Since GANs are quite advanced, diving deeply into the technical details would be out of scope for us. Our specific implementation will be a deep convolutional GAN, or DCGAN: a GAN where the generator and discriminator are deep convnets. In particular, it leverages a `Conv2DTranspose` layer for image upsampling in the generator.
We will train our GAN on images from CIFAR10, a dataset of 50,000 32x32 RGB images belonging to 10 classes (5,000 images per class). To make things even easier, we will only use images belonging to the class "frog".
Schematically, our GAN looks like this:

* A `generator` network maps vectors of shape `(latent_dim,)` to images of shape `(32, 32, 3)`.
* A `discriminator` network maps images of shape `(32, 32, 3)` to a binary score estimating the probability that the image is real.
* A `gan` network chains the generator and the discriminator together: `gan(x) = discriminator(generator(x))`. Thus this `gan` network maps latent space vectors to the discriminator's assessment of the realism of these latent vectors as decoded by the generator.
* To train the generator, we use the gradients of the generator's weights with regard to the loss of the `gan` model. This means that, at every step, we move the weights of the generator in a direction that will make the discriminator more likely to classify as "real" the images decoded by the generator. In other words, we train the generator to fool the discriminator.

Training GANs and tuning GAN implementations is notoriously difficult. There are a number of known "tricks" that one should keep in mind. Like most things in deep learning, it is more alchemy than science: these tricks are really just heuristics, not theory-backed guidelines. They are backed by some level of intuitive understanding of the phenomenon at hand, and they are known to work well empirically, albeit not necessarily in every context.
Here are a few of the tricks that we leverage in our own implementation of a GAN generator and discriminator below. It is not an exhaustive list of GAN-related tricks; you will find many more across the GAN literature.

* We use `tanh` as the last activation in the generator, instead of `sigmoid`, which would be more commonly found in other types of models.
* We use a `LeakyReLU` layer instead of a ReLU activation. It is similar to ReLU, but it relaxes sparsity constraints by allowing small negative activation values.
* To limit the "checkerboard" artifacts that strided upsampling tends to produce, we use a kernel size that is divisible by the stride size whenever we use a strided `Conv2DTranspose` or `Conv2D` in both the generator and discriminator.

First, we develop a `generator` model, which turns a vector from the latent space (during training it will be sampled at random) into a candidate image. One of the many issues that commonly arise with GANs is that the generator gets stuck producing images that look like noise. A possible solution is to use dropout on both the discriminator and generator.
```python
import keras
from keras import layers
import numpy as np

latent_dim = 32
height = 32
width = 32
channels = 3

generator_input = keras.Input(shape=(latent_dim,))

# First, transform the input into a 16x16 128-channel feature map
x = layers.Dense(128 * 16 * 16)(generator_input)
x = layers.LeakyReLU()(x)
x = layers.Reshape((16, 16, 128))(x)

# Then, add a convolution layer
x = layers.Conv2D(256, 5, padding='same')(x)
x = layers.LeakyReLU()(x)

# Upsample to 32x32
x = layers.Conv2DTranspose(256, 4, strides=2, padding='same')(x)
x = layers.LeakyReLU()(x)

# A few more conv layers
x = layers.Conv2D(256, 5, padding='same')(x)
x = layers.LeakyReLU()(x)
x = layers.Conv2D(256, 5, padding='same')(x)
x = layers.LeakyReLU()(x)

# Produce a 32x32 3-channel (RGB) image with values in [-1, 1]
x = layers.Conv2D(channels, 7, activation='tanh', padding='same')(x)
generator = keras.models.Model(generator_input, x)
```
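Two of the tricks listed earlier appear directly in the generator: `LeakyReLU` after the hidden layers and `tanh` on the output layer. The NumPy sketch below mimics both activations. It assumes a negative slope of 0.3, which is the default of the Keras `LeakyReLU` layer. It is an illustration, not the Keras implementation.

```python
import numpy as np


def relu(x):
    return np.maximum(x, 0.0)


def leaky_relu(x, alpha=0.3):
    # Identity for positive inputs, slope `alpha` for negative ones.
    # alpha=0.3 is assumed to match the Keras LeakyReLU default.
    return np.where(x > 0, x, alpha * x)


x = np.array([-2.0, -0.5, 0.0, 1.5])
print(relu(x).tolist())        # [0.0, 0.0, 0.0, 1.5]
print(leaky_relu(x).tolist())  # [-0.6, -0.15, 0.0, 1.5]

# tanh squashes any pre-activation into [-1, 1], so every generated
# pixel stays in that range no matter how large the inputs get.
rng = np.random.default_rng(0)
pre_activations = rng.normal(scale=10.0, size=(4, 32, 32, 3))
pixels = np.tanh(pre_activations)
print(pixels.min() >= -1.0, pixels.max() <= 1.0)  # True True
```

ReLU zeroes every negative input, while LeakyReLU keeps a scaled-down version of it, so gradients still flow through negative activations. Because `tanh` bounds outputs to [-1, 1], real images are often rescaled to the same range (for example `x / 127.5 - 1` on 8-bit pixels). The training code in this notebook instead divides pixel values by 255, which maps them to [0, 1].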
Next, we develop a `discriminator` model. It takes a candidate image (real or generated) as input and classifies it as either "generated" or "real".

```python
discriminator_input = layers.Input(shape=(height, width, channels))
x = layers.Conv2D(128, 3)(discriminator_input)
x = layers.LeakyReLU()(x)
x = layers.Conv2D(128, 4, strides=2)(x)
x = layers.LeakyReLU()(x)
x = layers.Conv2D(128, 4, strides=2)(x)
x = layers.LeakyReLU()(x)
x = layers.Conv2D(128, 4, strides=2)(x)
x = layers.LeakyReLU()(x)
x = layers.Flatten()(x)

# One dropout layer - important trick!
x = layers.Dropout(0.4)(x)

# Classification layer
x = layers.Dense(1, activation='sigmoid')(x)
discriminator = keras.models.Model(discriminator_input, x)

# To stabilize training, we use learning rate decay
# and gradient clipping (by value) in the optimizer.
# InverseTimeDecay with decay_steps=1 reproduces the legacy `decay=1e-8`
# argument: learning rate = 0.0008 / (1 + 1e-8 * step).
discriminator_optimizer = keras.optimizers.RMSprop(
    learning_rate=keras.optimizers.schedules.InverseTimeDecay(
        0.0008, decay_steps=1, decay_rate=1e-8),
    clipvalue=1.0)
discriminator.compile(optimizer=discriminator_optimizer, loss='binary_crossentropy')
```
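One quick way to check the layer stacks is to trace the spatial size through them by hand. The sketch below assumes the usual output-size formulas for `'valid'` and `'same'` padding, with no dilation and no `output_padding`. The helper names `conv_output_size` and `conv_transpose_output_size` are hypothetical. The trace shows that the discriminator shrinks a 32x32 input down to a 2x2x128 map before `Flatten`, and that the generator's single strided `Conv2DTranspose` doubles 16x16 to 32x32.

```python
import math


def conv_output_size(size, kernel, stride=1, padding="valid"):
    # Output length of a Conv2D layer along one spatial dimension.
    if padding == "same":
        return math.ceil(size / stride)
    return (size - kernel) // stride + 1


def conv_transpose_output_size(size, kernel, stride=1, padding="valid"):
    # Output length of a Conv2DTranspose layer (no output_padding).
    if padding == "same":
        return size * stride
    return size * stride + max(kernel - stride, 0)


# Discriminator: (kernel, stride) of each Conv2D, all with 'valid' padding.
disc_sizes = []
size = 32
for kernel, stride in [(3, 1), (4, 2), (4, 2), (4, 2)]:
    size = conv_output_size(size, kernel, stride)
    disc_sizes.append(size)
flat_features = size * size * 128  # 128 filters in every discriminator conv
print(disc_sizes, flat_features)  # [30, 14, 6, 2] 512

# Generator: 'same' Conv2D layers keep 16x16; the strided transpose doubles it.
gen_size = conv_output_size(16, 5, padding="same")
gen_size = conv_transpose_output_size(gen_size, 4, stride=2, padding="same")
print(gen_size)  # 32
```

Every strided layer here uses a kernel size of 4 with a stride of 2, so the kernel size is divisible by the stride.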
Finally, we set up the GAN, which chains the generator and the discriminator. When trained, this model moves the generator in a direction that improves its ability to fool the discriminator. It turns latent space points into a classification decision, "fake" or "real", and it is meant to be trained with labels that always say "these are real images". Training `gan` therefore updates the weights of `generator` in a way that makes `discriminator` more likely to predict "real" when looking at fake images.

Very importantly, we freeze the discriminator (set it to non-trainable) for this model, so its weights are not updated when training `gan`. If the discriminator weights could be updated during this process, we would be training the discriminator to always predict "real", which is not what we want!
```python
# Set discriminator weights to non-trainable
# (will only apply to the `gan` model)
discriminator.trainable = False

gan_input = keras.Input(shape=(latent_dim,))
gan_output = discriminator(generator(gan_input))
gan = keras.models.Model(gan_input, gan_output)

# Same legacy-equivalent decay as above: 0.0004 / (1 + 1e-8 * step)
gan_optimizer = keras.optimizers.RMSprop(
    learning_rate=keras.optimizers.schedules.InverseTimeDecay(
        0.0004, decay_steps=1, decay_rate=1e-8),
    clipvalue=1.0)
gan.compile(optimizer=gan_optimizer, loss='binary_crossentropy')
```
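To see why freezing matters, here is a deliberately tiny NumPy analogue of training `gan`. A hypothetical one-dimensional linear generator feeds a fixed logistic discriminator. Only the generator's parameters receive gradient updates, computed by the chain rule through the discriminator. This loosely mirrors what `gan.train_on_batch` does once `discriminator.trainable = False`. The toy models, plain gradient descent, learning rate and step count are illustrative assumptions, not part of the notebook, which uses RMSprop on convnets.

```python
import numpy as np

rng = np.random.default_rng(0)


def sigmoid(t):
    return 1.0 / (1.0 + np.exp(-t))


def bce(y_true, y_pred, eps=1e-7):
    # Mean binary crossentropy over the batch.
    y_pred = np.clip(y_pred, eps, 1.0 - eps)
    return -np.mean(y_true * np.log(y_pred) + (1.0 - y_true) * np.log(1.0 - y_pred))


# Hypothetical 1-D toy models, not the notebook's networks:
# generator g(z) = a*z + b, discriminator d(x) = sigmoid(w*x + c).
gen_params = {"a": 1.0, "b": 0.0}
disc_params = {"w": 2.0, "c": -1.0}  # frozen: never updated below


def toy_gan(z):
    # toy_gan(z) = discriminator(generator(z))
    x = gen_params["a"] * z + gen_params["b"]
    return sigmoid(disc_params["w"] * x + disc_params["c"])


REAL = 0.0  # the notebook's loop labels real images 0 and generated images 1
z = rng.normal(size=64)
targets = np.full_like(z, REAL)  # "misleading" targets: everything is "real"

loss_before = bce(targets, toy_gan(z))
learning_rate = 0.5
for _ in range(100):
    p = toy_gan(z)
    # For a sigmoid output with mean BCE, dLoss/dlogit = (p - y) / N.
    # Chain rule through logit = w * (a*z + b) + c:
    dlogit = (p - targets) / z.size
    grad_a = np.sum(dlogit * disc_params["w"] * z)
    grad_b = np.sum(dlogit * disc_params["w"])
    # Only the generator's parameters are updated.
    gen_params["a"] -= learning_rate * grad_a
    gen_params["b"] -= learning_rate * grad_b
loss_after = bce(targets, toy_gan(z))

print("loss before:", loss_before, "after:", loss_after)
print("discriminator params:", disc_params)  # unchanged
```

The discriminator's parameters are never updated, so its decision function stays fixed. Meanwhile the generator drifts toward outputs that the discriminator scores as "real".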
Now we can start training. To recapitulate, this is schematically what the training loop looks like:
for each epoch:
* Draw random points in the latent space (random noise).
* Generate images with `generator` using this random noise.
* Mix the generated images with real ones.
* Train `discriminator` using these mixed images, with corresponding targets, either "real" (for the real images) or "fake" (for the generated images).
* Draw new random points in the latent space.
* Train `gan` using these random vectors, with targets that all say "these are real images". This will update the weights of the generator (only, since discriminator is frozen inside `gan`) to move them towards getting the discriminator to predict "these are real images" for generated images, i.e. this trains the generator to fool the discriminator.
Let's implement it:
```python
import os
from keras.preprocessing import image

# Load CIFAR10 data
(x_train, y_train), (_, _) = keras.datasets.cifar10.load_data()

# Select frog images (class 6)
x_train = x_train[y_train.flatten() == 6]

# Normalize data
x_train = x_train.reshape(
    (x_train.shape[0],) + (height, width, channels)).astype('float32') / 255.

iterations = 10000
batch_size = 20
save_dir = '/home/ubuntu/gan_images/'
os.makedirs(save_dir, exist_ok=True)

# Start training loop
start = 0
for step in range(iterations):
    # Sample random points in the latent space
    random_latent_vectors = np.random.normal(size=(batch_size, latent_dim))

    # Decode them to fake images
    generated_images = generator.predict(random_latent_vectors, verbose=0)

    # Combine them with real images
    stop = start + batch_size
    real_images = x_train[start: stop]
    combined_images = np.concatenate([generated_images, real_images])

    # Assemble labels discriminating real from fake images
    # (1 = generated, 0 = real)
    labels = np.concatenate([np.ones((batch_size, 1)),
                             np.zeros((batch_size, 1))])
    # Add random noise to the labels - important trick!
    labels += 0.05 * np.random.random(labels.shape)

    # Train the discriminator
    d_loss = discriminator.train_on_batch(combined_images, labels)

    # Sample random points in the latent space
    random_latent_vectors = np.random.normal(size=(batch_size, latent_dim))

    # Assemble labels that say "all real images"
    misleading_targets = np.zeros((batch_size, 1))

    # Train the generator (via the gan model,
    # where the discriminator weights are frozen)
    a_loss = gan.train_on_batch(random_latent_vectors, misleading_targets)

    # Move to the next batch of real images, wrapping around at the end
    start += batch_size
    if start > len(x_train) - batch_size:
        start = 0

    # Occasionally save / plot
    if step % 100 == 0:
        # Save model weights (the filename is a placeholder)
        gan.save_weights(os.path.join(save_dir, 'gan.weights.h5'))

        # Print metrics
        print('discriminator loss at step %s: %s' % (step, d_loss))
        print('adversarial loss at step %s: %s' % (step, a_loss))

        # Save one generated image
        img = image.array_to_img(generated_images[0] * 255., scale=False)
        img.save(os.path.join(save_dir, 'generated_frog' + str(step) + '.png'))

        # Save one real image, for comparison
        img = image.array_to_img(real_images[0] * 255., scale=False)
        img.save(os.path.join(save_dir, 'real_frog' + str(step) + '.png'))
```
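The label bookkeeping inside the training loop can be checked on its own, without any model. The sketch below rebuilds one discriminator batch using random stand-in images. The stand-ins are an assumption: the real loop uses `generator.predict` outputs and CIFAR10 frogs. The sketch then confirms the shapes and the label ranges produced by the 0.05 label noise.

```python
import numpy as np

batch_size = 20  # same value as the notebook
rng = np.random.default_rng(0)

# Random stand-ins for one batch of generated and real images.
generated_images = rng.uniform(-1, 1, size=(batch_size, 32, 32, 3))
real_images = rng.uniform(0, 1, size=(batch_size, 32, 32, 3))
combined_images = np.concatenate([generated_images, real_images])

# Same convention as the loop: 1 = generated ("fake"), 0 = real.
labels = np.concatenate([np.ones((batch_size, 1)), np.zeros((batch_size, 1))])
# Label noise: each label is shifted upward by a value in [0, 0.05).
labels += 0.05 * rng.random(labels.shape)

# Generator step: every target says "real" (0).
misleading_targets = np.zeros((batch_size, 1))

print(combined_images.shape, labels.shape)  # (40, 32, 32, 3) (40, 1)
```

With this noise scheme, the "fake" targets land in [1, 1.05) and the "real" targets in [0, 0.05).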
Let's display a few of our fake images:
```python
import matplotlib.pyplot as plt

# Sample random points in the latent space
random_latent_vectors = np.random.normal(size=(10, latent_dim))

# Decode them to fake images
generated_images = generator.predict(random_latent_vectors, verbose=0)

for i in range(generated_images.shape[0]):
    img = image.array_to_img(generated_images[i] * 255., scale=False)
    plt.figure()
    plt.imshow(img)

plt.show()
```
[Figure: Froggy with some pixellated artifacts.]