Licensed under the Apache License, Version 2.0 (the "License").
|
This notebook demonstrates how to generate images of handwritten digits using tf.keras and eager execution. To do so, we use a Deep Convolutional Generative Adversarial Network (DCGAN).
Training this model takes about 30 seconds per epoch on a single Tesla K80 on Colab (as of July 2018), using tf.contrib.eager.defun to compile the models into graph functions.
Below is the output generated after training the generator and discriminator models for 150 epochs.
In [0]:
# to generate gifs: imageio is used by the last cells to stitch the saved
# per-epoch preview PNGs into an animated GIF.
# NOTE(review): prefer `%pip install imageio==<pinned version>` so the package
# is installed into the kernel's environment and re-runs are reproducible.
!pip install imageio
In [0]:
from __future__ import absolute_import, division, print_function

import glob
import os
import shutil
import time

import imageio
import matplotlib.pyplot as plt
import numpy as np
import PIL
import PIL.Image
from IPython import display

# Import TensorFlow >= 1.9 and enable eager execution
import tensorflow as tf
tf.enable_eager_execution()
In [0]:
# Only the MNIST training split is used; the test split is discarded.
# train_labels is not referenced by any later cell shown here -- the GAN
# trains on images alone.
(train_images, train_labels), (_, _) = tf.keras.datasets.mnist.load_data()
In [0]:
# Add an explicit channel axis -> (N, 28, 28, 1), cast to float32, and rescale
# pixels with (x - 127.5) / 127.5 so they lie in [-1, 1], the same range as the
# generator's tanh output.
# NOTE: not idempotent -- re-running this cell rescales the already-scaled data.
train_images = (
    train_images.reshape(train_images.shape[0], 28, 28, 1).astype('float32')
    - 127.5
) / 127.5
In [0]:
# Shuffle-buffer size for tf.data; presumably chosen to equal the number of
# training images so the whole set is shuffled -- TODO confirm it matches
# train_images.shape[0].
BUFFER_SIZE = 60000
# Real images per training step (train() also draws this many noise vectors).
BATCH_SIZE = 256
In [0]:
# Input pipeline: one element per training image, shuffled through a
# BUFFER_SIZE-element buffer, then grouped into batches of BATCH_SIZE
# (the final batch of an epoch may be smaller).
train_dataset = (
    tf.data.Dataset
    .from_tensor_slices(train_images)
    .shuffle(BUFFER_SIZE)
    .batch(BATCH_SIZE)
)
Generator
Discriminator
In [0]:
class Generator(tf.keras.Model):
    """DCGAN generator: maps noise vectors to single-channel 28x28 images.

    A bias-free dense projection is reshaped into a 7x7x64 feature map and
    upsampled by three 5x5 transposed convolutions with strides 1, 2 and 2
    ('same' padding: 7x7 -> 7x7 -> 14x14 -> 28x28). Every hidden stage is
    batch-normalized and ReLU-activated; the output goes through tanh, so
    pixel values lie in [-1, 1].
    """

    def __init__(self):
        super(Generator, self).__init__()
        # Keras tracks sub-layers by attribute assignment; names and order are
        # kept stable.
        self.fc1 = tf.keras.layers.Dense(7 * 7 * 64, use_bias=False)
        self.batchnorm1 = tf.keras.layers.BatchNormalization()

        self.conv1 = tf.keras.layers.Conv2DTranspose(
            64, (5, 5), strides=(1, 1), padding='same', use_bias=False)
        self.batchnorm2 = tf.keras.layers.BatchNormalization()

        self.conv2 = tf.keras.layers.Conv2DTranspose(
            32, (5, 5), strides=(2, 2), padding='same', use_bias=False)
        self.batchnorm3 = tf.keras.layers.BatchNormalization()

        self.conv3 = tf.keras.layers.Conv2DTranspose(
            1, (5, 5), strides=(2, 2), padding='same', use_bias=False)

    def call(self, x, training=True):
        """Generate a batch of images from the batch of noise vectors `x`.

        `training` is forwarded to the batch-norm layers: True uses batch
        statistics (training), False uses the moving averages (inference).
        """
        # Project the noise and reshape it into a 7x7x64 feature map.
        h = tf.nn.relu(self.batchnorm1(self.fc1(x), training=training))
        h = tf.reshape(h, shape=(-1, 7, 7, 64))
        # 7x7x64 -> 7x7x64 -> 14x14x32.
        h = tf.nn.relu(self.batchnorm2(self.conv1(h), training=training))
        h = tf.nn.relu(self.batchnorm3(self.conv2(h), training=training))
        # 14x14x32 -> 28x28x1, squashed into [-1, 1].
        return tf.nn.tanh(self.conv3(h))
In [0]:
class Discriminator(tf.keras.Model):
    """DCGAN discriminator: scores each image with a single real/fake logit.

    Two 5x5 stride-2 convolutions (64 then 128 filters), each followed by a
    leaky ReLU and dropout (rate 0.3), then flatten and a one-unit dense layer.
    The output is an unbounded logit; the loss functions apply the sigmoid.
    """

    def __init__(self):
        super(Discriminator, self).__init__()
        self.conv1 = tf.keras.layers.Conv2D(64, (5, 5), strides=(2, 2), padding='same')
        self.conv2 = tf.keras.layers.Conv2D(128, (5, 5), strides=(2, 2), padding='same')
        # A single Dropout layer is reused after both convolutions.
        self.dropout = tf.keras.layers.Dropout(0.3)
        self.flatten = tf.keras.layers.Flatten()
        self.fc1 = tf.keras.layers.Dense(1)

    def call(self, x, training=True):
        """Return one logit per image in the batch `x`.

        `training` toggles dropout (active only when True).
        """
        h = x
        # conv -> leaky ReLU -> dropout, applied for each conv layer in turn.
        for conv in (self.conv1, self.conv2):
            h = self.dropout(tf.nn.leaky_relu(conv(h)), training=training)
        return self.fc1(self.flatten(h))
In [0]:
# Build both networks. No input shapes are declared, so Keras creates each
# model's weights on its first call.
generator = Generator()
discriminator = Discriminator()
In [0]:
# Defun gives 10 secs/epoch performance boost
# Replacing `call` on each instance with a defun-compiled graph function means
# every later model(...) invocation (training and image generation) runs the
# compiled version. Run this after constructing the models and before train().
generator.call = tf.contrib.eager.defun(generator.call)
discriminator.call = tf.contrib.eager.defun(discriminator.call)
Discriminator loss
Generator loss
In [0]:
def discriminator_loss(real_output, generated_output):
    """Sigmoid cross-entropy loss for the discriminator.

    Logits for real images are pushed toward label 1 and logits for generated
    images toward label 0; the two reduced losses are summed.

    Args:
      real_output: discriminator logits for a batch of real images.
      generated_output: discriminator logits for a batch of generated images.

    Returns:
      The summed loss tensor.
    """
    def _bce_against(label_fn, logits):
        # label_fn builds the constant target tensor (all ones or all zeros).
        return tf.losses.sigmoid_cross_entropy(
            multi_class_labels=label_fn(logits), logits=logits)

    return (_bce_against(tf.ones_like, real_output)
            + _bce_against(tf.zeros_like, generated_output))
In [0]:
def generator_loss(generated_output):
    """Sigmoid cross-entropy loss for the generator.

    The generator is rewarded when the discriminator labels its images as
    real, so every logit in `generated_output` is compared against label 1.

    Args:
      generated_output: discriminator logits for a batch of generated images.

    Returns:
      The loss tensor.
    """
    all_real = tf.ones_like(generated_output)
    return tf.losses.sigmoid_cross_entropy(multi_class_labels=all_real,
                                           logits=generated_output)
In [0]:
# One Adam optimizer per network (learning rate 1e-4): the generator and the
# discriminator are updated from separate losses and gradients in train().
discriminator_optimizer = tf.train.AdamOptimizer(1e-4)
generator_optimizer = tf.train.AdamOptimizer(1e-4)
In [0]:
EPOCHS = 150
noise_dim = 100  # length of each noise vector fed to the generator
num_examples_to_generate = 100  # images in each saved preview grid
# keeping the random vector constant for generation (prediction) so
# it will be easier to see the improvement of the gan.
# NOTE(review): no random seed is set in the cells shown, so this vector and
# the training run differ between executions.
random_vector_for_generation = tf.random_normal([num_examples_to_generate,
noise_dim])
In [0]:
def generate_and_save_images(model, epoch, test_input):
    """Render the model's outputs for `test_input` as a grid, then save and show it.

    Args:
      model: the generator; called as ``model(test_input, training=False)``.
      epoch: epoch number used in the saved file name
        ('image_at_epoch_{:04d}.png').
      test_input: batch of noise vectors; one image is drawn per vector.
    """
    # make sure the training parameter is set to False because we
    # don't want to train the batchnorm layer when doing inference.
    predictions = model(test_input, training=False)

    # Near-square grid sized from the batch (10x10 for 100 images), so changing
    # num_examples_to_generate does not overflow a hard-coded layout.
    num_images = int(predictions.shape[0])
    ncols = max(1, int(np.ceil(np.sqrt(num_images))))
    nrows = max(1, -(-num_images // ncols))  # ceiling division

    fig = plt.figure(figsize=(10, 10))
    for i in range(num_images):
        ax = fig.add_subplot(nrows, ncols, i + 1)
        # Undo the [-1, 1] training normalization back to pixel values.
        ax.imshow(predictions[i, :, :, 0] * 127.5 + 127.5, cmap='gray')
        ax.axis('off')

    # tight_layout minimizes the overlap between 2 sub-plots
    fig.tight_layout()
    fig.savefig('image_at_epoch_{:04d}.png'.format(epoch))
    plt.show()
    # Release the figure explicitly; this runs repeatedly during training.
    plt.close(fig)
In [0]:
def _train_step(images, noise_dim):
    """Run one simultaneous generator/discriminator update on a batch of real images.

    Uses the notebook-level generator, discriminator, their optimizers and the
    generator_loss / discriminator_loss functions.

    Args:
      images: one batch of real images from the training dataset.
      noise_dim: length of each noise vector fed to the generator.
    """
    # Draw noise from a standard normal distribution, one vector per real
    # image, so the last (smaller) batch of an epoch is paired with the same
    # number of fakes instead of a full BATCH_SIZE of them.
    batch_size = int(images.shape[0])
    noise = tf.random_normal([batch_size, noise_dim])

    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        generated_images = generator(noise, training=True)

        real_output = discriminator(images, training=True)
        generated_output = discriminator(generated_images, training=True)

        gen_loss = generator_loss(generated_output)
        disc_loss = discriminator_loss(real_output, generated_output)

    # Only trainable variables get gradients; `.variables` would also include
    # the batch-norm moving statistics, which the loss does not depend on.
    gen_vars = generator.trainable_variables
    disc_vars = discriminator.trainable_variables

    gradients_of_generator = gen_tape.gradient(gen_loss, gen_vars)
    gradients_of_discriminator = disc_tape.gradient(disc_loss, disc_vars)

    generator_optimizer.apply_gradients(zip(gradients_of_generator, gen_vars))
    discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, disc_vars))


def train(dataset, epochs, noise_dim):
    """Train the GAN and periodically save previews from the fixed noise vector.

    A preview grid is generated (and the cell output cleared) after epochs
    1, 11, 21, ... (whenever the 0-based epoch index is a multiple of 10), and
    once more after the final epoch. The wall time of every epoch is printed.

    Args:
      dataset: iterable of real-image batches (e.g. train_dataset).
      epochs: number of passes over `dataset`.
      noise_dim: length of each noise vector fed to the generator.
    """
    for epoch in range(epochs):
        start = time.time()

        for images in dataset:
            _train_step(images, noise_dim)

        if epoch % 10 == 0:
            display.clear_output(wait=True)
            generate_and_save_images(generator,
                                     epoch + 1,
                                     random_vector_for_generation)

        print('Time taken for epoch {} is {} sec'.format(epoch + 1,
                                                         time.time() - start))

    # generating after the final epoch
    generate_and_save_images(generator,
                             epochs,
                             random_vector_for_generation)
In [0]:
# Long-running cell: trains for EPOCHS epochs, saving a preview grid every 10
# epochs and once more after the last epoch.
train(train_dataset, EPOCHS, noise_dim)
In [0]:
def display_image(epoch_no):
    """Show the preview grid that training saved for epoch `epoch_no`.

    Reads 'image_at_epoch_{:04d}.png' (formatted with `epoch_no`) from the
    working directory and draws it on a large figure without axes.

    Args:
      epoch_no: epoch number used in the saved file name.
    """
    path = 'image_at_epoch_{:04d}.png'.format(epoch_no)
    # Copy the pixels out and close the file right away; opening first also
    # means a missing file raises before an empty figure is created.
    with PIL.Image.open(path) as img:
        pixels = np.array(img)
    plt.figure(figsize=(15, 15))
    plt.imshow(pixels)
    plt.axis('off')
In [0]:
# Show the grid train() saved after the final epoch (file labelled with EPOCHS).
display_image(EPOCHS)
In [0]:
# Stitch the saved preview grids into an animated GIF in epoch order. The
# epoch in each file name is zero-padded, so a lexical sort is chronological.
with imageio.get_writer('dcgan.gif', mode='I') as writer:
    for filename in sorted(glob.glob('image_at_epoch_*.png')):
        writer.append_data(imageio.imread(filename))

# this is a hack to display the gif inside the notebook: the GIF is exposed
# under a .png name for display.Image below (presumably because older IPython
# versions only accept PNG/JPEG extensions -- TODO confirm).
# shutil.move replaces os.system('mv ...'): it needs no POSIX shell and raises
# on failure instead of returning an ignored exit status.
shutil.move('dcgan.gif', 'dcgan.gif.png')
In [0]:
# Render the animated GIF inline (it was renamed to .png in the previous cell).
display.Image(filename="dcgan.gif.png")
In [0]: