In [1]:
import numpy as np
import h5py
import keras
from keras.models import Model, Sequential
from keras.layers import Input, Activation, Conv2D, Reshape, Dense, BatchNormalization, Dropout, Flatten
from keras.layers import UpSampling2D, Conv2DTranspose, AveragePooling2D # new!
from keras.optimizers import RMSprop
from matplotlib import pyplot as plt
%matplotlib inline
In [2]:
data = np.load('../quickdraw/baseball.npy')
data = data/255                                       # scale pixel values to [0, 1]
data = np.reshape(data, (data.shape[0], 28, 28, 1))   # add a channel dimension
img_w, img_h = data.shape[1:3]
data.shape
Out[2]:
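A quick way to confirm the bitmaps loaded correctly is to plot a few of them; a minimal sketch, assuming the `data` array defined above:
In [ ]:
# Plot the first 16 training sketches in a 4x4 grid
plt.figure(figsize=(5,5))
for k in range(16):
    plt.subplot(4, 4, k+1)
    plt.imshow(data[k, :, :, 0], cmap='gray')
    plt.axis('off')
plt.tight_layout()
plt.show()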
In [12]:
def discriminator_builder(dim=64, p=0.4):

    # Define inputs
    inputs = Input((img_w, img_h, 1))

    # Convolutional layers
    conv1 = Conv2D(dim*1, 5, strides=2, padding='same', activation='relu')(inputs)
    conv1 = Dropout(p)(conv1)

    conv2 = Conv2D(dim*2, 5, strides=2, padding='same', activation='relu')(conv1)
    conv2 = Dropout(p)(conv2)

    conv3 = Conv2D(dim*4, 5, strides=2, padding='same', activation='relu')(conv2)
    conv3 = Dropout(p)(conv3)

    conv4 = Conv2D(dim*8, 5, strides=1, padding='same', activation='relu')(conv3)
    conv4 = Dropout(p)(conv4)
    conv4 = Flatten()(conv4)

    # Output layer: probability that the input image is real
    outputs = Dense(1, activation='sigmoid')(conv4)

    # Model definition
    model = Model(inputs=inputs, outputs=outputs)
    model.summary()

    return model
In [13]:
# Compile discriminator:
discriminator_model = discriminator_builder()
In [5]:
discriminator_model.compile(loss='binary_crossentropy', optimizer=RMSprop(lr=0.0002, decay=6e-8), metrics=['accuracy'])
In [6]:
def generator_builder(z_dim=100, dim=64, p=0.4):

    # Define inputs
    inputs = Input((z_dim,))

    # First dense layer: project the latent vector to a 7x7x64 feature map
    dense1 = Dense(7*7*64)(inputs)
    dense1 = BatchNormalization(axis=-1, momentum=0.9)(dense1)
    dense1 = Activation(activation='relu')(dense1)
    dense1 = Reshape((7,7,64))(dense1)
    dense1 = Dropout(p)(dense1)

    # Deconvolutional layers
    conv1 = UpSampling2D()(dense1)
    conv1 = Conv2DTranspose(int(dim/2), kernel_size=5, padding='same', activation=None)(conv1)
    conv1 = BatchNormalization(axis=-1, momentum=0.9)(conv1)
    conv1 = Activation(activation='relu')(conv1)

    conv2 = UpSampling2D()(conv1)
    conv2 = Conv2DTranspose(int(dim/4), kernel_size=5, padding='same', activation=None)(conv2)
    conv2 = BatchNormalization(axis=-1, momentum=0.9)(conv2)
    conv2 = Activation(activation='relu')(conv2)

    # No third upsampling: 7x7 -> 14x14 -> 28x28 already matches the image size
    # conv3 = UpSampling2D()(conv2)
    conv3 = Conv2DTranspose(int(dim/8), kernel_size=5, padding='same', activation=None)(conv2)
    conv3 = BatchNormalization(axis=-1, momentum=0.9)(conv3)
    conv3 = Activation(activation='relu')(conv3)

    # Define output layers
    outputs = Conv2D(1, kernel_size=5, padding='same', activation='sigmoid')(conv3)

    # Model definition
    model = Model(inputs=inputs, outputs=outputs)
    model.summary()

    return model
In [7]:
generator = generator_builder()
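Before wiring the generator into the adversarial model, it is worth verifying that it maps a latent vector to a 28x28x1 image; a minimal sketch, assuming the `generator` built above:
In [ ]:
# Feed a single random latent vector through the (untrained) generator
z = np.random.uniform(-1.0, 1.0, size=[1, 100])
sample = generator.predict(z)
print(sample.shape)   # expect (1, 28, 28, 1)
plt.imshow(sample[0, :, :, 0], cmap='gray')
plt.axis('off')
plt.show()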
In [14]:
def adversarial_builder(z_dim=100):
    # Stack the generator and discriminator; note the discriminator's
    # weights remain trainable inside this combined model
    model = Sequential()
    model.add(generator)
    model.add(discriminator_model)
    model.compile(loss='binary_crossentropy', optimizer=RMSprop(lr=0.0001, decay=3e-8), metrics=['accuracy'])
    model.summary()
    return model
In [15]:
AM = adversarial_builder()
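The stacked model should map latent vectors straight through to discriminator probabilities; a quick check, assuming `AM` as compiled above:
In [ ]:
# One forward pass through generator + discriminator: noise in, probability out
probs = AM.predict(np.random.uniform(-1.0, 1.0, size=[4, 100]))
print(probs.shape)   # expect (4, 1), each entry a sigmoid probability in (0, 1)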
In [16]:
import os

output_dir = './images'
if not os.path.exists(output_dir):
    os.makedirs(output_dir)
In [ ]:
def train(epochs=2000, batch=128):
    # Note: each 'epoch' here is a single batch of updates
    for i in range(epochs):

        # Train the discriminator on a half-real, half-fake batch
        real_imgs = np.reshape(data[np.random.choice(data.shape[0], batch, replace=False)], (batch, 28, 28, 1))
        fake_imgs = generator.predict(np.random.uniform(-1.0, 1.0, size=[batch, 100]))
        x = np.concatenate((real_imgs, fake_imgs))
        y = np.ones([2*batch, 1])
        y[batch:, :] = 0     # label real images 1, fake images 0
        d_loss = discriminator_model.train_on_batch(x, y)

        # Train the generator (via the adversarial model) to fool the
        # discriminator: fake images are labeled as real
        noise = np.random.uniform(-1.0, 1.0, size=[batch, 100])
        y = np.ones([batch, 1])
        a_loss = AM.train_on_batch(noise, y)

        # Every 1000 steps, log losses and save a grid of generated images
        if (i+1) % 1000 == 0:
            print('Epoch #{}'.format(i+1))
            log_mesg = "%d: [D loss: %f, acc: %f]" % (i, d_loss[0], d_loss[1])
            log_mesg = "%s [A loss: %f, acc: %f]" % (log_mesg, a_loss[0], a_loss[1])
            print(log_mesg)

            noise = np.random.uniform(-1.0, 1.0, size=[16, 100])
            gen_imgs = generator.predict(noise)
            plt.figure(figsize=(5,5))
            for k in range(gen_imgs.shape[0]):
                plt.subplot(4, 4, k+1)
                plt.imshow(gen_imgs[k, :, :, 0], cmap='gray')
                plt.axis('off')
            plt.tight_layout()
            # Save before show(), which clears the current figure
            plt.savefig('./images/baseball_{}.png'.format(i+1))
            plt.show()
In [ ]:
train(epochs=20000)
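After training, the generator can be kept for later sampling without retraining; a minimal sketch using Keras's standard save_weights/load_weights (the filename is illustrative):
In [ ]:
# Persist the trained generator so new sketches can be sampled later
generator.save_weights('./baseball_generator.h5')
# To restore: rebuild the architecture, then load the weights
# generator = generator_builder()
# generator.load_weights('./baseball_generator.h5')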
In [ ]: