NNabla CIFAR-10 Training Tutorial

In this tutorial, we show how to train a classifier on the CIFAR-10 dataset using NNabla, including setting up the data iterator and the network.


In [ ]:
# If you run this notebook on Google Colab, uncomment and run the following to set up dependencies.
# !pip install nnabla-ext-cuda100
# !git clone https://github.com/sony/nnabla-examples.git
# %cd nnabla-examples

Let's import dependencies first.


In [ ]:
import os, sys
import time
import nnabla as nn
from nnabla.ext_utils import get_extension_context
import nnabla.functions as F
import nnabla.parametric_functions as PF
import nnabla.solvers as S
import numpy as np
import functools
import nnabla.utils.save as save

from utils.neu.checkpoint_util import save_checkpoint, load_checkpoint
from utils.neu.save_nnp import save_nnp

We then define a data iterator for CIFAR-10. When instantiated, it downloads the dataset (on first use) and feeds batches of samples to the network during training.


In [ ]:
from contextlib import contextmanager
import struct
import tarfile
import zlib
import errno

from nnabla.logger import logger
from nnabla.utils.data_iterator import data_iterator
from nnabla.utils.data_source import DataSource
from nnabla.utils.data_source_loader import download, get_data_home


class Cifar10DataSource(DataSource):
    '''
    Get data directly from the CIFAR-10 dataset on the Internet (www.cs.toronto.edu/~kriz).
    '''

    def _get_data(self, position):
        image = self._images[self._indexes[position]]
        label = self._labels[self._indexes[position]]
        return (image, label)


    def __init__(self, train=True, shuffle=False, rng=None):
        super(Cifar10DataSource, self).__init__(shuffle=shuffle, rng=rng)
        self._train = train
        data_uri = "https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz"
        logger.info('Getting labeled data from {}.'.format(data_uri))
        r = download(data_uri)  # file object returned
        with tarfile.open(fileobj=r, mode="r:gz") as fpin:
            # Training data
            if train:
                images = []
                labels = []
                for member in fpin.getmembers():
                    if "data_batch" not in member.name:
                        continue
                    fp = fpin.extractfile(member)
                    data = np.load(fp, encoding="bytes", allow_pickle=True)
                    images.append(data[b"data"])
                    labels.append(data[b"labels"])
                self._size = 50000
                self._images = np.concatenate(
                    images).reshape(self._size, 3, 32, 32)
                self._labels = np.concatenate(labels).reshape(-1, 1)
            # Validation data
            else:
                for member in fpin.getmembers():
                    if "test_batch" not in member.name:
                        continue
                    fp = fpin.extractfile(member)
                    data = np.load(fp, encoding="bytes", allow_pickle=True)
                    images = data[b"data"]
                    labels = data[b"labels"]
                self._size = 10000
                self._images = images.reshape(self._size, 3, 32, 32)
                self._labels = np.array(labels).reshape(-1, 1)
        r.close()
        logger.info('Done loading labeled data from {}.'.format(data_uri))

        self._size = self._labels.size
        self._variables = ('x', 'y')
        if rng is None:
            rng = np.random.RandomState(313)
        self.rng = rng
        self.reset()

    def reset(self):
        if self._shuffle:
            self._indexes = self.rng.permutation(self._size)
        else:
            self._indexes = np.arange(self._size)
        super(Cifar10DataSource, self).reset()

    @property
    def images(self):
        """Get copy of whole data with a shape of (N, 1, H, W)."""
        return self._images.copy()

    @property
    def labels(self):
        """Get copy of whole label with a shape of (N, 1)."""
        return self._labels.copy()


def data_iterator_cifar10(batch_size,
                          train=True,
                          rng=None,
                          shuffle=True,
                          with_memory_cache=False,
                          with_file_cache=False):
    '''
    Provide a DataIterator with :py:class:`Cifar10DataSource`.
    The default value of both with_memory_cache and with_file_cache is False,
    because :py:class:`Cifar10DataSource` can hold the entire dataset in memory.

    '''
    return data_iterator(Cifar10DataSource(train=train, shuffle=shuffle, rng=rng),
                         batch_size,
                         rng,
                         with_memory_cache,
                         with_file_cache)
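
If you want to check the iterator before wiring it into training, you could pull a single batch and inspect its shapes, along the lines of the sketch below (the names di, images and labels are illustrative; the first call also triggers the CIFAR-10 download).


In [ ]:
# Optional sanity check: fetch one batch from the test split and inspect it.
# Note that this creates its own data source, separate from the ones used later.
di = data_iterator_cifar10(8, train=False, shuffle=False)
images, labels = di.next()
print(images.shape, labels.shape)  # expected: (8, 3, 32, 32) (8, 1)
print(labels[:5].ravel())          # a few raw class indices in [0, 9]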

We also define a small helper that computes the classification error rate from the network predictions and the ground-truth labels.


In [ ]:
def categorical_error(pred, label):
    pred_label = pred.argmax(1)
    return (pred_label != label.flat).mean()
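
As a quick worked example, consider a batch of two samples where only the second prediction is wrong; the error rate is then 0.5. A minimal check with toy arrays:


In [ ]:
# Two samples, two classes: argmax of pred is [1, 0], the true labels are [1, 1],
# so exactly one of the two predictions is wrong -> error rate 0.5.
pred = np.array([[0.1, 0.9],
                 [0.8, 0.2]])
label = np.array([[1], [1]])
print(categorical_error(pred, label))  # 0.5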

We now define our neural network. In this example, we employ a slightly modified architecture based on ResNet. Data augmentation (random rotation and horizontal flipping via F.image_augmentation) is also applied inside the graph, during training only.


In [ ]:
def resnet23_prediction(image, test=False, ncls=10, nmaps=64, act=F.relu):
    """
    Construct ResNet 23
    """
    # Residual Unit
    def res_unit(x, scope_name, dn=False):
        C = x.shape[1]
        with nn.parameter_scope(scope_name):
            # Conv -> BN -> Nonlinear
            with nn.parameter_scope("conv1"):
                h = PF.convolution(x, C // 2, kernel=(1, 1), pad=(0, 0),
                                   with_bias=False)
                h = PF.batch_normalization(h, batch_stat=not test)
                h = act(h)
            # Conv -> BN -> Nonlinear
            with nn.parameter_scope("conv2"):
                h = PF.convolution(h, C // 2, kernel=(3, 3), pad=(1, 1),
                                   with_bias=False)
                h = PF.batch_normalization(h, batch_stat=not test)
                h = act(h)
            # Conv -> BN
            with nn.parameter_scope("conv3"):
                h = PF.convolution(h, C, kernel=(1, 1), pad=(0, 0),
                                   with_bias=False)
                h = PF.batch_normalization(h, batch_stat=not test)
            # Residual -> Nonlinear
            h = act(F.add2(h, x, inplace=True))
            # Maxpooling
            if dn:
                h = F.max_pooling(h, kernel=(2, 2), stride=(2, 2))
            return h
    # Conv -> BN -> Nonlinear
    with nn.parameter_scope("conv1"):
        # Preprocess
        if not test:
            image = F.image_augmentation(image, contrast=1.0,
                                         angle=0.25,
                                         flip_lr=True)
            image.need_grad = False
        h = PF.convolution(image, nmaps, kernel=(3, 3),
                           pad=(1, 1), with_bias=False)
        h = PF.batch_normalization(h, batch_stat=not test)
        h = act(h)

    h = res_unit(h, "conv2", False)    # -> 32x32
    h = res_unit(h, "conv3", True)     # -> 16x16
    h = res_unit(h, "conv4", False)    # -> 16x16
    h = res_unit(h, "conv5", True)     # -> 8x8
    h = res_unit(h, "conv6", False)    # -> 8x8
    h = res_unit(h, "conv7", True)     # -> 4x4
    h = res_unit(h, "conv8", False)    # -> 4x4
    h = F.average_pooling(h, kernel=(4, 4))  # -> 1x1
    pred = PF.affine(h, ncls)

    return pred
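
If you want to confirm the output shape before building the real graphs, you could construct a small throwaway inference graph like the sketch below; it creates temporary parameters and clears them again afterwards.


In [ ]:
# Build a throwaway inference graph just to confirm the output shape, then
# clear the created parameters so they don't linger in the global scope.
x = nn.Variable((4, 3, 32, 32))
y = resnet23_prediction(x, test=True)
print(y.shape)  # expected: (4, 10)
nn.clear_parameters()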

We define our loss function, which in this case is the mean softmax cross entropy computed from the predictions and the labels.


In [ ]:
def loss_function(pred, label):
    loss = F.mean(F.softmax_cross_entropy(pred, label))
    return loss
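
Because the softmax of all-zero logits is uniform over the 10 classes, the expected loss in that case is ln(10) ≈ 2.30, which makes for a quick numerical sanity check (a minimal sketch; the variables logits and targets are illustrative):


In [ ]:
# With all-zero logits the softmax is uniform, so the loss should be ~ln(10).
logits = nn.Variable((8, 10))
targets = nn.Variable((8, 1))
loss_check = loss_function(logits, targets)
logits.d = np.zeros(logits.shape)
targets.d = np.random.randint(0, 10, size=(8, 1))
loss_check.forward()
print(loss_check.d)  # ~2.3026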

We are almost ready to start training! Let's define some hyper-parameters, set the default context to the cuDNN extension, and bind the network's fixed arguments with functools.partial.


In [ ]:
n_train_samples = 50000
batch_size = 64
bs_valid = 64  # batch size for validation
extension_module = 'cudnn'
ctx = get_extension_context(extension_module)
nn.set_default_context(ctx)
prediction = functools.partial(
    resnet23_prediction, ncls=10, nmaps=64, act=F.relu)
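
This tutorial assumes a GPU with the cuDNN extension. On a CPU-only machine you could instead replace the context setup above with a fallback along these lines:


In [ ]:
# Fall back to the CPU extension when the cuDNN extension is not installed.
try:
    ctx = get_extension_context('cudnn')
except Exception:
    ctx = get_extension_context('cpu')
nn.set_default_context(ctx)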

We then create the training and validation graphs. Note that no label variable is created for validation: the validation graph only produces predictions, and the error is computed later from those predictions with categorical_error.


In [ ]:
# Create training graphs
test = False
image_train = nn.Variable((batch_size, 3, 32, 32))
label_train = nn.Variable((batch_size, 1))
pred_train = prediction(image_train, test)
loss_train = loss_function(pred_train, label_train)
input_image_train = {"image": image_train, "label": label_train}

# Create validation graph
test = True
image_valid = nn.Variable((bs_valid, 3, 32, 32))
pred_valid = prediction(image_valid, test)
input_image_valid = {"image": image_valid}
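
Both graphs share the same parameters, so at this point you could get a rough idea of the model size by counting the registered parameters (a minimal sketch):


In [ ]:
# Count the trainable parameters registered by the graphs above.
n_params = sum(int(np.prod(v.shape)) for v in nn.get_parameters().values())
print('Number of trainable parameters: {}'.format(n_params))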

Let's also define our solver. We employ Adam in this example, but other solvers can be used too. We also set up monitors to keep track of progress during training. Note that if you want to resume from previously saved weights and solver states, you can load them with load_checkpoint.


In [ ]:
# Solvers
solver = S.Adam()
solver.set_parameters(nn.get_parameters())
start_point = 0

# If necessary, load weights and solver state info from specified checkpoint file.
# start_point = load_checkpoint(specified_checkpoint, solver)

# Create monitor
from nnabla.monitor import Monitor, MonitorSeries, MonitorTimeElapsed
monitor = Monitor('tmp.monitor')
monitor_loss = MonitorSeries("Training loss", monitor, interval=10)
monitor_err = MonitorSeries("Training error", monitor, interval=10)
monitor_time = MonitorTimeElapsed("Training time", monitor, interval=10)
monitor_verr = MonitorSeries("Test error", monitor, interval=1)

We create data iterators separately for training and validation, using the data iterator function defined earlier. Note that the second argument (train) differs between the two, selecting either the training or the test split.


In [ ]:
# Data Iterator
tdata = data_iterator_cifar10(batch_size, True)
vdata = data_iterator_cifar10(batch_size, False)

# Uncomment to save the network and initial weights as an NNP file before training
#contents = save_nnp({'x': image_valid}, {'y': pred_valid}, batch_size)
#save.save(os.path.join('tmp.monitor',
#                       '{}_epoch0_result.nnp'.format('cifar10_resnet23')), contents)

We are good to go now! Start training, get a coffee, and watch how the training loss and test error decline as the training proceeds.


In [ ]:
max_iter = 40000
val_iter = 100
model_save_interval = 10000
model_save_path = 'tmp.monitor'
# Training-loop
for i in range(start_point, max_iter):
    # Validation
    if i % int(n_train_samples / batch_size) == 0:
        ve = 0.
        for j in range(val_iter):
            image, label = vdata.next()
            input_image_valid["image"].d = image
            pred_valid.forward()
            ve += categorical_error(pred_valid.d, label)
        ve /= val_iter
        monitor_verr.add(i, ve)
    if i % model_save_interval == 0:
        # save checkpoint file
        save_checkpoint(model_save_path, i, solver)

    # Forward/Zerograd/Backward
    image, label = tdata.next()
    input_image_train["image"].d = image
    input_image_train["label"].d = label
    loss_train.forward()
    solver.zero_grad()
    loss_train.backward()

    # Solvers update
    solver.update()

    e = categorical_error(
        pred_train.d, input_image_train["label"].d)
    monitor_loss.add(i, loss_train.d.copy())
    monitor_err.add(i, e)
    monitor_time.add(i)

nn.save_parameters(os.path.join(model_save_path,
                                'params_%06d.h5' % (max_iter)))

# save_nnp_lastepoch
contents = save_nnp({'x': image_valid}, {'y': pred_valid}, batch_size)
save.save(os.path.join(model_save_path,
                       '{}_result.nnp'.format('cifar10_resnet23')), contents)
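
Once training has finished, you could reload the saved weights and run the validation graph on a single batch to confirm that everything round-trips; a minimal sketch (the file name matches the parameters saved above):


In [ ]:
# Reload the final weights into the existing parameter scope and run the
# validation graph on one batch.
nn.load_parameters(os.path.join(model_save_path, 'params_%06d.h5' % max_iter))
image, label = vdata.next()
image_valid.d = image
pred_valid.forward(clear_buffer=True)
print('Error on this batch: {}'.format(categorical_error(pred_valid.d, label)))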