Generative Adversarial Networks

Generative Adversarial Networks (GANs) are one of the most promising recent developments in deep learning. A GAN, introduced by Ian Goodfellow et al. in 2014, attacks the problem of unsupervised learning by training two deep networks, called the Generator and the Discriminator, that compete with and cooperate with each other. In the course of training, both networks eventually learn how to perform their tasks.

A GAN is almost always explained with the analogy of a counterfeiter (the Generator) and the police (the Discriminator). Initially, the counterfeiter shows the police a fake banknote. The police say it is fake and give the counterfeiter feedback on why. The counterfeiter makes a new fake note based on that feedback. The police again say the note is fake and offer new feedback. The counterfeiter tries again based on the latest feedback. The cycle continues until the police are fooled because the fake note looks real.
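Formally, this game is the minimax objective from Goodfellow et al. (2014): the Discriminator D is trained to score real samples high and fakes low, while the Generator G tries to minimize the same quantity:

$$\min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{\mathrm{data}}}[\log D(x)] + \mathbb{E}_{z \sim p_z}[\log(1 - D(G(z)))]$$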

While the idea of a GAN is simple in theory, it is very difficult to build a model that works. Two deep networks are coupled together, making backpropagation of gradients twice as challenging. The Deep Convolutional GAN (DCGAN) is one of the models that demonstrated how to build a practical GAN able to learn by itself how to synthesize new images. In this article, we discuss how a working DCGAN can be built with Keras (the 1.x API, on a Theano backend, as the imports and outputs below show) in less than 200 lines of code. We will train the DCGAN to learn how to write handwritten digits, the MNIST way.


In [1]:
%matplotlib inline
import random
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns  # imported for its matplotlib styling side effects
from tqdm import tqdm
from IPython import display

from keras.models import Model
from keras.layers import Input
from keras.layers.core import Reshape, Dense, Dropout, Activation, Flatten
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.convolutional import Convolution2D, UpSampling2D
from keras.layers.normalization import BatchNormalization
from keras.optimizers import Adam
from keras.datasets import mnist
from keras import backend as K  # needed later for K.set_value


Using Theano backend.
Using gpu device 0: GeForce GTX 1080 Ti.

In [2]:
img_rows, img_cols = 28, 28

# the data, shuffled and split between train and test sets
(X_train, y_train), (X_test, y_test) = mnist.load_data()

# reshape to channels-first (samples, 1, 28, 28), matching the
# Theano image dimension ordering
X_train = X_train.reshape(X_train.shape[0], 1, img_rows, img_cols)
X_test = X_test.reshape(X_test.shape[0], 1, img_rows, img_cols)
# scale pixel values to [0, 1] to match the generator's sigmoid output
X_train = X_train.astype('float32')
X_test = X_test.astype('float32')
X_train /= 255
X_test /= 255

print np.min(X_train), np.max(X_train)

print('X_train shape:', X_train.shape)
print(X_train.shape[0], 'train samples')
print(X_test.shape[0], 'test samples')


0.0 1.0
('X_train shape:', (60000, 1, 28, 28))
(60000, 'train samples')
(10000, 'test samples')
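As a quick sanity check (assuming Keras 1.x, whose backend exposes image_dim_ordering()), the arrays above are channels-first, matching Theano's image ordering:

from keras import backend as K
# Expect 'th': the (samples, channels, rows, cols) layout assumed by
# the (60000, 1, 28, 28) reshape above
print K.image_dim_ordering()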

In [3]:
def make_trainable(net, val):
    # Toggle the trainable flag on a model and each of its layers
    net.trainable = val
    for l in net.layers:
        l.trainable = val

Discriminator

A discriminator that tells how real an image is is basically a deep convolutional neural network (CNN), as shown in Figure 1. For the MNIST dataset, the input is an image (28 pixels x 28 pixels x 1 channel). The output is the probability that the image is real (0.0 is certainly fake, 1.0 is certainly real, anything in between is a gray area); in this notebook it is implemented as a two-way softmax over the classes fake and real, which is equivalent to a single sigmoid output. The difference from a typical CNN is the absence of max-pooling between layers: a strided convolution is used for downsampling instead. The activation function in each layer is a leaky ReLU, and dropout between layers (0.25 here) prevents overfitting and memorization. Listing 1 shows the implementation in Keras.

Figure 1. The Discriminator of the DCGAN tells how real an input digit image is. The MNIST dataset is used as ground truth for real images. Strided convolutions, rather than max-pooling, downsample the image.
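Listing 1. The discriminator, condensed from the full model-building cell (In [4]) below; dropout_rate there is 0.25, written out literally here.

d_input = Input(shape=(1, 28, 28))
H = Convolution2D(256, 5, 5, subsample=(2, 2), border_mode='same')(d_input)  # 28x28 -> 14x14
H = LeakyReLU(0.2)(H)
H = Dropout(0.25)(H)
H = Convolution2D(512, 5, 5, subsample=(2, 2), border_mode='same')(H)  # 14x14 -> 7x7
H = LeakyReLU(0.2)(H)
H = Dropout(0.25)(H)
H = Flatten()(H)
H = Dense(256)(H)
H = LeakyReLU(0.2)(H)
H = Dropout(0.25)(H)
d_V = Dense(2, activation='softmax')(H)  # [fake, real] probabilities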

Generator

The generator synthesizes fake images. In Figure 2, the fake image is generated from 100-dimensional noise (uniformly sampled from 0.0 to 1.0 in this notebook; the DCGAN paper samples from -1.0 to 1.0). Instead of the fractionally-strided (transposed) convolutions suggested in the DCGAN paper, upsampling followed by ordinary convolution is used between the first layers, since it synthesizes more realistic handwriting images. Batch normalization between layers stabilizes learning, the activation function after each layer is a ReLU, and a sigmoid at the last layer produces the fake image. Listing 2 shows the implementation in Keras.

Figure 2. The Generator model synthesizes fake MNIST images from noise. Upsampling followed by convolution is used instead of fractionally-strided (transposed) convolution.
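Listing 2. The generator, condensed from the model-building cell (In [4]) below; there nch = 200, so the channel counts 100 and 50 are written out here.

g_input = Input(shape=[100])
H = Dense(200*14*14, init='glorot_normal')(g_input)
H = BatchNormalization(mode=2)(H)
H = Activation('relu')(H)
H = Reshape([200, 14, 14])(H)
H = UpSampling2D(size=(2, 2))(H)  # 14x14 -> 28x28
H = Convolution2D(100, 3, 3, border_mode='same', init='glorot_uniform')(H)
H = BatchNormalization(mode=2)(H)
H = Activation('relu')(H)
H = Convolution2D(50, 3, 3, border_mode='same', init='glorot_uniform')(H)
H = BatchNormalization(mode=2)(H)
H = Activation('relu')(H)
H = Convolution2D(1, 1, 1, border_mode='same', init='glorot_uniform')(H)
g_V = Activation('sigmoid')(H)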


In [4]:
shp = X_train.shape[1:]
print shp

dropout_rate = 0.25

# Optimizers: the stacked GAN (generator) trains with a higher
# learning rate than the discriminator
opt = Adam(lr=1e-3)
dopt = Adam(lr=1e-4)
nch = 200

# Build the generative model
g_input = Input(shape=[100])
H = Dense(nch*14*14, init='glorot_normal')(g_input)
H = BatchNormalization(mode=2)(H)
H = Activation('relu')(H)
H = Reshape( [nch, 14, 14] )(H)
H = UpSampling2D(size=(2, 2))(H)  # 14x14 -> 28x28
H = Convolution2D(nch/2, 3, 3, border_mode='same', init='glorot_uniform')(H)
H = BatchNormalization(mode=2)(H)
H = Activation('relu')(H)
H = Convolution2D(nch/4, 3, 3, border_mode='same', init='glorot_uniform')(H)
H = BatchNormalization(mode=2)(H)
H = Activation('relu')(H)
H = Convolution2D(1, 1, 1, border_mode='same', init='glorot_uniform')(H)
g_V = Activation('sigmoid')(H)
generator = Model(g_input, g_V)
generator.compile(loss='binary_crossentropy', optimizer=opt)
generator.summary()


# Build the discriminative model
d_input = Input(shape=shp)
H = Convolution2D(256, 5, 5, subsample=(2, 2), border_mode='same')(d_input)  # 28x28 -> 14x14
H = LeakyReLU(0.2)(H)
H = Dropout(dropout_rate)(H)
H = Convolution2D(512, 5, 5, subsample=(2, 2), border_mode='same')(H)  # 14x14 -> 7x7
H = LeakyReLU(0.2)(H)
H = Dropout(dropout_rate)(H)
H = Flatten()(H)
H = Dense(256)(H)
H = LeakyReLU(0.2)(H)
H = Dropout(dropout_rate)(H)
d_V = Dense(2, activation='softmax')(H)  # [fake, real] probabilities
discriminator = Model(d_input, d_V)
discriminator.compile(loss='categorical_crossentropy', optimizer=dopt)
discriminator.summary()

# Freeze the discriminator's weights for stacked training
make_trainable(discriminator, False)

# Build the stacked GAN model: noise -> generator -> discriminator
gan_input = Input(shape=[100])
H = generator(gan_input)
gan_V = discriminator(H)
GAN = Model(gan_input, gan_V)
GAN.compile(loss='categorical_crossentropy', optimizer=opt)
GAN.summary()


(1, 28, 28)
____________________________________________________________________________________________________
Layer (type)                     Output Shape          Param #     Connected to                     
====================================================================================================
input_1 (InputLayer)             (None, 100)           0                                            
____________________________________________________________________________________________________
dense_1 (Dense)                  (None, 39200)         3959200     input_1[0][0]                    
____________________________________________________________________________________________________
batchnormalization_1 (BatchNormal(None, 39200)         78400       dense_1[0][0]                    
____________________________________________________________________________________________________
activation_1 (Activation)        (None, 39200)         0           batchnormalization_1[0][0]       
____________________________________________________________________________________________________
reshape_1 (Reshape)              (None, 200, 14, 14)   0           activation_1[0][0]               
____________________________________________________________________________________________________
upsampling2d_1 (UpSampling2D)    (None, 200, 28, 28)   0           reshape_1[0][0]                  
____________________________________________________________________________________________________
convolution2d_1 (Convolution2D)  (None, 100, 28, 28)   180100      upsampling2d_1[0][0]             
____________________________________________________________________________________________________
batchnormalization_2 (BatchNormal(None, 100, 28, 28)   56          convolution2d_1[0][0]            
____________________________________________________________________________________________________
activation_2 (Activation)        (None, 100, 28, 28)   0           batchnormalization_2[0][0]       
____________________________________________________________________________________________________
convolution2d_2 (Convolution2D)  (None, 50, 28, 28)    45050       activation_2[0][0]               
____________________________________________________________________________________________________
batchnormalization_3 (BatchNormal(None, 50, 28, 28)    56          convolution2d_2[0][0]            
____________________________________________________________________________________________________
activation_3 (Activation)        (None, 50, 28, 28)    0           batchnormalization_3[0][0]       
____________________________________________________________________________________________________
convolution2d_3 (Convolution2D)  (None, 1, 28, 28)     51          activation_3[0][0]               
____________________________________________________________________________________________________
activation_4 (Activation)        (None, 1, 28, 28)     0           convolution2d_3[0][0]            
====================================================================================================
Total params: 4262913
____________________________________________________________________________________________________
____________________________________________________________________________________________________
Layer (type)                     Output Shape          Param #     Connected to                     
====================================================================================================
input_2 (InputLayer)             (None, 1, 28, 28)     0                                            
____________________________________________________________________________________________________
convolution2d_4 (Convolution2D)  (None, 256, 14, 14)   6656        input_2[0][0]                    
____________________________________________________________________________________________________
leakyrelu_1 (LeakyReLU)          (None, 256, 14, 14)   0           convolution2d_4[0][0]            
____________________________________________________________________________________________________
dropout_1 (Dropout)              (None, 256, 14, 14)   0           leakyrelu_1[0][0]                
____________________________________________________________________________________________________
convolution2d_5 (Convolution2D)  (None, 512, 7, 7)     3277312     dropout_1[0][0]                  
____________________________________________________________________________________________________
leakyrelu_2 (LeakyReLU)          (None, 512, 7, 7)     0           convolution2d_5[0][0]            
____________________________________________________________________________________________________
dropout_2 (Dropout)              (None, 512, 7, 7)     0           leakyrelu_2[0][0]                
____________________________________________________________________________________________________
flatten_1 (Flatten)              (None, 25088)         0           dropout_2[0][0]                  
____________________________________________________________________________________________________
dense_2 (Dense)                  (None, 256)           6422784     flatten_1[0][0]                  
____________________________________________________________________________________________________
leakyrelu_3 (LeakyReLU)          (None, 256)           0           dense_2[0][0]                    
____________________________________________________________________________________________________
dropout_3 (Dropout)              (None, 256)           0           leakyrelu_3[0][0]                
____________________________________________________________________________________________________
dense_3 (Dense)                  (None, 2)             514         dropout_3[0][0]                  
====================================================================================================
Total params: 9707266
____________________________________________________________________________________________________
____________________________________________________________________________________________________
Layer (type)                     Output Shape          Param #     Connected to                     
====================================================================================================
input_3 (InputLayer)             (None, 100)           0                                            
____________________________________________________________________________________________________
model_1 (Model)                  (None, 1, 28, 28)     4262913     input_3[0][0]                    
____________________________________________________________________________________________________
model_2 (Model)                  (None, 2)             0           model_1[1][0]                    
====================================================================================================
Total params: 4262913
____________________________________________________________________________________________________

In [5]:
def plot_loss(losses):
    display.clear_output(wait=True)
    display.display(plt.gcf())
    plt.figure(figsize=(10,8))
    plt.plot(losses["d"], label='discriminative loss')
    plt.plot(losses["g"], label='generative loss')
    plt.legend()
    plt.show()

In [6]:
def plot_gen(n_ex=16, dim=(4,4), figsize=(10,10)):
    # Sample noise from the same uniform [0, 1) range used in training
    noise = np.random.uniform(0,1,size=[n_ex,100])
    generated_images = generator.predict(noise)

    plt.figure(figsize=figsize)
    for i in range(generated_images.shape[0]):
        plt.subplot(dim[0],dim[1],i+1)
        img = generated_images[i,0,:,:]
        plt.imshow(img)
        plt.axis('off')
    plt.tight_layout()
    plt.show()

In [7]:
ntrain = 10000
trainidx = random.sample(range(0,X_train.shape[0]), ntrain)
XT = X_train[trainidx,:,:,:]

# Pre-train the discriminator network on real vs. generated images
noise_gen = np.random.uniform(0,1,size=[XT.shape[0],100])
generated_images = generator.predict(noise_gen)
X = np.concatenate((XT, generated_images))
n = XT.shape[0]
y = np.zeros([2*n,2])
y[:n,1] = 1   # first half is real: class 1
y[n:,0] = 1   # second half is fake: class 0

make_trainable(discriminator,True)
discriminator.fit(X,y, nb_epoch=1, batch_size=32)
y_hat = discriminator.predict(X)


Epoch 1/1
20000/20000 [==============================] - 19s - loss: 0.0063    

In [8]:
y_hat_idx = np.argmax(y_hat,axis=1)
y_idx = np.argmax(y,axis=1)
diff = y_idx-y_hat_idx
n_tot = y.shape[0]
n_rig = (diff==0).sum()
acc = n_rig*100.0/n_tot
print "Accuracy: %0.02f pct (%d of %d) right"%(acc, n_rig, n_tot)


Accuracy: 100.00 pct (20000 of 20000) right
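The near-perfect accuracy at this point is expected: the generator is still untrained, so its outputs are essentially noise and trivially distinguishable from real digits.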

In [9]:
# set up loss storage vector
losses = {"d":[], "g":[]}

Adversarial Training

The adversarial model is just the generator and discriminator stacked together, as shown in Figure 3. The Generator part tries to fool the Discriminator while learning from its feedback at the same time. Listing 4 shows the implementation in Keras. The stacked model is compiled with its own Adam optimizer; in this notebook its learning rate (1e-3) is ten times that of the discriminator's (1e-4).

Figure 3. The Adversarial model is simply the generator with its output connected to the input of the discriminator. Also shown is the training process, wherein the Generator labels its fake image output as real (1.0), trying to fool the Discriminator.
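Listing 4. The stacked adversarial model, condensed from the model-building cell (In [4]) above; the discriminator's weights are frozen so that only the generator is updated through this model.

make_trainable(discriminator, False)
gan_input = Input(shape=[100])
gan_V = discriminator(generator(gan_input))
GAN = Model(gan_input, gan_V)
GAN.compile(loss='categorical_crossentropy', optimizer=opt)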

Training

Training is the hardest part. We first verify that the Discriminator model works on its own by pre-training it on real and fake images, as done above. Afterwards, the Discriminator and Adversarial models are trained one after the other. Figure 4 shows the Discriminator model and Figure 3 the Adversarial model during training. The cell below (In [10]) shows the training loop in Keras.

Figure 4. The Discriminator model is trained to distinguish real from fake handwritten digits.


In [10]:
def train_for_n(nb_epoch=5000, plt_frq=25, BATCH_SIZE=32):

    for e in tqdm(range(nb_epoch)):

        # Sample a batch of real images and generate a batch of fakes
        image_batch = X_train[np.random.randint(0,X_train.shape[0],size=BATCH_SIZE),:,:,:]
        noise_gen = np.random.uniform(0,1,size=[BATCH_SIZE,100])
        generated_images = generator.predict(noise_gen)

        # Train the discriminator on real vs. generated images
        X = np.concatenate((image_batch, generated_images))
        y = np.zeros([2*BATCH_SIZE,2])
        y[0:BATCH_SIZE,1] = 1  # real: class 1
        y[BATCH_SIZE:,0] = 1   # fake: class 0

        # Note: Keras bakes the trainable flags in when each model's
        # training function is first compiled, so these toggles mostly
        # document intent; the GAN model was already compiled with the
        # discriminator frozen.
        make_trainable(discriminator,True)
        d_loss  = discriminator.train_on_batch(X,y)
        losses["d"].append(d_loss)

        # Train the generator-discriminator stack on input noise,
        # labeling the fakes as real so the generator learns to fool
        # the (frozen) discriminator
        noise_tr = np.random.uniform(0,1,size=[BATCH_SIZE,100])
        y2 = np.zeros([BATCH_SIZE,2])
        y2[:,1] = 1

        make_trainable(discriminator,False)
        g_loss = GAN.train_on_batch(noise_tr, y2)
        losses["g"].append(g_loss)

        # Update plots
        if e%plt_frq==plt_frq-1:
            plot_loss(losses)
            plot_gen()
In [11]:
train_for_n(nb_epoch=250, plt_frq=25,BATCH_SIZE=128)


100%|██████████| 250/250 [01:32<00:00,  1.42it/s]

In [ ]:
K.set_value(opt.lr, 1e-4)
K.set_value(dopt.lr, 1e-5)
train_for_n(nb_epoch=100, plt_frq=10,BATCH_SIZE=128)


100%|██████████| 100/100 [00:44<00:00,  1.32it/s]

In [ ]:
K.set_value(opt.lr, 1e-5)
K.set_value(dopt.lr, 1e-6)
train_for_n(nb_epoch=100, plt_frq=10,BATCH_SIZE=256)


 13%|█▎        | 13/100 [00:10<01:10,  1.24it/s]

Hands On

Retrain the GAN and try to get better results!
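One simple starting point is to train for more iterations and then anneal the learning rates, reusing the notebook's own train_for_n helper and K.set_value pattern; the iteration counts and rates below are illustrative guesses, not tuned values.

# Hypothetical retraining schedule (illustrative settings only)
train_for_n(nb_epoch=2000, plt_frq=100, BATCH_SIZE=128)

# Anneal the learning rates, as done earlier in the notebook
K.set_value(opt.lr, 1e-4)
K.set_value(dopt.lr, 1e-5)
train_for_n(nb_epoch=1000, plt_frq=100, BATCH_SIZE=128)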

