We have modularized commonly used training and inference code into the module (or mod for short) package. This package provides intermediate- and high-level interfaces for executing predefined networks.
In this tutorial, we will use a simple multilayer perceptron with 10 output classes and a synthetic dataset.
In [1]:
import mxnet as mx
from data_iter import SyntheticData
# mlp
net = mx.sym.Variable('data')
net = mx.sym.FullyConnected(net, name='fc1', num_hidden=64)
net = mx.sym.Activation(net, name='relu1', act_type="relu")
net = mx.sym.FullyConnected(net, name='fc2', num_hidden=10)
net = mx.sym.SoftmaxOutput(net, name='softmax')
# synthetic dataset with 10 classes and 128-dimensional inputs
data = SyntheticData(10, 128)
mx.viz.plot_network(net)
Out[1]:
The most widely used module class is Module, which wraps a Symbol and one or more Executors.
We construct a module by specifying the symbol to run, the context(s) to run on, and the names of the input data and labels.
One can refer to data.ipynb for more explanations about the last two arguments. Here we have only one data named data, and one label, with the name softmax_label, which is automatically named for us following the name softmax we specified for the SoftmaxOutput operator.
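The automatic label naming follows a simple convention: the label name is the output operator's name with `_label` appended. A minimal sketch of that convention (an illustration, not MXNet's actual implementation):

```python
def auto_label_name(output_name):
    """Derive a label name the way MXNet does for output
    operators: the operator name plus '_label'."""
    return output_name + '_label'

print(auto_label_name('softmax'))  # softmax_label
```

This is why the module below lists 'softmax_label' among its label_names without us ever declaring it in the symbol.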
In [2]:
mod = mx.mod.Module(symbol=net,
                    context=mx.cpu(),
                    data_names=['data'],
                    label_names=['softmax_label'])
In [3]:
# @@@ AUTOTEST_OUTPUT_IGNORED_CELL
import logging
logging.basicConfig(level=logging.INFO)
batch_size = 32
mod.fit(data.get_iter(batch_size),
        eval_data=data.get_iter(batch_size),
        optimizer='sgd',
        optimizer_params={'learning_rate': 0.1},
        eval_metric='acc',
        num_epoch=5)
To predict with a module, simply call predict() with a DataIter. It will collect and return all the prediction results.
In [4]:
y = mod.predict(data.get_iter(batch_size))
'shape of predict: %s' % (y.shape,)
Out[4]:
When the prediction results might be too large to fit in memory, another convenient API, iter_predict, returns them one batch at a time:
In [5]:
# @@@ AUTOTEST_OUTPUT_IGNORED_CELL
for preds, i_batch, batch in mod.iter_predict(data.get_iter(batch_size)):
    pred_label = preds[0].asnumpy().argmax(axis=1)
    label = batch.label[0].asnumpy().astype('int32')
    print('batch %d, accuracy %f' % (i_batch, float(sum(pred_label == label)) / len(label)))
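The memory saving comes from yielding one batch of predictions at a time instead of concatenating all of them. The idea can be sketched in plain Python with a generator (a simplified stand-in for iter_predict, where predict_batch is a hypothetical per-batch predictor):

```python
def iter_predict_sketch(predict_batch, data_iter):
    """Yield (predictions, batch_index, batch) one batch at a time,
    so only one batch of outputs is alive in memory at once."""
    for i, batch in enumerate(data_iter):
        yield predict_batch(batch), i, batch

# toy usage: "predict" by doubling each value in the batch
batches = [[1, 2], [3, 4]]
double = lambda batch: [2 * x for x in batch]
results = [(preds, i) for preds, i, _ in iter_predict_sketch(double, batches)]
print(results)  # [([2, 4], 0), ([6, 8], 1)]
```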
If we do not need the prediction outputs but only want to evaluate on a test set, we can call the score() function with a DataIter and an EvalMetric:
In [6]:
# @@@ AUTOTEST_OUTPUT_IGNORED_CELL
mod.score(data.get_iter(batch_size), ['mse', 'acc'])
Out[6]:
In [7]:
# @@@ AUTOTEST_OUTPUT_IGNORED_CELL
# construct a callback function to save checkpoints
model_prefix = 'mx_mlp'
checkpoint = mx.callback.do_checkpoint(model_prefix)
mod = mx.mod.Module(symbol=net)
mod.fit(data.get_iter(batch_size), num_epoch=5, epoch_end_callback=checkpoint)
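The callback returned by do_checkpoint is invoked by fit() at the end of every epoch; each call writes the symbol once and the parameters for that epoch under the given prefix. A sketch of the file-name convention (the pattern matches what load_checkpoint expects below):

```python
def checkpoint_files(prefix, epoch):
    """File names used by MXNet checkpoints: a shared symbol file
    plus one zero-padded parameter file per saved epoch."""
    return ('%s-symbol.json' % prefix,
            '%s-%04d.params' % (prefix, epoch))

print(checkpoint_files('mx_mlp', 3))  # ('mx_mlp-symbol.json', 'mx_mlp-0003.params')
```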
To load the saved module parameters, call the load_checkpoint function. It loads the Symbol and the associated parameters. We can then set the loaded parameters into the module.
In [8]:
sym, arg_params, aux_params = mx.model.load_checkpoint(model_prefix, 3)
print(sym.tojson() == net.tojson())
# assign the loaded parameters to the module
mod.set_params(arg_params, aux_params)
Or, if we just want to resume training from a saved checkpoint, instead of calling set_params() we can directly call fit() with the loaded parameters, so that fit() knows to start from those parameters instead of initializing randomly. We also set begin_epoch so that fit() knows we are resuming from a previously saved epoch.
In [9]:
# @@@ AUTOTEST_OUTPUT_IGNORED_CELL
mod = mx.mod.Module(symbol=sym)
mod.fit(data.get_iter(batch_size),
        num_epoch=5,
        arg_params=arg_params,
        aux_params=aux_params,
        begin_epoch=3)
We have already seen how to use a module for basic training and inference. Now we will show a more flexible usage of the module.
A module represents a computation component. The design purpose of a module is to abstract a computation "machine" that accepts Symbol programs and data, on which we can run forward and backward passes, update parameters, and so on.
We aim to make the APIs easy and flexible to use, especially when we need to use the imperative API to work with multiple modules (e.g. a stochastic depth network).
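The computation-machine abstraction can be illustrated with a tiny, framework-free toy: a "module" whose parameters are initialized and then driven imperatively through the forward/backward/update cycle. The names mirror the Module API, but the model (y = w * x trained by SGD on squared error) is invented for the example:

```python
class ToyModule:
    """A minimal 'computation machine': predicts y = w * x and is
    trained by SGD, mirroring the init/forward/backward/update cycle."""
    def __init__(self, lr=0.1):
        self.lr = lr
        self.w = None   # parameters exist only after init_params()
        self.grad = 0.0

    def init_params(self, w0=0.0):
        self.w = w0

    def forward(self, x):
        self.x = x
        self.y = self.w * x
        return self.y

    def backward(self, label):
        # gradient of (y - label)^2 with respect to w
        self.grad = 2.0 * (self.y - label) * self.x

    def update(self):
        self.w -= self.lr * self.grad

mod = ToyModule(lr=0.05)
mod.init_params()
for _ in range(100):       # learn y = 3x from a single example
    mod.forward(1.0)
    mod.backward(3.0)
    mod.update()
print(round(mod.w, 3))     # 3.0
```

The real Module adds one more step before init_params: binding, which tells the machine the input shapes so device memory can be allocated, as the next example shows.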
A module has several states: initial (constructed, no memory allocated), bound (input shapes known and memory allocated), parameters initialized, and optimizer installed.
The following code implements a simplified fit(). Here we use other components, including the initializer, optimizer, and metric, which are explained in other notebooks.
In [10]:
# @@@ AUTOTEST_OUTPUT_IGNORED_CELL
# initial state
mod = mx.mod.Module(symbol=net)
# bind, tell the module the data and label shapes, so
# that memory could be allocated on the devices for computation
train_iter = data.get_iter(batch_size)
mod.bind(data_shapes=train_iter.provide_data, label_shapes=train_iter.provide_label)
# init parameters
mod.init_params(initializer=mx.init.Xavier(magnitude=2.))
# init optimizer
mod.init_optimizer(optimizer='sgd', optimizer_params=(('learning_rate', 0.1), ))
# use accuracy as the metric
metric = mx.metric.create('acc')
# train one epoch, i.e. going over the data iter one pass
for batch in train_iter:
    mod.forward(batch, is_train=True)       # compute predictions
    mod.update_metric(metric, batch.label)  # accumulate prediction accuracy
    mod.backward()                          # compute gradients
    mod.update()                            # update parameters using SGD

# training accuracy
print(metric.get())
Besides these operations, a module provides a lot of useful information:
- basic names, e.g. data_names
- state information, e.g. binded, params_initialized
- input/output information, e.g. data_shapes, label_shapes, output_shapes
- parameters (for modules with parameters), e.g. get_params()
In [11]:
print((mod.data_shapes, mod.label_shapes, mod.output_shapes))
print(mod.get_params())
The module API simplifies the implementation of new modules. For example:

- SequentialModule can chain multiple modules together.
- BucketingModule can handle bucketing, which is useful for variable-length inputs and outputs.
- PythonModule implements many APIs as empty functions to make it easier for users to implement customized modules.

See also example/module for a list of code examples using the module API.
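For instance, chaining in the style of SequentialModule amounts to feeding each sub-module's output into the next; a framework-free sketch (the sub-modules here are plain callables, not mx.mod.Module objects):

```python
class SequentialSketch:
    """Chain sub-modules: forward runs them in order, feeding each
    output into the next module's input."""
    def __init__(self, *modules):
        self.modules = modules

    def forward(self, x):
        for m in self.modules:
            x = m(x)
        return x

seq = SequentialSketch(lambda x: x + 1, lambda x: x * 2)
print(seq.forward(3))  # (3 + 1) * 2 = 8
```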
The module API is implemented in Python, located at python/mxnet/module.