Generative Adversarial Networks (GANs) were introduced by Ian Goodfellow and colleagues in 2014 (https://arxiv.org/abs/1406.2661).
"There are many interesting recent development in deep learning…The most important one, in my opinion, is adversarial training (also called GAN for Generative Adversarial Networks). This, and the variations that are now being proposed is the most interesting idea in the last 10 years in ML, in my opinion." – Yann LeCun
One network generates candidates and the other evaluates them; that is, a GAN consists of two models, a generative model and a discriminative model. Before looking at GANs, recall the difference between the two: a discriminative model learns to map an input to a label, whereas a generative model learns the distribution of the data itself so that it can produce new samples.
The discriminative model has the task of determining whether a given image looks natural (an image from the dataset) or looks like it has been artificially created. The task of the generator is to create images that the discriminator mistakes for real ones. This can be thought of as a zero-sum or minimax two-player game. As Goodfellow et al. describe it: "the generative model is pitted against an adversary: a discriminative model that learns to determine whether a sample is from the model distribution or the data distribution. The generative model can be thought of as analogous to a team of counterfeiters, trying to produce fake currency and use it without detection, while the discriminative model is analogous to the police, trying to detect the counterfeit currency. Competition in this game drives both teams to improve their methods until the counterfeits are indistinguishable from the genuine articles."
The generator is typically a deconvolutional neural network, and the discriminator is a convolutional neural network. Convolutional networks are a bottom-up approach in which the input signal is subjected to multiple layers of convolutions, non-linearities and sub-sampling. By contrast, each layer in a deconvolutional network is top-down: it seeks to generate the input signal by a sum over convolutions of the feature maps (as opposed to the input) with learned filters. In the original formulation of deconvolutional networks, inferring the feature map activations from an input and a set of filters requires solving a multi-component deconvolution problem that is computationally challenging. The generator built below does not solve that inference problem; it uses transposed convolution layers, which are often loosely called deconvolutions.
Here is a short overview of the process:
What are the pros and cons of Generative Adversarial Networks?
Why are they important? Because the discriminator has been trained to tell real images from the dataset apart from artificially created ones, it has learned an “internal representation of the data”. It can therefore be reused as a feature extractor in a CNN.
In [1]:
import numpy as np
import matplotlib.pyplot as plt
import keras
from keras.datasets import mnist
from keras.layers import (Input, UpSampling2D, Conv2DTranspose, Conv2D, LeakyReLU,
                          Reshape, Dense, Dropout, Activation, Flatten,
                          BatchNormalization)
from keras.models import Sequential
from keras.optimizers import RMSprop, Adam
# The tutorials module below only exists in TensorFlow 1.x, so the import is optional.
try:
    from tensorflow.examples.tutorials.mnist import input_data
except ImportError:
    input_data = None
%matplotlib inline
The MNIST database (Modified National Institute of Standards and Technology database) is a large database of handwritten digits that is commonly used for training various image processing systems. The database is also widely used for training and testing in the field of machine learning.
The MNIST database contains 60,000 training images and 10,000 testing images. Half of the training set and half of the test set were taken from NIST's training dataset, while the other half of the training set and the other half of the test set were taken from NIST's testing dataset.
In [2]:
(X_train, y_train), (X_test, y_test) = mnist.load_data()
In [3]:
# Scale the pixel values to [0, 1] and add a channel axis: (60000, 28, 28, 1)
x_train = X_train.astype(np.float32) / 255.0
x_train = x_train.reshape(-1, 28, 28, 1)
print(X_train.shape[0], 'train samples')
print(X_test.shape[0], 'test samples')
pixels = x_train[0]
pixels = pixels.reshape((28, 28))
# Plot
plt.imshow(pixels, cmap='gray')
plt.show()
In [4]:
# Build discriminator
Dis = Sequential()
input_shape = (28, 28, 1)
# Output: 14 x 14 x 64 (input_shape is only needed on the first layer)
Dis.add(Conv2D(64, 5, strides=2, input_shape=input_shape, padding='same'))
Dis.add(LeakyReLU(0.2))
Dis.add(Dropout(0.2))
# Output: 7 x 7 x 128
Dis.add(Conv2D(128, 5, strides=2, padding='same'))
Dis.add(LeakyReLU(0.2))
Dis.add(Dropout(0.2))
# Output: 4 x 4 x 256
Dis.add(Conv2D(256, 5, strides=2, padding='same'))
Dis.add(LeakyReLU(0.2))
Dis.add(Dropout(0.2))
# Output: 4 x 4 x 512
Dis.add(Conv2D(512, 5, strides=1, padding='same'))
Dis.add(LeakyReLU(0.2))
Dis.add(Dropout(0.2))
# Out: 1-dim probability
Dis.add(Flatten())
Dis.add(Dense(1))
Dis.add(Activation('sigmoid'))
Dis.summary()
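The output sizes noted in the comments of the discriminator follow from how padding='same' works: the spatial size is divided by the stride and rounded up. The sketch below is plain Python, not part of the model, and the helper name same_padding_output_size is made up for illustration.

```python
import math

def same_padding_output_size(size, stride):
    """Spatial output size of a convolution that uses padding='same'."""
    return math.ceil(size / stride)

size = 28  # MNIST images are 28 x 28
sizes = []
for stride in (2, 2, 2, 1):  # strides of the four Conv2D layers in the discriminator
    size = same_padding_output_size(size, stride)
    sizes.append(size)

# Number of values that reach the final Dense layer after Flatten: 4 * 4 * 512
flattened = sizes[-1] * sizes[-1] * 512

print(sizes)      # [14, 7, 4, 4]
print(flattened)  # 8192
```

The third layer is the only one where rounding matters: 7 / 2 = 3.5 is rounded up to 4.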
For the generator we start from a 100-dimensional random noise vector and eventually map it to a 28 x 28 x 1 image, so that the generated samples have the same shape as the MNIST data. In Keras, the layer for deconvolution is Conv2DTranspose, a transposed convolution layer (sometimes called deconvolution). The need for transposed convolutions generally arises from the desire to use a transformation going in the opposite direction of a normal convolution, i.e., from something that has the shape of the output of some convolution to something that has the shape of its input while maintaining a connectivity pattern that is compatible with said convolution.
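The name "transposed convolution" can be made concrete in one dimension. The following is a minimal NumPy sketch, not the Keras implementation: it writes a strided convolution as a matrix C and applies the transpose of C to go back to the input shape. The helper conv_matrix and the toy kernel are assumptions made for illustration.

```python
import numpy as np

def conv_matrix(kernel, input_size, stride):
    """Matrix C such that C @ x is the strided 1-D 'valid' cross-correlation of x with kernel."""
    k = len(kernel)
    output_size = (input_size - k) // stride + 1
    C = np.zeros((output_size, input_size))
    for row in range(output_size):
        start = row * stride
        C[row, start:start + k] = kernel
    return C

kernel = np.array([1.0, 2.0, 3.0])
C = conv_matrix(kernel, input_size=7, stride=2)

x = np.arange(7, dtype=float)  # a "signal" with 7 samples
y = C @ x                      # convolution: 7 values -> 3 values
x_up = C.T @ y                 # transposed convolution: 3 values -> 7 values

print(C.shape, y.shape, x_up.shape)  # (3, 7) (3,) (7,)
```

Note that C.T @ y has the shape of x but does not recover x itself; the transposed operation only reverses the shape and the connectivity pattern, which is why "deconvolution" is a loose name for it.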
In [5]:
# Build generator
Gen = Sequential()
# Project the 100-dimensional noise vector and reshape it to 7 x 7 x 256
Gen.add(Dense(7*7*256, input_dim=100, kernel_initializer="glorot_normal"))
Gen.add(BatchNormalization(momentum=0.9))
Gen.add(Activation('relu'))
Gen.add(Reshape((7, 7, 256)))
# Input 7 x 7 x 256
# Output 14 x 14 x 128
Gen.add(UpSampling2D())
Gen.add(Conv2DTranspose(128, 5, padding='same', kernel_initializer="glorot_normal"))
Gen.add(BatchNormalization(momentum=0.9))
Gen.add(Activation('relu'))
# Input 14 x 14 x 128
# Output 28 x 28 x 64
Gen.add(UpSampling2D())
Gen.add(Conv2DTranspose(64, 5, padding='same', kernel_initializer="glorot_normal"))
Gen.add(BatchNormalization(momentum=0.9))
Gen.add(Activation('relu'))
# Input 28 x 28 x 64
# Output 28 x 28 x 32
Gen.add(Conv2DTranspose(32, 5, padding='same', kernel_initializer="glorot_normal"))
Gen.add(BatchNormalization(momentum=0.9))
Gen.add(Activation('relu'))
# Out: 28 x 28 x 1, the sigmoid keeps pixel values in [0, 1]
Gen.add(Conv2DTranspose(1, 5, padding='same', kernel_initializer="glorot_normal"))
Gen.add(Activation('sigmoid'))
Gen.summary()
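In the generator above the spatial size grows in the UpSampling2D layers, not in the Conv2DTranspose layers, which use stride 1 and padding='same'. With its default settings UpSampling2D repeats every row and every column twice (nearest-neighbour upsampling). A small NumPy sketch of that operation on a toy 2 x 2 feature map, for illustration only:

```python
import numpy as np

feature_map = np.arange(4).reshape(2, 2)

# Repeat each row twice, then each column twice: 2 x 2 -> 4 x 4
upsampled = feature_map.repeat(2, axis=0).repeat(2, axis=1)

print(upsampled)
# [[0 0 1 1]
#  [0 0 1 1]
#  [2 2 3 3]
#  [2 2 3 3]]

# Two such steps take the generator from 7 x 7 to 14 x 14 to 28 x 28.
sizes = [7 * 2 ** n for n in range(3)]
print(sizes)  # [7, 14, 28]
```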
In [6]:
# Discriminator model
optimizer = Adam(lr=0.0002, beta_1=0.5)
DM = Sequential()
DM.add(Dis)
DM.compile(loss='binary_crossentropy', optimizer=optimizer,metrics=['accuracy'])
DM.summary()
When training the GAN we are searching for an equilibrium point, which is the optimal point in a minimax game:
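For reference, the minimax objective from the original paper (Goodfellow et al., 2014) can be written as follows, where D(x) is the discriminator's estimate of the probability that x is real, G(z) is the image generated from noise z, p_data is the data distribution and p_z is the noise prior:

```latex
$$
\min_G \max_D V(D, G) =
\mathbb{E}_{x \sim p_{\text{data}}(x)}\big[\log D(x)\big]
+ \mathbb{E}_{z \sim p_z(z)}\big[\log\big(1 - D(G(z))\big)\big]
$$
```

The discriminator tries to maximize V, while the generator tries to minimize it. In the training loop of this notebook the generator is trained on the label 1 for generated images, which corresponds to maximizing log D(G(z)) instead of minimizing log(1 - D(G(z))); the paper suggests this variant because it gives stronger gradients early in training.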
In [7]:
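# Adversarial model: the generator stacked on the discriminator.
# Keras records which weights are trainable when compile() is called, so the
# discriminator is frozen before compiling; training AM then only updates the generator.
optimizer = Adam(lr=0.0002, beta_1=0.5)
Dis.trainable = False
AM = Sequential()
AM.add(Gen)
AM.add(Dis)
AM.compile(loss='binary_crossentropy', optimizer=optimizer, metrics=['accuracy'])
AM.summary()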
In [8]:
# Freeze weights in discriminator D for stacked training
def make_trainable(net, val):
    net.trainable = val
    for l in net.layers:
        l.trainable = val

make_trainable(Dis, False)
The algorithm for training a GAN is the following:
In [9]:
train_steps = 50000
batch_size = 256
for i in range(train_steps):
    # Sample a random batch of real images and generate an equally sized batch of fakes
    images_train = x_train[np.random.randint(0, x_train.shape[0], size=batch_size), :, :, :]
    noise = np.random.normal(0.0, 1.0, size=[batch_size, 100])
    images_fake = Gen.predict(noise)
    # Discriminator step: real images are labelled 1, generated images 0
    make_trainable(Dis, True)
    x = np.concatenate((images_train, images_fake))
    y = np.ones([2*batch_size, 1])
    y[batch_size:, :] = 0
    d_loss = DM.train_on_batch(x, y)
    # Generator step: train the stacked model on fresh noise with the label 1 ("real")
    make_trainable(Dis, False)
    y = np.ones([batch_size, 1])
    noise = np.random.normal(0.0, 1.0, size=[batch_size, 100])
    a_loss = AM.train_on_batch(noise, y)
# Save the trained generator once training has finished
Gen.save('Generator_model.h5')
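The labels used in the loop above decide what each training step optimizes. The following NumPy sketch computes the binary cross-entropy by hand for a tiny batch; the discriminator outputs d_real and d_fake are made-up numbers, and the function is a simplified stand-in for the Keras loss, not its implementation.

```python
import numpy as np

def binary_crossentropy(y_true, y_pred, eps=1e-7):
    """Mean binary cross-entropy between labels and predicted probabilities."""
    y_pred = np.clip(y_pred, eps, 1.0 - eps)
    losses = -(y_true * np.log(y_pred) + (1.0 - y_true) * np.log(1.0 - y_pred))
    return float(np.mean(losses))

# Hypothetical discriminator outputs (made-up numbers, not from a trained model)
d_real = np.array([0.9, 0.8])  # D(x) on two real images
d_fake = np.array([0.3, 0.1])  # D(G(z)) on two generated images

# Discriminator step: real images are labelled 1, generated images 0
d_loss = binary_crossentropy(np.array([1.0, 1.0, 0.0, 0.0]),
                             np.concatenate([d_real, d_fake]))

# Generator step: generated images are labelled 1, so the loss drops as D(G(z)) rises
g_loss = binary_crossentropy(np.ones(2), d_fake)

print(d_loss, g_loss)
```

With these numbers the discriminator loss is small because it classifies all four images correctly, while the generator loss is large because its images are not yet fooling the discriminator.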
With the trained model we can now check whether the generator has learned to produce realistic digits.
In [10]:
noise = np.random.normal(0.0, 1.0, size=[256, 100])
generated_images = Gen.predict(noise)
# Show the first 10 generated digits
for i in range(10):
    pixels = generated_images[i]
    pixels = pixels.reshape((28, 28))
    # Plot
    plt.imshow(pixels, cmap='gray')
    plt.show()
Sources: