"Quick, Draw!" GAN


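A DCGAN-style generative adversarial network trained on baseball sketches from Google's Quick, Draw! dataset: a convolutional discriminator learns to distinguish real 28×28 bitmaps from forgeries, while a deconvolutional generator learns to map 100-dimensional noise vectors to convincing sketches.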
In [1]:
import numpy as np
import h5py

import keras
from keras.models import Model, Sequential
from keras.layers import Input, Activation, Conv2D, Reshape, Dense, BatchNormalization, Dropout, Flatten
from keras.layers import UpSampling2D, Conv2DTranspose, AveragePooling2D # new! 
from keras.optimizers import RMSprop

from matplotlib import pyplot as plt
%matplotlib inline


Using TensorFlow backend.

In [2]:
data = np.load('../quickdraw/baseball.npy')
data = data/255
data = np.reshape(data,(data.shape[0],28,28,1))
img_w,img_h = data.shape[1:3]
data.shape


Out[2]:
(126845, 28, 28, 1)

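To confirm the bitmaps loaded and scaled correctly, a quick sanity-check cell (not part of the original run) can plot a few real training sketches:

In [ ]:
# Plot a 4x4 grid of real sketches to verify the data pipeline
plt.figure(figsize=(5,5))
for k in range(16):
    plt.subplot(4, 4, k+1)
    plt.imshow(data[k, :, :, 0], cmap='gray')
    plt.axis('off')
plt.tight_layout()
plt.show()
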
In [12]:
def discriminator_builder(dim=64,p=0.4):
    
    # Define inputs
    inputs = Input((img_w, img_h, 1))
    
    # Convolutional layers
    conv1 = Conv2D(dim*1, 5, strides=2, padding='same', activation='relu')(inputs)
    conv1 = Dropout(p)(conv1)
    
    conv2 = Conv2D(dim*2, 5, strides=2, padding='same', activation='relu')(conv1)
    conv2 = Dropout(p)(conv2)
    
    conv3 = Conv2D(dim*4, 5, strides=2, padding='same', activation='relu')(conv2)
    conv3 = Dropout(p)(conv3)
    
    conv4 = Conv2D(dim*8, 5, strides=1, padding='same', activation='relu')(conv3)
    conv4 = Dropout(p)(conv4)
    conv4 = Flatten()(conv4)
    
    outputs = Dense(1, activation='sigmoid')(conv4)
    
    model = Model(inputs=inputs, outputs=outputs)
    model.summary()
    
    return model

In [13]:
# Build the discriminator network:
discriminator_model = discriminator_builder()


_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_4 (InputLayer)         (None, 28, 28, 1)         0         
_________________________________________________________________
conv2d_10 (Conv2D)           (None, 14, 14, 64)        1664      
_________________________________________________________________
dropout_10 (Dropout)         (None, 14, 14, 64)        0         
_________________________________________________________________
conv2d_11 (Conv2D)           (None, 7, 7, 128)         204928    
_________________________________________________________________
dropout_11 (Dropout)         (None, 7, 7, 128)         0         
_________________________________________________________________
conv2d_12 (Conv2D)           (None, 4, 4, 256)         819456    
_________________________________________________________________
dropout_12 (Dropout)         (None, 4, 4, 256)         0         
_________________________________________________________________
conv2d_13 (Conv2D)           (None, 4, 4, 512)         3277312   
_________________________________________________________________
dropout_13 (Dropout)         (None, 4, 4, 512)         0         
_________________________________________________________________
flatten_3 (Flatten)          (None, 8192)              0         
_________________________________________________________________
dense_4 (Dense)              (None, 1)                 8193      
=================================================================
Total params: 4,311,553
Trainable params: 4,311,553
Non-trainable params: 0
_________________________________________________________________

In [5]:
discriminator_model.compile(loss='binary_crossentropy', optimizer=RMSprop(lr=0.0002, decay=6e-8),
                            metrics=['accuracy'])  # accuracy metric is needed because train() reports d_loss[1]

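As a smoke test, the freshly compiled discriminator can be run on a handful of real images; before any training its sigmoid outputs are essentially arbitrary, but the shape should be (batch, 1):

In [ ]:
# Sanity check: the discriminator maps a batch of images to one probability each
sample = data[np.random.choice(data.shape[0], 4, replace=False)]
preds = discriminator_model.predict(sample)
print(preds.shape)   # (4, 1)
print(preds.ravel())
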
In [6]:
def generator_builder(z_dim=100,dim=64,p=0.4):
    
    # Define inputs
    inputs = Input((z_dim,))
    
    # First dense layer
    dense1 = Dense(7*7*64)(inputs)
    dense1 = BatchNormalization(axis=-1,momentum=0.9)(dense1)
    dense1 = Activation(activation='relu')(dense1)
    dense1 = Reshape((7,7,64))(dense1)
    dense1 = Dropout(p)(dense1)
    
    # Deconvolutional layers
    conv1 = UpSampling2D()(dense1)
    conv1 = Conv2DTranspose(int(dim/2), kernel_size=5, padding='same', activation=None)(conv1)
    conv1 = BatchNormalization(axis=-1, momentum=0.9)(conv1)
    conv1 = Activation(activation='relu')(conv1)
    
    conv2 = UpSampling2D()(conv1)
    conv2 = Conv2DTranspose(int(dim/4), kernel_size=5, padding='same', activation=None)(conv2)
    conv2 = BatchNormalization(axis=-1, momentum=0.9)(conv2)
    conv2 = Activation(activation='relu')(conv2)
    
#     conv3 = UpSampling2D()(conv2)
    conv3 = Conv2DTranspose(int(dim/8), kernel_size=5, padding='same', activation=None)(conv2)
    conv3 = BatchNormalization(axis=-1, momentum=0.9)(conv3)
    conv3 = Activation(activation='relu')(conv3)
    
    # Define output layers
    outputs = Conv2D(1, kernel_size=5, padding='same', activation='sigmoid')(conv3)
    
    # Model definition    
    model = Model(inputs=inputs, outputs=outputs)
    
    model.summary()
    
    return model

In [7]:
generator = generator_builder()


_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_2 (InputLayer)         (None, 100)               0         
_________________________________________________________________
dense_2 (Dense)              (None, 3136)              316736    
_________________________________________________________________
batch_normalization_1 (Batch (None, 3136)              12544     
_________________________________________________________________
activation_1 (Activation)    (None, 3136)              0         
_________________________________________________________________
reshape_1 (Reshape)          (None, 7, 7, 64)          0         
_________________________________________________________________
dropout_5 (Dropout)          (None, 7, 7, 64)          0         
_________________________________________________________________
up_sampling2d_1 (UpSampling2 (None, 14, 14, 64)        0         
_________________________________________________________________
conv2d_transpose_1 (Conv2DTr (None, 14, 14, 32)        51232     
_________________________________________________________________
batch_normalization_2 (Batch (None, 14, 14, 32)        128       
_________________________________________________________________
activation_2 (Activation)    (None, 14, 14, 32)        0         
_________________________________________________________________
up_sampling2d_2 (UpSampling2 (None, 28, 28, 32)        0         
_________________________________________________________________
conv2d_transpose_2 (Conv2DTr (None, 28, 28, 16)        12816     
_________________________________________________________________
batch_normalization_3 (Batch (None, 28, 28, 16)        64        
_________________________________________________________________
activation_3 (Activation)    (None, 28, 28, 16)        0         
_________________________________________________________________
conv2d_transpose_3 (Conv2DTr (None, 28, 28, 8)         3208      
_________________________________________________________________
batch_normalization_4 (Batch (None, 28, 28, 8)         32        
_________________________________________________________________
activation_4 (Activation)    (None, 28, 28, 8)         0         
_________________________________________________________________
conv2d_5 (Conv2D)            (None, 28, 28, 1)         201       
=================================================================
Total params: 396,961
Trainable params: 390,577
Non-trainable params: 6,384
_________________________________________________________________

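It can also be useful to sample the untrained generator once, both to verify the output shape and to see the unstructured output it starts from:

In [ ]:
# Sample the untrained generator and confirm it emits (16, 28, 28, 1) images
noise = np.random.uniform(-1.0, 1.0, size=[16, 100])
untrained_imgs = generator.predict(noise)
print(untrained_imgs.shape)
plt.imshow(untrained_imgs[0, :, :, 0], cmap='gray')
plt.axis('off')
plt.show()
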
In [14]:
def adversarial_builder(z_dim=100):
    
    # Stack the generator and discriminator: training this combined model
    # teaches the generator to produce images the discriminator calls real
    model = Sequential()

    model.add(generator)
    model.add(discriminator_model)
    
    model.compile(loss='binary_crossentropy', optimizer=RMSprop(lr=0.0001, decay=3e-8), metrics=['accuracy'])
    
    model.summary()

    return model

In [15]:
AM = adversarial_builder()


_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
model_2 (Model)              (None, 28, 28, 1)         396961    
_________________________________________________________________
model_4 (Model)              (None, 1)                 4311553   
=================================================================
Total params: 4,708,514
Trainable params: 4,702,130
Non-trainable params: 6,384
_________________________________________________________________

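Note that in the combined model above the discriminator's weights stay trainable, so AM.train_on_batch nudges the discriminator toward labelling fakes as real at the same time as it updates the generator (the summary's trainable-parameter count includes both networks). A common DCGAN variant, shown below as a sketch rather than what was run here, freezes the discriminator before compiling the stacked model:

In [ ]:
# Variant: freeze the discriminator inside the stacked model so that
# AM.train_on_batch updates only the generator. The separately compiled
# discriminator_model keeps training normally via its own train_on_batch.
def adversarial_builder_frozen(z_dim=100):
    discriminator_model.trainable = False
    model = Sequential()
    model.add(generator)
    model.add(discriminator_model)
    model.compile(loss='binary_crossentropy',
                  optimizer=RMSprop(lr=0.0001, decay=3e-8),
                  metrics=['accuracy'])
    model.summary()
    return model
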
In [16]:
import os
output_dir = './images'
if not os.path.exists(output_dir):
    os.makedirs(output_dir)

In [ ]:
def train(epochs=2000, batch=128):
    for i in range(epochs):

        # Assemble a half-real, half-fake batch for the discriminator
        real_imgs = np.reshape(data[np.random.choice(data.shape[0], batch, replace=False)], (batch, 28, 28, 1))
        fake_imgs = generator.predict(np.random.uniform(-1.0, 1.0, size=[batch, 100]))
        
        x = np.concatenate((real_imgs, fake_imgs))
        y = np.ones([2*batch, 1])
        y[batch:, :] = 0   # real images labelled 1, generated images 0
        
        d_loss = discriminator_model.train_on_batch(x, y)
        
        # Train the generator (through the stacked model) to make the
        # discriminator label its output as real
        noise = np.random.uniform(-1.0, 1.0, size=[batch, 100])
        y = np.ones([batch, 1])
        a_loss = AM.train_on_batch(noise, y)
                
        if (i+1) % 1000 == 0:
            print('Epoch #{}'.format(i+1))
            log_mesg = "%d: [D loss: %f, acc: %f]" % (i+1, d_loss[0], d_loss[1])
            log_mesg = "%s  [A loss: %f, acc: %f]" % (log_mesg, a_loss[0], a_loss[1])
            print(log_mesg)
            
            # Plot and save a 4x4 grid of generated sketches
            noise = np.random.uniform(-1.0, 1.0, size=[16, 100])
            gen_imgs = generator.predict(noise)
            plt.figure(figsize=(5,5))
            
            for k in range(gen_imgs.shape[0]):
                plt.subplot(4, 4, k+1)
                plt.imshow(gen_imgs[k, :, :, 0], cmap='gray')
                plt.axis('off')
            
            plt.tight_layout()
            plt.savefig('./images/baseball_{}.png'.format(i+1))  # save before show(): the inline backend closes the figure once it is shown
            plt.show()

In [ ]:
train(epochs=20000)

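Once training finishes, the weights can be persisted so sketches can be generated later without retraining (filenames here are arbitrary; Keras writes .h5 weight files via the h5py import above):

In [ ]:
# Save the trained weights for later sampling
generator.save_weights('baseball_generator.h5')
discriminator_model.save_weights('baseball_discriminator.h5')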