In [1]:
from __future__ import absolute_import, division, print_function

# Import TensorFlow >= 1.10 and enable eager execution
import tensorflow as tf
tf.enable_eager_execution()

import os
import time
import numpy as np
import glob
import matplotlib.pyplot as plt
import PIL
import imageio
from IPython import display


C:\ProgramData\Anaconda3\lib\site-packages\h5py\__init__.py:34: FutureWarning: Conversion of the second argument of issubdtype from `float` to `np.floating` is deprecated. In future, it will be treated as `np.float64 == np.dtype(float).type`.
  from ._conv import register_converters as _register_converters

In [2]:
# Shuffle buffer size and mini-batch size for the input pipeline.
BUFFER_SIZE = 60000
BATCH_SIZE = 256

# Only the training split is loaded; the test split is discarded.
(train_images, train_labels), (_, _) = tf.keras.datasets.mnist.load_data()

# Add a trailing channel axis and convert to float32, then rescale the
# pixel values to the range [-1, 1].
train_images = train_images.reshape(-1, 28, 28, 1).astype(np.float32)
train_images = (train_images - 127.5) / 127.5

# Build the pipeline one stage at a time: slice -> shuffle -> batch.
train_dataset = tf.data.Dataset.from_tensor_slices(train_images)
train_dataset = train_dataset.shuffle(BUFFER_SIZE)
train_dataset = train_dataset.batch(BATCH_SIZE)

In [3]:
class Generator(tf.keras.Model):
  """Generator network: turns a batch of noise vectors into images.

  The input is projected by a dense layer, reshaped to a 7x7x64 feature
  map and passed through three transposed convolutions (strides 1, 2, 2).
  The final activation is tanh, so outputs lie in [-1, 1].
  """

  def __init__(self):
    super(Generator, self).__init__()
    layers = tf.keras.layers

    # Dense projection; its output is reshaped to (7, 7, 64) in call().
    self.fc1 = layers.Dense(units=7 * 7 * 64, use_bias=False)
    self.batchnorm1 = layers.BatchNormalization()

    self.conv1 = layers.Conv2DTranspose(
        filters=64, kernel_size=(5, 5), strides=(1, 1),
        padding='same', use_bias=False)
    self.batchnorm2 = layers.BatchNormalization()

    self.conv2 = layers.Conv2DTranspose(
        filters=32, kernel_size=(5, 5), strides=(2, 2),
        padding='same', use_bias=False)
    self.batchnorm3 = layers.BatchNormalization()

    # Single output channel; no batch norm after this layer.
    self.conv3 = layers.Conv2DTranspose(
        filters=1, kernel_size=(5, 5), strides=(2, 2),
        padding='same', use_bias=False)

  def call(self, x, training=True):
    """Run the forward pass.

    Args:
      x: batch of noise vectors.
      training: forwarded to every batch-normalization layer.

    Returns:
      The tanh-activated output of the last transposed convolution.
    """
    hidden = tf.nn.relu(self.batchnorm1(self.fc1(x), training=training))

    feature_map = tf.reshape(hidden, shape=(-1, 7, 7, 64))

    feature_map = tf.nn.relu(
        self.batchnorm2(self.conv1(feature_map), training=training))
    feature_map = tf.nn.relu(
        self.batchnorm3(self.conv2(feature_map), training=training))

    return tf.nn.tanh(self.conv3(feature_map))

In [4]:
class Discriminator(tf.keras.Model):
  """Discriminator network: scores a batch of images with one logit each.

  Two strided convolutions, each followed by leaky ReLU and dropout,
  feed a flatten step and a single-unit dense layer. No sigmoid is
  applied here; the output is a raw logit.
  """

  def __init__(self):
    super(Discriminator, self).__init__()
    layers = tf.keras.layers

    self.conv1 = layers.Conv2D(
        filters=64, kernel_size=(5, 5), strides=(2, 2), padding='same')
    self.conv2 = layers.Conv2D(
        filters=128, kernel_size=(5, 5), strides=(2, 2), padding='same')
    # One dropout layer object is applied after both convolutions.
    self.dropout = layers.Dropout(rate=0.3)
    self.flatten = layers.Flatten()
    self.fc1 = layers.Dense(units=1)

  def call(self, x, training=True):
    """Run the forward pass.

    Args:
      x: batch of images.
      training: forwarded to the dropout layer.

    Returns:
      The output of the final dense layer (one value per input image).
    """
    features = x
    for conv in (self.conv1, self.conv2):
      features = self.dropout(tf.nn.leaky_relu(conv(features)),
                              training=training)
    return self.fc1(self.flatten(features))

In [5]:
# Instantiate the two competing networks.
generator, discriminator = Generator(), Discriminator()

# Wrap each model's forward pass with defun; the original author
# reported this as roughly a 10 seconds/epoch performance boost.
discriminator.call = tf.contrib.eager.defun(discriminator.call)
generator.call = tf.contrib.eager.defun(generator.call)

def discriminator_loss(real_output, generated_output):
  """Sigmoid cross-entropy loss for the discriminator.

  Args:
    real_output: discriminator logits for real images.
    generated_output: discriminator logits for generated images.

  Returns:
    The loss on real images (target all ones) plus the loss on
    generated images (target all zeros).
  """
  cross_entropy = tf.losses.sigmoid_cross_entropy

  # Real images should be classified as 1.
  loss_on_real = cross_entropy(
      multi_class_labels=tf.ones_like(real_output), logits=real_output)
  # Generated images should be classified as 0.
  loss_on_generated = cross_entropy(
      multi_class_labels=tf.zeros_like(generated_output),
      logits=generated_output)

  return loss_on_real + loss_on_generated
def generator_loss(generated_output):
  """Sigmoid cross-entropy loss for the generator.

  Args:
    generated_output: discriminator logits for generated images.

  Returns:
    The cross-entropy of those logits against an all-ones target, i.e.
    the generator is rewarded when its images are scored as real.
  """
  all_real = tf.ones_like(generated_output)
  return tf.losses.sigmoid_cross_entropy(
      multi_class_labels=all_real, logits=generated_output)

# Step size shared by both Adam optimizers; named once so the two
# networks cannot silently drift apart when it is tuned.
LEARNING_RATE = 1e-4

discriminator_optimizer = tf.train.AdamOptimizer(LEARNING_RATE)
generator_optimizer = tf.train.AdamOptimizer(LEARNING_RATE)

In [6]:
# Run configuration: number of epochs, size of each noise vector, and
# how many sample images to render per snapshot.
EPOCHS = 100
noise_dim = 100
num_examples_to_generate = 16

# One fixed batch of noise vectors for generation (prediction). Keeping
# it constant makes it easier to see the GAN improve between snapshots.
random_vector_for_generation = tf.random_normal(
    shape=[num_examples_to_generate, noise_dim])

def generate_and_save_images(model, epoch, test_input):
  """Render a grid of generated images, save it as a PNG and show it.

  Args:
    model: callable invoked as `model(test_input, training=False)`; its
      result is indexed as `[i, :, :, 0]`, one image per leading index.
    epoch: integer used in the output file name
      `gan_image_at_epoch_NNNN.png`.
    test_input: input batch passed to `model`.

  The grid is the smallest square that fits every image, so the default
  16 samples still produce a 4x4 figure of size (4, 4), while other
  sample counts no longer overflow a fixed 4x4 layout.
  """
  # training=False: batch-norm layers must run in inference mode here.
  predictions = model(test_input, training=False)

  num_images = int(predictions.shape[0])
  grid_size = max(1, int(np.ceil(np.sqrt(num_images))))

  fig = plt.figure(figsize=(grid_size, grid_size))

  for i in range(num_images):
    plt.subplot(grid_size, grid_size, i + 1)
    # Undo the [-1, 1] scaling so pixel values are back in [0, 255].
    plt.imshow(predictions[i, :, :, 0] * 127.5 + 127.5, cmap='gray')
    plt.axis('off')

  plt.savefig('gan_image_at_epoch_{:04d}.png'.format(epoch))
  plt.show()
  # Release the figure explicitly; this function runs once per epoch and
  # not every backend closes figures on show().
  plt.close(fig)
    
def train(dataset, epochs, noise_dim, checkpoint_dir='./training_checkpoints'):
  """Train the module-level generator and discriminator.

  Relies on these module-level names: `generator`, `discriminator`,
  `generator_optimizer`, `discriminator_optimizer`, `BATCH_SIZE`,
  `random_vector_for_generation`, `generator_loss`, `discriminator_loss`
  and `generate_and_save_images`.

  Args:
    dataset: iterable yielding batches of real images.
    epochs: number of passes over `dataset`.
    noise_dim: length of each noise vector fed to the generator.
    checkpoint_dir: directory that receives a checkpoint every 15 epochs.

  After every epoch a sample grid is rendered and saved, and the epoch
  duration is printed.
  """
  # The checkpoint object is built here so that the periodic save below
  # does not depend on names defined in some other cell.
  checkpoint_prefix = os.path.join(checkpoint_dir, 'ckpt')
  checkpoint = tf.train.Checkpoint(
      generator_optimizer=generator_optimizer,
      discriminator_optimizer=discriminator_optimizer,
      generator=generator,
      discriminator=discriminator)

  for epoch in range(epochs):
    start = time.time()

    for images in dataset:
      # generating noise from a normal distribution
      noise = tf.random_normal([BATCH_SIZE, noise_dim])

      # Two tapes: each network's gradients are taken from its own loss.
      with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        generated_images = generator(noise, training=True)

        real_output = discriminator(images, training=True)
        generated_output = discriminator(generated_images, training=True)

        gen_loss = generator_loss(generated_output)
        disc_loss = discriminator_loss(real_output, generated_output)

      gradients_of_generator = gen_tape.gradient(gen_loss, generator.variables)
      gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.variables)

      generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.variables))
      discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.variables))

    # Replace the previous snapshot with one for the epoch just finished.
    display.clear_output(wait=True)
    generate_and_save_images(generator,
                             epoch + 1,
                             random_vector_for_generation)

    # saving (checkpoint) the model every 15 epochs
    if (epoch + 1) % 15 == 0:
      checkpoint.save(file_prefix=checkpoint_prefix)

    print ('Time taken for epoch {} is {} sec'.format(epoch + 1,
                                                      time.time()-start))
  # generating after the final epoch
  display.clear_output(wait=True)
  generate_and_save_images(generator,
                           epochs,
                           random_vector_for_generation)

