In [1]:
import keras
from keras.datasets import mnist
from keras.models import Model, Sequential
from keras.layers import *
from keras.optimizers import Adam
from tqdm import tqdm
from keras.layers.advanced_activations import ELU
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline


(X_train, Y_train), (X_test, Y_test) = mnist.load_data()
print(X_train.shape, Y_train.shape, X_test.shape, Y_test.shape)


X_train = X_train.reshape(X_train.shape[0], 28, 28, 1)
X_test = X_test.reshape(X_test.shape[0], 28, 28, 1)

X_train = X_train.astype('float32')

# Scaling the range of the image to [-1, 1]
# Because we are using tanh as the activation function in the last layer of the generator
# and tanh restricts the weights in the range [-1, 1]
X_train = (X_train - 127.5) / 127.5

X_train.shape

generator = Sequential([
        Dense(128*7*7, input_shape=(100,)),
        ELU(0.8),
        BatchNormalization(),
        Reshape((7,7,128)),
        UpSampling2D(),
        Conv2D(64, (2,2), padding='same'),
        ELU(0.8),
        BatchNormalization(),
        UpSampling2D(),
        Conv2D(1, (3,3), padding='same', activation=keras.activations.tanh),
    ])
generator.summary()

discriminator = Sequential([
        Conv2D(64, (3,3), strides=(2,2), input_shape=(28,28,1), padding='same'),
        ELU(0.8),
        Dropout(0.3),
        Conv2D(128, (2,2), strides=(2,2), padding='same'),
        ELU(0.8),
        Dropout(0.3),
        Flatten(),
        Dense(1, activation='sigmoid')
    ])
discriminator.summary()

generator.compile(loss='binary_crossentropy', optimizer=Adam())
discriminator.compile(loss='binary_crossentropy', optimizer=Adam(0.001))

gan_input = Input(shape=(100,))
gen_output = generator(gan_input)
dis_output = discriminator(gen_output)
gan = Model(inputs=gan_input, outputs=dis_output)
gan.compile(loss='binary_crossentropy', optimizer=Adam(0.0001))
gan.summary()

def plot_output():
    try_input = np.random.rand(25, 100)
    preds = generator.predict(try_input)

    plt.figure(figsize=(5,5))
    for i in range(preds.shape[0]):
        plt.subplot(5, 5, i+1)
        plt.imshow(preds[i, :, :, 0], cmap='gray')
        plt.axis('off')
    
    # tight_layout minimizes the overlap between 2 sub-plots
    plt.tight_layout()
    

def train(epoch=30, batch_size=128):
    batch_count = X_train.shape[0] // batch_size
    
    for i in range(epoch):
        dis_loss = 0
        gen_loss = 0
        for j in tqdm(range(batch_count)):
            # Input for the generator
            noise_input = np.random.rand(batch_size, 100)
            
            # getting random images from X_train of size=batch_size 
            # these are the real images that will be fed to the discriminator
            image_batch = X_train[np.random.randint(0, X_train.shape[0], size=batch_size)]
            
            # these are the predicted images from the generator
            predictions = generator.predict(noise_input, batch_size=batch_size)
            
            # the discriminator takes in the real images and the generated images
            X = np.concatenate([predictions, image_batch])
            
            # labels for the discriminator
            y_discriminator = [0]*batch_size + [1]*batch_size
            
            # Let's train the discriminator
            discriminator.trainable = True
            dis_loss += discriminator.train_on_batch(X, y_discriminator)
            
            # Let's train the generator
            noise_input = np.random.rand(batch_size, 100)
            y_generator = [1]*batch_size
            discriminator.trainable = False
            gen_loss += gan.train_on_batch(noise_input, y_generator)
        print(i, gen_loss, dis_loss)


Using TensorFlow backend.
(60000, 28, 28) (60000,) (10000, 28, 28) (10000,)
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
dense_1 (Dense)              (None, 6272)              633472    
_________________________________________________________________
elu_1 (ELU)                  (None, 6272)              0         
_________________________________________________________________
batch_normalization_1 (Batch (None, 6272)              25088     
_________________________________________________________________
reshape_1 (Reshape)          (None, 7, 7, 128)         0         
_________________________________________________________________
up_sampling2d_1 (UpSampling2 (None, 14, 14, 128)       0         
_________________________________________________________________
conv2d_1 (Conv2D)            (None, 14, 14, 64)        32832     
_________________________________________________________________
elu_2 (ELU)                  (None, 14, 14, 64)        0         
_________________________________________________________________
batch_normalization_2 (Batch (None, 14, 14, 64)        256       
_________________________________________________________________
up_sampling2d_2 (UpSampling2 (None, 28, 28, 64)        0         
_________________________________________________________________
conv2d_2 (Conv2D)            (None, 28, 28, 1)         577       
=================================================================
Total params: 692,225
Trainable params: 679,553
Non-trainable params: 12,672
_________________________________________________________________
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
conv2d_3 (Conv2D)            (None, 14, 14, 64)        640       
_________________________________________________________________
elu_3 (ELU)                  (None, 14, 14, 64)        0         
_________________________________________________________________
dropout_1 (Dropout)          (None, 14, 14, 64)        0         
_________________________________________________________________
conv2d_4 (Conv2D)            (None, 7, 7, 128)         32896     
_________________________________________________________________
elu_4 (ELU)                  (None, 7, 7, 128)         0         
_________________________________________________________________
dropout_2 (Dropout)          (None, 7, 7, 128)         0         
_________________________________________________________________
flatten_1 (Flatten)          (None, 6272)              0         
_________________________________________________________________
dense_2 (Dense)              (None, 1)                 6273      
=================================================================
Total params: 39,809
Trainable params: 39,809
Non-trainable params: 0
_________________________________________________________________
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_1 (InputLayer)         (None, 100)               0         
_________________________________________________________________
sequential_1 (Sequential)    (None, 28, 28, 1)         692225    
_________________________________________________________________
sequential_2 (Sequential)    (None, 1)                 39809     
=================================================================
Total params: 732,034
Trainable params: 719,362
Non-trainable params: 12,672
_________________________________________________________________

In [2]:
train()


100%|██████████| 468/468 [00:25<00:00, 17.82it/s]
  0%|          | 2/468 [00:00<00:26, 17.73it/s]
0 45.728616739 15.3082632502
100%|██████████| 468/468 [00:23<00:00, 20.02it/s]
  1%|          | 3/468 [00:00<00:22, 20.64it/s]
1 359.946834362 14.2421313586
100%|██████████| 468/468 [00:22<00:00, 20.48it/s]
  1%|          | 3/468 [00:00<00:22, 20.43it/s]
2 270.268580432 26.7404733138
100%|██████████| 468/468 [00:23<00:00, 20.34it/s]
  0%|          | 2/468 [00:00<00:23, 19.91it/s]
3 207.252195539 125.339413667
100%|██████████| 468/468 [00:23<00:00, 20.42it/s]
  0%|          | 2/468 [00:00<00:23, 19.81it/s]
4 244.330880871 183.634676274
100%|██████████| 468/468 [00:23<00:00, 20.34it/s]
  0%|          | 2/468 [00:00<00:23, 19.77it/s]
5 154.591438912 178.437432662
100%|██████████| 468/468 [00:22<00:00, 20.57it/s]
  1%|          | 3/468 [00:00<00:22, 20.24it/s]
6 263.467713242 191.27296399
100%|██████████| 468/468 [00:22<00:00, 20.48it/s]
  0%|          | 2/468 [00:00<00:24, 19.41it/s]
7 178.767658334 184.590664051
100%|██████████| 468/468 [00:22<00:00, 20.51it/s]
  0%|          | 2/468 [00:00<00:23, 19.95it/s]
8 444.643265292 256.853784651
100%|██████████| 468/468 [00:22<00:00, 20.40it/s]
  0%|          | 2/468 [00:00<00:23, 19.58it/s]
9 157.054298878 165.443401635
100%|██████████| 468/468 [00:22<00:00, 20.48it/s]
  0%|          | 2/468 [00:00<00:23, 19.65it/s]
10 284.274783287 196.384821415
100%|██████████| 468/468 [00:22<00:00, 20.44it/s]
  1%|          | 3/468 [00:00<00:22, 20.34it/s]
11 377.572860263 230.750669181
100%|██████████| 468/468 [00:22<00:00, 20.32it/s]
  0%|          | 2/468 [00:00<00:23, 19.91it/s]
12 290.769304315 185.583918795
100%|██████████| 468/468 [00:22<00:00, 20.53it/s]
  0%|          | 2/468 [00:00<00:23, 19.79it/s]
13 330.230689708 145.313964101
100%|██████████| 468/468 [00:22<00:00, 20.46it/s]
  0%|          | 2/468 [00:00<00:23, 19.79it/s]
14 443.009476949 188.566826776
100%|██████████| 468/468 [00:22<00:00, 20.51it/s]
  1%|          | 3/468 [00:00<00:22, 20.61it/s]
15 534.515924342 218.49495174
100%|██████████| 468/468 [00:22<00:00, 20.50it/s]
  0%|          | 2/468 [00:00<00:23, 19.95it/s]
16 635.477267072 206.870078653
100%|██████████| 468/468 [00:22<00:00, 20.57it/s]
  0%|          | 2/468 [00:00<00:23, 19.87it/s]
17 549.681995727 175.507111285
100%|██████████| 468/468 [00:22<00:00, 20.59it/s]
  1%|          | 3/468 [00:00<00:22, 20.62it/s]
18 433.982109614 146.780628055
100%|██████████| 468/468 [00:22<00:00, 20.56it/s]
  0%|          | 2/468 [00:00<00:23, 19.98it/s]
19 679.951272696 179.956161603
100%|██████████| 468/468 [00:22<00:00, 20.49it/s]
  1%|          | 3/468 [00:00<00:22, 20.41it/s]
20 857.470400035 207.166205823
100%|██████████| 468/468 [00:22<00:00, 20.56it/s]
  1%|          | 3/468 [00:00<00:23, 20.02it/s]
21 826.314484775 211.347297013
100%|██████████| 468/468 [00:22<00:00, 20.52it/s]
  1%|          | 3/468 [00:00<00:22, 20.39it/s]
22 731.507044226 173.80821453
100%|██████████| 468/468 [00:22<00:00, 20.58it/s]
  0%|          | 2/468 [00:00<00:24, 19.24it/s]
23 531.372854441 152.711646944
100%|██████████| 468/468 [00:22<00:00, 20.58it/s]
  1%|          | 3/468 [00:00<00:23, 20.19it/s]
24 788.935490608 173.497263253
100%|██████████| 468/468 [00:22<00:00, 20.65it/s]
  1%|          | 3/468 [00:00<00:22, 20.59it/s]
25 768.602527201 196.603465095
100%|██████████| 468/468 [00:22<00:00, 20.55it/s]
  1%|          | 3/468 [00:00<00:22, 20.84it/s]
26 827.01235348 226.080884427
100%|██████████| 468/468 [00:22<00:00, 20.58it/s]
  1%|          | 3/468 [00:00<00:22, 20.39it/s]
27 677.701851219 219.211761683
100%|██████████| 468/468 [00:22<00:00, 20.61it/s]
  1%|          | 3/468 [00:00<00:22, 20.82it/s]
28 696.187817067 210.243203133
100%|██████████| 468/468 [00:22<00:00, 20.65it/s]
29 730.016027868 230.822440714


In [3]:
plot_output()



In [ ]:


In [ ]: