Deep learning frequently requires a large amount of labeled data, but in practice, it can be very costly to collect data with labels. The semi-supervised setting has gained attention because it can leverage unlabeled data to train a model.
In this tutorial, we will show you how to perform semi-supervised learning on MNIST with NNabla, using the model known as virtual adversarial training (VAT). Although MNIST is fully labeled, we will assume a setting where some of the labels are missing.
In [ ]:
# If you run this notebook on Google Colab, uncomment and run the following to set up dependencies.
# !pip install nnabla-ext-cuda100
# !git clone https://github.com/sony/nnabla-examples.git
# %cd nnabla-examples
As always, let's start by importing dependencies.
In [ ]:
from __future__ import absolute_import
from six.moves import range
import nnabla as nn
import nnabla.functions as F
import nnabla.parametric_functions as PF
import nnabla.solver as S
from nnabla.logger import logger
import nnabla.utils.save as save
from nnabla.utils.data_iterator import data_iterator_simple
from utils.neu.save_nnp import save_nnp
import numpy as np
import time
import os
Let's also define a data iterator for MNIST. You can disregard the details for now.
In [ ]:
import struct
import zlib
from nnabla.logger import logger
from nnabla.utils.data_iterator import data_iterator
from nnabla.utils.data_source import DataSource
from nnabla.utils.data_source_loader import download
def _download_and_decompress(uri):
    '''
    Download ``uri`` and return its decompressed payload as bytes.
    '''
    r = download(uri)
    try:
        # MAX_WBITS | 32 lets zlib auto-detect a gzip or zlib header.
        return zlib.decompress(r.read(), zlib.MAX_WBITS | 32)
    finally:
        # Release the handle even when reading or decompressing fails.
        r.close()


def load_mnist(train=True):
    '''
    Load MNIST dataset images and labels from the original page by Yann LeCun or the cache file.

    Args:
        train (bool): The testing dataset will be returned if False. Training data has 60000 images, while testing has 10000 images.

    Returns:
        numpy.ndarray: A shape of (#images, 1, 28, 28). dtype uint8, raw pixel
            values in [0, 255] (callers scale them, e.g. by dividing by 255).
        numpy.ndarray: A shape of (#images, 1). dtype uint8, values in {0, 1, ..., 9}.

    Raises:
        ValueError: If a payload length does not match the item count
            declared in its header.
    '''
    if train:
        image_uri = 'http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz'
        label_uri = 'http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz'
    else:
        image_uri = 'http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz'
        label_uri = 'http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz'

    logger.info('Getting label data from {}.'.format(label_uri))
    data = _download_and_decompress(label_uri)
    # Label file header: two big-endian uint32 (magic number, item count).
    _, size = struct.unpack('>II', data[0:8])
    labels = np.frombuffer(data[8:], np.uint8).reshape(size, 1)
    logger.info('Getting label data done.')

    logger.info('Getting image data from {}.'.format(image_uri))
    data = _download_and_decompress(image_uri)
    # Image file header: four big-endian uint32
    # (magic number, item count, height, width).
    _, size, height, width = struct.unpack('>IIII', data[0:16])
    images = np.frombuffer(data[16:], np.uint8).reshape(
        size, 1, height, width)
    logger.info('Getting image data done.')
    return images, labels
class MnistDataSource(DataSource):
    '''
    Data source that serves MNIST samples loaded through ``load_mnist``
    (fetched from the Internet, yann.lecun.com).
    '''

    def __init__(self, train=True, shuffle=False, rng=None):
        '''
        Args:
            train (bool): Passed to ``load_mnist`` to pick the training or
                the testing split.
            shuffle (bool): Forwarded to the ``DataSource`` base class; when
                set, ``reset`` draws a random sample order.
            rng (numpy.random.RandomState): Generator used for shuffling.
                A ``RandomState(313)`` is created when None.
        '''
        super(MnistDataSource, self).__init__(shuffle=shuffle)
        self._train = train
        self._images, self._labels = load_mnist(train)
        self._size = self._labels.size
        self._variables = ('x', 'y')
        self.rng = np.random.RandomState(313) if rng is None else rng
        self.reset()

    def reset(self):
        '''Rebuild the sample order, then reset the base class state.'''
        if not self._shuffle:
            self._indexes = np.arange(self._size)
        else:
            self._indexes = self.rng.permutation(self._size)
        super(MnistDataSource, self).reset()

    def _get_data(self, position):
        '''Return the (image, label) pair stored at ``position`` of the current order.'''
        sample_index = self._indexes[position]
        return (self._images[sample_index], self._labels[sample_index])

    @property
    def images(self):
        """Get copy of whole data with a shape of (N, 1, H, W)."""
        return self._images.copy()

    @property
    def labels(self):
        """Get copy of whole label with a shape of (N, 1)."""
        return self._labels.copy()
def data_iterator_mnist(batch_size,
                        train=True,
                        rng=None,
                        shuffle=True,
                        with_memory_cache=False,
                        with_file_cache=False):
    '''
    Provide DataIterator with :py:class:`MnistDataSource`
    with_memory_cache and with_file_cache option's default value is all False,
    because :py:class:`MnistDataSource` is able to store all data into memory.

    Args:
        batch_size (int): Number of samples per mini-batch.
        train (bool): Use the training split if True, the testing split otherwise.
        rng (numpy.random.RandomState): Generator handed to both the data
            source and the iterator.
        shuffle (bool): Whether the data source shuffles its sample order.
        with_memory_cache (bool): Forwarded to ``data_iterator``.
        with_file_cache (bool): Forwarded to ``data_iterator``.

    Returns:
        The iterator object created by ``data_iterator``.

    For example,

    .. code-block:: python

        with data_iterator_mnist(batch_size, train=True) as di:
            for data in di:
                SOME CODE TO USE data.
    '''
    source = MnistDataSource(train=train, shuffle=shuffle, rng=rng)
    # Options are passed by keyword so that they keep reaching the intended
    # parameters even if data_iterator's positional order differs between
    # nnabla releases.
    return data_iterator(source,
                         batch_size,
                         rng=rng,
                         with_memory_cache=with_memory_cache,
                         with_file_cache=with_file_cache)
We now define a multi-layer perceptron (MLP) network to be used later. Our MLP consists of 3 fully-connected layers, two of which are followed by batch normalization and non-linear activation.
In [ ]:
def mlp_net(x, n_h, n_y, test=False):
    """
    Build a three-layer perceptron on top of ``x``.

    Args:
        x(`~nnabla.Variable`): N-D array
        n_h(int): number of units in each of the two hidden layers
        n_y(int): number of classes (units of the output layer)
        test(bool): batch normalization is called with ``batch_stat=not test``,
            i.e. batch statistics are used unless ``test`` is True.
    Returns:
        ~nnabla.Variable: output of the last affine layer (no activation applied)
    """
    h = x
    # Two hidden layers with the same layout: affine -> batch norm -> ReLU.
    for scope_name in ("fc1", "fc2"):
        with nn.parameter_scope(scope_name):
            affine_out = PF.affine(h, n_h)
            normalized = PF.batch_normalization(
                affine_out, batch_stat=not test)
            h = F.relu(normalized, inplace=True)
    # Output layer: plain affine projection to n_y units.
    with nn.parameter_scope("fc3"):
        h = PF.affine(h, n_y)
    return h
Let's also define a function to measure the distance between two distributions. In this example, we use a function called multinomial Kullback-Leibler divergence, commonly known as KL-divergence.
In [ ]:
def distance(y0, y1):
    """
    Kullback-Leibler divergence between two categorical distributions.

    Both arguments are passed through softmax before being handed to
    ``F.kl_multinomial``.
    """
    p = F.softmax(y0)
    q = F.softmax(y1)
    return F.kl_multinomial(p, q)
Before we get into the main computational graph, let's also define a function to evaluate the network. This function simply returns error rate during validation, which is averaged over the number of iterations.
In [ ]:
def calc_validation_error(di_v, xv, tv, err, val_iter):
    """
    Average the validation error over ``val_iter`` mini-batches.

    Args:
        di_v: validation data iterator; ``next()`` yields (inputs, labels)
        xv: variable for input
        tv: variable for label
        err: variable for error estimation
        val_iter: number of iterations (mini-batches) to average over
    Returns:
        error rate, i.e. the sum of ``err.d`` over all iterations divided
        by ``val_iter``
    """
    total_error = 0.0
    for _ in range(val_iter):
        batch_x, batch_t = di_v.next()
        xv.d = batch_x
        tv.d = batch_t
        # Scale the stored input by 255 before evaluating the error.
        xv.d = xv.d / 255
        err.forward(clear_buffer=True)
        total_error = total_error + err.d
    return total_error / val_iter
Now we get into the main computational graph. We start by setting context to use cuDNN, and loading data iterator for MNIST.
In [ ]:
# Get context.
# Select the 'cudnn' extension context and make it the default context for
# the computation graph built in the following cells.
# NOTE(review): presumably another extension name (e.g. 'cpu') can be used on
# machines without a GPU -- confirm against the nnabla documentation.
from nnabla.ext_utils import get_extension_context
ctx = get_extension_context('cudnn')
nn.set_default_context(ctx)
# Load MNIST Dataset
# Raw training images and labels as returned by load_mnist (pixel values are
# scaled by 255 later, right before they are fed to the network).
images, labels = load_mnist(train=True)
# Fixed seed; this same generator is later handed to both training iterators.
rng = np.random.RandomState(706)
# Random permutation of all training indices; feed_labeled / feed_unlabeled
# look samples up through it.
inds = rng.permutation(len(images))
Let's define two functions for loading data for labeled and unlabeled settings respectively. Although the `feed_unlabeled` function also returns labels, we will later see that the labels are disregarded in the graph.
After declaring some hyperparameters, we also define data iterator variables using the two load functions we just defined, separately for labeled and unlabeled settings. Let's also define a data iterator variable for validation.
In [ ]:
def feed_labeled(i):
    """Return the (image, label) pair at position ``i`` of the shuffled order ``inds``."""
    return images[inds[i]], labels[inds[i]]
def feed_unlabeled(i):
    """
    Return the (image, label) pair at position ``i`` of the shuffled order ``inds``.

    The label is returned as well; the training loop discards it for
    unlabeled batches.
    """
    return images[inds[i]], labels[inds[i]]
# Shape of one input sample: (channels, height, width), matching the
# (#images, 1, 28, 28) layout returned by load_mnist.
shape_x = (1, 28, 28)
n_h = 1200  # number of units in each hidden layer
n_y = 10  # number of classes
n_labeled = 100  # number of samples served by the labeled iterator
n_train = 60000  # number of samples served by the unlabeled iterator
batchsize_l = 100  # mini-batch size for labeled data
batchsize_u = 250  # mini-batch size for unlabeled data
batchsize_v = 100  # mini-batch size for validation data
# Labeled iterator.
# NOTE(review): presumably data_iterator_simple calls feed_labeled with
# indices 0 .. n_labeled-1, so only the first n_labeled entries of inds act
# as labeled data -- confirm against the nnabla documentation.
di_l = data_iterator_simple(feed_labeled, n_labeled,
                            batchsize_l, shuffle=True, rng=rng, with_file_cache=False)
# Unlabeled iterator over n_train samples; it shares rng with di_l.
di_u = data_iterator_simple(feed_unlabeled, n_train,
                            batchsize_u, shuffle=True, rng=rng, with_file_cache=False)
# Validation iterator over the MNIST testing split (train=False).
di_v = data_iterator_mnist(batchsize_v, train=False)
We first define a simple forward function that calls the multi-layer perceptron network that we defined above.
We then define the variables separately for labeled and unlabeled data. `xl`, `xu` and `yl`, `yu` refer to the inputs and outputs of the MLP network. In the labeled setting, we also have a teacher variable `tl`, from which we can calculate the loss by applying softmax cross entropy. Note that this loss is for labeled data only, and we will define a separate loss variable later for unlabeled data.
Also, notice that we do not have a teacher variable for the unlabeled setting, because we assume that the labels are missing. Instead, we define an unlinked variable of `yu`.
In [ ]:
# Create networks
# feed-forward-net building function
def forward(x, test=False):
    """Apply ``mlp_net`` to ``x`` using the notebook-level sizes ``n_h`` and ``n_y``."""
    return mlp_net(x, n_h, n_y, test=test)
# Net for learning labeled data
# xl: input batch, tl: teacher labels with shape (batchsize_l, 1).
xl = nn.Variable((batchsize_l,) + shape_x, need_grad=False)
yl = forward(xl, test=False)
tl = nn.Variable((batchsize_l, 1), need_grad=False)
# Supervised loss: softmax cross entropy averaged over the batch.
loss_l = F.mean(F.softmax_cross_entropy(yl, tl))
# Net for learning unlabeled data
# No teacher variable is created here; the network output yu itself serves
# as the reference for the distance-based losses defined later.
xu = nn.Variable((batchsize_u,) + shape_x, need_grad=False)
yu = forward(xu, test=False)
# y1 is an unlinked variable of yu with need_grad disabled.
# NOTE(review): presumably y1 shares yu's data while being detached from the
# graph, so the noise-free prediction acts as a constant target -- confirm
# against the nnabla Variable documentation.
y1 = yu.get_unlinked_variable()
y1.need_grad = False
We now define variables for noise, which are added to the input variable `xu` and fed to the MLP. The KL-divergence between the MLP outputs for the noisy input and the noise-free input is used to compute the loss. Of the two losses, one is used to perform the power method iteration, and the other is the loss for unlabeled data.
In [ ]:
# Scale of the perturbation used in the power-method loss (loss_k).
xi_for_vat = 10.0
# Scale of the perturbation used in the unlabeled training loss (loss_u).
eps_for_vat = 1.5
# Noise has the same shape as an unlabeled batch; need_grad=True because the
# training loop reads the gradient that flows back towards it (via r).
noise = nn.Variable((batchsize_u,) + shape_x, need_grad=True)
# Normalize the noise of every sample to unit L2 norm (sum of squares over
# axes 1, 2, 3, then square root).
r = noise / (F.sum(noise ** 2, [1, 2, 3], keepdims=True)) ** 0.5
# r.grad is read after loss_k.backward(clear_buffer=True) in the training loop.
# NOTE(review): presumably persistent=True keeps r's buffers from being
# cleared -- confirm against the nnabla Variable documentation.
r.persistent = True
# Outputs for the two perturbed inputs.
y2 = forward(xu + xi_for_vat * r, test=False)
y3 = forward(xu + eps_for_vat * r, test=False)
# Distance between the noise-free prediction y1 and each perturbed prediction.
loss_k = F.mean(distance(y1, y2))
loss_u = F.mean(distance(y1, y3))
# Net for evaluating validation data
# Built with test=True, so batch normalization is called with batch_stat=False.
xv = nn.Variable((batchsize_v,) + shape_x, need_grad=False)
hv = forward(xv, test=True)
tv = nn.Variable((batchsize_v, 1), need_grad=False)
# Top-1 error averaged over the validation batch.
err = F.mean(F.top_n_error(hv, tv, n=1))
We define our solver and monitor variables. We will use Adam as our solver.
In [ ]:
# Create solver
# Adam constructed with 2e-3 (presumably the learning rate -- the training
# loop later rescales solver.learning_rate()); it is registered with every
# parameter created so far.
solver = S.Adam(2e-3)
solver.set_parameters(nn.get_parameters())
# Monitor training and validation stats.
# Name handed to the Monitor; the optional save snippet at the end of the
# notebook also uses it as the output directory.
model_save_path = 'tmp.monitor.vat'
import nnabla.monitor as M
monitor = M.Monitor(model_save_path)
# Both monitors use interval=240, the same value as val_interval and
# iter_per_epoch in the training loop.
monitor_verr = M.MonitorSeries("Test error", monitor, interval=240)
monitor_time = M.MonitorTimeElapsed("Elapsed time", monitor, interval=240)
Now we get into our training loop. We will have separate training stages for labeled and unlabeled data. We first start with labeled data, which is pretty much the same as the usual training graph.
Then, we define our training graph for unlabeled data. Note that we are ignoring the label returned by the data iterator, assigning it to a garbage variable `_`. We first forward the noise-free variable, and then calculate the adversarial noise by generating random noise followed by power method iterations. Finally, we compute the loss for unlabeled data.
In [ ]:
# Training Loop.
# Each iteration performs one update on a labeled batch followed by one VAT
# update on an unlabeled batch.
t0 = time.time()
max_iter = 24000  # total number of training iterations
val_interval = 240  # validation is run every val_interval iterations
val_iter = 100  # number of validation batches averaged per evaluation
weight_decay = 0  # value passed to solver.weight_decay()
n_iter_for_power_method = 1  # power-method steps per unlabeled batch
iter_per_epoch = 240  # the learning rate is decayed every iter_per_epoch iterations
learning_rate_decay = 0.9  # multiplicative learning-rate decay factor
for i in range(max_iter):
    # Validation Test
    if i % val_interval == 0:
        valid_error = calc_validation_error(
            di_v, xv, tv, err, val_iter)
        monitor_verr.add(i, valid_error)
    #################################
    ## Training by Labeled Data #####
    #################################
    # forward, backward and update
    xl.d, tl.d = di_l.next()
    # Scale the stored pixel values by 255.
    xl.d = xl.d / 255
    solver.zero_grad()
    loss_l.forward(clear_no_need_grad=True)
    loss_l.backward(clear_buffer=True)
    solver.weight_decay(weight_decay)
    solver.update()
    #################################
    ## Training by Unlabeled Data ###
    #################################
    # Calculate y without noise, only once.
    # The label returned by the iterator is discarded.
    xu.d, _ = di_u.next()
    xu.d = xu.d / 255
    yu.forward(clear_buffer=True)
    ##### Calculate Adversarial Noise #####
    # Do power method iteration
    # Start from Gaussian noise drawn from NumPy's global random state (not
    # from the seeded rng used by the data iterators).
    noise.d = np.random.normal(size=xu.shape).astype(np.float32)
    for k in range(n_iter_for_power_method):
        r.grad.zero()
        loss_k.forward(clear_no_need_grad=True)
        loss_k.backward(clear_buffer=True)
        # The gradient of loss_k w.r.t. r becomes the new noise direction.
        noise.data.copy_from(r.grad)
    ##### Calculate loss for unlabeled data #####
    # forward, backward and update
    solver.zero_grad()
    loss_u.forward(clear_no_need_grad=True)
    loss_u.backward(clear_buffer=True)
    solver.weight_decay(weight_decay)
    solver.update()
    ##### Learning rate update #####
    # The condition also holds at i == 0, so the first decay is applied right
    # after the very first iteration.
    if i % iter_per_epoch == 0:
        solver.set_learning_rate(
            solver.learning_rate() * learning_rate_decay)
    monitor_time.add(i)
Finally, we evaluate our model on the validation dataset. If the model was trained correctly, we should get an error rate of around 1.5%.
In [ ]:
# Evaluate the final model by the error rate with validation dataset
# (averaged over val_iter mini-batches from di_v).
valid_error = calc_validation_error(di_v, xv, tv, err, val_iter)
print(valid_error)
# If you need to save the model, please uncomment the following lines:
# parameter_file = os.path.join(
#     model_save_path, 'params_%06d.h5' % max_iter)
# nn.save_parameters(parameter_file)