In this tutorial, we will show you how to learn a feature embedding with Siamese Neural Networks using NNabla. Siamese Neural Networks were originally proposed for one-shot image recognition, and they are also useful for feature embedding: the network learns from pairs of images what makes two images similar or dissimilar, as we will see below.

We will use MNIST images for this tutorial. At the end of the tutorial, you will be able to see how the model has learned to embed each digit class.


In [ ]:
# If you run this notebook on Google Colab, uncomment and run the following to set up dependencies.
# !pip install nnabla-ext-cuda100
# !git clone https://github.com/sony/nnabla-examples.git
# %cd nnabla-examples

Let's start by importing dependencies.


In [ ]:
import os

import numpy as np

import nnabla as nn
import nnabla.functions as F
import nnabla.parametric_functions as PF
import nnabla.solver as S
from nnabla.logger import logger

Let's also define a data iterator for MNIST. You can skip over the details for now.


In [ ]:
import struct
import zlib

from nnabla.logger import logger
from nnabla.utils.data_iterator import data_iterator
from nnabla.utils.data_source import DataSource
from nnabla.utils.data_source_loader import download


def load_mnist(train=True):
    '''
    Load MNIST images and labels from Yann LeCun's original page or from the cache file.
    Args:
        train (bool): If False, the test set is returned. The training set has 60000 images; the test set has 10000.
    Returns:
        numpy.ndarray: uint8 images with shape (#images, 1, 28, 28) and values in [0, 255].
        numpy.ndarray: uint8 labels with shape (#images, 1) and values in {0, 1, ..., 9}.
    '''
    if train:
        image_uri = 'http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz'
        label_uri = 'http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz'
    else:
        image_uri = 'http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz'
        label_uri = 'http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz'
    logger.info('Getting label data from {}.'.format(label_uri))
    r = download(label_uri)
    data = zlib.decompress(r.read(), zlib.MAX_WBITS | 32)
    _, size = struct.unpack('>II', data[0:8])
    labels = np.frombuffer(data[8:], np.uint8).reshape(-1, 1)
    r.close()
    logger.info('Getting label data done.')

    logger.info('Getting image data from {}.'.format(image_uri))
    r = download(image_uri)
    data = zlib.decompress(r.read(), zlib.MAX_WBITS | 32)
    _, size, height, width = struct.unpack('>IIII', data[0:16])
    images = np.frombuffer(data[16:], np.uint8).reshape(
        size, 1, height, width)
    r.close()
    logger.info('Getting image data done.')

    return images, labels

class MnistDataSource(DataSource):
    '''
    Load the MNIST dataset directly from the Internet (yann.lecun.com).
    '''

    def _get_data(self, position):
        image = self._images[self._indexes[position]]
        label = self._labels[self._indexes[position]]
        return (image, label)

    def __init__(self, train=True, shuffle=False, rng=None):
        super(MnistDataSource, self).__init__(shuffle=shuffle)
        self._train = train

        self._images, self._labels = load_mnist(train)

        self._size = self._labels.size
        self._variables = ('x', 'y')
        if rng is None:
            rng = np.random.RandomState(313)
        self.rng = rng
        self.reset()

    def reset(self):
        if self._shuffle:
            self._indexes = self.rng.permutation(self._size)
        else:
            self._indexes = np.arange(self._size)
        super(MnistDataSource, self).reset()

    @property
    def images(self):
        """Get copy of whole data with a shape of (N, 1, H, W)."""
        return self._images.copy()

    @property
    def labels(self):
        """Get copy of whole label with a shape of (N, 1)."""
        return self._labels.copy()

def data_iterator_mnist(batch_size,
                        train=True,
                        rng=None,
                        shuffle=True,
                        with_memory_cache=False,
                        with_file_cache=False):
    '''
    Provide a DataIterator built on :py:class:`MnistDataSource`.
    The with_memory_cache and with_file_cache options both default to False,
    because :py:class:`MnistDataSource` already holds all data in memory.
    For example,

    .. code-block:: python

        with data_iterator_mnist(batch_size, train=True) as di:
            x, y = di.next()  # x: (batch_size, 1, 28, 28), y: (batch_size, 1)
    '''
    return data_iterator(MnistDataSource(train=train, shuffle=shuffle, rng=rng),
                         batch_size,
                         rng,
                         with_memory_cache,
                         with_file_cache)
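The loader above reads the IDX format used by the MNIST files: a big-endian header of 32-bit integers (a magic number, then one size per dimension) followed by raw uint8 pixels. A minimal sketch, assuming only the standard library and NumPy, that builds a tiny synthetic IDX image file in memory with a hypothetical helper `make_idx_images` and parses it with the same `struct`/`zlib` calls as `load_mnist`, without downloading anything:

```python
import gzip
import struct
import zlib

import numpy as np


def make_idx_images(images):
    """Hypothetical helper: serialize uint8 images of shape (N, H, W) to gzipped IDX bytes.

    The IDX image header is the magic number 2051 (0x00000803) followed by
    N, H and W, each stored as a big-endian unsigned 32-bit integer.
    """
    n, h, w = images.shape
    header = struct.pack('>IIII', 2051, n, h, w)
    return gzip.compress(header + images.astype(np.uint8).tobytes())


def parse_idx_images(blob):
    """Parse gzipped IDX image bytes the same way load_mnist does."""
    # MAX_WBITS | 32 tells zlib to auto-detect a gzip or zlib header.
    data = zlib.decompress(blob, zlib.MAX_WBITS | 32)
    magic, size, height, width = struct.unpack('>IIII', data[0:16])
    images = np.frombuffer(data[16:], np.uint8).reshape(size, 1, height, width)
    return magic, images


# Round-trip a tiny synthetic "dataset" of three 4x5 images.
fake = np.arange(3 * 4 * 5, dtype=np.uint8).reshape(3, 4, 5)
magic, images = parse_idx_images(make_idx_images(fake))
print(magic, images.shape)  # 2051 (3, 1, 4, 5)
```

The label files follow the same pattern with a shorter header (magic number and item count only), which is why `load_mnist` unpacks `'>II'` from the first 8 bytes there.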

Let's define a simple LeNet-style network for MNIST. We stack two 5x5 convolution layers, each followed by a non-linear activation (ELU) and 2x2 average pooling. We then apply three fully-connected (affine) layers; the last one maps each image to a 2-dimensional embedding, which lets us plot the learned features directly at the end of the tutorial.


In [ ]:
def mnist_lenet_feature(image, test=False):
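    """
    Construct a LeNet-style feature extractor for MNIST that maps each
    image to a 2-D embedding. `test` is unused because the network has
    no dropout or batch normalization.
    """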
    c1 = F.elu(PF.convolution(image, 20, (5, 5), name='conv1'))
    c1 = F.average_pooling(c1, (2, 2))
    c2 = F.elu(PF.convolution(c1, 50, (5, 5), name='conv2'))
    c2 = F.average_pooling(c2, (2, 2))
    c3 = F.elu(PF.affine(c2, 500, name='fc3'))
    c4 = PF.affine(c3, 10, name='fc4')
    c5 = PF.affine(c4, 2, name='fc_embed')
    return c5
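It helps to trace the tensor shapes through this network. A plain-Python sketch of the arithmetic, assuming NNabla's defaults as used above: `PF.convolution` with no padding and stride 1, `F.average_pooling` with stride equal to the kernel size, biases enabled in every layer, and `PF.affine` flattening all axes after the batch axis:

```python
def conv_out(size, kernel):
    # Output size of a convolution with stride 1 and no padding.
    return size - kernel + 1


def pool_out(size, kernel):
    # Output size of non-overlapping pooling (stride == kernel).
    return size // kernel


s = 28
s = pool_out(conv_out(s, 5), 2)  # conv1 + pool: 28 -> 24 -> 12
s = pool_out(conv_out(s, 5), 2)  # conv2 + pool: 12 -> 8 -> 4
flat = 50 * s * s                # 50 maps of 4x4 -> 800 inputs to fc3

# Weights + biases per layer.
layers = {
    'conv1': 20 * 1 * 5 * 5 + 20,
    'conv2': 50 * 20 * 5 * 5 + 50,
    'fc3': flat * 500 + 500,
    'fc4': 500 * 10 + 10,
    'fc_embed': 10 * 2 + 2,
}
print(s, flat, sum(layers.values()))  # 4 800 431102
```

Under these assumptions, most of the parameters sit in `fc3`, and each 28x28 image ends up as a single 2-D point.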

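Here we define a function that takes two images, feeds each one through the LeNet defined above, and computes the squared Euclidean distance between the two output features (the element-wise squared error summed over the feature axis). Both images go through the same network with shared weights: because the parametric functions register their parameters by name (conv1, fc3, and so on), the second call reuses the parameters created by the first. This weight sharing is why the model is called a Siamese network!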


In [ ]:
def mnist_lenet_siamese(x0, x1, test=False):
    """Return the squared Euclidean distance between the embeddings of x0 and x1."""
    h0 = mnist_lenet_feature(x0, test)
    h1 = mnist_lenet_feature(x1, test)  # share weights
    h = F.squared_error(h0, h1)
    p = F.sum(h, axis=1)
    return p
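To make the weight sharing and the distance computation concrete, here is a NumPy sketch in which a single hypothetical linear layer `embed` stands in for `mnist_lenet_feature`. Both branches use the same `W` and `b`, and summing the element-wise squared error over the feature axis gives the squared Euclidean distance per pair, just as `F.squared_error` followed by `F.sum(..., axis=1)` does above:

```python
import numpy as np

rng = np.random.RandomState(0)
W = rng.randn(4, 2)  # one weight matrix, shared by both branches
b = np.zeros(2)


def embed(x):
    # Hypothetical stand-in for mnist_lenet_feature: one linear layer.
    return x @ W + b


def siamese_distance(x0, x1):
    h0 = embed(x0)
    h1 = embed(x1)  # same W and b as the first branch: shared weights
    # Element-wise squared error, summed over the feature axis.
    return np.sum((h0 - h1) ** 2, axis=1)


x0 = rng.randn(5, 4)
x1 = rng.randn(5, 4)
d2 = siamese_distance(x0, x1)  # one squared distance per pair, shape (5,)
print(d2)
```

Because both branches share parameters, the distance is symmetric in its two inputs and is zero for identical images.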

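Next, we define the contrastive loss. Here sd is the squared distance returned by the Siamese network, and l is 1 for a similar pair (same digit) and 0 for a dissimilar pair. Similar pairs are penalized by their squared distance, which pulls them together. Dissimilar pairs are penalized by (margin - d)^2 only when their distance d = sqrt(sd + eps) is smaller than margin, which pushes them apart. The small eps keeps the gradient of the square root finite when sd is zero.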


In [ ]:
def contrastive_loss(sd, l, margin=1.0, eps=1e-4):
    sim_cost = l * sd
    dissim_cost = (1 - l) * \
        (F.maximum_scalar(margin - (sd + eps) ** (0.5), 0) ** 2)
    return sim_cost + dissim_cost
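For reference, a NumPy mirror of `contrastive_loss` (the name `contrastive_loss_np` is ours, for illustration only) makes it easy to check both cases on a few hand-picked squared distances:

```python
import numpy as np


def contrastive_loss_np(sd, l, margin=1.0, eps=1e-4):
    """NumPy mirror of contrastive_loss.

    sd: squared distances, shape (B,); l: 1 for similar pairs, 0 for dissimilar.
    """
    sd = np.asarray(sd, dtype=np.float64)
    l = np.asarray(l, dtype=np.float64)
    sim_cost = l * sd
    d = np.sqrt(sd + eps)
    dissim_cost = (1 - l) * np.maximum(margin - d, 0) ** 2
    return sim_cost + dissim_cost


sd = np.array([0.25, 0.25, 4.0])
l = np.array([1, 0, 0])
pair_loss = contrastive_loss_np(sd, l)
# Similar pair: penalized by its squared distance, 0.25.
# Dissimilar pair inside the margin: about (1 - 0.5)**2 = 0.25.
# Dissimilar pair beyond the margin: 0.
print(pair_loss)
```

The margin is what keeps dissimilar pairs from being pushed apart indefinitely: once they are at least margin apart, they contribute nothing to the loss.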

Since a Siamese network takes two images as input, we wrap two independently shuffled MNIST iterators so that each call returns a pair of image batches and a pair label: 1 if the two images show the same digit, 0 otherwise.


In [ ]:
class MnistSiameseDataIterator(object):

    def __init__(self, itr0, itr1):
        self.itr0 = itr0
        self.itr1 = itr1

    def next(self):
        x0, l0 = self.itr0.next()
        x1, l1 = self.itr1.next()
        sim = (l0 == l1).astype(np.int32).flatten()  # 1 if same digit, else 0
        return x0 / 255., x1 / 255., sim  # scale pixels to [0, 1]


def siamese_data_iterator(batch_size, train, rng=None):
    itr0 = data_iterator_mnist(batch_size, train=train, rng=rng, shuffle=True)
    itr1 = data_iterator_mnist(batch_size, train=train, rng=rng, shuffle=True)
    return MnistSiameseDataIterator(itr0, itr1)
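Because the two iterators are shuffled independently, pairs are formed at random. With ten roughly balanced digit classes, only about one pair in ten is therefore labeled similar, so most of the contrastive loss comes from dissimilar pairs. A quick NumPy sketch of the pair-label computation, assuming uniformly distributed labels as a stand-in for MNIST:

```python
import numpy as np

# Assumption: labels drawn uniformly from 10 classes stand in for MNIST,
# whose classes are roughly balanced. Shapes (n, 1) match the MNIST iterator.
rng = np.random.RandomState(313)
n = 100000
l0 = rng.randint(0, 10, size=(n, 1))
l1 = rng.randint(0, 10, size=(n, 1))
sim = (l0 == l1).astype(np.int32).flatten()  # 1 = same digit, 0 = different
print(sim.shape, sim.mean())  # shape (100000,); mean close to 0.1
```

If this imbalance became a problem, one could sample balanced pairs instead, but the tutorial keeps the simple random pairing.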

Before we start training, let's set the context to use the GPU via cuDNN. To run on CPU only, use get_extension_context('cpu') instead.

We now define the computation graph for training. First, we define two variables for the input images and one for the pair labels. The image variables are fed into the Siamese LeNet defined above, and the resulting squared distances are combined with the labels to compute the contrastive loss, also defined above. We define the computation graph for validation in the same way; it reuses the same parameters.

Let's also set up the solver and the monitors that track training progress. We use Adam as our solver.


In [ ]:
# Get context.
from nnabla.ext_utils import get_extension_context
ctx = get_extension_context('cudnn')
nn.set_default_context(ctx)

# Create CNN network for both training and testing.
margin = 1.0  # Margin for contrastive loss.

# TRAIN
# Create input variables.
batch_size = 128
image0 = nn.Variable([batch_size, 1, 28, 28])
image1 = nn.Variable([batch_size, 1, 28, 28])
label = nn.Variable([batch_size])
# Create prediction graph.
pred = mnist_lenet_siamese(image0, image1, test=False)
# Create loss function.
loss = F.mean(contrastive_loss(pred, label, margin))

# TEST
# Create input variables.
vimage0 = nn.Variable([batch_size, 1, 28, 28])
vimage1 = nn.Variable([batch_size, 1, 28, 28])
vlabel = nn.Variable([batch_size])
# Create prediction graph.
vpred = mnist_lenet_siamese(vimage0, vimage1, test=True)
vloss = F.mean(contrastive_loss(vpred, vlabel, margin))

# Create Solver.
learning_rate = 1e-3
solver = S.Adam(learning_rate)
solver.set_parameters(nn.get_parameters())

start_point = 0

# Create monitor.
import nnabla.monitor as M
model_save_path = 'tmp.monitor.siamese'
monitor = M.Monitor(model_save_path)
monitor_loss = M.MonitorSeries("Training loss", monitor, interval=10)
monitor_time = M.MonitorTimeElapsed("Training time", monitor, interval=100)
monitor_vloss = M.MonitorSeries("Test loss", monitor, interval=10)

After creating the paired data iterators defined earlier, we can start training. The parameters obtained at the end of training are used to visualize the learned feature embedding space.


In [ ]:
rng = np.random.RandomState(313)
data = siamese_data_iterator(batch_size, True, rng)
vdata = siamese_data_iterator(batch_size, False, rng)

# Training loop.
max_iter = 5000
val_interval = 100
val_iter = 10
weight_decay = 0
for i in range(start_point, max_iter):
    if i % val_interval == 0:
        # Validation
        ve = 0.0
        for j in range(val_iter):
            vimage0.d, vimage1.d, vlabel.d = vdata.next()
            vloss.forward(clear_buffer=True)
            ve += vloss.d
        monitor_vloss.add(i, ve / val_iter)
    image0.d, image1.d, label.d = data.next()
    solver.zero_grad()
    # Training forward, backward and update
    loss.forward(clear_no_need_grad=True)
    loss.backward(clear_buffer=True)
    solver.weight_decay(weight_decay)
    solver.update()
    monitor_loss.add(i, loss.d.copy())
    monitor_time.add(i)

# Uncomment the following lines to save the learned parameters.
# parameter_file = os.path.join(
#     model_save_path, 'params_%06d.h5' % max_iter)
# nn.save_parameters(parameter_file)
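To see why minimizing this loss organizes the embedding, the sketch below treats a pair of 2-D embeddings as free parameters (an assumption for illustration; in the notebook they are network outputs, and the gradient flows back into the shared weights) and takes one plain gradient-descent step using the hand-derived gradient of the contrastive loss. The helper names are hypothetical, and plain gradient descent replaces Adam for simplicity. A similar pair is pulled together, while a dissimilar pair closer than the margin is pushed apart:

```python
import numpy as np


def contrastive_grad(h0, h1, l, margin=1.0, eps=1e-4):
    """Gradient of the contrastive loss w.r.t. h0 for one pair (w.r.t. h1 it is the negative)."""
    diff = h0 - h1
    if l == 1:
        # d/dh0 of ||h0 - h1||^2
        return 2.0 * diff
    d = np.sqrt(np.sum(diff ** 2) + eps)
    if d >= margin:
        return np.zeros_like(diff)  # hinge inactive: no push beyond the margin
    # d/dh0 of (margin - d)^2, using dd/dh0 = diff / d
    return -2.0 * (margin - d) * diff / d


def gd_step(h0, h1, l, lr=0.1):
    g = contrastive_grad(h0, h1, l)
    # The loss depends only on h0 - h1, so h1 moves opposite to h0.
    return h0 - lr * g, h1 + lr * g


p0, p1 = np.array([0.2, 0.0]), np.array([-0.2, 0.0])  # distance 0.4
for l in (1, 0):
    q0, q1 = gd_step(p0, p1, l)
    print(l, np.linalg.norm(q0 - q1))  # similar: about 0.24; dissimilar: about 0.64
```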

Let's visually check how each digit class is represented in the feature embedding space that our model learned. We take the 10,000 images of the MNIST test set, extract their features with the trained LeNet, and plot them. Because the embedding is two-dimensional, we plot the features directly, without a dimensionality reduction technique such as t-SNE. Each dot represents a sample, and each color represents a digit class. If training succeeded, dots of the same color should form distinct clusters, which suggests that digits can be classified reliably from these features.


In [ ]:
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt

batch_size = 500

# Create embedded network
image = nn.Variable([batch_size, 1, 28, 28])
feature = mnist_lenet_feature(image, test=True)

features = []
labels = []

rng = np.random.RandomState(313)
data = data_iterator_mnist(batch_size, train=False, rng=rng, shuffle=True)
for i in range(10000 // batch_size):
    image_data, label_data = data.next()
    image.d = image_data / 255.
    feature.forward(clear_buffer=True)
    features.append(feature.d.copy())
    labels.append(label_data.copy())
features = np.vstack(features)
labels = np.vstack(labels)

# Visualize
f = plt.figure(figsize=(16, 9))
for i in range(10):
    c = plt.cm.Set1(i / 10.)
    mask = labels.ravel() == i
    plt.plot(features[mask, 0], features[mask, 1], '.', c=c)
plt.legend(list(map(str, range(10))))
plt.grid()
plt.savefig(os.path.join(model_save_path, "embed.png"))

from IPython.display import Image, display
display(Image(os.path.join(model_save_path, "embed.png")))
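The plot gives a qualitative impression. As a rough quantitative check (an addition, not part of the original tutorial), one option is nearest-centroid classification in the embedding space: assign each point to the class whose mean embedding is closest. The sketch below defines a hypothetical helper and demonstrates it on synthetic, well-separated 2-D clusters; in this notebook you could call it as nearest_centroid_accuracy(features, labels). Because the centroids are computed from the same points being scored, this is an optimistic in-sample measure, not a test accuracy.

```python
import numpy as np


def nearest_centroid_accuracy(features, labels):
    """Fraction of points whose nearest class centroid matches their label.

    features: (N, D) embeddings; labels: (N, 1) or (N,) integer classes.
    """
    labels = np.asarray(labels).ravel()
    classes = np.unique(labels)
    centroids = np.stack([features[labels == c].mean(axis=0) for c in classes])
    # (N, C) squared distances from every point to every centroid.
    d2 = ((features[:, None, :] - centroids[None, :, :]) ** 2).sum(axis=2)
    pred = classes[np.argmin(d2, axis=1)]
    return float(np.mean(pred == labels))


# Synthetic demo: three tight, well-separated clusters in 2-D.
demo_rng = np.random.RandomState(0)
centers = np.array([[0.0, 0.0], [10.0, 0.0], [0.0, 10.0]])
demo_lab = np.repeat(np.arange(3), 50).reshape(-1, 1)
demo_feat = centers[demo_lab.ravel()] + 0.1 * demo_rng.randn(150, 2)
print(nearest_centroid_accuracy(demo_feat, demo_lab))
```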