NDArray


In [1]:
import mxnet as mx 
import numpy as np
import matplotlib.pyplot as plt 
dir(mx)


Out[1]:
['AttrScope',
 'Context',
 'MXNetError',
 '__builtins__',
 '__doc__',
 '__file__',
 '__name__',
 '__package__',
 '__path__',
 '__version__',
 '_ctypes',
 '_cy2',
 'absolute_import',
 'attribute',
 'autograd',
 'base',
 'callback',
 'context',
 'contrib',
 'cpu',
 'current_context',
 'engine',
 'executor',
 'executor_manager',
 'gluon',
 'gpu',
 'image',
 'img',
 'init',
 'initializer',
 'io',
 'kv',
 'kvstore',
 'kvstore_server',
 'libinfo',
 'log',
 'lr_scheduler',
 'metric',
 'mod',
 'model',
 'module',
 'mon',
 'monitor',
 'name',
 'nd',
 'ndarray',
 'ndarray_doc',
 'notebook',
 'operator',
 'optimizer',
 'profiler',
 'random',
 'recordio',
 'registry',
 'rnd',
 'rnn',
 'rtc',
 'sym',
 'symbol',
 'symbol_doc',
 'test_utils',
 'th',
 'torch',
 'visualization',
 'viz']

What is NDArray Module, like numpy!


In [2]:
dir(mx.nd)


Out[2]:
['Activation',
 'BatchNorm',
 'BatchNorm_v1',
 'BilinearSampler',
 'BlockGrad',
 'CachedOp',
 'Cast',
 'Concat',
 'Convolution',
 'Convolution_v1',
 'Correlation',
 'Crop',
 'Custom',
 'Deconvolution',
 'Dropout',
 'ElementWiseSum',
 'Embedding',
 'Flatten',
 'FullyConnected',
 'GridGenerator',
 'IdentityAttachKLSparseReg',
 'InstanceNorm',
 'L2Normalization',
 'LRN',
 'LeakyReLU',
 'LinearRegressionOutput',
 'LogisticRegressionOutput',
 'MAERegressionOutput',
 'MakeLoss',
 'NDArray',
 'Pad',
 'Pooling',
 'Pooling_v1',
 'RNN',
 'ROIPooling',
 'Reshape',
 'SVMOutput',
 'SequenceLast',
 'SequenceMask',
 'SequenceReverse',
 'SliceChannel',
 'Softmax',
 'SoftmaxActivation',
 'SoftmaxOutput',
 'SpatialTransformer',
 'SwapAxis',
 'UpSampling',
 '_DTYPE_MX_TO_NP',
 '_DTYPE_NP_TO_MX',
 '_GRAD_REQ_MAP',
 '__all__',
 '__builtins__',
 '__doc__',
 '__file__',
 '__name__',
 '__package__',
 '__path__',
 '_internal',
 '_ndarray_cls',
 '_new_empty_handle',
 'abs',
 'adam_update',
 'add',
 'add_n',
 'arange',
 'arccos',
 'arccosh',
 'arcsin',
 'arcsinh',
 'arctan',
 'arctanh',
 'argmax',
 'argmax_channel',
 'argmin',
 'argsort',
 'array',
 'batch_dot',
 'batch_take',
 'broadcast_add',
 'broadcast_axes',
 'broadcast_axis',
 'broadcast_div',
 'broadcast_equal',
 'broadcast_greater',
 'broadcast_greater_equal',
 'broadcast_hypot',
 'broadcast_lesser',
 'broadcast_lesser_equal',
 'broadcast_maximum',
 'broadcast_minimum',
 'broadcast_minus',
 'broadcast_mod',
 'broadcast_mul',
 'broadcast_not_equal',
 'broadcast_plus',
 'broadcast_power',
 'broadcast_sub',
 'broadcast_to',
 'cast',
 'cast_storage',
 'cbrt',
 'ceil',
 'choose_element_0index',
 'clip',
 'concat',
 'concatenate',
 'contrib',
 'cos',
 'cosh',
 'crop',
 'degrees',
 'divide',
 'dot',
 'elemwise_add',
 'elemwise_div',
 'elemwise_mul',
 'elemwise_sub',
 'empty',
 'equal',
 'exp',
 'expand_dims',
 'expm1',
 'fill_element_0index',
 'fix',
 'flatten',
 'flip',
 'floor',
 'ftrl_update',
 'full',
 'gamma',
 'gammaln',
 'gather_nd',
 'gen__internal',
 'gen_contrib',
 'gen_linalg',
 'gen_op',
 'gen_sparse',
 'greater',
 'greater_equal',
 'identity',
 'imdecode',
 'lesser',
 'lesser_equal',
 'linalg',
 'linalg_gelqf',
 'linalg_gemm',
 'linalg_gemm2',
 'linalg_potrf',
 'linalg_potri',
 'linalg_sumlogdiag',
 'linalg_syrk',
 'linalg_trmm',
 'linalg_trsm',
 'load',
 'log',
 'log10',
 'log1p',
 'log2',
 'log_softmax',
 'make_loss',
 'max',
 'max_axis',
 'maximum',
 'mean',
 'min',
 'min_axis',
 'minimum',
 'modulo',
 'moveaxis',
 'mp_sgd_mom_update',
 'mp_sgd_update',
 'multiply',
 'nanprod',
 'nansum',
 'ndarray',
 'negative',
 'norm',
 'normal',
 'not_equal',
 'one_hot',
 'onehot_encode',
 'ones',
 'ones_like',
 'op',
 'pad',
 'pick',
 'power',
 'prod',
 'radians',
 'random',
 'random_exponential',
 'random_gamma',
 'random_generalized_negative_binomial',
 'random_negative_binomial',
 'random_normal',
 'random_poisson',
 'random_uniform',
 'rcbrt',
 'reciprocal',
 'register',
 'relu',
 'repeat',
 'reshape',
 'reshape_like',
 'reverse',
 'rint',
 'rmsprop_update',
 'rmspropalex_update',
 'round',
 'rsqrt',
 'sample_exponential',
 'sample_gamma',
 'sample_generalized_negative_binomial',
 'sample_multinomial',
 'sample_negative_binomial',
 'sample_normal',
 'sample_poisson',
 'sample_uniform',
 'save',
 'scatter_nd',
 'sgd_mom_update',
 'sgd_update',
 'sigmoid',
 'sign',
 'sin',
 'sinh',
 'slice',
 'slice_axis',
 'smooth_l1',
 'softmax',
 'softmax_cross_entropy',
 'sort',
 'sparse',
 'split',
 'sqrt',
 'square',
 'stack',
 'stop_gradient',
 'subtract',
 'sum',
 'sum_axis',
 'swapaxes',
 'take',
 'tan',
 'tanh',
 'tile',
 'topk',
 'transpose',
 'true_divide',
 'trunc',
 'uniform',
 'utils',
 'waitall',
 'where',
 'zeros',
 'zeros_like']

In [3]:
dir(np)


Out[3]:
['ALLOW_THREADS',
 'AxisError',
 'BUFSIZE',
 'CLIP',
 'ComplexWarning',
 'DataSource',
 'ERR_CALL',
 'ERR_DEFAULT',
 'ERR_IGNORE',
 'ERR_LOG',
 'ERR_PRINT',
 'ERR_RAISE',
 'ERR_WARN',
 'FLOATING_POINT_SUPPORT',
 'FPE_DIVIDEBYZERO',
 'FPE_INVALID',
 'FPE_OVERFLOW',
 'FPE_UNDERFLOW',
 'False_',
 'Inf',
 'Infinity',
 'MAXDIMS',
 'MAY_SHARE_BOUNDS',
 'MAY_SHARE_EXACT',
 'MachAr',
 'ModuleDeprecationWarning',
 'NAN',
 'NINF',
 'NZERO',
 'NaN',
 'PINF',
 'PZERO',
 'PackageLoader',
 'RAISE',
 'RankWarning',
 'SHIFT_DIVIDEBYZERO',
 'SHIFT_INVALID',
 'SHIFT_OVERFLOW',
 'SHIFT_UNDERFLOW',
 'ScalarType',
 'Tester',
 'TooHardError',
 'True_',
 'UFUNC_BUFSIZE_DEFAULT',
 'UFUNC_PYVALS_NAME',
 'VisibleDeprecationWarning',
 'WRAP',
 '_NoValue',
 '__NUMPY_SETUP__',
 '__all__',
 '__builtins__',
 '__config__',
 '__doc__',
 '__file__',
 '__git_revision__',
 '__name__',
 '__package__',
 '__path__',
 '__version__',
 '_distributor_init',
 '_globals',
 '_import_tools',
 '_mat',
 'abs',
 'absolute',
 'absolute_import',
 'add',
 'add_docstring',
 'add_newdoc',
 'add_newdoc_ufunc',
 'add_newdocs',
 'alen',
 'all',
 'allclose',
 'alltrue',
 'amax',
 'amin',
 'angle',
 'any',
 'append',
 'apply_along_axis',
 'apply_over_axes',
 'arange',
 'arccos',
 'arccosh',
 'arcsin',
 'arcsinh',
 'arctan',
 'arctan2',
 'arctanh',
 'argmax',
 'argmin',
 'argpartition',
 'argsort',
 'argwhere',
 'around',
 'array',
 'array2string',
 'array_equal',
 'array_equiv',
 'array_repr',
 'array_split',
 'array_str',
 'asanyarray',
 'asarray',
 'asarray_chkfinite',
 'ascontiguousarray',
 'asfarray',
 'asfortranarray',
 'asmatrix',
 'asscalar',
 'atleast_1d',
 'atleast_2d',
 'atleast_3d',
 'average',
 'bartlett',
 'base_repr',
 'bench',
 'binary_repr',
 'bincount',
 'bitwise_and',
 'bitwise_not',
 'bitwise_or',
 'bitwise_xor',
 'blackman',
 'block',
 'bmat',
 'bool',
 'bool8',
 'bool_',
 'broadcast',
 'broadcast_arrays',
 'broadcast_to',
 'busday_count',
 'busday_offset',
 'busdaycalendar',
 'byte',
 'byte_bounds',
 'bytes_',
 'c_',
 'can_cast',
 'cast',
 'cbrt',
 'cdouble',
 'ceil',
 'cfloat',
 'char',
 'character',
 'chararray',
 'choose',
 'clip',
 'clongdouble',
 'clongfloat',
 'column_stack',
 'common_type',
 'compare_chararrays',
 'compat',
 'complex',
 'complex128',
 'complex256',
 'complex64',
 'complex_',
 'complexfloating',
 'compress',
 'concatenate',
 'conj',
 'conjugate',
 'convolve',
 'copy',
 'copysign',
 'copyto',
 'core',
 'corrcoef',
 'correlate',
 'cos',
 'cosh',
 'count_nonzero',
 'cov',
 'cross',
 'csingle',
 'ctypeslib',
 'cumprod',
 'cumproduct',
 'cumsum',
 'datetime64',
 'datetime_as_string',
 'datetime_data',
 'deg2rad',
 'degrees',
 'delete',
 'deprecate',
 'deprecate_with_doc',
 'diag',
 'diag_indices',
 'diag_indices_from',
 'diagflat',
 'diagonal',
 'diff',
 'digitize',
 'disp',
 'divide',
 'division',
 'divmod',
 'dot',
 'double',
 'dsplit',
 'dstack',
 'dtype',
 'e',
 'ediff1d',
 'einsum',
 'einsum_path',
 'emath',
 'empty',
 'empty_like',
 'equal',
 'errstate',
 'euler_gamma',
 'exp',
 'exp2',
 'expand_dims',
 'expm1',
 'extract',
 'eye',
 'fabs',
 'fastCopyAndTranspose',
 'fft',
 'fill_diagonal',
 'find_common_type',
 'finfo',
 'fix',
 'flatiter',
 'flatnonzero',
 'flexible',
 'flip',
 'fliplr',
 'flipud',
 'float',
 'float128',
 'float16',
 'float32',
 'float64',
 'float_',
 'float_power',
 'floating',
 'floor',
 'floor_divide',
 'fmax',
 'fmin',
 'fmod',
 'format_parser',
 'frexp',
 'frombuffer',
 'fromfile',
 'fromfunction',
 'fromiter',
 'frompyfunc',
 'fromregex',
 'fromstring',
 'full',
 'full_like',
 'fv',
 'generic',
 'genfromtxt',
 'geomspace',
 'get_array_wrap',
 'get_include',
 'get_printoptions',
 'getbuffer',
 'getbufsize',
 'geterr',
 'geterrcall',
 'geterrobj',
 'gradient',
 'greater',
 'greater_equal',
 'half',
 'hamming',
 'hanning',
 'heaviside',
 'histogram',
 'histogram2d',
 'histogramdd',
 'hsplit',
 'hstack',
 'hypot',
 'i0',
 'identity',
 'iinfo',
 'imag',
 'in1d',
 'index_exp',
 'indices',
 'inexact',
 'inf',
 'info',
 'infty',
 'inner',
 'insert',
 'int',
 'int0',
 'int16',
 'int32',
 'int64',
 'int8',
 'int_',
 'int_asbuffer',
 'intc',
 'integer',
 'interp',
 'intersect1d',
 'intp',
 'invert',
 'ipmt',
 'irr',
 'is_busday',
 'isclose',
 'iscomplex',
 'iscomplexobj',
 'isfinite',
 'isfortran',
 'isin',
 'isinf',
 'isnan',
 'isnat',
 'isneginf',
 'isposinf',
 'isreal',
 'isrealobj',
 'isscalar',
 'issctype',
 'issubclass_',
 'issubdtype',
 'issubsctype',
 'iterable',
 'ix_',
 'kaiser',
 'kron',
 'ldexp',
 'left_shift',
 'less',
 'less_equal',
 'lexsort',
 'lib',
 'linalg',
 'linspace',
 'little_endian',
 'load',
 'loads',
 'loadtxt',
 'log',
 'log10',
 'log1p',
 'log2',
 'logaddexp',
 'logaddexp2',
 'logical_and',
 'logical_not',
 'logical_or',
 'logical_xor',
 'logspace',
 'long',
 'longcomplex',
 'longdouble',
 'longfloat',
 'longlong',
 'lookfor',
 'ma',
 'mafromtxt',
 'mask_indices',
 'mat',
 'math',
 'matmul',
 'matrix',
 'matrixlib',
 'max',
 'maximum',
 'maximum_sctype',
 'may_share_memory',
 'mean',
 'median',
 'memmap',
 'meshgrid',
 'mgrid',
 'min',
 'min_scalar_type',
 'minimum',
 'mintypecode',
 'mirr',
 'mod',
 'modf',
 'moveaxis',
 'msort',
 'multiply',
 'nan',
 'nan_to_num',
 'nanargmax',
 'nanargmin',
 'nancumprod',
 'nancumsum',
 'nanmax',
 'nanmean',
 'nanmedian',
 'nanmin',
 'nanpercentile',
 'nanprod',
 'nanstd',
 'nansum',
 'nanvar',
 'nbytes',
 'ndarray',
 'ndenumerate',
 'ndfromtxt',
 'ndim',
 'ndindex',
 'nditer',
 'negative',
 'nested_iters',
 'newaxis',
 'newbuffer',
 'nextafter',
 'nonzero',
 'not_equal',
 'nper',
 'npv',
 'numarray',
 'number',
 'obj2sctype',
 'object',
 'object0',
 'object_',
 'ogrid',
 'oldnumeric',
 'ones',
 'ones_like',
 'outer',
 'packbits',
 'pad',
 'partition',
 'percentile',
 'pi',
 'piecewise',
 'pkgload',
 'place',
 'pmt',
 'poly',
 'poly1d',
 'polyadd',
 'polyder',
 'polydiv',
 'polyfit',
 'polyint',
 'polymul',
 'polynomial',
 'polysub',
 'polyval',
 'positive',
 'power',
 'ppmt',
 'print_function',
 'prod',
 'product',
 'promote_types',
 'ptp',
 'put',
 'putmask',
 'pv',
 'r_',
 'rad2deg',
 'radians',
 'random',
 'rank',
 'rate',
 'ravel',
 'ravel_multi_index',
 'real',
 'real_if_close',
 'rec',
 'recarray',
 'recfromcsv',
 'recfromtxt',
 'reciprocal',
 'record',
 'remainder',
 'repeat',
 'require',
 'reshape',
 'resize',
 'result_type',
 'right_shift',
 'rint',
 'roll',
 'rollaxis',
 'roots',
 'rot90',
 'round',
 'round_',
 'row_stack',
 's_',
 'safe_eval',
 'save',
 'savetxt',
 'savez',
 'savez_compressed',
 'sctype2char',
 'sctypeDict',
 'sctypeNA',
 'sctypes',
 'searchsorted',
 'select',
 'set_numeric_ops',
 'set_printoptions',
 'set_string_function',
 'setbufsize',
 'setdiff1d',
 'seterr',
 'seterrcall',
 'seterrobj',
 'setxor1d',
 'shape',
 'shares_memory',
 'short',
 'show_config',
 'sign',
 'signbit',
 'signedinteger',
 'sin',
 'sinc',
 'single',
 'singlecomplex',
 'sinh',
 'size',
 'sometrue',
 'sort',
 'sort_complex',
 'source',
 'spacing',
 'split',
 'sqrt',
 'square',
 'squeeze',
 'stack',
 'std',
 'str',
 'str_',
 'string0',
 'string_',
 'subtract',
 'sum',
 'swapaxes',
 'sys',
 'take',
 'tan',
 'tanh',
 'tensordot',
 'test',
 'testing',
 'tile',
 'timedelta64',
 'trace',
 'tracemalloc_domain',
 'transpose',
 'trapz',
 'tri',
 'tril',
 'tril_indices',
 'tril_indices_from',
 'trim_zeros',
 'triu',
 'triu_indices',
 'triu_indices_from',
 'true_divide',
 'trunc',
 'typeDict',
 'typeNA',
 'typecodes',
 'typename',
 'ubyte',
 'ufunc',
 'uint',
 'uint0',
 'uint16',
 'uint32',
 'uint64',
 'uint8',
 'uintc',
 'uintp',
 'ulonglong',
 'unicode',
 'unicode0',
 'unicode_',
 'union1d',
 'unique',
 'unpackbits',
 'unravel_index',
 'unsignedinteger',
 'unwrap',
 'ushort',
 'vander',
 'var',
 'vdot',
 'vectorize',
 'version',
 'void',
 'void0',
 'vsplit',
 'vstack',
 'warnings',
 'where',
 'who',
 'zeros',
 'zeros_like']

What if I want to fast verify my network? Use NDArray


In [4]:
img_numpy = plt.imread('Lenna.png')
print ('img shape,', img_numpy.shape) #(HWC)
print ('img type, ', type(img_numpy))
data = mx.nd.array(img_numpy,ctx=mx.cpu()) # use mx.gpu(#) if you happen to have a/many gpus
print ('mx ND ', data)
plt.imshow(data.asnumpy()) # asnumpy is the magic communicating numpy and mxnet NDarray 
plt.show()


('img shape,', (512, 512, 3))
('img type, ', <type 'numpy.ndarray'>)
('mx ND ', 
[[[ 0.88627452  0.53725493  0.49019608]
  [ 0.88627452  0.53725493  0.49019608]
  [ 0.87450981  0.53725493  0.52156866]
  ..., 
  [ 0.90196079  0.58039218  0.47843137]
  [ 0.86666667  0.50980395  0.43137255]
  [ 0.78431374  0.3882353   0.35294119]]

 [[ 0.88627452  0.53725493  0.49019608]
  [ 0.88627452  0.53725493  0.49019608]
  [ 0.87450981  0.53725493  0.52156866]
  ..., 
  [ 0.90196079  0.58039218  0.47843137]
  [ 0.86666667  0.50980395  0.43137255]
  [ 0.78431374  0.3882353   0.35294119]]

 [[ 0.88627452  0.53725493  0.49019608]
  [ 0.88627452  0.53725493  0.49019608]
  [ 0.87450981  0.53725493  0.52156866]
  ..., 
  [ 0.90196079  0.58039218  0.47843137]
  [ 0.86666667  0.50980395  0.43137255]
  [ 0.78431374  0.3882353   0.35294119]]

 ..., 
 [[ 0.32941177  0.07058824  0.23529412]
  [ 0.32941177  0.07058824  0.23529412]
  [ 0.36078432  0.10588235  0.22745098]
  ..., 
  [ 0.67843139  0.28627452  0.32941177]
  [ 0.67450982  0.26666668  0.29803923]
  [ 0.69411767  0.24313726  0.30980393]]

 [[ 0.32156864  0.08627451  0.22352941]
  [ 0.32156864  0.08627451  0.22352941]
  [ 0.3764706   0.1254902   0.24313726]
  ..., 
  [ 0.7019608   0.27450982  0.30980393]
  [ 0.70980394  0.27843139  0.31764707]
  [ 0.72549021  0.29019609  0.31764707]]

 [[ 0.32156864  0.08627451  0.22352941]
  [ 0.32156864  0.08627451  0.22352941]
  [ 0.3764706   0.1254902   0.24313726]
  ..., 
  [ 0.7019608   0.27450982  0.30980393]
  [ 0.70980394  0.27843139  0.31764707]
  [ 0.72549021  0.29019609  0.31764707]]]
<NDArray 512x512x3 @cpu(0)>)

Let us build a CNN!


In [5]:
# Again we have an image, Lenna, the Goddess. 
# But its axis follows HWC which is not the format that numpy/mxnet like.
# MXNet like N->C->H->W
# We need to do some image preprocessing
image = mx.nd.transpose(data, axes=(2,0,1))
image = mx.nd.expand_dims(image,0)
print (image)


[[[[ 0.88627452  0.88627452  0.87450981 ...,  0.90196079  0.86666667
     0.78431374]
   [ 0.88627452  0.88627452  0.87450981 ...,  0.90196079  0.86666667
     0.78431374]
   [ 0.88627452  0.88627452  0.87450981 ...,  0.90196079  0.86666667
     0.78431374]
   ..., 
   [ 0.32941177  0.32941177  0.36078432 ...,  0.67843139  0.67450982
     0.69411767]
   [ 0.32156864  0.32156864  0.3764706  ...,  0.7019608   0.70980394
     0.72549021]
   [ 0.32156864  0.32156864  0.3764706  ...,  0.7019608   0.70980394
     0.72549021]]

  [[ 0.53725493  0.53725493  0.53725493 ...,  0.58039218  0.50980395
     0.3882353 ]
   [ 0.53725493  0.53725493  0.53725493 ...,  0.58039218  0.50980395
     0.3882353 ]
   [ 0.53725493  0.53725493  0.53725493 ...,  0.58039218  0.50980395
     0.3882353 ]
   ..., 
   [ 0.07058824  0.07058824  0.10588235 ...,  0.28627452  0.26666668
     0.24313726]
   [ 0.08627451  0.08627451  0.1254902  ...,  0.27450982  0.27843139
     0.29019609]
   [ 0.08627451  0.08627451  0.1254902  ...,  0.27450982  0.27843139
     0.29019609]]

  [[ 0.49019608  0.49019608  0.52156866 ...,  0.47843137  0.43137255
     0.35294119]
   [ 0.49019608  0.49019608  0.52156866 ...,  0.47843137  0.43137255
     0.35294119]
   [ 0.49019608  0.49019608  0.52156866 ...,  0.47843137  0.43137255
     0.35294119]
   ..., 
   [ 0.23529412  0.23529412  0.22745098 ...,  0.32941177  0.29803923
     0.30980393]
   [ 0.22352941  0.22352941  0.24313726 ...,  0.30980393  0.31764707
     0.31764707]
   [ 0.22352941  0.22352941  0.24313726 ...,  0.30980393  0.31764707
     0.31764707]]]]
<NDArray 1x3x512x512 @cpu(0)>

We want to do imge->conv->relu->pool->FC->softmax

Let's do CONV first


In [6]:
# define filter
w = mx.nd.ones((32,3,3,3), ctx=mx.cpu())/27. #follow NCHW
# bias = mx.nd.ones((32,1),ctx=mx.cpu())
conv = mx.nd.Convolution(data=image,num_filter=32, kernel=(3,3), pad=(1,1),stride=(1,1),weight=w,no_bias=True)
type(conv)


Out[6]:
mxnet.ndarray.ndarray.NDArray

In [7]:
plt.imshow(conv.asnumpy()[0,0,:,:],cmap='gray')
plt.show()


Then ReLU


In [8]:
relu = mx.nd.Activation(data=conv,act_type='tanh')
#print (relu.asnumpy())
plt.imshow(relu.asnumpy()[0,0,:,:],cmap='gray')
plt.show()


Then Pooling


In [9]:
pool = mx.nd.Pooling(data=relu, kernel=(2,2), stride=(2,2), pool_type='max')
print (pool.asnumpy()[0,0,:,:].shape)
plt.imshow(pool.asnumpy()[0,0,:,:],cmap='gray')
plt.show()


(256, 256)

Then FC


In [10]:
w_fc = mx.nd.ones((10,np.prod(pool.shape)))/np.prod(pool.shape)
fc = mx.nd.FullyConnected(data=pool, num_hidden=10, weight=w_fc, no_bias=True)
print (fc.asnumpy())


[[ 0.46400189  0.46400189  0.46400189  0.46400189  0.46400189  0.46400189
   0.46400189  0.46400189  0.46400189  0.46400189]]

Finally Softmax


In [11]:
softmax = mx.nd.softmax(fc)
print (softmax.asnumpy())
# Assume out label is People which index is 1
label = mx.nd.array([1])
loss = mx.nd.softmax_cross_entropy(fc,label)
print (loss.asnumpy())


[[ 0.1  0.1  0.1  0.1  0.1  0.1  0.1  0.1  0.1  0.1]]
[ 2.30258512]

GOOOOOOD JOB


In [ ]: