Understanding the vanishing gradient problem through visualization

There are reasons why deep neural networks can work very well today, while for a long time few people got promising results simply by making their neural networks deeper:

  • Computational power and data have grown tremendously; more complex models and faster computers have made training deep models feasible.
  • Researchers have come to realize and understand the difficulties associated with training a deep model.

In this tutorial, we would like to give you some insight into the techniques that researchers find useful for training a deep model, using MXNet and its visualization tool -- TensorBoard.

Let's recap some of the relevant issues in training a deep model:

  • Weight initialization. If you initialize the network with small random weights and look at the gradients flowing down from the top layer, you will find they get smaller and smaller, so the first layers barely change because the gradients are too small to make a significant update. Without a chance to learn the first layers effectively, it is impossible to train a good deep model.
  • Nonlinear activation. When sigmoid or tanh is used as the activation function, the gradient, just as above, keeps shrinking as it is backpropagated, because each layer multiplies it by the activation's local derivative (recall the formulas for the parameter update and the gradient); see the small numerical sketch after this list.
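
To make this concrete, here is a tiny numpy sketch (illustrative only, not part of the tutorial code). The local derivative of the sigmoid never exceeds 0.25, so a gradient backpropagated through several sigmoid layers with small weights shrinks geometrically:

import numpy as np

# sigmoid and its local derivative sigma'(x) = sigma(x) * (1 - sigma(x))
sigmoid = lambda x: 1.0 / (1.0 + np.exp(-x))
x = np.linspace(-5, 5, 201)
local_grad = sigmoid(x) * (1 - sigmoid(x))
print(local_grad.max())   # ~0.25, the best case (at x = 0)
# a gradient backpropagated through 6 sigmoid layers is scaled by at most
print(0.25 ** 6)          # ~2.4e-04, before even multiplying by the (small) weights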

Data and DataIter


In [1]:
import os

def download(data_dir):
    """Download and unzip the MNIST dataset into data_dir if it is not already there."""
    if not os.path.isdir(data_dir):
        os.system('mkdir ' + data_dir)
    os.chdir(data_dir)
    if (not os.path.exists('train-images-idx3-ubyte')) or \
       (not os.path.exists('train-labels-idx1-ubyte')) or \
       (not os.path.exists('t10k-images-idx3-ubyte')) or \
       (not os.path.exists('t10k-labels-idx1-ubyte')):
           os.system('wget http://data.mxnet.io/mxnet/data/mnist.zip')
           os.system('unzip mnist.zip; rm mnist.zip')
    os.chdir('..')

In [2]:
import os
import mxnet as mx

def get_iterator(data_shape):
    def get_iterator_impl(args, kv):
        data_dir = args.data_dir
        # if Windows
        if os.name == "nt":
            data_dir = data_dir[:-1] + "\\"
        if '://' not in args.data_dir:
            download(data_dir)
        flat = False if len(data_shape) == 3 else True

        train           = mx.io.MNISTIter(
            image       = data_dir + "train-images-idx3-ubyte",
            label       = data_dir + "train-labels-idx1-ubyte",
            input_shape = data_shape,
            batch_size  = args.batch_size,
            shuffle     = True,
            flat        = flat,
            num_parts   = kv.num_workers,
            part_index  = kv.rank)

        val = mx.io.MNISTIter(
            image       = data_dir + "t10k-images-idx3-ubyte",
            label       = data_dir + "t10k-labels-idx1-ubyte",
            input_shape = data_shape,
            batch_size  = args.batch_size,
            flat        = flat,
            num_parts   = kv.num_workers,
            part_index  = kv.rank)

        return (train, val)
    return get_iterator_impl

Network Structure

Here's the network structure:


In [3]:
def get_mlp(acti="relu"):
    """
    multi-layer perceptron
    """
    data = mx.symbol.Variable('data')
    fc   = mx.symbol.FullyConnected(data = data, name='fc', num_hidden=512)
    act  = mx.symbol.Activation(data = fc, name='act', act_type=acti)
    fc0  = mx.symbol.FullyConnected(data = act, name='fc0', num_hidden=256)
    act0 = mx.symbol.Activation(data = fc0, name='act0', act_type=acti)
    fc1  = mx.symbol.FullyConnected(data = act0, name='fc1', num_hidden=128)
    act1 = mx.symbol.Activation(data = fc1, name='act1', act_type=acti)
    fc2  = mx.symbol.FullyConnected(data = act1, name = 'fc2', num_hidden = 64)
    act2 = mx.symbol.Activation(data = fc2, name='act2', act_type=acti)
    fc3  = mx.symbol.FullyConnected(data = act2, name='fc3', num_hidden=32)
    act3 = mx.symbol.Activation(data = fc3, name='act3', act_type=acti)
    fc4  = mx.symbol.FullyConnected(data = act3, name='fc4', num_hidden=16)
    act4 = mx.symbol.Activation(data = fc4, name='act4', act_type=acti)
    fc5  = mx.symbol.FullyConnected(data = act4, name='fc5', num_hidden=10)
    mlp  = mx.symbol.SoftmaxOutput(data = fc5, name = 'softmax')
    return mlp

As you may have noticed, we intentionally add more layers than usual, as the vanishing gradient problem becomes more severe as the network gets deeper.
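
As a quick sanity check, here is a small illustrative snippet (not part of the original tutorial code) that inspects the symbol's parameters and output shape for flattened 28x28 MNIST images, using the standard Symbol methods list_arguments and infer_shape:

net = get_mlp("sigmoid")
print(net.list_arguments())     # data, fc_weight, fc_bias, ..., fc5_weight, fc5_bias, softmax_label
arg_shapes, out_shapes, aux_shapes = net.infer_shape(data=(128, 784))
print(out_shapes)               # [(128, 10)] -- a ten-way softmax over the digit classes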

Experiment Setting

Here we create a simple MLP for the MNIST dataset and visualize the learning process through loss/accuracy, as well as the gradient distributions, while changing the initialization and activation settings.

General Setting

We adopt an MLP as our model and run our experiments on the MNIST dataset. We then visualize the gradients of a layer using Monitor in MXNet and histograms in TensorBoard.

Weight Initialization

For weight initialization we compare two schemes, uniform and Xavier:

if args.init == 'uniform':
    init = mx.init.Uniform(0.1)
if args.init == 'xavier':
    init = mx.init.Xavier(factor_type="in", magnitude=2.34)

Note that Uniform(0.1) draws the weights uniformly from [-0.1, 0.1], i.e. we intentionally choose small values near zero.

Activation Function

We compare two different activation functions, sigmoid and relu.

# acti = sigmoid or relu.
act  = mx.symbol.Activation(data = fc, name='act', act_type=acti)

Logging with TensorBoard and Monitor

In order to monitor the weights and gradients of this network under the different settings, we use MXNet's Monitor for logging and TensorBoard for visualization.

Usage

Here's a code snippet from train_model.py:

import mxnet as mx
import numpy as np
from tensorboard import summary
from tensorboard import FileWriter

# where to keep your TensorBoard logging file
logdir = './logs/'
summary_writer = FileWriter(logdir)

# mx.mon.Monitor's callback
def get_gradient(g):
    # flatten the gradient into a 1-D array
    grad = g.asnumpy().flatten()
    # log it to TensorBoard as a histogram summary
    s = summary.histogram('fc_backward_weight', grad)
    summary_writer.add_summary(s)
    # the value printed to the console log is the RMS of the gradient
    return mx.nd.norm(g)/np.sqrt(g.size)

mon = mx.mon.Monitor(int(args.num_examples/args.batch_size), get_gradient, pattern='fc_backward_weight')  # monitor the gradient passed to the first fully-connected layer

# training
model.fit(
        X                  = train,
        eval_data          = val,
        eval_metric        = eval_metrics,
        kvstore            = kv,
        monitor            = mon,
        epoch_end_callback = checkpoint)

# close summary_writer
summary_writer.close()
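
With the monitor interval set to int(args.num_examples/args.batch_size) (468 for MNIST with a batch size of 128), the callback fires roughly once per epoch, which is why the training logs below report fc_backward_weight at batches 1, 469, 937, and so on.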

In [4]:
import mxnet as mx
import argparse
import os, sys

def parse_args(init_type, name):
    parser = argparse.ArgumentParser(description='train an image classifier on mnist')
    parser.add_argument('--network', type=str, default='mlp',
                        choices = ['mlp', 'lenet', 'lenet-stn'],
                        help = 'the cnn to use')
    parser.add_argument('--data-dir', type=str, default='mnist/',
                        help='the input data directory')
    parser.add_argument('--gpus', type=str,
                        help='the gpus will be used, e.g "0,1,2,3"')
    parser.add_argument('--num-examples', type=int, default=60000,
                        help='the number of training examples')
    parser.add_argument('--batch-size', type=int, default=128,
                        help='the batch size')
    parser.add_argument('--lr', type=float, default=.1,
                        help='the initial learning rate')
    parser.add_argument('--model-prefix', type=str,
                        help='the prefix of the model to load/save')
    parser.add_argument('--save-model-prefix', type=str,
                        help='the prefix of the model to save')
    parser.add_argument('--num-epochs', type=int, default=10,
                        help='the number of training epochs')
    parser.add_argument('--load-epoch', type=int,
                        help="load the model on an epoch using the model-prefix")
    parser.add_argument('--kv-store', type=str, default='local',
                        help='the kvstore type')
    parser.add_argument('--lr-factor', type=float, default=1,
                        help='times the lr with a factor for every lr-factor-epoch epoch')
    parser.add_argument('--lr-factor-epoch', type=float, default=1,
                        help='the number of epoch to factor the lr, could be .5')
    parser.add_argument('--init', type=str, default=init_type,
                        help='the weight initialization method')
    parser.add_argument('--name', type=str, default=name,
                        help='name for summary.histogram for gradient/weight logging')
    return parser.parse_args("")

In [5]:
import mxnet as mx
import logging
import os
import numpy as np
from tensorboard import summary
from tensorboard import FileWriter

def fit(args, network, data_loader, batch_end_callback=None):
    # kvstore
    kv = mx.kvstore.create(args.kv_store)

    # logging
    head = '%(asctime)-15s Node[' + str(kv.rank) + '] %(message)s'
    if 'log_file' in args and args.log_file is not None:
        log_file = args.log_file
        log_dir = args.log_dir
        log_file_full_name = os.path.join(log_dir, log_file)
        if not os.path.exists(log_dir):
            os.mkdir(log_dir)
        logger = logging.getLogger()
        handler = logging.FileHandler(log_file_full_name)
        formatter = logging.Formatter(head)
        handler.setFormatter(formatter)
        logger.addHandler(handler)
        logger.setLevel(logging.DEBUG)
        logger.info('start with arguments %s', args)
    else:
        logging.basicConfig(level=logging.DEBUG, format=head)
        logging.info('start with arguments %s', args)

    # load model
    model_prefix = args.model_prefix
    if model_prefix is not None:
        model_prefix += "-%d" % (kv.rank)
    model_args = {}
    if args.load_epoch is not None:
        assert model_prefix is not None
        tmp = mx.model.FeedForward.load(model_prefix, args.load_epoch)
        model_args = {'arg_params' : tmp.arg_params,
                      'aux_params' : tmp.aux_params,
                      'begin_epoch' : args.load_epoch}
        # TODO: check epoch_size for 'dist_sync'
        epoch_size = args.num_examples / args.batch_size
        model_args['begin_num_update'] = epoch_size * args.load_epoch

    # save model
    save_model_prefix = args.save_model_prefix
    if save_model_prefix is None:
        save_model_prefix = model_prefix
    checkpoint = None if save_model_prefix is None else mx.callback.do_checkpoint(save_model_prefix)

    # data
    (train, val) = data_loader(args, kv)

    # train
    devs = [mx.cpu(i) for i in range(4)] if args.gpus is None else [
        mx.gpu(int(i)) for i in args.gpus.split(',')]

    epoch_size = args.num_examples / args.batch_size

    if args.kv_store == 'dist_sync':
        epoch_size /= kv.num_workers
        model_args['epoch_size'] = epoch_size

    if 'lr_factor' in args and args.lr_factor < 1:
        model_args['lr_scheduler'] = mx.lr_scheduler.FactorScheduler(
            step = max(int(epoch_size * args.lr_factor_epoch), 1),
            factor = args.lr_factor)

    if 'clip_gradient' in args and args.clip_gradient is not None:
        model_args['clip_gradient'] = args.clip_gradient

    # disable kvstore for single device
    if 'local' in kv.type and (
            args.gpus is None or len(args.gpus.split(',')) == 1):
        kv = None
    
    if args.init == 'uniform':
        init = mx.init.Uniform(0.1)
    if args.init == 'normal':
        init = mx.init.Normal(0.1)
    if args.init == 'xavier':
        init = mx.init.Xavier(factor_type="in", magnitude=2.34)
    model = mx.model.FeedForward(
        ctx                = devs,
        symbol             = network,
        num_epoch          = args.num_epochs,
        learning_rate      = args.lr,
        momentum           = 0.9,
        wd                 = 0.00001,
        initializer        = init,
        **model_args)

    eval_metrics = ['accuracy']
    ## TopKAccuracy only allows top_k > 1
    for top_k in [5]:
        eval_metrics.append(mx.metric.create('top_k_accuracy', top_k = top_k))

    if batch_end_callback is not None:
        if not isinstance(batch_end_callback, list):
            batch_end_callback = [batch_end_callback]
    else:
        batch_end_callback = []
    batch_end_callback.append(mx.callback.Speedometer(args.batch_size, 50))
    
    logdir = './logs/'
    summary_writer = FileWriter(logdir)
    def get_grad(g):
        # logging using tensorboard
        grad = g.asnumpy().flatten()
        s = summary.histogram(args.name, grad)
        summary_writer.add_summary(s)
        return mx.nd.norm(g)/np.sqrt(g.size)
    mon = mx.mon.Monitor(int(args.num_examples/args.batch_size), get_grad, pattern='fc_backward_weight')  # monitor the gradient passed to the first fully-connected layer
    
    model.fit(
        X                  = train,
        eval_data          = val,
        eval_metric        = eval_metrics,
        kvstore            = kv,
        monitor            = mon,
        epoch_end_callback = checkpoint)

    summary_writer.close()

What to expect?

If a setting suffers from the vanishing gradient problem, the gradients passed down from the top layers should be very close to zero, and the weights of the network should barely change or update.

Uniform and Sigmoid


In [6]:
# Uniform and sigmoid
args = parse_args('uniform', 'uniform_sigmoid')
data_shape = (784, )
net = get_mlp("sigmoid")

# train
fit(args, net, get_iterator(data_shape))


2017-01-12 19:30:11,856 Node[0] start with arguments Namespace(batch_size=128, data_dir='mnist/', gpus=None, init='uniform', kv_store='local', load_epoch=None, lr=0.1, lr_factor=1, lr_factor_epoch=1, model_prefix=None, name='uniform_sigmoid', network='mlp', num_epochs=10, num_examples=60000, save_model_prefix=None)
2017-01-12 19:30:14,983 Node[0] [Deprecation Warning] mxnet.model.FeedForward has been deprecated. Please use mxnet.mod.Module instead.
2017-01-12 19:30:14,990 Node[0] Start training with [cpu(0), cpu(1), cpu(2), cpu(3)]
2017-01-12 19:30:17,453 Node[0] Batch:       1 fc_backward_weight             5.1907e-07	
2017-01-12 19:30:17,454 Node[0] Batch:       1 fc_backward_weight             4.2085e-07	
2017-01-12 19:30:17,455 Node[0] Batch:       1 fc_backward_weight             4.31894e-07	
2017-01-12 19:30:17,456 Node[0] Batch:       1 fc_backward_weight             5.80652e-07	
2017-01-12 19:30:22,885 Node[0] Epoch[0] Resetting Data Iterator
2017-01-12 19:30:22,888 Node[0] Epoch[0] Time cost=7.839
2017-01-12 19:30:23,215 Node[0] Epoch[0] Validation-accuracy=0.105769
2017-01-12 19:30:23,216 Node[0] Epoch[0] Validation-top_k_accuracy_5=0.509115
2017-01-12 19:30:25,728 Node[0] Batch:     469 fc_backward_weight             5.15008e-07	
2017-01-12 19:30:25,730 Node[0] Batch:     469 fc_backward_weight             5.52044e-07	
2017-01-12 19:30:25,730 Node[0] Batch:     469 fc_backward_weight             4.48535e-07	
2017-01-12 19:30:25,732 Node[0] Batch:     469 fc_backward_weight             5.8659e-07	
2017-01-12 19:30:31,356 Node[0] Epoch[1] Resetting Data Iterator
2017-01-12 19:30:31,357 Node[0] Epoch[1] Time cost=8.140
2017-01-12 19:30:31,868 Node[0] Epoch[1] Validation-accuracy=0.105769
2017-01-12 19:30:31,869 Node[0] Epoch[1] Validation-top_k_accuracy_5=0.504507
2017-01-12 19:30:34,348 Node[0] Batch:     937 fc_backward_weight             5.96259e-07	
2017-01-12 19:30:34,349 Node[0] Batch:     937 fc_backward_weight             5.97974e-07	
2017-01-12 19:30:34,350 Node[0] Batch:     937 fc_backward_weight             4.51892e-07	
2017-01-12 19:30:34,351 Node[0] Batch:     937 fc_backward_weight             6.5213e-07	
2017-01-12 19:30:39,779 Node[0] Epoch[2] Resetting Data Iterator
2017-01-12 19:30:39,780 Node[0] Epoch[2] Time cost=7.910
2017-01-12 19:30:40,325 Node[0] Epoch[2] Validation-accuracy=0.105769
2017-01-12 19:30:40,327 Node[0] Epoch[2] Validation-top_k_accuracy_5=0.510216
2017-01-12 19:30:42,989 Node[0] Batch:    1405 fc_backward_weight             6.52871e-07	
2017-01-12 19:30:42,989 Node[0] Batch:    1405 fc_backward_weight             6.20821e-07	
2017-01-12 19:30:42,990 Node[0] Batch:    1405 fc_backward_weight             4.46476e-07	
2017-01-12 19:30:42,992 Node[0] Batch:    1405 fc_backward_weight             7.53641e-07	
2017-01-12 19:30:48,245 Node[0] Epoch[3] Resetting Data Iterator
2017-01-12 19:30:48,246 Node[0] Epoch[3] Time cost=7.917
2017-01-12 19:30:48,555 Node[0] Epoch[3] Validation-accuracy=0.105769
2017-01-12 19:30:48,556 Node[0] Epoch[3] Validation-top_k_accuracy_5=0.510216
2017-01-12 19:30:51,142 Node[0] Batch:    1873 fc_backward_weight             6.63064e-07	
2017-01-12 19:30:51,143 Node[0] Batch:    1873 fc_backward_weight             6.33577e-07	
2017-01-12 19:30:51,145 Node[0] Batch:    1873 fc_backward_weight             4.2922e-07	
2017-01-12 19:30:51,147 Node[0] Batch:    1873 fc_backward_weight             8.31741e-07	
2017-01-12 19:30:56,116 Node[0] Epoch[4] Resetting Data Iterator
2017-01-12 19:30:56,117 Node[0] Epoch[4] Time cost=7.559
2017-01-12 19:30:56,430 Node[0] Epoch[4] Validation-accuracy=0.103666
2017-01-12 19:30:56,431 Node[0] Epoch[4] Validation-top_k_accuracy_5=0.509014
2017-01-12 19:30:59,020 Node[0] Batch:    2341 fc_backward_weight             6.47525e-07	
2017-01-12 19:30:59,021 Node[0] Batch:    2341 fc_backward_weight             6.37593e-07	
2017-01-12 19:30:59,021 Node[0] Batch:    2341 fc_backward_weight             4.12299e-07	
2017-01-12 19:30:59,022 Node[0] Batch:    2341 fc_backward_weight             8.71203e-07	
2017-01-12 19:31:04,117 Node[0] Epoch[5] Resetting Data Iterator
2017-01-12 19:31:04,118 Node[0] Epoch[5] Time cost=7.686
2017-01-12 19:31:04,505 Node[0] Epoch[5] Validation-accuracy=0.103666
2017-01-12 19:31:04,505 Node[0] Epoch[5] Validation-top_k_accuracy_5=0.509014
2017-01-12 19:31:07,535 Node[0] Batch:    2809 fc_backward_weight             6.23424e-07	
2017-01-12 19:31:07,536 Node[0] Batch:    2809 fc_backward_weight             6.33117e-07	
2017-01-12 19:31:07,537 Node[0] Batch:    2809 fc_backward_weight             3.99334e-07	
2017-01-12 19:31:07,539 Node[0] Batch:    2809 fc_backward_weight             8.78155e-07	
2017-01-12 19:31:13,145 Node[0] Epoch[6] Resetting Data Iterator
2017-01-12 19:31:13,149 Node[0] Epoch[6] Time cost=8.642
2017-01-12 19:31:13,733 Node[0] Epoch[6] Validation-accuracy=0.107472
2017-01-12 19:31:13,735 Node[0] Epoch[6] Validation-top_k_accuracy_5=0.509014
2017-01-12 19:31:16,660 Node[0] Batch:    3277 fc_backward_weight             5.97921e-07	
2017-01-12 19:31:16,661 Node[0] Batch:    3277 fc_backward_weight             6.22105e-07	
2017-01-12 19:31:16,662 Node[0] Batch:    3277 fc_backward_weight             3.89208e-07	
2017-01-12 19:31:16,663 Node[0] Batch:    3277 fc_backward_weight             8.6379e-07	
2017-01-12 19:31:21,486 Node[0] Epoch[7] Resetting Data Iterator
2017-01-12 19:31:21,487 Node[0] Epoch[7] Time cost=7.742
2017-01-12 19:31:21,781 Node[0] Epoch[7] Validation-accuracy=0.109776
2017-01-12 19:31:21,782 Node[0] Epoch[7] Validation-top_k_accuracy_5=0.509014
2017-01-12 19:31:24,268 Node[0] Batch:    3745 fc_backward_weight             5.73259e-07	
2017-01-12 19:31:24,270 Node[0] Batch:    3745 fc_backward_weight             6.06878e-07	
2017-01-12 19:31:24,270 Node[0] Batch:    3745 fc_backward_weight             3.80379e-07	
2017-01-12 19:31:24,271 Node[0] Batch:    3745 fc_backward_weight             8.37382e-07	
2017-01-12 19:31:29,187 Node[0] Epoch[8] Resetting Data Iterator
2017-01-12 19:31:29,188 Node[0] Epoch[8] Time cost=7.405
2017-01-12 19:31:29,508 Node[0] Epoch[8] Validation-accuracy=0.105970
2017-01-12 19:31:29,509 Node[0] Epoch[8] Validation-top_k_accuracy_5=0.512620
2017-01-12 19:31:31,994 Node[0] Batch:    4213 fc_backward_weight             5.49988e-07	
2017-01-12 19:31:31,995 Node[0] Batch:    4213 fc_backward_weight             5.89305e-07	
2017-01-12 19:31:31,996 Node[0] Batch:    4213 fc_backward_weight             3.71941e-07	
2017-01-12 19:31:31,997 Node[0] Batch:    4213 fc_backward_weight             8.05085e-07	
2017-01-12 19:31:37,268 Node[0] Epoch[9] Resetting Data Iterator
2017-01-12 19:31:37,270 Node[0] Epoch[9] Time cost=7.760
2017-01-12 19:31:37,707 Node[0] Epoch[9] Validation-accuracy=0.105970
2017-01-12 19:31:37,708 Node[0] Epoch[9] Validation-top_k_accuracy_5=0.512620

As you can see, the fc_backward_weight values are very close to zero, and they barely change across batches.

2017-01-07 15:44:38,845 Node[0] Batch:       1 fc_backward_weight             5.1907e-07    
2017-01-07 15:44:38,846 Node[0] Batch:       1 fc_backward_weight             4.2085e-07    
2017-01-07 15:44:38,847 Node[0] Batch:       1 fc_backward_weight             4.31894e-07   
2017-01-07 15:44:38,848 Node[0] Batch:       1 fc_backward_weight             5.80652e-07

2017-01-07 15:45:50,199 Node[0] Batch:    4213 fc_backward_weight             5.49988e-07   
2017-01-07 15:45:50,200 Node[0] Batch:    4213 fc_backward_weight             5.89305e-07   
2017-01-07 15:45:50,201 Node[0] Batch:    4213 fc_backward_weight             3.71941e-07   
2017-01-07 15:45:50,202 Node[0] Batch:    4213 fc_backward_weight             8.05085e-07

You might wonder why there are four different fc_backward_weight values per batch: it is because we train on 4 CPU devices, and the monitor reports one value per device.

Uniform and ReLU


In [7]:
# Uniform and relu
args = parse_args('uniform', 'uniform_relu')
data_shape = (784, )
net = get_mlp("relu")

# train
fit(args, net, get_iterator(data_shape))


2017-01-12 19:31:37,722 Node[0] start with arguments Namespace(batch_size=128, data_dir='mnist/', gpus=None, init='uniform', kv_store='local', load_epoch=None, lr=0.1, lr_factor=1, lr_factor_epoch=1, model_prefix=None, name='uniform_relu', network='mlp', num_epochs=10, num_examples=60000, save_model_prefix=None)
2017-01-12 19:31:40,900 Node[0] [Deprecation Warning] mxnet.model.FeedForward has been deprecated. Please use mxnet.mod.Module instead.
2017-01-12 19:31:40,906 Node[0] Start training with [cpu(0), cpu(1), cpu(2), cpu(3)]
2017-01-12 19:31:43,382 Node[0] Batch:       1 fc_backward_weight             0.000267409	
2017-01-12 19:31:43,383 Node[0] Batch:       1 fc_backward_weight             0.00031988	
2017-01-12 19:31:43,384 Node[0] Batch:       1 fc_backward_weight             0.000306785	
2017-01-12 19:31:43,385 Node[0] Batch:       1 fc_backward_weight             0.000347533	
2017-01-12 19:31:48,518 Node[0] Epoch[0] Resetting Data Iterator
2017-01-12 19:31:48,519 Node[0] Epoch[0] Time cost=7.595
2017-01-12 19:31:48,821 Node[0] Epoch[0] Validation-accuracy=0.694912
2017-01-12 19:31:48,822 Node[0] Epoch[0] Validation-top_k_accuracy_5=0.976362
2017-01-12 19:31:51,256 Node[0] Batch:     469 fc_backward_weight             0.0527437	
2017-01-12 19:31:51,257 Node[0] Batch:     469 fc_backward_weight             0.0421219	
2017-01-12 19:31:51,259 Node[0] Batch:     469 fc_backward_weight             0.0495309	
2017-01-12 19:31:51,260 Node[0] Batch:     469 fc_backward_weight             0.0421051	
2017-01-12 19:31:56,442 Node[0] Epoch[1] Resetting Data Iterator
2017-01-12 19:31:56,443 Node[0] Epoch[1] Time cost=7.619
2017-01-12 19:31:56,723 Node[0] Epoch[1] Validation-accuracy=0.907652
2017-01-12 19:31:56,724 Node[0] Epoch[1] Validation-top_k_accuracy_5=0.986679
2017-01-12 19:31:59,184 Node[0] Batch:     937 fc_backward_weight             0.0285753	
2017-01-12 19:31:59,185 Node[0] Batch:     937 fc_backward_weight             0.0520748	
2017-01-12 19:31:59,186 Node[0] Batch:     937 fc_backward_weight             0.0807526	
2017-01-12 19:31:59,187 Node[0] Batch:     937 fc_backward_weight             0.0502396	
2017-01-12 19:32:04,648 Node[0] Epoch[2] Resetting Data Iterator
2017-01-12 19:32:04,649 Node[0] Epoch[2] Time cost=7.924
2017-01-12 19:32:04,923 Node[0] Epoch[2] Validation-accuracy=0.921675
2017-01-12 19:32:04,923 Node[0] Epoch[2] Validation-top_k_accuracy_5=0.987380
2017-01-12 19:32:07,411 Node[0] Batch:    1405 fc_backward_weight             0.0596137	
2017-01-12 19:32:07,412 Node[0] Batch:    1405 fc_backward_weight             0.145902	
2017-01-12 19:32:07,413 Node[0] Batch:    1405 fc_backward_weight             0.0783883	
2017-01-12 19:32:07,414 Node[0] Batch:    1405 fc_backward_weight             0.0810687	
2017-01-12 19:32:13,291 Node[0] Epoch[3] Resetting Data Iterator
2017-01-12 19:32:13,292 Node[0] Epoch[3] Time cost=8.368
2017-01-12 19:32:13,621 Node[0] Epoch[3] Validation-accuracy=0.947516
2017-01-12 19:32:13,623 Node[0] Epoch[3] Validation-top_k_accuracy_5=0.990084
2017-01-12 19:32:16,028 Node[0] Batch:    1873 fc_backward_weight             0.113804	
2017-01-12 19:32:16,029 Node[0] Batch:    1873 fc_backward_weight             0.0355092	
2017-01-12 19:32:16,030 Node[0] Batch:    1873 fc_backward_weight             0.0510211	
2017-01-12 19:32:16,031 Node[0] Batch:    1873 fc_backward_weight             0.0461469	
2017-01-12 19:32:20,539 Node[0] Epoch[4] Resetting Data Iterator
2017-01-12 19:32:20,541 Node[0] Epoch[4] Time cost=6.917
2017-01-12 19:32:20,823 Node[0] Epoch[4] Validation-accuracy=0.949319
2017-01-12 19:32:20,823 Node[0] Epoch[4] Validation-top_k_accuracy_5=0.991587
2017-01-12 19:32:23,312 Node[0] Batch:    2341 fc_backward_weight             0.0304884	
2017-01-12 19:32:23,313 Node[0] Batch:    2341 fc_backward_weight             0.0153732	
2017-01-12 19:32:23,314 Node[0] Batch:    2341 fc_backward_weight             0.0638052	
2017-01-12 19:32:23,315 Node[0] Batch:    2341 fc_backward_weight             0.0358958	
2017-01-12 19:32:27,721 Node[0] Epoch[5] Resetting Data Iterator
2017-01-12 19:32:27,722 Node[0] Epoch[5] Time cost=6.897
2017-01-12 19:32:28,116 Node[0] Epoch[5] Validation-accuracy=0.952224
2017-01-12 19:32:28,117 Node[0] Epoch[5] Validation-top_k_accuracy_5=0.991687
2017-01-12 19:32:30,555 Node[0] Batch:    2809 fc_backward_weight             0.180743	
2017-01-12 19:32:30,556 Node[0] Batch:    2809 fc_backward_weight             0.0453026	
2017-01-12 19:32:30,558 Node[0] Batch:    2809 fc_backward_weight             0.0212601	
2017-01-12 19:32:30,558 Node[0] Batch:    2809 fc_backward_weight             0.0950233	
2017-01-12 19:32:36,190 Node[0] Epoch[6] Resetting Data Iterator
2017-01-12 19:32:36,191 Node[0] Epoch[6] Time cost=8.074
2017-01-12 19:32:36,548 Node[0] Epoch[6] Validation-accuracy=0.949219
2017-01-12 19:32:36,552 Node[0] Epoch[6] Validation-top_k_accuracy_5=0.992889
2017-01-12 19:32:39,129 Node[0] Batch:    3277 fc_backward_weight             0.0977342	
2017-01-12 19:32:39,130 Node[0] Batch:    3277 fc_backward_weight             0.0354421	
2017-01-12 19:32:39,131 Node[0] Batch:    3277 fc_backward_weight             0.00394049	
2017-01-12 19:32:39,132 Node[0] Batch:    3277 fc_backward_weight             0.0402826	
2017-01-12 19:32:44,758 Node[0] Epoch[7] Resetting Data Iterator
2017-01-12 19:32:44,759 Node[0] Epoch[7] Time cost=8.206
2017-01-12 19:32:45,051 Node[0] Epoch[7] Validation-accuracy=0.956130
2017-01-12 19:32:45,052 Node[0] Epoch[7] Validation-top_k_accuracy_5=0.993389
2017-01-12 19:32:47,585 Node[0] Batch:    3745 fc_backward_weight             0.012503	
2017-01-12 19:32:47,586 Node[0] Batch:    3745 fc_backward_weight             0.064014	
2017-01-12 19:32:47,587 Node[0] Batch:    3745 fc_backward_weight             0.0158367	
2017-01-12 19:32:47,588 Node[0] Batch:    3745 fc_backward_weight             0.00945755	
2017-01-12 19:32:53,593 Node[0] Epoch[8] Resetting Data Iterator
2017-01-12 19:32:53,594 Node[0] Epoch[8] Time cost=8.541
2017-01-12 19:32:54,017 Node[0] Epoch[8] Validation-accuracy=0.957031
2017-01-12 19:32:54,018 Node[0] Epoch[8] Validation-top_k_accuracy_5=0.992788
2017-01-12 19:32:56,820 Node[0] Batch:    4213 fc_backward_weight             0.0226081	
2017-01-12 19:32:56,821 Node[0] Batch:    4213 fc_backward_weight             0.0039793	
2017-01-12 19:32:56,822 Node[0] Batch:    4213 fc_backward_weight             0.0306151	
2017-01-12 19:32:56,823 Node[0] Batch:    4213 fc_backward_weight             0.00818676	
2017-01-12 19:33:02,386 Node[0] Epoch[9] Resetting Data Iterator
2017-01-12 19:33:02,387 Node[0] Epoch[9] Time cost=8.368
2017-01-12 19:33:02,666 Node[0] Epoch[9] Validation-accuracy=0.959736
2017-01-12 19:33:02,667 Node[0] Epoch[9] Validation-top_k_accuracy_5=0.991987

Even though we have a "poor" initialization, the model still converges quickly with a proper activation function, and the gradient magnitudes are significantly larger:

2017-01-07 15:54:12,286 Node[0] Batch:       1 fc_backward_weight             0.000267409   
2017-01-07 15:54:12,287 Node[0] Batch:       1 fc_backward_weight             0.00031988    
2017-01-07 15:54:12,288 Node[0] Batch:       1 fc_backward_weight             0.000306785   
2017-01-07 15:54:12,289 Node[0] Batch:       1 fc_backward_weight             0.000347533

2017-01-07 15:55:25,936 Node[0] Batch:    4213 fc_backward_weight             0.0226081 
2017-01-07 15:55:25,937 Node[0] Batch:    4213 fc_backward_weight             0.0039793 
2017-01-07 15:55:25,937 Node[0] Batch:    4213 fc_backward_weight             0.0306151 
2017-01-07 15:55:25,938 Node[0] Batch:    4213 fc_backward_weight             0.00818676

Xavier and Sigmoid


In [8]:
# Xavier and sigmoid
args = parse_args('xavier', 'xavier_sigmoid')
data_shape = (784, )
net = get_mlp("sigmoid")

# train
fit(args, net, get_iterator(data_shape))


2017-01-12 19:33:02,682 Node[0] start with arguments Namespace(batch_size=128, data_dir='mnist/', gpus=None, init='xavier', kv_store='local', load_epoch=None, lr=0.1, lr_factor=1, lr_factor_epoch=1, model_prefix=None, name='xavier_sigmoid', network='mlp', num_epochs=10, num_examples=60000, save_model_prefix=None)
2017-01-12 19:33:05,863 Node[0] [Deprecation Warning] mxnet.model.FeedForward has been deprecated. Please use mxnet.mod.Module instead.
2017-01-12 19:33:05,871 Node[0] Start training with [cpu(0), cpu(1), cpu(2), cpu(3)]
2017-01-12 19:33:08,355 Node[0] Batch:       1 fc_backward_weight             9.27798e-06	
2017-01-12 19:33:08,356 Node[0] Batch:       1 fc_backward_weight             8.58008e-06	
2017-01-12 19:33:08,358 Node[0] Batch:       1 fc_backward_weight             8.96261e-06	
2017-01-12 19:33:08,359 Node[0] Batch:       1 fc_backward_weight             7.33611e-06	
2017-01-12 19:33:13,214 Node[0] Epoch[0] Resetting Data Iterator
2017-01-12 19:33:13,215 Node[0] Epoch[0] Time cost=7.320
2017-01-12 19:33:13,516 Node[0] Epoch[0] Validation-accuracy=0.105769
2017-01-12 19:33:13,517 Node[0] Epoch[0] Validation-top_k_accuracy_5=0.509115
2017-01-12 19:33:15,898 Node[0] Batch:     469 fc_backward_weight             6.76125e-06	
2017-01-12 19:33:15,899 Node[0] Batch:     469 fc_backward_weight             6.54805e-06	
2017-01-12 19:33:15,900 Node[0] Batch:     469 fc_backward_weight             6.80302e-06	
2017-01-12 19:33:15,901 Node[0] Batch:     469 fc_backward_weight             7.39115e-06	
2017-01-12 19:33:21,153 Node[0] Epoch[1] Resetting Data Iterator
2017-01-12 19:33:21,154 Node[0] Epoch[1] Time cost=7.637
2017-01-12 19:33:21,438 Node[0] Epoch[1] Validation-accuracy=0.105769
2017-01-12 19:33:21,439 Node[0] Epoch[1] Validation-top_k_accuracy_5=0.504507
2017-01-12 19:33:23,894 Node[0] Batch:     937 fc_backward_weight             5.83071e-06	
2017-01-12 19:33:23,895 Node[0] Batch:     937 fc_backward_weight             5.59626e-06	
2017-01-12 19:33:23,895 Node[0] Batch:     937 fc_backward_weight             5.776e-06	
2017-01-12 19:33:23,896 Node[0] Batch:     937 fc_backward_weight             6.28738e-06	
2017-01-12 19:33:28,578 Node[0] Epoch[2] Resetting Data Iterator
2017-01-12 19:33:28,580 Node[0] Epoch[2] Time cost=7.139
2017-01-12 19:33:28,870 Node[0] Epoch[2] Validation-accuracy=0.105769
2017-01-12 19:33:28,871 Node[0] Epoch[2] Validation-top_k_accuracy_5=0.510216
2017-01-12 19:33:31,294 Node[0] Batch:    1405 fc_backward_weight             4.951e-06	
2017-01-12 19:33:31,296 Node[0] Batch:    1405 fc_backward_weight             4.72836e-06	
2017-01-12 19:33:31,299 Node[0] Batch:    1405 fc_backward_weight             4.8514e-06	
2017-01-12 19:33:31,302 Node[0] Batch:    1405 fc_backward_weight             5.26915e-06	
2017-01-12 19:33:36,266 Node[0] Epoch[3] Resetting Data Iterator
2017-01-12 19:33:36,267 Node[0] Epoch[3] Time cost=7.395
2017-01-12 19:33:36,576 Node[0] Epoch[3] Validation-accuracy=0.105769
2017-01-12 19:33:36,577 Node[0] Epoch[3] Validation-top_k_accuracy_5=0.509014
2017-01-12 19:33:38,997 Node[0] Batch:    1873 fc_backward_weight             4.22193e-06	
2017-01-12 19:33:38,998 Node[0] Batch:    1873 fc_backward_weight             4.03044e-06	
2017-01-12 19:33:38,999 Node[0] Batch:    1873 fc_backward_weight             4.11877e-06	
2017-01-12 19:33:39,000 Node[0] Batch:    1873 fc_backward_weight             4.45402e-06	
2017-01-12 19:33:44,271 Node[0] Epoch[4] Resetting Data Iterator
2017-01-12 19:33:44,272 Node[0] Epoch[4] Time cost=7.695
2017-01-12 19:33:44,567 Node[0] Epoch[4] Validation-accuracy=0.105769
2017-01-12 19:33:44,568 Node[0] Epoch[4] Validation-top_k_accuracy_5=0.509014
2017-01-12 19:33:47,092 Node[0] Batch:    2341 fc_backward_weight             3.64564e-06	
2017-01-12 19:33:47,094 Node[0] Batch:    2341 fc_backward_weight             3.48901e-06	
2017-01-12 19:33:47,094 Node[0] Batch:    2341 fc_backward_weight             3.55765e-06	
2017-01-12 19:33:47,095 Node[0] Batch:    2341 fc_backward_weight             3.82692e-06	
2017-01-12 19:33:52,308 Node[0] Epoch[5] Resetting Data Iterator
2017-01-12 19:33:52,309 Node[0] Epoch[5] Time cost=7.740
2017-01-12 19:33:52,674 Node[0] Epoch[5] Validation-accuracy=0.105769
2017-01-12 19:33:52,675 Node[0] Epoch[5] Validation-top_k_accuracy_5=0.509014
2017-01-12 19:33:55,045 Node[0] Batch:    2809 fc_backward_weight             3.19336e-06	
2017-01-12 19:33:55,046 Node[0] Batch:    2809 fc_backward_weight             3.06777e-06	
2017-01-12 19:33:55,046 Node[0] Batch:    2809 fc_backward_weight             3.12543e-06	
2017-01-12 19:33:55,047 Node[0] Batch:    2809 fc_backward_weight             3.34344e-06	
2017-01-12 19:34:02,004 Node[0] Epoch[6] Resetting Data Iterator
2017-01-12 19:34:02,005 Node[0] Epoch[6] Time cost=9.328
2017-01-12 19:34:02,295 Node[0] Epoch[6] Validation-accuracy=0.107472
2017-01-12 19:34:02,296 Node[0] Epoch[6] Validation-top_k_accuracy_5=0.509014
2017-01-12 19:34:04,830 Node[0] Batch:    3277 fc_backward_weight             2.83478e-06	
2017-01-12 19:34:04,831 Node[0] Batch:    3277 fc_backward_weight             2.73443e-06	
2017-01-12 19:34:04,832 Node[0] Batch:    3277 fc_backward_weight             2.78607e-06	
2017-01-12 19:34:04,833 Node[0] Batch:    3277 fc_backward_weight             2.9644e-06	
2017-01-12 19:34:10,903 Node[0] Epoch[7] Resetting Data Iterator
2017-01-12 19:34:10,904 Node[0] Epoch[7] Time cost=8.607
2017-01-12 19:34:11,209 Node[0] Epoch[7] Validation-accuracy=0.105970
2017-01-12 19:34:11,210 Node[0] Epoch[7] Validation-top_k_accuracy_5=0.512620
2017-01-12 19:34:13,543 Node[0] Batch:    3745 fc_backward_weight             2.54587e-06	
2017-01-12 19:34:13,544 Node[0] Batch:    3745 fc_backward_weight             2.46527e-06	
2017-01-12 19:34:13,545 Node[0] Batch:    3745 fc_backward_weight             2.51372e-06	
2017-01-12 19:34:13,546 Node[0] Batch:    3745 fc_backward_weight             2.66109e-06	
2017-01-12 19:34:17,928 Node[0] Epoch[8] Resetting Data Iterator
2017-01-12 19:34:17,928 Node[0] Epoch[8] Time cost=6.718
2017-01-12 19:34:18,225 Node[0] Epoch[8] Validation-accuracy=0.105970
2017-01-12 19:34:18,226 Node[0] Epoch[8] Validation-top_k_accuracy_5=0.512620
2017-01-12 19:34:20,790 Node[0] Batch:    4213 fc_backward_weight             2.30903e-06	
2017-01-12 19:34:20,791 Node[0] Batch:    4213 fc_backward_weight             2.24373e-06	
2017-01-12 19:34:20,792 Node[0] Batch:    4213 fc_backward_weight             2.29058e-06	
2017-01-12 19:34:20,793 Node[0] Batch:    4213 fc_backward_weight             2.41351e-06	
2017-01-12 19:34:26,309 Node[0] Epoch[9] Resetting Data Iterator
2017-01-12 19:34:26,310 Node[0] Epoch[9] Time cost=8.083
2017-01-12 19:34:26,629 Node[0] Epoch[9] Validation-accuracy=0.105970
2017-01-12 19:34:26,630 Node[0] Epoch[9] Validation-top_k_accuracy_5=0.512620
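
Even with Xavier initialization, the sigmoid network still struggles: the fc_backward_weight values stay on the order of 1e-06 (only about an order of magnitude larger than in the uniform case) and the validation accuracy remains stuck around 10%, so the first layer is still barely learning.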

Visualization

Now start using TensorBoard:

tensorboard --logdir=logs/
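
By default TensorBoard serves on http://localhost:6006. The gradient histograms logged above appear under the HISTOGRAMS tab, one series per args.name (uniform_sigmoid, uniform_relu and xavier_sigmoid), so you can compare how the gradient distributions evolve under each setting.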