Image Generation with DCGAN

This tutorial shows how to generate images using DCGAN. We use the MNIST dataset here, but any other dataset of reasonable size can be used.


In [ ]:
# If you run this notebook on Google Colab, uncomment and run the following to set up dependencies.
# !pip install nnabla-ext-cuda100
# !git clone https://github.com/sony/nnabla-examples.git
# %cd nnabla-examples

Let's start by importing dependencies.


In [ ]:
from __future__ import absolute_import
from six.moves import range

import numpy as np

import nnabla as nn
import nnabla.logger as logger
import nnabla.functions as F
import nnabla.parametric_functions as PF
import nnabla.solvers as S
import nnabla.utils.save as save

import os

from utils.neu.checkpoint_util import save_checkpoint, load_checkpoint
from utils.neu.save_nnp import save_nnp

Now let's define a function to download and load the MNIST dataset. This function will pass image-label pairs to the DataSource class, which we will define next.


In [ ]:
import struct
import zlib

from nnabla.logger import logger
from nnabla.utils.data_iterator import data_iterator
from nnabla.utils.data_source import DataSource
from nnabla.utils.data_source_loader import download


def load_mnist(train=True):
    '''
    Load MNIST dataset images and labels from the original page by Yann LeCun or from the cache file.
    Args:
        train (bool): The training dataset is returned if True (60000 images); otherwise the testing dataset is returned (10000 images).
    Returns:
        numpy.ndarray: Images with a shape of (#images, 1, 28, 28). uint8 values in [0, 255].
        numpy.ndarray: Labels with a shape of (#images, 1). Values in {0, 1, ..., 9}.
    '''
    if train:
        image_uri = 'http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz'
        label_uri = 'http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz'
    else:
        image_uri = 'http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz'
        label_uri = 'http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz'
    logger.info('Getting label data from {}.'.format(label_uri))
    r = download(label_uri)
    data = zlib.decompress(r.read(), zlib.MAX_WBITS | 32)
    _, size = struct.unpack('>II', data[0:8])
    labels = np.frombuffer(data[8:], np.uint8).reshape(-1, 1)
    r.close()
    logger.info('Getting label data done.')

    logger.info('Getting image data from {}.'.format(image_uri))
    r = download(image_uri)
    data = zlib.decompress(r.read(), zlib.MAX_WBITS | 32)
    _, size, height, width = struct.unpack('>IIII', data[0:16])
    images = np.frombuffer(data[16:], np.uint8).reshape(
        size, 1, height, width)
    r.close()
    logger.info('Getting image data done.')

    return images, labels
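
If you want a quick sanity check of the loader, the following prints the shapes and value ranges documented above (the files are cached after the first download):


In [ ]:
images, labels = load_mnist(train=True)
print(images.shape, labels.shape)                # (60000, 1, 28, 28) (60000, 1)
print(images.dtype, images.min(), images.max())  # uint8 0 255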

Now we define a data source and a data iterator that feed the images and labels into the actual computation graph.


In [ ]:
class MnistDataSource(DataSource):
    '''
    Provide data directly from the MNIST dataset downloaded from the Internet (yann.lecun.com).
    '''

    def _get_data(self, position):
        image = self._images[self._indexes[position]]
        label = self._labels[self._indexes[position]]
        return (image, label)

    def __init__(self, train=True, shuffle=False, rng=None):
        super(MnistDataSource, self).__init__(shuffle=shuffle)
        self._train = train

        self._images, self._labels = load_mnist(train)

        self._size = self._labels.size
        self._variables = ('x', 'y')
        if rng is None:
            rng = np.random.RandomState(313)
        self.rng = rng
        self.reset()

    def reset(self):
        if self._shuffle:
            self._indexes = self.rng.permutation(self._size)
        else:
            self._indexes = np.arange(self._size)
        super(MnistDataSource, self).reset()

    @property
    def images(self):
        """Get copy of whole data with a shape of (N, 1, H, W)."""
        return self._images.copy()

    @property
    def labels(self):
        """Get copy of whole label with a shape of (N, 1)."""
        return self._labels.copy()

def data_iterator_mnist(batch_size,
                        train=True,
                        rng=None,
                        shuffle=True,
                        with_memory_cache=False,
                        with_file_cache=False):
    '''
    Provide DataIterator with :py:class:`MnistDataSource`
    with_memory_cache and with_file_cache both default to False
    because :py:class:`MnistDataSource` can store all data in memory.
    For example,
    .. code-block:: python
        with data_iterator_mnist(batch_size, True) as di:
            for data in di:
                SOME CODE TO USE data.
    '''
    return data_iterator(MnistDataSource(train=train, shuffle=shuffle, rng=rng),
                         batch_size,
                         rng,
                         with_memory_cache,
                         with_file_cache)
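
Before wiring the iterator into training, we can draw a single batch to confirm its output; the arrays come back in the ('x', 'y') order declared in MnistDataSource:


In [ ]:
di = data_iterator_mnist(16, train=True)
image, label = di.next()
print(image.shape, label.shape)  # (16, 1, 28, 28) (16, 1)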

Generative adversarial networks and their variants, including DCGAN, involve an adversarial setting in which a generator network and a discriminator network compete against each other. Let's define our generator network first. We implement it with 4 consecutive deconvolution layers, each followed by batch normalization and an ELU non-linear activation. Finally, we apply a convolutional layer followed by a hyperbolic tangent.


In [ ]:
def generator(z, maxh=256, test=False, output_hidden=False):
    """
    Building generator network which takes (B, Z, 1, 1) inputs and generates
    (B, 1, 28, 28) outputs.
    """
    # Define shortcut functions
    def bn(x):
        # Batch normalization
        return PF.batch_normalization(x, batch_stat=not test)

    def upsample2(x, c):
        # Twice upsampling with deconvolution.
        return PF.deconvolution(x, c, kernel=(4, 4), pad=(1, 1), stride=(2, 2), with_bias=False)

    assert maxh // 8 > 0
    with nn.parameter_scope("gen"):
        # (Z, 1, 1) --> (256, 4, 4)
        with nn.parameter_scope("deconv1"):
            d1 = F.elu(bn(PF.deconvolution(z, maxh, (4, 4), with_bias=False)))
        # (256, 4, 4) --> (128, 8, 8)
        with nn.parameter_scope("deconv2"):
            d2 = F.elu(bn(upsample2(d1, maxh // 2)))
        # (128, 8, 8) --> (64, 16, 16)
        with nn.parameter_scope("deconv3"):
            d3 = F.elu(bn(upsample2(d2, maxh // 4)))
        # (64, 16, 16) --> (32, 28, 28)
        with nn.parameter_scope("deconv4"):
            # Convolution with kernel=4, pad=3 and stride=2 transforms a 28 x 28 map
            # to a 16 x 16 map. Deconvolution with those parameters behaves like an
            # inverse operation, i.e. maps 16 x 16 to 28 x 28.
            d4 = F.elu(bn(PF.deconvolution(
                d3, maxh // 8, (4, 4), pad=(3, 3), stride=(2, 2), with_bias=False)))
        # (32, 28, 28) --> (1, 28, 28)
        with nn.parameter_scope("conv5"):
            x = F.tanh(PF.convolution(d4, 1, (3, 3), pad=(1, 1)))
    if output_hidden:
        return x, [d1, d2, d3, d4]
    return x
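
As a quick shape check, we can push a dummy latent batch through the generator. The parameters created here live under the "gen" parameter scope and are reused when we build the training graph below:


In [ ]:
z_check = nn.Variable([8, 100, 1, 1])
x_check = generator(z_check, test=True)
print(x_check.shape)  # (8, 1, 28, 28)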

We then define the other half of the adversarial setting, the discriminator network. We can think of the discriminator as the reverse of the generator, with 4 consecutive convolutional layers instead of deconvolution layers. Each convolutional layer is followed by batch normalization, and all but the last also by an ELU activation. Finally, we apply an affine layer to produce a single logit per image.


In [ ]:
def discriminator(x, maxh=256, test=False, output_hidden=False):
    """
    Building discriminator network which maps a (B, 1, 28, 28) input to
    a (B, 1).
    """
    # Define shortcut functions
    def bn(xx):
        # Batch normalization
        return PF.batch_normalization(xx, batch_stat=not test)

    def downsample2(xx, c):
        return PF.convolution(xx, c, (3, 3), pad=(1, 1), stride=(2, 2), with_bias=False)

    assert maxh // 8 > 0
    with nn.parameter_scope("dis"):
        # (1, 28, 28) --> (32, 16, 16)
        with nn.parameter_scope("conv1"):
            c1 = F.elu(bn(PF.convolution(x, maxh // 8,
                                         (3, 3), pad=(3, 3), stride=(2, 2), with_bias=False)))
        # (32, 16, 16) --> (64, 8, 8)
        with nn.parameter_scope("conv2"):
            c2 = F.elu(bn(downsample2(c1, maxh // 4)))
        # (64, 8, 8) --> (128, 4, 4)
        with nn.parameter_scope("conv3"):
            c3 = F.elu(bn(downsample2(c2, maxh // 2)))
        # (128, 4, 4) --> (256, 4, 4)
        with nn.parameter_scope("conv4"):
            c4 = bn(PF.convolution(c3, maxh, (3, 3),
                                   pad=(1, 1), with_bias=False))
        # (256, 4, 4) --> (1,)
        with nn.parameter_scope("fc1"):
            f = PF.affine(c4, 1)
    if output_hidden:
        return f, [c1, c2, c3, c4]
    return f
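
Similarly, the discriminator maps a batch of images to a single logit per image:


In [ ]:
x_check = nn.Variable([8, 1, 28, 28])
f_check = discriminator(x_check, test=True)
print(f_check.shape)  # (8, 1)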

Now we are ready to get into the training part. Let's first set the context to use the GPU and define the hyperparameters. We also need to define the noise variable z, which is fed into the generator network to produce fake images; these in turn are fed to the discriminator network. We define separate losses for the generator and the discriminator, both using sigmoid cross-entropy.


In [ ]:
# Get context.
from nnabla.ext_utils import get_extension_context
ctx = get_extension_context('cudnn')
nn.set_default_context(ctx)

# Build the training graph.

# Hyperparameters
batch_size = 64
learning_rate = 0.0002
max_iter = 20000
weight_decay = 0.0001

# Fake path
z = nn.Variable([batch_size, 100, 1, 1])
fake = generator(z)
fake.persistent = True  # Keep the buffer from being cleared during backward.
pred_fake = discriminator(fake)
loss_gen = F.mean(F.sigmoid_cross_entropy(
    pred_fake, F.constant(1, pred_fake.shape)))
fake_dis = fake.get_unlinked_variable(need_grad=True)
pred_fake_dis = discriminator(fake_dis)
loss_dis = F.mean(F.sigmoid_cross_entropy(
    pred_fake_dis, F.constant(0, pred_fake_dis.shape)))
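
To unpack the losses: sigmoid cross-entropy with an all-ones target evaluates to -log(sigmoid(pred)), so minimizing loss_gen pushes the discriminator's output on fake images towards 1, while the all-zeros target gives -log(1 - sigmoid(pred)), pushing the discriminator's output on fakes towards 0. A quick numerical check of the first identity (not needed for training):


In [ ]:
pred = nn.Variable.from_numpy_array(np.array([[-2.0], [0.0], [3.0]]))
ce = F.sigmoid_cross_entropy(pred, F.constant(1, pred.shape))
ce.forward()
# -log(sigmoid(p)) == log(1 + exp(-p))
print(np.allclose(ce.d, np.log(1 + np.exp(-pred.d))))  # True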

Likewise, let's define a variable for real images, which are also fed into the discriminator network. Note that the discriminator loss has to account for both fake and real images, so we add a real-image term to the fake-image term defined above.


In [ ]:
# Real path
x = nn.Variable([batch_size, 1, 28, 28])
pred_real = discriminator(x)
loss_dis += F.mean(F.sigmoid_cross_entropy(pred_real,
                                           F.constant(1, pred_real.shape)))

We also define separate solvers for the generator and the discriminator. Following the DCGAN paper, we use Adam for both. Let's also define monitors to keep track of the training progress; one of them periodically saves tiles of the generated fake images.


In [ ]:
# Create Solver.
solver_gen = S.Adam(learning_rate, beta1=0.5)
solver_dis = S.Adam(learning_rate, beta1=0.5)
with nn.parameter_scope("gen"):
    solver_gen.set_parameters(nn.get_parameters())
with nn.parameter_scope("dis"):
    solver_dis.set_parameters(nn.get_parameters())
start_point = 0

# If necessary, load weights and solver state info from specified checkpoint files.
# start_point = load_checkpoint(
#     specified_checkpoint, {"gen": solver_gen, "dis": solver_dis})

# Create monitor.
import nnabla.monitor as M
monitor_path = 'tmp.monitor.dcgan'
monitor = M.Monitor(monitor_path)
monitor_loss_gen = M.MonitorSeries("Generator loss", monitor, interval=10)
monitor_loss_dis = M.MonitorSeries(
    "Discriminator loss", monitor, interval=10)
monitor_time = M.MonitorTimeElapsed("Time", monitor, interval=100)
monitor_fake = M.MonitorImageTile(
    "Fake images", monitor, normalize_method=lambda x: (x + 1) / 2.)

We are now ready to go! We create the data iterator that we defined earlier and write the training loop, in which we alternate between a generator update and a discriminator update.


In [ ]:
data = data_iterator_mnist(batch_size, True)

# Training loop.
for i in range(start_point, max_iter):

    # Training forward
    image, _ = data.next()
    x.d = image / 255. - 0.5  # [0, 255] to [-0.5, 0.5].
    z.d = np.random.randn(*z.shape)

    # Generator update.
    solver_gen.zero_grad()
    loss_gen.forward(clear_no_need_grad=True)
    loss_gen.backward(clear_buffer=True)
    solver_gen.weight_decay(weight_decay)
    solver_gen.update()
    monitor_fake.add(i, fake)
    monitor_loss_gen.add(i, loss_gen.d.copy())

    # Discriminator update.
    solver_dis.zero_grad()
    loss_dis.forward(clear_no_need_grad=True)
    loss_dis.backward(clear_buffer=True)
    solver_dis.weight_decay(weight_decay)
    solver_dis.update()
    monitor_loss_dis.add(i, loss_dis.d.copy())
    monitor_time.add(i)
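
We can also save the weights and solver states so that training can be resumed with the load_checkpoint call commented out earlier. A minimal sketch, assuming save_checkpoint from utils.neu.checkpoint_util takes the same (path, iteration, {name: solver}) arguments as its load counterpart:


In [ ]:
# Save network parameters and solver states under monitor_path;
# the resulting files can be passed to load_checkpoint to resume training.
save_checkpoint(monitor_path, max_iter - 1,
                {"gen": solver_gen, "dis": solver_dis})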

Now that training is done, let's see how the generated fake images evolved over the course of training! The image tile monitor saved a tile of fake images every 1000 iterations.


In [ ]:
from IPython.display import Image, display
for i in range(20):
    print("At iteration",(i+1)*1000-1)
    display(Image('tmp.monitor.dcgan/Fake-images/{:06d}.png'.format((i+1)*1000-1)))
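
Finally, we can use the trained generator to produce brand-new samples. A minimal sketch: rebuild the generator graph with test=True so that batch normalization uses its accumulated statistics instead of batch statistics, feed fresh noise, and write the result through a new image tile monitor with interval=1 so that a single call saves a file. We assume the monitor writes under a directory derived from its name, as the fake-image monitor did above:


In [ ]:
z_test = nn.Variable([batch_size, 100, 1, 1])
fake_test = generator(z_test, test=True)
z_test.d = np.random.randn(*z_test.shape)
fake_test.forward(clear_buffer=True)
monitor_test = M.MonitorImageTile(
    "Generated images", monitor, interval=1,
    normalize_method=lambda x: (x + 1) / 2.)
monitor_test.add(0, fake_test)
display(Image('tmp.monitor.dcgan/Generated-images/000000.png'))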