In [ ]:
# Start the full training run with the settings defined above.
train(dataset=train_dataset, epochs=EPOCHS, noise_dim=noise_dim)


---------------------------------------------------------------------------
KeyboardInterrupt                         Traceback (most recent call last)
<ipython-input-7-595742673dee> in <module>()
----> 1 train(train_dataset, EPOCHS, noise_dim)

<ipython-input-6-5f713d666945> in train(dataset, epochs, noise_dim)
     32 
     33       with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
---> 34         generated_images = generator(noise, training=True)
     35 
     36         real_output = discriminator(images, training=True)

C:\ProgramData\Anaconda3\lib\site-packages\tensorflow\python\keras\engine\base_layer.py in __call__(self, inputs, *args, **kwargs)
    734 
    735       if not in_deferred_mode:
--> 736         outputs = self.call(inputs, *args, **kwargs)
    737         if outputs is None:
    738           raise ValueError('A layer\'s `call` method should return a Tensor '

C:\ProgramData\Anaconda3\lib\site-packages\tensorflow\python\eager\function.py in __call__(self, *args, **kwds)
    851     """Calls a graph function specialized for this input signature."""
    852     graph_function, inputs = self._maybe_define_function(*args, **kwds)
--> 853     return graph_function(*inputs)
    854 
    855   def call_python_function(self, *args, **kwargs):

C:\ProgramData\Anaconda3\lib\site-packages\tensorflow\python\eager\function.py in __call__(self, *args)
    595       if self._backward_function is None:
    596         self._construct_backprop_function()
--> 597       return self._backprop_call(tensor_inputs)
    598 
    599     ctx = context.context()

C:\ProgramData\Anaconda3\lib\site-packages\tensorflow\python\eager\function.py in _backprop_call(self, args)
    528     all_args = args + self._extra_inputs
    529     ctx = context.context()
--> 530     outputs = self._forward_fdef.call(ctx, all_args, self._output_shapes)
    531     if isinstance(outputs, ops.Operation) or outputs is None:
    532       return outputs

C:\ProgramData\Anaconda3\lib\site-packages\tensorflow\python\eager\function.py in call(self, ctx, args, output_shapes)
    368           f=self,
    369           tout=self._output_types,
--> 370           executing_eagerly=executing_eagerly)
    371 
    372     if executing_eagerly:

C:\ProgramData\Anaconda3\lib\site-packages\tensorflow\python\ops\functional_ops.py in partitioned_call(args, f, tout, executing_eagerly)
    976     if f.stateful_ops:
    977       outputs = gen_functional_ops.stateful_partitioned_call(
--> 978           args=args, Tout=tout, f=f)
    979     else:
    980       outputs = gen_functional_ops.partitioned_call(args=args, Tout=tout, f=f)

C:\ProgramData\Anaconda3\lib\site-packages\tensorflow\python\ops\gen_functional_ops.py in stateful_partitioned_call(args, Tout, f, name)
    453         _ctx._context_handle, _ctx._eager_context.device_name,
    454         "StatefulPartitionedCall", name, _ctx._post_execution_callbacks, args,
--> 455         "Tout", Tout, "f", f)
    456       return _result
    457     except _core._FallbackException:

KeyboardInterrupt: 

In [ ]:
def display_image(epoch_no):
  """Open the sample grid saved for one epoch.

  Args:
    epoch_no: epoch number used in the file name
      `gan_image_at_epoch_NNNN.png` (zero-padded to four digits).

  Returns:
    The image object returned by `PIL.Image.open` for that file.
  """
  # A bare `import PIL` does not load the Image submodule, so import it
  # explicitly instead of relying on another library having done so.
  import PIL.Image
  return PIL.Image.open('gan_image_at_epoch_{:04d}.png'.format(epoch_no))

# Show the grid saved for the last epoch.
display_image(epoch_no=EPOCHS)