In [151]:
import mxnet as mx
from collections import namedtuple
import cv2
import os, urllib
import numpy as np

def save_net(model, prefix):
    """Write a module's parameters and symbol graph to disk.

    Two files are produced next to each other:
    ``<prefix>-0001.params`` (via ``model.save_params``) and
    ``<prefix>-symbol.json`` (via ``model.symbol.save``).

    Parameters
    ----------
    model : object exposing ``save_params(path)`` and a ``symbol`` with ``save(path)``
    prefix : str
        Common path prefix for both output files.
    """
    params_path = prefix + '-0001.params'
    model.save_params(params_path)

    symbol_path = prefix + '-symbol.json'
    model.symbol.save(symbol_path)
    
def print_net_params(param):
    """Print every parameter name with its shape, one per line, sorted by name.

    Parameters
    ----------
    param : dict
        Maps parameter name to an object that exposes a ``shape`` attribute.
    """
    for key in sorted(param):
        # Single-argument print() emits "<name> <shape>" on both Python 2 and 3,
        # unlike the Python-2-only ``print a, b`` statement.
        print("{} {}".format(key, param[key].shape))
        
def plot_network(symbol, name='network'):
    """Render a symbol graph to a PNG with graphviz and open it in a viewer.

    Parameters
    ----------
    symbol : mx.symbol.Symbol
        Network to draw.
    name : str
        Base name of the rendered file; the graph source is written to
        ``<name>.gv``. The default keeps the previous ``network.gv`` path.
    """
    graph = mx.viz.plot_network(
        symbol=symbol,
        node_attrs={"shape": 'rect', "fixedsize": 'false', 'rankdir': 'TB'})
    graph.format = 'png'
    # Use the caller-supplied name so that plotting several networks does not
    # overwrite the same output file.
    graph.render(name + '.gv', view=True)
    
def print_inferred_shape(net, name, nch=3, size=300, node_type='rgb'):
    """Print the output shapes ``net.infer_shape`` reports for a square input.

    Parameters
    ----------
    net : object exposing ``infer_shape(**shapes)``
        Its second return value (the output shapes) is what gets printed.
    name : str
        Currently unused; kept so existing calls keep working.
    nch : int
        Channel count of the primary input.
    size : int
        Height and width of the input.
    node_type : {'rgb', 'tir', 'spectral'}
        Selects which input variables are fed: ``rgb`` only, ``tir`` only, or
        both (``spectral``, where ``tir`` is single-channel and ``rgb`` has
        ``nch`` channels).

    Raises
    ------
    ValueError
        If ``node_type`` is not one of the supported values. Previously this
        surfaced as a NameError on the print line.
    """
    if node_type == 'rgb':
        input_shapes = {'rgb': (1, nch, size, size)}
    elif node_type == 'tir':
        input_shapes = {'tir': (1, nch, size, size)}
    elif node_type == 'spectral':
        input_shapes = {'tir': (1, 1, size, size), 'rgb': (1, nch, size, size)}
    else:
        raise ValueError(
            "unknown node_type {!r}; expected 'rgb', 'tir' or 'spectral'".format(node_type))

    _, out_shapes, _ = net.infer_shape(**input_shapes)
    print(out_shapes)
    
    
def bn_act_conv_layer(from_layer, name, num_filter, kernel=(1,1), pad=(0,0), stride=(1,1)):
    """Stack BatchNorm -> ReLU -> Convolution on top of ``from_layer``.

    The batch-norm node is named ``bn<name>`` and the convolution
    ``conv<name>``; the ReLU node is left to MXNet's automatic naming.

    Returns
    -------
    (conv, relu) : tuple of symbols
        The convolution output and the ReLU output that feeds it.
    """
    normalized = mx.symbol.BatchNorm(data=from_layer, name="bn{}".format(name))
    activated = mx.symbol.Activation(data=normalized, act_type='relu')
    convolved = mx.symbol.Convolution(
        data=activated,
        num_filter=num_filter,
        kernel=kernel,
        stride=stride,
        pad=pad,
        name="conv{}".format(name),
    )
    return convolved, activated

def conv_act_layer(from_layer, name, num_filter, kernel=(1,1), pad=(0,0), stride=(1,1), act_type="relu"):
    """Apply an activation to ``from_layer`` and then a convolution.

    Despite the function name, the activation comes first: the activation node
    is named ``<act_type><name>`` and the convolution that consumes it is named
    ``conv<name>``.

    Returns
    -------
    (conv, act) : tuple of symbols
        The convolution output and the activation output that feeds it.
    """
    activated = mx.symbol.Activation(
        data=from_layer,
        act_type=act_type,
        name="{}{}".format(act_type, name),
    )
    convolved = mx.symbol.Convolution(
        data=activated,
        num_filter=num_filter,
        kernel=kernel,
        stride=stride,
        pad=pad,
        name="conv{}".format(name),
    )
    return convolved, activated
    
def residual_unit(data, num_filter, stride, dim_match, name, bn_mom=0.9, workspace=256):
    """Build one residual unit: two BN -> ReLU -> 3x3 conv stages plus a shortcut.

    Parameters
    ----------
    data : symbol
        Input to the unit.
    num_filter : int
        Output channels of both 3x3 convolutions (and of the projection shortcut).
    stride : tuple
        Stride of the first 3x3 convolution and of the projection shortcut.
    dim_match : bool
        If true the shortcut is ``data`` itself; otherwise it is a 1x1
        convolution applied to the first ReLU output.
    name : str
        Prefix for every named node (``_bn1``, ``_relu1``, ``_conv1``, ``_bn2``,
        ``_relu2``, ``_conv2``, ``_sc``).
    bn_mom : float
        Momentum passed to both BatchNorm nodes.
    workspace : int
        ``workspace`` argument forwarded to every convolution.

    Returns
    -------
    symbol
        Sum of the second convolution's output and the shortcut.
    """
    bn_args = dict(fix_gamma=False, momentum=bn_mom, eps=2e-5)
    conv_args = dict(num_filter=num_filter, no_bias=True, workspace=workspace)

    # First BN -> ReLU -> conv stage; this one carries the unit's stride.
    first_bn = mx.sym.BatchNorm(data=data, name=name + '_bn1', **bn_args)
    first_act = mx.symbol.Activation(data=first_bn, act_type='relu', name=name + '_relu1')
    first_conv = mx.sym.Convolution(data=first_act, kernel=(3, 3), stride=stride, pad=(1, 1),
                                    name=name + '_conv1', **conv_args)

    # Second stage always uses stride 1.
    second_bn = mx.sym.BatchNorm(data=first_conv, name=name + '_bn2', **bn_args)
    second_act = mx.symbol.Activation(data=second_bn, act_type='relu', name=name + '_relu2')
    second_conv = mx.sym.Convolution(data=second_act, kernel=(3, 3), stride=(1, 1), pad=(1, 1),
                                     name=name + '_conv2', **conv_args)

    if not dim_match:
        # Projection shortcut taken from the first activation.
        skip = mx.sym.Convolution(data=first_act, kernel=(1, 1), stride=stride,
                                  name=name + '_sc', **conv_args)
    else:
        skip = data
    return second_conv + skip


# fusion functions
def cf_unit(res_unit_rgb, res_unit_tir, num_filters=64, name='fusion1_unit_1', mode='conv'):
    """Fuse an RGB feature map with a TIR feature map.

    Parameters
    ----------
    res_unit_rgb, res_unit_tir : symbol
        The two feature maps to combine.
    num_filters : int
        Output channels of the fusing convolution (used only by ``mode='conv'``).
    name : str
        Prefix for the created nodes. The result is named ``<name>_conv`` in
        every mode.
    mode : {'conv', 'sum', 'max'}
        ``'conv'`` concatenates along axis 1, applies ReLU, then a 3x3
        convolution; ``'sum'`` uses ``broadcast_add``; ``'max'`` uses
        ``broadcast_maximum``.

    Returns
    -------
    symbol
        The fused feature map.

    Raises
    ------
    ValueError
        If ``mode`` is not one of the supported values. Previously this
        surfaced as an UnboundLocalError on the return statement.
    """
    # Validate first so a bad mode fails before any graph node is created.
    if mode not in ('conv', 'sum', 'max'):
        raise ValueError(
            "unknown fusion mode {!r}; expected 'conv', 'sum' or 'max'".format(mode))

    if mode == 'sum':
        return mx.symbol.broadcast_add(res_unit_rgb, res_unit_tir, name=name + '_conv')
    if mode == 'max':
        return mx.symbol.broadcast_maximum(res_unit_rgb, res_unit_tir, name=name + '_conv')

    # mode == 'conv'
    concat = mx.symbol.Concat(res_unit_rgb, res_unit_tir, dim=1, name=name + '_concat')
    act = mx.symbol.Activation(data=concat, act_type='relu', name=name + '_relu1')
    return mx.sym.Convolution(act, num_filter=num_filters, kernel=(3, 3), stride=(1, 1), pad=(1, 1),
                              workspace=256, name=name + '_conv')

def spectral_net():
    """Build the TIR branch and return the five feature maps used for fusion.

    The branch reads the ``tir`` variable and consists of five
    conv(3x3) -> ReLU -> max-pool(3x3, stride 2) blocks named
    ``conv<i>`` / ``relu<i>`` / ``pool<i>``, followed by three stride-2
    ``conv_act_layer`` blocks tagged ``8_2t``, ``9_2t`` and ``10_2t``.

    NOTE(review): ``conv0``, ``relu0`` and ``relu1`` are also used as layer
    names in ``resnet()``; confirm this does not clash when both branches are
    combined into one graph.

    Returns
    -------
    list of symbol
        ``[pool2, pool3, conv8_2t, conv9_2t, conv10_2t]`` outputs, in that order.
    """
    filter_list = [64, 64, 128, 256, 512]
    net_tir = mx.sym.Variable(name='tir')

    # Five identical conv -> relu -> pool blocks; only the width and index vary.
    for idx, width in enumerate(filter_list):
        net_tir = mx.sym.Convolution(net_tir, num_filter=width, kernel=(3, 3), stride=(1, 1),
                                     pad=(1, 1), name='conv%d' % idx)
        net_tir = mx.symbol.Activation(net_tir, act_type='relu', name='relu%d' % idx)
        net_tir = mx.symbol.Pooling(net_tir, kernel=(3, 3), stride=(2, 2), pad=(1, 1),
                                    pool_type='max', name='pool%d' % idx)

    # Extra down-sampling layers; the activation outputs are not needed here.
    for tag, width in (("8_2t", 512), ("9_2t", 256), ("10_2t", 256)):
        net_tir, _ = conv_act_layer(net_tir, tag, width, kernel=(3, 3), pad=(1, 1), stride=(2, 2))

    internals = net_tir.get_internals()
    fusion_outputs = (
        "pool2_output",
        "pool3_output",
        "conv8_2t_output",
        "conv9_2t_output",
        "conv10_2t_output",
    )
    return [internals[output_name] for output_name in fusion_outputs]


def resnet():
    """Build the RGB branch (ResNet-style body plus extra layers).

    The branch reads the ``rgb`` variable. It is made of a BN/7x7-conv/BN/ReLU/
    max-pool head, four stages of two ``residual_unit`` blocks each, a final
    BN + ReLU, three pairs of ``bn_act_conv_layer`` blocks (``8_*``, ``9_*``,
    ``10_*``) and a global average pool named ``pool10``.

    Returns
    -------
    list of symbol
        ``[stage3_unit1_relu1, stage4_unit1_relu1, conv8_2, conv9_2, conv10_2,
        pool10]``, in that order.
    """
    filter_list = [64, 64, 128, 256, 512]
    bn_mom = 0.9
    workspace = 256

    # rgb head
    body = mx.sym.Variable(name='rgb')
    body = mx.sym.BatchNorm(body, fix_gamma=True, eps=2e-5, momentum=bn_mom, name='bn_data')
    body = mx.sym.Convolution(body, num_filter=filter_list[0], kernel=(7, 7), stride=(2, 2),
                              pad=(3, 3), no_bias=True, name="conv0", workspace=workspace)
    body = mx.sym.BatchNorm(body, fix_gamma=False, eps=2e-5, momentum=bn_mom, name='bn0')
    body = mx.symbol.Activation(body, act_type='relu', name='relu0')
    body = mx.symbol.Pooling(body, kernel=(3, 3), stride=(2, 2), pad=(1, 1), pool_type='max')

    # Stages 1-4: a projecting unit followed by an identity unit. Only the
    # first stage keeps the spatial resolution.
    stage_strides = [(1, 1), (2, 2), (2, 2), (2, 2)]
    for stage, (width, stride) in enumerate(zip(filter_list[1:], stage_strides), 1):
        body = residual_unit(body, width, stride, False,
                             name='stage%d_unit1' % stage, workspace=workspace)
        body = residual_unit(body, width, (1, 1), True,
                             name='stage%d_unit2' % stage, workspace=workspace)

    body = mx.sym.BatchNorm(body, fix_gamma=False, eps=2e-5, momentum=bn_mom, name='bn1')
    body = mx.symbol.Activation(body, act_type='relu', name='relu1')

    # ssd extra layers: a 1x1 conv followed by a stride-2 3x3 conv, three times.
    extra_specs = [("8", 256, 512), ("9", 128, 256), ("10", 128, 256)]
    extra_outputs = []
    for tag, squeeze_width, out_width in extra_specs:
        body, _ = bn_act_conv_layer(body, tag + "_1", squeeze_width,
                                    kernel=(1, 1), pad=(0, 0), stride=(1, 1))
        body, _ = bn_act_conv_layer(body, tag + "_2", out_width,
                                    kernel=(3, 3), pad=(1, 1), stride=(2, 2))
        extra_outputs.append(body)
    conv8_2, conv9_2, conv10_2 = extra_outputs

    pool10 = mx.symbol.Pooling(data=conv10_2, pool_type="avg", global_pool=True,
                               kernel=(1, 1), name='pool10')

    internals = pool10.get_internals()
    return [
        internals["stage3_unit1_relu1_output"],
        internals["stage4_unit1_relu1_output"],
        conv8_2,
        conv9_2,
        conv10_2,
        pool10,
    ]


def fusion_net():
    """Fuse the RGB and TIR branches level by level.

    The first five outputs of ``resnet()`` are paired with the five outputs of
    ``spectral_net()`` and merged by ``cf_unit`` (default mode) under the names
    ``fusion1`` ... ``fusion5``. The sixth ``resnet()`` output is passed through
    unchanged.

    Returns
    -------
    list of symbol
        Five fused feature maps followed by the last RGB-only output.
    """
    rgb_layers = resnet()
    tir_layers = spectral_net()

    # Output width of the fusing convolution at each of the five levels.
    fusion_widths = (128, 256, 512, 256, 256)

    fused = []
    for level, width in enumerate(fusion_widths):
        fused.append(cf_unit(rgb_layers[level], tir_layers[level],
                             num_filters=width, name='fusion%d' % (level + 1)))
    fused.append(rgb_layers[5])
    return fused

In [152]:
# Build the fusion network; fusion_net() returns a list of six output symbols.
symb=fusion_net()

In [147]:
# Print the inferred shape of each RGB-branch feature map that feeds a fusion
# unit, for a 1x3x300x300 input.
rgb_fusion_outputs = (
    "stage3_unit1_relu1_output",
    "stage4_unit1_relu1_output",
    'conv8_2_output',
    'conv9_2_output',
    'conv10_2_output',
)
internals = resnet()[-1].get_internals()
for output in internals.list_outputs():
    if output not in rgb_fusion_outputs:
        continue
    print("{} {}".format(output, internals[output].infer_shape(rgb=(1,3,300,300))[1]))


stage3_unit1_relu1_output [(1L, 128L, 38L, 38L)]
stage4_unit1_relu1_output [(1L, 256L, 19L, 19L)]
conv8_2_output [(1L, 512L, 5L, 5L)]
conv9_2_output [(1L, 256L, 3L, 3L)]
conv10_2_output [(1L, 256L, 2L, 2L)]

In [149]:
# Print the inferred shape of each TIR-branch feature map that feeds a fusion
# unit, for a 1x1x300x300 input.
tir_fusion_outputs = (
    'pool2_output',
    'pool3_output',
    'conv8_2t_output',
    'conv9_2t_output',
    'conv10_2t_output',
)
# spectral_net() returns a list of symbols, so take the last (deepest) one:
# its internals contain every earlier layer of the branch.
internals = spectral_net()[-1].get_internals()
for output in internals.list_outputs():
    if output in tir_fusion_outputs:
        print("{} {}".format(output, internals[output].infer_shape(tir=(1,1,300,300))[1]))


pool2_output [(1L, 128L, 38L, 38L)]
pool3_output [(1L, 256L, 19L, 19L)]
conv8_2t_output [(1L, 512L, 5L, 5L)]
conv9_2t_output [(1L, 256L, 3L, 3L)]
conv10_2t_output [(1L, 256L, 2L, 2L)]

In [174]:
# Print the inferred shape of the ReLU inside the fifth fusion unit when both
# inputs are 300x300 (tir single-channel, rgb three-channel).
internals = fusion_net()[-2].get_internals()
for output in internals.list_outputs():
    # print output   (uncomment to list every internal output name)
    if output not in ['fusion5_relu1_output']:
        continue
    print("{} {}".format(
        output,
        internals[output].infer_shape(tir=(1,1,300,300), rgb=(1,3,300,300))[1]))


fusion5_relu1_output [(1L, 512L, 2L, 2L)]

In [128]:
# Load the trained ssd-resnet 0.68 checkpoint.
net_params_ssd = mx.initializer.Load('ssd_300-0300.params')
net_params = {}

# Start from a straight copy of every checkpoint entry.
for arg, value in net_params_ssd.param.items():
    net_params[arg] = value
    # Disabled experiment: add the same weights for the tir network.
#     if arg.startswith('bn_data') or arg.startswith('bn0') or arg.startswith('conv0') or arg.startswith('stage1') \
#     or arg.startswith('stage2') or arg.startswith('stage3') or arg.startswith('stage4'):
#         net_params['tir_'+arg] = net_params_ssd.param[arg]

# Reuse the checkpoint's stage3 scale parameter for the fusion unit.
net_params['fusion3_1_conv_scale'] = net_params_ssd.param['stage3_unit1_relu1_scale']

# prediction params for fusion units
# net_params['fusion3_1_conv_cls_pred_conv_bias'] = net_params_ssd.param['stage3_unit1_relu1_cls_pred_conv_bias']
# net_params['fusion3_1_conv_cls_pred_conv_weight'] = net_params_ssd.param['stage3_unit1_relu1_cls_pred_conv_weight']
# net_params['fusion3_1_conv_loc_pred_conv_bias'] = net_params_ssd.param['stage3_unit1_relu1_loc_pred_conv_bias']
# net_params['fusion3_1_conv_loc_pred_conv_weight'] = net_params_ssd.param['stage3_unit1_relu1_loc_pred_conv_weight']

# net_params['fusion4_1_conv_cls_pred_conv_bias'] = net_params_ssd.param['stage4_unit1_relu1_cls_pred_conv_bias']
# net_params['fusion4_1_conv_cls_pred_conv_weight'] = net_params_ssd.param['stage4_unit1_relu1_cls_pred_conv_weight']
# net_params['fusion4_1_conv_loc_pred_conv_bias'] = net_params_ssd.param['stage4_unit1_relu1_loc_pred_conv_bias']
# net_params['fusion4_1_conv_loc_pred_conv_weight'] = net_params_ssd.param['stage4_unit1_relu1_loc_pred_conv_weight']

# # take all values from the red channel since it is close to TIR
# net_params['tir_conv0_weight'] = mx.nd.array(net_params['tir_conv0_weight'].asnumpy()[:, :1, :, :])
# net_params['tir_bn_data_gamma'] = mx.nd.array(net_params['tir_bn_data_gamma'].asnumpy()[:1])
# net_params['tir_bn_data_beta'] = mx.nd.array(net_params['tir_bn_data_beta'].asnumpy()[:1])
# net_params['tir_bn_data_moving_var'] = mx.nd.array(net_params['tir_bn_data_moving_var'].asnumpy()[:1])
# net_params['tir_bn_data_moving_mean'] = mx.nd.array(net_params['tir_bn_data_moving_mean'].asnumpy()[:1])

In [129]:
# def get_person_class(name, bbox=3):
#     cls=14
#     b = net_params[name].asnumpy()[:bbox]
#     p = net_params[name].asnumpy()[cls*bbox:(cls+1)*bbox]
#     return mx.nd.array(np.concatenate([b,p], axis=0))

In [130]:
# transform predictions only for person class
# net_params['fusion3_1_conv_cls_pred_conv_bias'] = get_person_class('fusion3_1_conv_cls_pred_conv_bias', 3)
# net_params['fusion3_1_conv_cls_pred_conv_weight'] = get_person_class('fusion3_1_conv_cls_pred_conv_weight', 3)

# net_params['fusion4_1_conv_cls_pred_conv_bias'] = get_person_class('fusion4_1_conv_cls_pred_conv_bias', 6)
# net_params['fusion4_1_conv_cls_pred_conv_weight'] = get_person_class('fusion4_1_conv_cls_pred_conv_weight', 6)

# net_params['conv8_2_cls_pred_conv_weight'] = get_person_class('conv8_2_cls_pred_conv_weight', 6)
# net_params['conv8_2_cls_pred_conv_bias'] = get_person_class('conv8_2_cls_pred_conv_bias', 6)

# net_params['conv9_2_cls_pred_conv_weight'] = get_person_class('conv9_2_cls_pred_conv_weight', 6)
# net_params['conv9_2_cls_pred_conv_bias'] = get_person_class('conv9_2_cls_pred_conv_bias', 6)

# net_params['conv10_2_cls_pred_conv_weight'] = get_person_class('conv10_2_cls_pred_conv_weight', 6)
# net_params['conv10_2_cls_pred_conv_bias'] = get_person_class('conv10_2_cls_pred_conv_bias', 6)

# net_params['pool10_cls_pred_conv_weight'] = get_person_class('pool10_cls_pred_conv_weight', 6)
# net_params['pool10_cls_pred_conv_bias'] = get_person_class('pool10_cls_pred_conv_bias', 6)

In [131]:
# Bind the two-input network and initialise it from the loaded checkpoint;
# parameters missing from the checkpoint fall back to Xavier initialisation.
#
# mx.mod.Module needs a single Symbol. resnet() returns a Python list and only
# defines the 'rgb' input, so it cannot be bound with data_names=['rgb', 'tir'].
# The outputs of fusion_net() are grouped into one multi-output Symbol instead.
# NOTE(review): resnet() and spectral_net() both name layers 'conv0', 'relu0'
# and 'relu1' -- confirm the grouped graph binds without parameter-name clashes.
model = mx.mod.Module(symbol=mx.sym.Group(fusion_net()), data_names=['rgb', 'tir'])
model.bind(data_shapes=[('rgb', (1, 3, 300, 300)), ('tir', (1, 1, 300, 300))])
model.init_params(arg_params=net_params, allow_missing=True, initializer=mx.initializer.Xavier())


[(1L, 128L, 38L, 38L)]
[(1L, 256L, 19L, 19L)]
[(1L, 128L, 38L, 38L)]
[(1L, 256L, 19L, 19L)]
[(1L, 512L, 10L, 10L)]

In [132]:
# Writes spectral-0001.params and spectral-symbol.json.
save_net(model, 'spectral')

In [ ]: