In [53]:
import mxnet as mx
from collections import namedtuple
import cv2
import os, urllib
import numpy as np

class ScaleInitializer(mx.init.Initializer):
    """
    Initializer for the trainable per-channel scale parameters used after
    L2Normalization: every parameter whose name ends in "scale" is set to ones.
    Any other name reaching this initializer is a wiring error.
    """
    def __init__(self):
        pass

    def _init_default(self, name, arr):
        # Guard clause: only "scale" parameters are handled here.
        if not name.endswith("scale"):
            raise ValueError('Unknown initialization pattern for %s' % name)
        self._init_one(name, arr)
            
# Mixed initializer: params matching ".*scale" go to ScaleInitializer (ones),
# everything else falls through to Xavier. Patterns are matched in order.
initializer = mx.init.Mixed([".*scale", ".*"], [ScaleInitializer(), mx.init.Xavier(magnitude=1)])

def save_net(model, prefix):
    """Persist a bound module as an MXNet checkpoint pair:
    <prefix>-0001.params (weights) and <prefix>-symbol.json (graph)."""
    params_path = prefix + '-0001.params'
    symbol_path = prefix + '-symbol.json'
    model.save_params(params_path)
    model.symbol.save(symbol_path)
    
def print_net_params(param):
    for m in sorted(param):
        print m, param[m].shape
        
def plot_network(symbol, name='network'):
    """Render `symbol` as a PNG via graphviz, writing <name>.gv / <name>.gv.png.

    BUGFIX: `name` was previously ignored and the output file was hard-coded
    to 'network.gv'; the default argument keeps the old behavior.
    """
    graph = mx.viz.plot_network(symbol=symbol, node_attrs={"shape":'rect', "fixedsize":'false', 'rankdir': 'TB'})
    graph.format = 'png'
    graph.render('{}.gv'.format(name), view=True)
    
def print_inferred_shape(net, name, nch=3, size=300, node_type='rgb'):
    if node_type == 'rgb':
        ar, ou, au = net.infer_shape(rgb=(1, nch, size, size))
    if node_type == 'tir': 
        ar, ou, au = net.infer_shape(tir=(1, nch, size, size))
    if node_type == 'spectral':
        ar, ou, au = net.infer_shape(tir=(1, 1, size, size), rgb=(1, nch, size, size))
    print ou
    
def multibox_layer(from_layers, num_classes, sizes=[.2, .95], ratios=[1], normalization=-1, num_channels=[], clip=True, interm_layer=0, start_offset=0.1):
    """
    the basic aggregation module for SSD detection. Takes in multiple layers,
    generate multiple object detection targets by customized layers

    Parameters:
    ----------
    from_layers : list of mx.symbol
        generate multibox detection from layers
    num_classes : int
        number of classes excluding background, will automatically handle
        background in this function
    sizes : list or list of list
        [min_size, max_size] for all layers or [[], [], []...] for specific layers
    ratios : list or list of list
        [ratio1, ratio2...] for all layers or [[], [], ...] for specific layers
    normalization : int or list of int
        use normalization value for all layers or [...] for specific layers,
        -1 indicate no normalization and scales
    num_channels : list of int
        number of input layer channels, used when normalization is enabled, the
        length of list should equals to number of normalization layers
    clip : bool
        whether to clip out-of-image boxes
    interm_layer : int
        if > 0, will add a intermediate Convolution layer
    start_offset : float
        smallest anchor size, used only when `sizes` is a [min, max] range that
        gets expanded into per-layer sizes. (BUGFIX: this name was previously
        undefined, so the range-expansion path raised NameError.)

    Returns:
    ----------
    list of outputs, as [loc_preds, cls_preds, anchor_boxes]
    loc_preds : localization regression prediction
    cls_preds : classification prediction
    anchor_boxes : generated anchor boxes
    """
    assert len(from_layers) > 0, "from_layers must not be empty list"
    assert num_classes > 0,      "num_classes {} must be larger than 0".format(num_classes)

    assert len(ratios) > 0, "aspect ratios must not be empty list"
    if not isinstance(ratios[0], list):
        # provided only one ratio list, broadcast to all from_layers
        ratios = [ratios] * len(from_layers)
    assert len(ratios) == len(from_layers), \
        "ratios and from_layers must have same length"

    assert len(sizes) > 0, "sizes must not be empty list"
    if len(sizes) == 2 and not isinstance(sizes[0], list):
        # provided size range, we need to compute the sizes for each layer
        assert sizes[0] > 0 and sizes[0] < 1
        assert sizes[1] > 0 and sizes[1] < 1 and sizes[1] > sizes[0]
        tmp = np.linspace(sizes[0], sizes[1], num=(len(from_layers)-1))
        min_sizes = [start_offset] + tmp.tolist()
        max_sizes = tmp.tolist() + [tmp[-1]+start_offset]
        # list() keeps this correct on Python 3, where zip() is lazy
        sizes = list(zip(min_sizes, max_sizes))
    assert len(sizes) == len(from_layers), \
        "sizes and from_layers must have same length"

    if not isinstance(normalization, list):
        normalization = [normalization] * len(from_layers)
    assert len(normalization) == len(from_layers)

    assert sum(x > 0 for x in normalization) == len(num_channels), \
        "must provide number of channels for each normalized layer"
    # copy before pop() below so the caller's list is not mutated
    num_channels = list(num_channels)

    loc_pred_layers = []
    cls_pred_layers = []
    anchor_layers = []
    num_classes += 1 # always use background as label 0

    for k, from_layer in enumerate(from_layers):
        from_name = from_layer.name
        # normalize: channel-wise L2 norm followed by a learned per-channel scale
        if normalization[k] > 0:
            from_layer = mx.symbol.L2Normalization(data=from_layer, mode="channel", name="{}_norm".format(from_name))
            scale = mx.symbol.Variable(name="{}_scale".format(from_name), shape=(1, num_channels.pop(0), 1, 1))
            from_layer = normalization[k] * mx.symbol.broadcast_mul(lhs=scale, rhs=from_layer)
        if interm_layer > 0:
            from_layer = mx.symbol.Convolution(data=from_layer, kernel=(3,3), stride=(1,1), pad=(1,1), num_filter=interm_layer, name="{}_inter_conv".format(from_name))
            from_layer = mx.symbol.Activation(data=from_layer, act_type="relu", name="{}_inter_relu".format(from_name))

        # estimate number of anchors per location
        # here I follow the original version in caffe
        # TODO: better way to shape the anchors??
        size = sizes[k]
        assert len(size) > 0, "must provide at least one size"
        size_str = "(" + ",".join([str(x) for x in size]) + ")"
        ratio = ratios[k]
        assert len(ratio) > 0, "must provide at least one ratio"
        ratio_str = "(" + ",".join([str(x) for x in ratio]) + ")"
        num_anchors = len(size) - 1 + len(ratio)

        # create location prediction layer: 4 offsets per anchor, flattened NHWC
        num_loc_pred = num_anchors * 4
        loc_pred = mx.symbol.Convolution(data=from_layer, kernel=(3,3), stride=(1,1), pad=(1,1), num_filter=num_loc_pred, name="{}_loc_pred_conv".format(from_name))
        loc_pred = mx.symbol.transpose(loc_pred, axes=(0,2,3,1))
        loc_pred = mx.symbol.Flatten(data=loc_pred)
        loc_pred_layers.append(loc_pred)

        # create class prediction layer: num_classes scores per anchor
        num_cls_pred = num_anchors * num_classes
        cls_pred = mx.symbol.Convolution(data=from_layer, kernel=(3,3), stride=(1,1), pad=(1,1), num_filter=num_cls_pred, name="{}_cls_pred_conv".format(from_name))
        cls_pred = mx.symbol.transpose(cls_pred, axes=(0,2,3,1))
        cls_pred = mx.symbol.Flatten(data=cls_pred)
        cls_pred_layers.append(cls_pred)

        # create anchor generation layer
        anchors = mx.contrib.symbol.MultiBoxPrior(from_layer, sizes=size_str, ratios=ratio_str, clip=clip, name="{}_anchors".format(from_name))
        anchors = mx.symbol.Flatten(data=anchors)
        anchor_layers.append(anchors)

    loc_preds = mx.symbol.Concat(*loc_pred_layers, num_args=len(loc_pred_layers), dim=1, name="multibox_loc_pred")
    cls_preds = mx.symbol.Concat(*cls_pred_layers, num_args=len(cls_pred_layers), dim=1)
    cls_preds = mx.symbol.Reshape(data=cls_preds, shape=(0, -1, num_classes))
    cls_preds = mx.symbol.transpose(cls_preds, axes=(0, 2, 1), name="multibox_cls_pred")
    anchor_boxes = mx.symbol.Concat(*anchor_layers, num_args=len(anchor_layers), dim=1)
    anchor_boxes = mx.symbol.Reshape(data=anchor_boxes, shape=(0, -1, 4), name="multibox_anchors")
    return [loc_preds, cls_preds, anchor_boxes]

    
    
def bn_act_conv_layer(from_layer, name, num_filter, kernel=(1,1), pad=(0,0), stride=(1,1)):
    """BatchNorm -> ReLU -> Convolution block (pre-activation style).

    Returns (conv_output, relu_output); the layers are named "bn<name>" and
    "conv<name>" so parameters match existing checkpoints.
    """
    normed = mx.symbol.BatchNorm(data=from_layer, name="bn{}".format(name))
    activated = mx.symbol.Activation(data=normed, act_type='relu')
    conv = mx.symbol.Convolution(data=activated, kernel=kernel, pad=pad, stride=stride,
                                 num_filter=num_filter, name="conv{}".format(name))
    return conv, activated

def conv_act_layer(from_layer, name, num_filter, kernel=(1,1), pad=(0,0), stride=(1,1), act_type="relu"):
    """Activation followed by Convolution.

    NOTE: despite the name, the activation is applied to the *input* before
    the convolution. Returns (conv_output, activation_output).
    """
    activated = mx.symbol.Activation(data=from_layer, act_type=act_type,
                                     name="{}{}".format(act_type, name))
    conv = mx.symbol.Convolution(data=activated, kernel=kernel, pad=pad, stride=stride,
                                 num_filter=num_filter, name="conv{}".format(name))
    return conv, activated
    
def residual_unit(data, num_filter, stride, dim_match, name, bn_mom=0.9, workspace=256):
    """Pre-activation residual unit: (BN -> ReLU -> 3x3 Conv) twice, plus a
    shortcut.

    When dim_match is True the shortcut is the identity; otherwise a 1x1
    convolution (taken from the first pre-activation) projects the shortcut to
    the new shape. `stride` is applied only in the first conv and the
    projection.
    """
    pre_bn = mx.sym.BatchNorm(data=data, fix_gamma=False, momentum=bn_mom, eps=2e-5, name=name + '_bn1')
    pre_act = mx.symbol.Activation(data=pre_bn, act_type='relu', name=name + '_relu1')
    body = mx.sym.Convolution(data=pre_act, num_filter=num_filter, kernel=(3, 3), stride=stride,
                              pad=(1, 1), no_bias=True, workspace=workspace, name=name + '_conv1')
    body = mx.sym.BatchNorm(data=body, fix_gamma=False, momentum=bn_mom, eps=2e-5, name=name + '_bn2')
    body = mx.symbol.Activation(data=body, act_type='relu', name=name + '_relu2')
    body = mx.sym.Convolution(data=body, num_filter=num_filter, kernel=(3, 3), stride=(1, 1),
                              pad=(1, 1), no_bias=True, workspace=workspace, name=name + '_conv2')
    if dim_match:
        shortcut = data
    else:
        # projection shortcut branches off the shared pre-activation
        shortcut = mx.sym.Convolution(data=pre_act, num_filter=num_filter, kernel=(1, 1), stride=stride,
                                      no_bias=True, workspace=workspace, name=name + '_sc')
    return body + shortcut


# fusion functions
def cf_unit(res_unit_rgb, res_unit_tir, num_filters=64, name='fusion1_unit_1', mode='conv'):
    """Cross-modality fusion unit combining an RGB and a TIR feature map.

    mode='conv': channel-concat followed by a 3x3 convolution (output named
    "<name>_conv"); mode='sum'/'max': element-wise broadcast add / maximum.

    BUGFIX: an unrecognized mode previously fell through and raised
    UnboundLocalError on the return; it now raises ValueError explicitly.
    Also removed a dead ReLU ("<name>_relu1") that was built in 'conv' mode
    but never part of the returned graph — downstream code consumes the raw
    "<name>_conv" output.
    """
    if mode == 'conv':
        concat = mx.symbol.Concat(*[res_unit_rgb, res_unit_tir], dim=1)
        fused = mx.sym.Convolution(concat, num_filter=num_filters, kernel=(3, 3), stride=(1, 1),
                                   pad=(1, 1), no_bias=True, workspace=256, name=name + '_conv')
    elif mode == 'sum':
        fused = mx.symbol.broadcast_add(res_unit_rgb, res_unit_tir, name=name + '_conv')
    elif mode == 'max':
        fused = mx.symbol.broadcast_maximum(res_unit_rgb, res_unit_tir, name=name + '_conv')
    else:
        raise ValueError("unknown fusion mode: %s" % mode)
    return fused


def spectral_net():
    """Build the thermal (TIR) branch: five conv/relu/max-pool stages followed
    by two strided conv_act_layer heads.

    Returns the five feature maps consumed by the fusion layers.
    """
    filter_list = [32, 32, 64, 128, 256]

    # NOTE: the layer names are historically inconsistent ('conv_0'/'pool_0'
    # vs 'conv2'/'pool2'). They are preserved verbatim because the
    # get_internals() lookups below and saved parameter names depend on them.
    conv_names = ['conv_0', 'conv_1', 'conv2', 'conv3', 'conv4']
    relu_names = ['relu_0', 'relu_1', 'relu2', 'relu3', 'relu4']
    pool_names = ['pool_0', 'pool1', 'pool2', 'pool3', 'pool4']

    net_tir = mx.sym.Variable(name='tir')
    for nf, cname, rname, pname in zip(filter_list, conv_names, relu_names, pool_names):
        net_tir = mx.sym.Convolution(net_tir, num_filter=nf, kernel=(3, 3), stride=(1, 1), pad=(1, 1), name=cname)
        net_tir = mx.symbol.Activation(net_tir, act_type='relu', name=rname)
        net_tir = mx.symbol.Pooling(net_tir, kernel=(3, 3), stride=(2, 2), pad=(1, 1), pool_type='max', name=pname)

    # two extra strided heads matching the small SSD feature maps
    net_tir, relu_f4 = conv_act_layer(net_tir, "f4", 128, kernel=(3, 3), pad=(1, 1), stride=(2, 2))
    net_tir, relu_f5 = conv_act_layer(net_tir, "f5", 128, kernel=(3, 3), pad=(1, 1), stride=(2, 2))

    internals = net_tir.get_internals()
    return [internals["pool2_output"],
            internals["pool3_output"],
            internals["pool4_output"],
            internals["convf4_output"],
            internals["convf5_output"]]



def resnet():
    """Build the RGB backbone (ResNet-style, 2 units per stage) plus SSD extra
    layers.

    Returns the five fusion feature maps (38x38, 19x19, 10x10, 5x5, 3x3 for a
    300x300 input, per the shape probes in this notebook) and the globally
    average-pooled top layer.
    """
    filter_list = [64, 64, 128, 256, 512]

    bn_mom = 0.9            # kept for parity with residual_unit's default
    workspace = 256

    rgb = mx.sym.Variable(name='rgb')

    # rgb head: 7x7/2 conv, ReLU, 3x3/2 max pool
    net_rgb = mx.sym.Convolution(rgb, num_filter=filter_list[0], kernel=(7, 7), stride=(2, 2), pad=(3, 3), no_bias=True, name="conv0", workspace=workspace)
    net_rgb = mx.symbol.Activation(net_rgb, act_type='relu', name='relu0')
    net_rgb = mx.symbol.Pooling(net_rgb, kernel=(3, 3), stride=(2, 2), pad=(1, 1), pool_type='max')

    # four residual stages, two units each; stages 2-4 downsample in unit 1
    for stage in range(1, 5):
        first_stride = (1, 1) if stage == 1 else (2, 2)
        net_rgb = residual_unit(net_rgb, filter_list[stage], first_stride, False,
                                name='stage%d_unit1' % stage, workspace=workspace)
        net_rgb = residual_unit(net_rgb, filter_list[stage], (1, 1), True,
                                name='stage%d_unit2' % stage, workspace=workspace)

    # ssd extra layers, chained from the 19x19 pre-activation of stage4_unit1
    extra_layers_input = net_rgb.get_internals()["stage4_unit1_relu1_output"]  # 19x19
    extra_specs = [("8_1", 256, (1, 1), (0, 0), (1, 1)),
                   ("8_2", 512, (3, 3), (1, 1), (2, 2)),
                   ("9_1", 128, (1, 1), (0, 0), (1, 1)),
                   ("9_2", 256, (3, 3), (1, 1), (2, 2)),
                   ("10_1", 128, (1, 1), (0, 0), (1, 1)),
                   ("10_2", 256, (3, 3), (1, 1), (2, 2))]
    extra = extra_layers_input
    extra_convs = []
    for suffix, nf, kernel, pad, stride in extra_specs:
        extra, _ = bn_act_conv_layer(extra, suffix, nf, kernel=kernel, pad=pad, stride=stride)
        extra_convs.append(extra)
    pool10 = mx.symbol.Pooling(data=extra_convs[-1], pool_type="avg", global_pool=True, kernel=(1, 1), name='pool10')
    net_rgb = pool10

    fusion_1 = net_rgb.get_internals()["stage3_unit1_relu1_output"]
    fusion_2 = net_rgb.get_internals()["stage4_unit1_relu1_output"]
    fusion_3 = extra_convs[1]   # conv8_2
    fusion_4 = extra_convs[3]   # conv9_2
    fusion_5 = extra_convs[5]   # conv10_2

    return [fusion_1, fusion_2, fusion_3, fusion_4, fusion_5, pool10]

def fusion_net():
    """Assemble the multispectral SSD symbol.

    Five matching RGB/TIR feature maps are fused with cf_unit, the globally
    pooled RGB layer is appended as a sixth source, and all six feed
    multibox_layer with one foreground class. Only loc_preds is returned;
    cls_preds and anchor_boxes are discarded (presumably intentional for this
    experiment — TODO confirm).
    """
    rgb_feats = resnet()
    tir_feats = spectral_net()

    # per-stage fusion output channels, matched to the RGB branch widths
    fusion_filters = [128, 256, 512, 256, 256]
    from_layers = []
    for idx, (rgb_f, tir_f, nf) in enumerate(zip(rgb_feats, tir_feats, fusion_filters)):
        from_layers.append(cf_unit(rgb_f, tir_f, num_filters=nf, name='fusion%d' % (idx + 1)))
    # the sixth source is the globally pooled RGB top layer, used as-is
    from_layers.append(rgb_feats[5])

    # SSD anchor configuration: per-layer sizes and aspect ratios
    sizes = [[.1], [.2, .276], [.38, .461], [.56, .644], [.74, .825], [.92, 1.01]]
    ratios = [[1, 2, .5],
              [1, 2, .5, 3, 1. / 3],
              [1, 2, .5, 3, 1. / 3],
              [1, 2, .5, 3, 1. / 3],
              [1, 2, .5, 3, 1. / 3],
              [1, 2, .5, 3, 1. / 3]]

    # L2-normalize (learned scale, init 20) only the first fusion layer (128 ch)
    normalizations = [20, -1, -1, -1, -1, -1]
    num_channels = [128]

    loc_preds, cls_preds, anchor_boxes = multibox_layer(from_layers, 1,
                                                        sizes=sizes,
                                                        ratios=ratios,
                                                        normalization=normalizations,
                                                        num_channels=num_channels,
                                                        clip=True,
                                                        interm_layer=0)
    return loc_preds

In [54]:
# Build the fusion SSD symbol (loc_preds head only) to check it constructs.
symb=fusion_net()

In [55]:
# Probe the RGB branch: print output shapes of the five feature maps that
# feed the fusion layers, for a 300x300 single-channel input.
internals = resnet()[-1].get_internals()
for output in internals.list_outputs():
    if output in ["stage3_unit1_relu1_output", 
                  "stage4_unit1_relu1_output", 
                  'conv8_2_output',
                  'conv9_2_output',
                  'conv10_2_output']:
        print output, internals[output].infer_shape(rgb=(1,1,300,300))[1]


stage3_unit1_relu1_output [(1L, 128L, 38L, 38L)]
stage4_unit1_relu1_output [(1L, 256L, 19L, 19L)]
conv8_2_output [(1L, 512L, 10L, 10L)]
conv9_2_output [(1L, 256L, 5L, 5L)]
conv10_2_output [(1L, 256L, 3L, 3L)]

In [56]:
# Probe the TIR branch feature-map shapes for a 300x300 input.
# NOTE(review): the 'conv*_2t_output' names below do not exist in
# spectral_net (its heads are named 'convf4'/'convf5'), so only the pool
# outputs print — these look like stale names from an earlier revision.
internals = spectral_net()[-1].get_internals()
for output in internals.list_outputs():
    if output in ['pool2_output',
                  'pool3_output', 
                  'pool4_output',
                  'conv8_2t_output', 
                  'conv9_2t_output',
                  'conv10_2t_output',]:
        print output, internals[output].infer_shape(tir=(1,1,300,300))[1]


pool2_output [(1L, 64L, 38L, 38L)]
pool3_output [(1L, 128L, 19L, 19L)]
pool4_output [(1L, 256L, 10L, 10L)]

In [57]:
# Probe the fused feature maps: each fusionN_conv output should match the
# corresponding RGB feature-map shape (38, 19, 10, 5, 3 spatial sizes).
internals = fusion_net().get_internals()
for output in internals.list_outputs():
    if output in ['fusion1_conv_output',
                  'fusion2_conv_output', 
                  'fusion3_conv_output', 
                  'fusion4_conv_output', 
                  'fusion5_conv_output']:
        print output, internals[output].infer_shape(tir=(1,1,300,300), rgb=(1,1,300,300))[1]


fusion1_conv_output [(1L, 128L, 38L, 38L)]
fusion2_conv_output [(1L, 256L, 19L, 19L)]
fusion3_conv_output [(1L, 512L, 10L, 10L)]
fusion4_conv_output [(1L, 256L, 5L, 5L)]
fusion5_conv_output [(1L, 256L, 3L, 3L)]

In [58]:
# read trained ssd-resnet 0.68
net_params_ssd = mx.initializer.Load('ssd_300-0300.params')
# Keep only backbone weights for transfer: any parameter whose name contains
# a 'loc' or 'pred' component belongs to the old detection heads and is
# dropped so the new heads are freshly initialized.
net_params = {}
for i in net_params_ssd.param:
    if 'loc' in i.split('_'):
        continue
    if 'pred' in i.split('_'):
        continue
#     if i.split('_')[0] in ['stage1', 
#                            'stage2', 
#                            'stage3',
#                            'stage4', 'conv8', 'conv9', 'conv10', 'pool10', 'bn', 'conv0']:
    net_params[i] = net_params_ssd.param[i]
net_params_ssd.param = net_params
# sorted(net_params.keys())

In [59]:
# Bind the two-input module and initialize: transferred backbone weights come
# from `net_params`; anything missing falls back to `initializer`
# (ScaleInitializer for *_scale, Xavier otherwise).
model = mx.mod.Module(symbol=fusion_net(), data_names=['rgb', 'tir'])
model.bind(data_shapes=[('rgb', (1, 3, 300, 300)), ('tir', (1, 1, 300, 300))])
model.init_params(arg_params=net_params,  allow_missing=True, initializer=initializer)

In [60]:
# Write spectral-0001.params / spectral-symbol.json to disk.
save_net(model, 'spectral')

In [20]:
# Reload the just-saved checkpoint to verify its contents.
net_params_final = mx.initializer.Load('spectral-0001.params')

In [21]:
# Display the saved parameter dict (name -> NDArray) for manual inspection.
net_params_final.param


Out[21]:
{'bn0_beta': <NDArray 64 @cpu(0)>,
 'bn0_gamma': <NDArray 64 @cpu(0)>,
 'bn0_moving_mean': <NDArray 64 @cpu(0)>,
 'bn0_moving_var': <NDArray 64 @cpu(0)>,
 'bn10_1_beta': <NDArray 256 @cpu(0)>,
 'bn10_1_gamma': <NDArray 256 @cpu(0)>,
 'bn10_1_moving_mean': <NDArray 256 @cpu(0)>,
 'bn10_1_moving_var': <NDArray 256 @cpu(0)>,
 'bn10_2_beta': <NDArray 128 @cpu(0)>,
 'bn10_2_gamma': <NDArray 128 @cpu(0)>,
 'bn10_2_moving_mean': <NDArray 128 @cpu(0)>,
 'bn10_2_moving_var': <NDArray 128 @cpu(0)>,
 'bn8_1_beta': <NDArray 256 @cpu(0)>,
 'bn8_1_gamma': <NDArray 256 @cpu(0)>,
 'bn8_1_moving_mean': <NDArray 256 @cpu(0)>,
 'bn8_1_moving_var': <NDArray 256 @cpu(0)>,
 'bn8_2_beta': <NDArray 256 @cpu(0)>,
 'bn8_2_gamma': <NDArray 256 @cpu(0)>,
 'bn8_2_moving_mean': <NDArray 256 @cpu(0)>,
 'bn8_2_moving_var': <NDArray 256 @cpu(0)>,
 'bn9_1_beta': <NDArray 512 @cpu(0)>,
 'bn9_1_gamma': <NDArray 512 @cpu(0)>,
 'bn9_1_moving_mean': <NDArray 512 @cpu(0)>,
 'bn9_1_moving_var': <NDArray 512 @cpu(0)>,
 'bn9_2_beta': <NDArray 128 @cpu(0)>,
 'bn9_2_gamma': <NDArray 128 @cpu(0)>,
 'bn9_2_moving_mean': <NDArray 128 @cpu(0)>,
 'bn9_2_moving_var': <NDArray 128 @cpu(0)>,
 'bn_data_beta': <NDArray 3 @cpu(0)>,
 'bn_data_gamma': <NDArray 3 @cpu(0)>,
 'bn_data_moving_mean': <NDArray 3 @cpu(0)>,
 'bn_data_moving_var': <NDArray 3 @cpu(0)>,
 'conv0_weight': <NDArray 64x3x7x7 @cpu(0)>,
 'conv10_1_bias': <NDArray 128 @cpu(0)>,
 'conv10_1_weight': <NDArray 128x256x1x1 @cpu(0)>,
 'conv10_2_bias': <NDArray 256 @cpu(0)>,
 'conv10_2_weight': <NDArray 256x128x3x3 @cpu(0)>,
 'conv10_2t_bias': <NDArray 256 @cpu(0)>,
 'conv10_2t_weight': <NDArray 256x256x3x3 @cpu(0)>,
 'conv2_bias': <NDArray 128 @cpu(0)>,
 'conv2_weight': <NDArray 128x64x3x3 @cpu(0)>,
 'conv3_bias': <NDArray 256 @cpu(0)>,
 'conv3_weight': <NDArray 256x128x3x3 @cpu(0)>,
 'conv4_bias': <NDArray 512 @cpu(0)>,
 'conv4_weight': <NDArray 512x256x3x3 @cpu(0)>,
 'conv8_1_bias': <NDArray 256 @cpu(0)>,
 'conv8_1_weight': <NDArray 256x256x1x1 @cpu(0)>,
 'conv8_2_bias': <NDArray 512 @cpu(0)>,
 'conv8_2_weight': <NDArray 512x256x3x3 @cpu(0)>,
 'conv9_1_bias': <NDArray 128 @cpu(0)>,
 'conv9_1_weight': <NDArray 128x512x1x1 @cpu(0)>,
 'conv9_2_bias': <NDArray 256 @cpu(0)>,
 'conv9_2_weight': <NDArray 256x128x3x3 @cpu(0)>,
 'conv9_2t_bias': <NDArray 256 @cpu(0)>,
 'conv9_2t_weight': <NDArray 256x512x3x3 @cpu(0)>,
 'conv_0_bias': <NDArray 64 @cpu(0)>,
 'conv_0_weight': <NDArray 64x1x3x3 @cpu(0)>,
 'conv_1_bias': <NDArray 64 @cpu(0)>,
 'conv_1_weight': <NDArray 64x64x3x3 @cpu(0)>,
 'fusion1_conv_bias': <NDArray 128 @cpu(0)>,
 'fusion1_conv_loc_pred_conv_bias': <NDArray 12 @cpu(0)>,
 'fusion1_conv_loc_pred_conv_weight': <NDArray 12x128x3x3 @cpu(0)>,
 'fusion1_conv_scale': <NDArray 1x128x1x1 @cpu(0)>,
 'fusion1_conv_weight': <NDArray 128x256x3x3 @cpu(0)>,
 'fusion2_conv_bias': <NDArray 256 @cpu(0)>,
 'fusion2_conv_loc_pred_conv_bias': <NDArray 24 @cpu(0)>,
 'fusion2_conv_loc_pred_conv_weight': <NDArray 24x256x3x3 @cpu(0)>,
 'fusion2_conv_weight': <NDArray 256x512x3x3 @cpu(0)>,
 'fusion3_conv_bias': <NDArray 512 @cpu(0)>,
 'fusion3_conv_loc_pred_conv_bias': <NDArray 24 @cpu(0)>,
 'fusion3_conv_loc_pred_conv_weight': <NDArray 24x512x3x3 @cpu(0)>,
 'fusion3_conv_weight': <NDArray 512x1024x3x3 @cpu(0)>,
 'fusion4_conv_bias': <NDArray 256 @cpu(0)>,
 'fusion4_conv_loc_pred_conv_bias': <NDArray 24 @cpu(0)>,
 'fusion4_conv_loc_pred_conv_weight': <NDArray 24x256x3x3 @cpu(0)>,
 'fusion4_conv_weight': <NDArray 256x512x3x3 @cpu(0)>,
 'fusion5_conv_bias': <NDArray 256 @cpu(0)>,
 'fusion5_conv_loc_pred_conv_bias': <NDArray 24 @cpu(0)>,
 'fusion5_conv_loc_pred_conv_weight': <NDArray 24x256x3x3 @cpu(0)>,
 'fusion5_conv_weight': <NDArray 256x512x3x3 @cpu(0)>,
 'pool10_loc_pred_conv_bias': <NDArray 24 @cpu(0)>,
 'pool10_loc_pred_conv_weight': <NDArray 24x256x3x3 @cpu(0)>,
 'stage1_unit1_bn1_beta': <NDArray 64 @cpu(0)>,
 'stage1_unit1_bn1_gamma': <NDArray 64 @cpu(0)>,
 'stage1_unit1_bn1_moving_mean': <NDArray 64 @cpu(0)>,
 'stage1_unit1_bn1_moving_var': <NDArray 64 @cpu(0)>,
 'stage1_unit1_bn2_beta': <NDArray 64 @cpu(0)>,
 'stage1_unit1_bn2_gamma': <NDArray 64 @cpu(0)>,
 'stage1_unit1_bn2_moving_mean': <NDArray 64 @cpu(0)>,
 'stage1_unit1_bn2_moving_var': <NDArray 64 @cpu(0)>,
 'stage1_unit1_conv1_weight': <NDArray 64x64x3x3 @cpu(0)>,
 'stage1_unit1_conv2_weight': <NDArray 64x64x3x3 @cpu(0)>,
 'stage1_unit1_sc_weight': <NDArray 64x64x1x1 @cpu(0)>,
 'stage1_unit2_bn1_beta': <NDArray 64 @cpu(0)>,
 'stage1_unit2_bn1_gamma': <NDArray 64 @cpu(0)>,
 'stage1_unit2_bn1_moving_mean': <NDArray 64 @cpu(0)>,
 'stage1_unit2_bn1_moving_var': <NDArray 64 @cpu(0)>,
 'stage1_unit2_bn2_beta': <NDArray 64 @cpu(0)>,
 'stage1_unit2_bn2_gamma': <NDArray 64 @cpu(0)>,
 'stage1_unit2_bn2_moving_mean': <NDArray 64 @cpu(0)>,
 'stage1_unit2_bn2_moving_var': <NDArray 64 @cpu(0)>,
 'stage1_unit2_conv1_weight': <NDArray 64x64x3x3 @cpu(0)>,
 'stage1_unit2_conv2_weight': <NDArray 64x64x3x3 @cpu(0)>,
 'stage2_unit1_bn1_beta': <NDArray 64 @cpu(0)>,
 'stage2_unit1_bn1_gamma': <NDArray 64 @cpu(0)>,
 'stage2_unit1_bn1_moving_mean': <NDArray 64 @cpu(0)>,
 'stage2_unit1_bn1_moving_var': <NDArray 64 @cpu(0)>,
 'stage2_unit1_bn2_beta': <NDArray 128 @cpu(0)>,
 'stage2_unit1_bn2_gamma': <NDArray 128 @cpu(0)>,
 'stage2_unit1_bn2_moving_mean': <NDArray 128 @cpu(0)>,
 'stage2_unit1_bn2_moving_var': <NDArray 128 @cpu(0)>,
 'stage2_unit1_conv1_weight': <NDArray 128x64x3x3 @cpu(0)>,
 'stage2_unit1_conv2_weight': <NDArray 128x128x3x3 @cpu(0)>,
 'stage2_unit1_sc_weight': <NDArray 128x64x1x1 @cpu(0)>,
 'stage2_unit2_bn1_beta': <NDArray 128 @cpu(0)>,
 'stage2_unit2_bn1_gamma': <NDArray 128 @cpu(0)>,
 'stage2_unit2_bn1_moving_mean': <NDArray 128 @cpu(0)>,
 'stage2_unit2_bn1_moving_var': <NDArray 128 @cpu(0)>,
 'stage2_unit2_bn2_beta': <NDArray 128 @cpu(0)>,
 'stage2_unit2_bn2_gamma': <NDArray 128 @cpu(0)>,
 'stage2_unit2_bn2_moving_mean': <NDArray 128 @cpu(0)>,
 'stage2_unit2_bn2_moving_var': <NDArray 128 @cpu(0)>,
 'stage2_unit2_conv1_weight': <NDArray 128x128x3x3 @cpu(0)>,
 'stage2_unit2_conv2_weight': <NDArray 128x128x3x3 @cpu(0)>,
 'stage3_unit1_bn1_beta': <NDArray 128 @cpu(0)>,
 'stage3_unit1_bn1_gamma': <NDArray 128 @cpu(0)>,
 'stage3_unit1_bn1_moving_mean': <NDArray 128 @cpu(0)>,
 'stage3_unit1_bn1_moving_var': <NDArray 128 @cpu(0)>,
 'stage3_unit1_bn2_beta': <NDArray 256 @cpu(0)>,
 'stage3_unit1_bn2_gamma': <NDArray 256 @cpu(0)>,
 'stage3_unit1_bn2_moving_mean': <NDArray 256 @cpu(0)>,
 'stage3_unit1_bn2_moving_var': <NDArray 256 @cpu(0)>,
 'stage3_unit1_conv1_weight': <NDArray 256x128x3x3 @cpu(0)>,
 'stage3_unit1_conv2_weight': <NDArray 256x256x3x3 @cpu(0)>,
 'stage3_unit1_sc_weight': <NDArray 256x128x1x1 @cpu(0)>,
 'stage3_unit2_bn1_beta': <NDArray 256 @cpu(0)>,
 'stage3_unit2_bn1_gamma': <NDArray 256 @cpu(0)>,
 'stage3_unit2_bn1_moving_mean': <NDArray 256 @cpu(0)>,
 'stage3_unit2_bn1_moving_var': <NDArray 256 @cpu(0)>,
 'stage3_unit2_bn2_beta': <NDArray 256 @cpu(0)>,
 'stage3_unit2_bn2_gamma': <NDArray 256 @cpu(0)>,
 'stage3_unit2_bn2_moving_mean': <NDArray 256 @cpu(0)>,
 'stage3_unit2_bn2_moving_var': <NDArray 256 @cpu(0)>,
 'stage3_unit2_conv1_weight': <NDArray 256x256x3x3 @cpu(0)>,
 'stage3_unit2_conv2_weight': <NDArray 256x256x3x3 @cpu(0)>,
 'stage4_unit1_bn1_beta': <NDArray 256 @cpu(0)>,
 'stage4_unit1_bn1_gamma': <NDArray 256 @cpu(0)>,
 'stage4_unit1_bn1_moving_mean': <NDArray 256 @cpu(0)>,
 'stage4_unit1_bn1_moving_var': <NDArray 256 @cpu(0)>}

In [39]:
def bna(net):
    """Activation-only block: ELU via LeakyReLU.

    A BatchNorm used to precede the activation here but is deliberately
    disabled in this revision — TODO confirm whether it should return.
    """
    return mx.symbol.LeakyReLU(net, act_type="elu")


def conv_bna(net, num_filter, is_pool=False):
    """3x3 Convolution -> BatchNorm -> ELU.

    With is_pool=True the convolution uses stride 2, halving the spatial
    resolution (a strided conv standing in for pooling).
    """
    conv_kwargs = {'num_filter': num_filter, 'kernel': (3, 3), 'pad': (1, 1)}
    if is_pool:
        conv_kwargs['stride'] = (2, 2)
    net = mx.symbol.Convolution(net, **conv_kwargs)
    net = mx.symbol.BatchNorm(net)
    net = mx.symbol.LeakyReLU(net, act_type="elu")
    return net


def up_bna(net, net_merge, num_filter, num_filter_up, up_type='deconv'):
    """U-Net decoder block: two conv_bna layers, 2x upsampling, skip concat,
    activation.

    Parameters
    ----------
    net : mx.symbol
        decoder feature map to upsample
    net_merge : mx.symbol
        encoder skip connection concatenated after upsampling
    num_filter : int
        channels of the two body convolutions
    num_filter_up : int
        channels after upsampling
    up_type : str
        'deconv' (learned 2x2/2 transposed convolution) or 'upsample'
        (nearest-neighbour). BUGFIX: any other value now raises ValueError;
        previously it silently skipped upsampling, which produced a shape
        mismatch at the Concat below.
    """
    net = conv_bna(net, num_filter)
    net = conv_bna(net, num_filter)

    if up_type == 'upsample':
        # Nearest-neighbour keeps values unchanged (blocky but parameter-free).
        net = mx.sym.UpSampling(net, scale=2, num_filter=num_filter_up, sample_type='nearest')
    elif up_type == 'deconv':
        net = mx.sym.Deconvolution(net, kernel=(2, 2), pad=(0, 0), stride=(2, 2), num_filter=num_filter_up)
    else:
        raise ValueError("unknown up_type: %s" % up_type)

    net = mx.symbol.Concat(*[net, net_merge], dim=1)
    net = bna(net)
    return net


def get_unet_symbol():
    """Build a U-Net segmentation symbol: 4 encoder groups with skip
    connections, 4 decoder groups, a 1x1 conv to a single channel, sigmoid,
    and a LogisticRegressionOutput head against the 'label' variable.
    """
    data = mx.sym.Variable('data')
    label = mx.sym.Variable('label')

    # encoder: conv, conv (kept as skip), strided conv downsample
    encoder_filters = [64, 128, 256, 512]
    skip_layers = []
    net = data
    for nf in encoder_filters:
        net = conv_bna(net, nf)
        skip = conv_bna(net, nf)
        skip_layers.append(skip)
        net = conv_bna(skip, nf, is_pool=True)

    # decoder: body filters shrink 1024 -> 128, upsample filters 512 -> 64,
    # merging skips in reverse (deepest first)
    decoder_specs = [(1024, 512), (512, 256), (256, 128), (128, 64)]
    for (nf_body, nf_up), skip in zip(decoder_specs, reversed(skip_layers)):
        net = up_bna(net, skip, nf_body, nf_up)

    # final group: two 3x3 convs, 1x1 projection to one channel, sigmoid
    net = conv_bna(conv_bna(net, 64), 64)
    net = mx.symbol.Convolution(net, num_filter=1, kernel=(1, 1))
    sigmoid = mx.symbol.Activation(net, act_type='sigmoid', name='sigmoid')

    return mx.symbol.LogisticRegressionOutput(sigmoid, label)

In [40]:
# Build the U-Net symbol to check it constructs.
symb=get_unet_symbol()

In [42]:
# Bind the U-Net module on 400x400 inputs and initialize all parameters from
# scratch with `initializer` (no pretrained weights here).
model = mx.mod.Module(symbol=get_unet_symbol(), data_names=['data'], label_names=['label'])
model.bind(data_shapes=[('data', (1, 3, 400, 400))], label_shapes=[('label', (1,1,400,400))])
model.init_params(allow_missing=True, initializer=initializer)

In [ ]: