Using a GAN for Generating Hand-written Digit Images

GAN (Generative Adversarial Network) [1] is a framework proposed by Ian Goodfellow, Yoshua Bengio and others in 2014.

A GAN can be trained to generate images from random noise. For example, we can train a GAN on MNIST (the hand-written digits dataset) to generate images that look like the hand-written digits in MNIST, which could in turn be used to train other neural networks.

The code in this notebook is based on the GAN MNIST example by Udacity [2], which uses TensorFlow directly, but here we use Keras on top of TensorFlow for more straightforward construction of the networks. Many of the training ideas come from How to Train a GAN? Tips and tricks to make GANs work [4].

MNIST

MNIST is a well known database of handwritten digits [3].


In [1]:
import numpy as np
import keras
import keras.backend as K
from keras.layers import Input, Dense, Activation, LeakyReLU, BatchNormalization
from keras.models import Sequential
from keras.optimizers import Adam
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
%matplotlib inline


Using TensorFlow backend.

The code below downloads the MNIST dataset (if it is not already cached).


In [2]:
(X_train, y_train), (X_test, y_test) = keras.datasets.mnist.load_data()

Let's examine some sample images. We use the 'gray' colormap since the images carry no color information.


In [3]:
plt.figure(figsize=(5, 4))
for i in range(20):
    plt.subplot(4, 5, i+1)
    plt.imshow(X_train[i], cmap='gray')
    plt.xticks([])
    plt.yticks([])
plt.tight_layout()
plt.show()


All MNIST digit images are 28x28 pixels.


In [4]:
sample = X_train[17]

plt.figure(figsize=(3, 2))
plt.title(sample.shape)
plt.imshow(sample, cmap='gray')
plt.show()


The minimum and maximum pixel values of the MNIST image data are 0 and 255, respectively.


In [5]:
X_train.min(), X_train.max()


Out[5]:
(0, 255)

Generator

We want to build a generator that generates realistic hand-written images.

The input to the generator is called a 'latent sample', which is a vector of randomly generated numbers. We draw it from the normal distribution rather than the uniform distribution, as suggested in [4].


In [6]:
def make_latent_samples(n_samples, sample_size):
    #return np.random.uniform(-1, 1, size=(n_samples, sample_size))
    return np.random.normal(loc=0, scale=1, size=(n_samples, sample_size))

The sample size is a hyperparameter. Below, we use a vector of 100 randomly generated numbers as a sample.


In [7]:
make_latent_samples(1, 100) # generates one sample


Out[7]:
array([[ 0.27473292, -0.10655303,  1.09659154, -0.15366093, -0.61587259,
         0.26621725, -0.01077858,  1.90989847, -1.18569009, -0.03249574,
         0.6695888 ,  0.93482157, -1.00884217, -0.29230866, -1.13929507,
        -0.01707456,  0.37335705,  0.33970841,  1.37859658, -0.53026397,
         1.12131043, -1.66056623, -0.01493777,  0.808652  , -0.36080177,
         1.18045669, -0.33343053,  0.60287437,  0.9825658 , -0.31889656,
         0.1179318 , -0.91808507, -1.03957979,  0.40643158, -0.32754068,
         0.80441626, -0.01966988, -0.79409032,  1.55127879,  0.34457065,
        -0.00255198,  1.01100422,  0.56261678, -0.39284342, -0.03217389,
         1.09418122, -0.90881511, -0.19342759,  0.3317994 ,  0.19549762,
        -1.40058816, -0.16028498, -1.86537691,  0.6165322 , -0.4672151 ,
        -0.23835781, -0.35751269, -0.97823372,  1.26912872, -1.29290883,
        -0.97779726,  1.76487061, -0.33914689, -0.57437618,  0.86655979,
        -0.27751868,  0.71869572, -1.22436248,  0.7134086 , -1.1244994 ,
         0.99746905, -0.00786507, -0.66620361, -1.40849483, -0.26476278,
        -0.40399178,  0.35693832, -1.45288997, -0.79326825, -0.48003003,
        -0.15437594,  0.12191884, -0.00680743, -1.30782153,  0.45268918,
         0.68991131,  1.85151145,  1.05853026, -0.4318387 , -0.19847975,
        -1.13489859, -0.59428163,  0.4483107 , -1.07967249, -0.02395855,
         1.00938931,  0.06230197, -1.74878905, -0.13811637, -0.4902213 ]])

The generator is a simple fully connected neural network with one hidden layer using the leaky ReLU activation. It takes one latent sample (100 values) and produces 784 (=28x28) values which represent a digit image.


In [8]:
generator = Sequential([
    Dense(128, input_shape=(100,)),
    LeakyReLU(alpha=0.01),
    Dense(784),
    Activation('tanh')
], name='generator')

generator.summary()


_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
dense_1 (Dense)              (None, 128)               12928     
_________________________________________________________________
leaky_re_lu_1 (LeakyReLU)    (None, 128)               0         
_________________________________________________________________
dense_2 (Dense)              (None, 784)               101136    
_________________________________________________________________
activation_1 (Activation)    (None, 784)               0         
=================================================================
Total params: 114,064
Trainable params: 114,064
Non-trainable params: 0
_________________________________________________________________

The last activation is tanh, which according to [4] works best. It also means that we need to rescale the MNIST pixel values to be between -1 and 1.
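
As a quick illustration of that rescaling (a minimal sketch; the preprocess and deprocess helpers defined later in this notebook apply the same formulas):

pixels = np.array([0., 127.5, 255.])    # raw MNIST pixel values
scaled = (pixels / 255 - 0.5) * 2       # -> [-1., 0., 1.], matching the tanh output range
restored = (scaled / 2 + 0.5) * 255     # -> back to [0., 127.5, 255.]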

Initially, the generator can only produce garbage.

As such, the generator needs to learn how to generate realistic hand-written images from the latent sample (randomly generated numbers).

How do we train this generator? That is the question the GAN framework tackles.

Before talking about the GAN as a whole, we shall discuss the discriminator.

Discriminator

The discriminator takes a digit image and classifies it as either real (1) or fake (0).

If the input image is from the MNIST database, the discriminator should classify it as real.

If the input image is from the generator, the discriminator should classify it as fake.

The discriminator is also a simple fully connected neural network with one hidden layer using the leaky ReLU activation.


In [9]:
discriminator = Sequential([
    Dense(128, input_shape=(784,)),
    LeakyReLU(alpha=0.01),
    Dense(1),
    Activation('sigmoid')
], name='discriminator')

discriminator.summary()


_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
dense_3 (Dense)              (None, 128)               100480    
_________________________________________________________________
leaky_re_lu_2 (LeakyReLU)    (None, 128)               0         
_________________________________________________________________
dense_4 (Dense)              (None, 1)                 129       
_________________________________________________________________
activation_2 (Activation)    (None, 1)                 0         
=================================================================
Total params: 100,609
Trainable params: 100,609
Non-trainable params: 0
_________________________________________________________________

The last activation is sigmoid, which outputs the probability that the input image is real.

We train the discriminator using both the MNIST images and the images generated by the generator.
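
As a minimal preview of what that training looks like (the actual training loop appears later in this notebook; the compile settings and the batch size of 64 here are purely illustrative):

# one illustrative discriminator update: a batch of real images labeled 1, a batch of fakes labeled 0
discriminator.compile(optimizer='adam', loss='binary_crossentropy')

real_batch = (X_train[:64].reshape(-1, 784) / 255 - 0.5) * 2      # rescale pixels to [-1, 1]
fake_batch = generator.predict(make_latent_samples(64, 100))      # images from the (untrained) generator

d_loss_real = discriminator.train_on_batch(real_batch, np.ones((64, 1)))
d_loss_fake = discriminator.train_on_batch(fake_batch, np.zeros((64, 1)))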

GAN

We connect the generator and the discriminator to produce a GAN.

It takes a latent sample; the generator inside the GAN produces a digit image, which the discriminator inside the GAN then classifies as real or fake.

If the generated digit image is realistic enough, the discriminator inside the GAN classifies it as real, which is what we want to achieve.

We set the discriminator inside the GAN to non-trainable, so it merely evaluates the quality of the generated image. The label is always 1 (real), so if the generator fails to produce a realistic digit image, the cost becomes high, and when back-propagation occurs in the GAN, the weights in the generator network get updated.


In [10]:
# the GAN shares weights with the generator and discriminator models defined above
gan = Sequential([
    generator,
    discriminator
])

gan.summary()


_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
generator (Sequential)       (None, 784)               114064    
_________________________________________________________________
discriminator (Sequential)   (None, 1)                 100609    
=================================================================
Total params: 214,673
Trainable params: 214,673
Non-trainable params: 0
_________________________________________________________________

As you can see, the GAN internally uses the same generator and discriminator models; it shares their weights. Therefore, training the GAN also trains the generator. However, we do not want the discriminator weights to be affected while training the GAN.

We train the discriminator and the GAN in turn and repeat the training many times until both are trained well.

While training the GAN, the back-propagation should update the weights of the generator but not the discriminator.

As such, we need a way to make the discriminator trainable and non-trainable.


In [11]:
def make_trainable(model, trainable):
    for layer in model.layers:
        layer.trainable = trainable

In [12]:
make_trainable(discriminator, False)
discriminator.summary()


_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
dense_3 (Dense)              (None, 128)               100480    
_________________________________________________________________
leaky_re_lu_2 (LeakyReLU)    (None, 128)               0         
_________________________________________________________________
dense_4 (Dense)              (None, 1)                 129       
_________________________________________________________________
activation_2 (Activation)    (None, 1)                 0         
=================================================================
Total params: 100,609
Trainable params: 0
Non-trainable params: 100,609
_________________________________________________________________

In [13]:
make_trainable(discriminator, True)
discriminator.summary()


_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
dense_3 (Dense)              (None, 128)               100480    
_________________________________________________________________
leaky_re_lu_2 (LeakyReLU)    (None, 128)               0         
_________________________________________________________________
dense_4 (Dense)              (None, 1)                 129       
_________________________________________________________________
activation_2 (Activation)    (None, 1)                 0         
=================================================================
Total params: 100,609
Trainable params: 100,609
Non-trainable params: 0
_________________________________________________________________

The function below combines everything we have discussed so far: it builds the generator, discriminator, and GAN models, and also compiles them for training.


In [14]:
def make_simple_GAN(sample_size, 
                    g_hidden_size, 
                    d_hidden_size, 
                    leaky_alpha, 
                    g_learning_rate,
                    d_learning_rate):
    K.clear_session()
    
    generator = Sequential([
        Dense(g_hidden_size, input_shape=(sample_size,)),
        LeakyReLU(alpha=leaky_alpha),
        Dense(784),        
        Activation('tanh')
    ], name='generator')    

    discriminator = Sequential([
        Dense(d_hidden_size, input_shape=(784,)),
        LeakyReLU(alpha=leaky_alpha),
        Dense(1),
        Activation('sigmoid')
    ], name='discriminator')    
    
    gan = Sequential([
        generator,
        discriminator
    ])
    
    discriminator.compile(optimizer=Adam(lr=d_learning_rate), loss='binary_crossentropy')
    gan.compile(optimizer=Adam(lr=g_learning_rate), loss='binary_crossentropy')
    
    return gan, generator, discriminator

Training GAN

Preprocessing

We need to flatten the digit image data because the fully connected input layer expects flat vectors. Also, since the generator uses the tanh activation in its output layer, we scale all the MNIST images to values between -1 and 1.


In [15]:
def preprocess(x):    
    x = x.reshape(-1, 784) # 784=28*28
    x = np.float64(x)
    x = (x / 255 - 0.5) * 2
    x = np.clip(x, -1, 1)
    return x

In [16]:
X_train_real = preprocess(X_train)
X_test_real  = preprocess(X_test)

Deprocessing

We also need a function to reverse the preprocessing so that we can display generated images.


In [17]:
def deprocess(x):
    x = (x / 2 + 0.5) * 255  # invert the preprocessing: [-1, 1] -> [0, 255]
    x = np.clip(x, 0, 255)
    x = np.uint8(x)
    x = x.reshape(28, 28)
    return x

In [18]:
plt.figure(figsize=(5, 4))
for i in range(20):
    img = deprocess(X_train_real[i])
    plt.subplot(4, 5, i+1)
    plt.imshow(img, cmap='gray')
    plt.xticks([])
    plt.yticks([])
plt.tight_layout()
plt.show()


Labels

The labels are 1 (real) or 0 (fake), shaped as 2D arrays with one column.


In [19]:
def make_labels(size):
    return np.ones([size, 1]), np.zeros([size, 1])

Below are 10 real and 10 fake label values.


In [20]:
y_real_10, y_fake_10 = make_labels(10)

y_real_10, y_fake_10


Out[20]:
(array([[ 1.],
        [ 1.],
        [ 1.],
        [ 1.],
        [ 1.],
        [ 1.],
        [ 1.],
        [ 1.],
        [ 1.],
        [ 1.]]), array([[ 0.],
        [ 0.],
        [ 0.],
        [ 0.],
        [ 0.],
        [ 0.],
        [ 0.],
        [ 0.],
        [ 0.],
        [ 0.]]))

Later, we prepare the labels for training and evaluation using the training batch size and the evaluation size.

Label Smoothing

One last point before we start training is label smoothing, which makes the discriminator generalize better [4].

For the real digit images, the labels are all 1s. However, when we train the discriminator, we use a value slightly smaller than 1 for the real digit images. Otherwise, the discriminator might overfit to the training data and reject anything that is even slightly different from the training images.
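
For example, with a smoothing factor of 0.1 (the value used in the training loop below), the targets for the real images become 0.9 instead of 1.0; a minimal sketch:

smooth = 0.1
y_real = np.ones((4, 1))                 # "real" labels
y_real_smoothed = y_real * (1 - smooth)  # [[0.9], [0.9], [0.9], [0.9]]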

Training Loop

We repeat the following to make both the discriminator and the generator better and better:

  • Prepare a batch of real images
  • Prepare a batch of fake images generated by the generator using latent samples
  • Make the discriminator trainable
  • Train the discriminator to classify the real and fake images
  • Make the discriminator non-trainable
  • Train the generator via the GAN

When training the generator via the GAN, the expected labels are all 1s (real). Initially, the generator produces unrealistic images, so the discriminator classifies them as fake (0), which causes back-propagation to adjust the weights inside the generator. The discriminator is not affected because we set it to non-trainable in this step.


In [21]:
# hyperparameters
sample_size     = 100     # latent sample size (i.e., 100 random numbers)
g_hidden_size   = 128
d_hidden_size   = 128
leaky_alpha     = 0.01
g_learning_rate = 0.0001  # learning rate for the generator
d_learning_rate = 0.001   # learning rate for the discriminator
epochs          = 100
batch_size      = 64      # train batch size
eval_size       = 16      # evaluate size
smooth          = 0.1

# labels for the training batch size and the evaluation size
y_train_real, y_train_fake = make_labels(batch_size)
y_eval_real,  y_eval_fake  = make_labels(eval_size)

# create a GAN, a generator and a discriminator
gan, generator, discriminator = make_simple_GAN(
    sample_size, 
    g_hidden_size, 
    d_hidden_size, 
    leaky_alpha, 
    g_learning_rate,
    d_learning_rate)

losses = []
for e in range(epochs):
    for i in range(len(X_train_real)//batch_size):
        # real MNIST digit images
        X_batch_real = X_train_real[i*batch_size:(i+1)*batch_size]
        
        # latent samples and the generated digit images
        latent_samples = make_latent_samples(batch_size, sample_size)
        X_batch_fake = generator.predict_on_batch(latent_samples)
        
        # train the discriminator to detect real and fake images
        make_trainable(discriminator, True)
        discriminator.train_on_batch(X_batch_real, y_train_real * (1 - smooth))
        discriminator.train_on_batch(X_batch_fake, y_train_fake)

        # train the generator via GAN
        make_trainable(discriminator, False)
        gan.train_on_batch(latent_samples, y_train_real)
    
    # evaluate
    X_eval_real = X_test_real[np.random.choice(len(X_test_real), eval_size, replace=False)]
    
    latent_samples = make_latent_samples(eval_size, sample_size)
    X_eval_fake = generator.predict_on_batch(latent_samples)

    d_loss  = discriminator.test_on_batch(X_eval_real, y_eval_real)
    d_loss += discriminator.test_on_batch(X_eval_fake, y_eval_fake)
    g_loss  = gan.test_on_batch(latent_samples, y_eval_real) # we want the fake to be realistic!
    
    losses.append((d_loss, g_loss))
    
    print("Epoch: {:>3}/{} Discriminator Loss: {:>6.4f} Generator Loss: {:>6.4f}".format(
        e+1, epochs, d_loss, g_loss))


Epoch:   1/100 Discriminator Loss: 0.9538 Generator Loss: 5.5816
Epoch:   2/100 Discriminator Loss: 0.1899 Generator Loss: 2.1410
Epoch:   3/100 Discriminator Loss: 0.3655 Generator Loss: 1.4118
Epoch:   4/100 Discriminator Loss: 0.1344 Generator Loss: 3.0701
Epoch:   5/100 Discriminator Loss: 0.3473 Generator Loss: 2.3738
Epoch:   6/100 Discriminator Loss: 0.5153 Generator Loss: 3.7222
Epoch:   7/100 Discriminator Loss: 0.5159 Generator Loss: 3.5699
Epoch:   8/100 Discriminator Loss: 0.7343 Generator Loss: 2.1149
Epoch:   9/100 Discriminator Loss: 0.4661 Generator Loss: 2.3607
Epoch:  10/100 Discriminator Loss: 0.2888 Generator Loss: 2.4522
Epoch:  11/100 Discriminator Loss: 0.4095 Generator Loss: 1.5612
Epoch:  12/100 Discriminator Loss: 0.5406 Generator Loss: 2.7845
Epoch:  13/100 Discriminator Loss: 0.3287 Generator Loss: 2.6246
Epoch:  14/100 Discriminator Loss: 0.2806 Generator Loss: 2.9350
Epoch:  15/100 Discriminator Loss: 0.6982 Generator Loss: 3.5069
Epoch:  16/100 Discriminator Loss: 0.4339 Generator Loss: 3.3039
Epoch:  17/100 Discriminator Loss: 0.7092 Generator Loss: 1.9226
Epoch:  18/100 Discriminator Loss: 0.8024 Generator Loss: 4.9868
Epoch:  19/100 Discriminator Loss: 0.2961 Generator Loss: 2.8417
Epoch:  20/100 Discriminator Loss: 0.5851 Generator Loss: 3.4906
Epoch:  21/100 Discriminator Loss: 0.3387 Generator Loss: 1.9244
Epoch:  22/100 Discriminator Loss: 0.4378 Generator Loss: 3.1012
Epoch:  23/100 Discriminator Loss: 0.2871 Generator Loss: 2.0432
Epoch:  24/100 Discriminator Loss: 0.3734 Generator Loss: 2.6555
Epoch:  25/100 Discriminator Loss: 0.7119 Generator Loss: 2.8028
Epoch:  26/100 Discriminator Loss: 0.1978 Generator Loss: 2.8457
Epoch:  27/100 Discriminator Loss: 0.5232 Generator Loss: 2.5416
Epoch:  28/100 Discriminator Loss: 0.2756 Generator Loss: 2.2270
Epoch:  29/100 Discriminator Loss: 0.3289 Generator Loss: 3.0124
Epoch:  30/100 Discriminator Loss: 0.5604 Generator Loss: 3.3040
Epoch:  31/100 Discriminator Loss: 0.8915 Generator Loss: 3.2336
Epoch:  32/100 Discriminator Loss: 0.6021 Generator Loss: 1.9195
Epoch:  33/100 Discriminator Loss: 0.4144 Generator Loss: 3.0869
Epoch:  34/100 Discriminator Loss: 0.6223 Generator Loss: 2.1233
Epoch:  35/100 Discriminator Loss: 0.2667 Generator Loss: 2.5386
Epoch:  36/100 Discriminator Loss: 0.3951 Generator Loss: 2.5837
Epoch:  37/100 Discriminator Loss: 0.6367 Generator Loss: 2.4578
Epoch:  38/100 Discriminator Loss: 0.4050 Generator Loss: 2.6965
Epoch:  39/100 Discriminator Loss: 0.3699 Generator Loss: 2.7588
Epoch:  40/100 Discriminator Loss: 0.6437 Generator Loss: 2.1038
Epoch:  41/100 Discriminator Loss: 0.2477 Generator Loss: 3.1998
Epoch:  42/100 Discriminator Loss: 0.3357 Generator Loss: 3.3893
Epoch:  43/100 Discriminator Loss: 0.8104 Generator Loss: 2.9031
Epoch:  44/100 Discriminator Loss: 0.5600 Generator Loss: 3.6187
Epoch:  45/100 Discriminator Loss: 0.4684 Generator Loss: 2.3988
Epoch:  46/100 Discriminator Loss: 0.2899 Generator Loss: 3.8236
Epoch:  47/100 Discriminator Loss: 0.2372 Generator Loss: 3.9306
Epoch:  48/100 Discriminator Loss: 0.5744 Generator Loss: 2.6210
Epoch:  49/100 Discriminator Loss: 0.5644 Generator Loss: 2.2713
Epoch:  50/100 Discriminator Loss: 0.9803 Generator Loss: 2.6462
Epoch:  51/100 Discriminator Loss: 0.5349 Generator Loss: 3.0719
Epoch:  52/100 Discriminator Loss: 0.8361 Generator Loss: 3.4607
Epoch:  53/100 Discriminator Loss: 0.4824 Generator Loss: 3.0189
Epoch:  54/100 Discriminator Loss: 0.6155 Generator Loss: 2.7298
Epoch:  55/100 Discriminator Loss: 0.6074 Generator Loss: 2.4785
Epoch:  56/100 Discriminator Loss: 0.6182 Generator Loss: 2.7999
Epoch:  57/100 Discriminator Loss: 0.8172 Generator Loss: 2.2989
Epoch:  58/100 Discriminator Loss: 0.6180 Generator Loss: 3.2786
Epoch:  59/100 Discriminator Loss: 0.8217 Generator Loss: 3.1931
Epoch:  60/100 Discriminator Loss: 0.6151 Generator Loss: 3.0382
Epoch:  61/100 Discriminator Loss: 0.8208 Generator Loss: 3.2423
Epoch:  62/100 Discriminator Loss: 0.7167 Generator Loss: 2.0826
Epoch:  63/100 Discriminator Loss: 0.7112 Generator Loss: 2.6495
Epoch:  64/100 Discriminator Loss: 0.5140 Generator Loss: 2.6749
Epoch:  65/100 Discriminator Loss: 0.8169 Generator Loss: 2.7637
Epoch:  66/100 Discriminator Loss: 0.5278 Generator Loss: 2.1983
Epoch:  67/100 Discriminator Loss: 0.7610 Generator Loss: 3.1669
Epoch:  68/100 Discriminator Loss: 0.5442 Generator Loss: 2.8738
Epoch:  69/100 Discriminator Loss: 0.8466 Generator Loss: 2.0486
Epoch:  70/100 Discriminator Loss: 0.6251 Generator Loss: 2.2485
Epoch:  71/100 Discriminator Loss: 0.6418 Generator Loss: 2.2814
Epoch:  72/100 Discriminator Loss: 0.4677 Generator Loss: 2.2908
Epoch:  73/100 Discriminator Loss: 0.6132 Generator Loss: 2.8723
Epoch:  74/100 Discriminator Loss: 0.7114 Generator Loss: 2.6701
Epoch:  75/100 Discriminator Loss: 1.0905 Generator Loss: 1.9237
Epoch:  76/100 Discriminator Loss: 0.9982 Generator Loss: 2.5257
Epoch:  77/100 Discriminator Loss: 0.7214 Generator Loss: 2.1899
Epoch:  78/100 Discriminator Loss: 0.7138 Generator Loss: 1.9392
Epoch:  79/100 Discriminator Loss: 0.9038 Generator Loss: 2.3197
Epoch:  80/100 Discriminator Loss: 1.0450 Generator Loss: 3.0043
Epoch:  81/100 Discriminator Loss: 0.5381 Generator Loss: 2.9071
Epoch:  82/100 Discriminator Loss: 0.5621 Generator Loss: 2.5895
Epoch:  83/100 Discriminator Loss: 0.8544 Generator Loss: 3.3824
Epoch:  84/100 Discriminator Loss: 0.8167 Generator Loss: 2.6601
Epoch:  85/100 Discriminator Loss: 0.7621 Generator Loss: 2.9904
Epoch:  86/100 Discriminator Loss: 0.8123 Generator Loss: 2.7157
Epoch:  87/100 Discriminator Loss: 0.5252 Generator Loss: 3.1781
Epoch:  88/100 Discriminator Loss: 0.9563 Generator Loss: 2.1756
Epoch:  89/100 Discriminator Loss: 1.0338 Generator Loss: 2.7354
Epoch:  90/100 Discriminator Loss: 0.5451 Generator Loss: 3.0826
Epoch:  91/100 Discriminator Loss: 0.9634 Generator Loss: 3.0830
Epoch:  92/100 Discriminator Loss: 0.7814 Generator Loss: 2.9515
Epoch:  93/100 Discriminator Loss: 0.8324 Generator Loss: 3.5539
Epoch:  94/100 Discriminator Loss: 1.8759 Generator Loss: 1.9687
Epoch:  95/100 Discriminator Loss: 0.9151 Generator Loss: 2.2101
Epoch:  96/100 Discriminator Loss: 0.9279 Generator Loss: 2.2720
Epoch:  97/100 Discriminator Loss: 0.8894 Generator Loss: 3.2808
Epoch:  98/100 Discriminator Loss: 0.9052 Generator Loss: 2.5838
Epoch:  99/100 Discriminator Loss: 0.6043 Generator Loss: 2.9611
Epoch: 100/100 Discriminator Loss: 0.5693 Generator Loss: 2.3100

Stabilizing GAN

As it turns out, training a GAN is quite hard, and many tricks and heuristics are required [4]. This is because the discriminator and the generator are not cooperating; each is individually learning to beat the other.

For example, the generator might learn to fool the discriminator with garbage. Ideally, the discriminator should learn faster than the generator so that it can classify images accurately.

Therefore, I used different learning rates for the generator and the discriminator. I wanted to slow down the generator learning so that the discriminator learns to classify well.

I am not 100% certain whether this is a generally good strategy, but it does seem to work in this project.

As the generator learns and its loss decreases, the discriminator's loss increases. I see a kind of equilibrium around 80-90 epochs.


In [22]:
losses = np.array(losses)

fig, ax = plt.subplots()
plt.plot(losses.T[0], label='Discriminator')
plt.plot(losses.T[1], label='Generator')
plt.title("Training Losses")
plt.legend()
plt.show()


Tracking Failures

According to [4],

  • A discriminator loss of 0 means something is wrong.
  • When things are working, the discriminator loss has low variance and goes down over time.
  • When things are not working, the discriminator loss has huge variance and spikes.
  • If the generator loss steadily decreases, it is fooling the discriminator with garbage (see the rough check sketched after this list).
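
Below is a rough sketch (not part of the original training code) of how one might check these heuristics against the losses recorded during training above:

loss_arr = np.array(losses)                     # shape (epochs, 2): (d_loss, g_loss) per epoch
d_losses, g_losses = loss_arr[:, 0], loss_arr[:, 1]

print('min discriminator loss :', d_losses.min())        # (near) 0 would mean something is wrong
print('d loss std, last 20    :', d_losses[-20:].std())  # huge variance / spikes indicate trouble
print('g loss, first 10 epochs:', g_losses[:10].mean())
print('g loss, last 10 epochs :', g_losses[-10:].mean()) # a steady decline may mean fooling with garbage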

Don't balance via loss statistics

According to [4],

  • i.e., don't try to find a (number of G/number of D) schedule to uncollapse training

Testing the Generator

Now we generate some digit images using the trained generator.


In [23]:
latent_samples = make_latent_samples(20, sample_size)
generated_digits = generator.predict(latent_samples)

plt.figure(figsize=(10, 8))
for i in range(20):
    img = deprocess(generated_digits[i])
    plt.subplot(4, 5, i+1)
    plt.imshow(img, cmap='gray')
    plt.xticks([])
    plt.yticks([])
plt.tight_layout()
plt.show()


The results are not outstanding since we are using very simple networks. A Deep Convolutional GAN (DCGAN) would produce better results.
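
For reference, a DCGAN-style generator replaces the fully connected hidden layer with transposed convolutions. The sketch below only illustrates that idea and is not used in this notebook; the layer sizes are arbitrary, and the discriminator and the preprocessing would need matching changes to handle 28x28x1 images.

from keras.layers import Dense, Reshape, Conv2DTranspose, BatchNormalization, LeakyReLU, Activation
from keras.models import Sequential

dcgan_generator = Sequential([
    Dense(7 * 7 * 64, input_shape=(100,)),                           # project the latent sample
    LeakyReLU(alpha=0.01),
    Reshape((7, 7, 64)),
    Conv2DTranspose(32, kernel_size=5, strides=2, padding='same'),   # upsample 7x7 -> 14x14
    BatchNormalization(),
    LeakyReLU(alpha=0.01),
    Conv2DTranspose(1, kernel_size=5, strides=2, padding='same'),    # upsample 14x14 -> 28x28
    Activation('tanh')                                               # output in [-1, 1], one channel
], name='dcgan_generator')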

References

[1] Generative Adversarial Networks

Ian J. Goodfellow, Jean Pouget-Abadie, Mehdi Mirza, Bing Xu, David Warde-Farley, Sherjil Ozair, Aaron Courville, Yoshua Bengio

https://arxiv.org/abs/1406.2661

[2] GAN MNIST Example in TensorFlow

Udacity

https://github.com/udacity/deep-learning/tree/master/gan_mnist

[3] MNIST dataset

Yann LeCun

http://yann.lecun.com/exdb/mnist/

[4] How to Train a GAN? Tips and tricks to make GANs work

Facebook AI Research: Soumith Chintala, Emily Denton, Martin Arjovsky, Michael Mathieu

https://github.com/soumith/ganhacks

https://www.youtube.com/watch?v=X1mUN6dD8uE