In [ ]:
import mxnet as mx
import matplotlib.pyplot as plt
import numpy as np
%pylab inline
pylab.rcParams['figure.figsize'] = (2, 3)
In [ ]:
# Step 1: data -- sanity-check the CIFAR-10 RecordIO input pipeline.
batch_size = 128
data_iter = mx.io.ImageRecordIter(
    path_imgrec='data/cifar10_train.rec',
    data_shape=(3, 28, 28),
    label_width=1,
    batch_size=batch_size,
)
print(data_iter)

# Print the first few batches and keep the numpy arrays of the last one
# printed. print(...) with a single argument behaves the same on Python 2
# and 3, unlike the old `print each` statement form.
num_debug_batches = 5
for i, each in enumerate(data_iter):
    if i >= num_debug_batches:
        break
    print(each)
    batch_numpy = each.data[0].asnumpy()
    label_numpy = each.label[0].asnumpy()
print(type(batch_numpy))
print(type(label_numpy))

# Show one random image from that batch as a grayscale intensity map.
# Each image is channel-first (data_shape is (3, 28, 28)), so summing over
# axis 0 gives an (H, W) array for any channel count. np.squeeze is not
# used: for a single-channel image it would already yield (H, W), and the
# sum would then wrongly collapse it to a 1-D row.
randidx = np.random.randint(0, batch_numpy.shape[0])
img = batch_numpy[randidx].sum(axis=0)
plt.imshow(img, cmap='gray')
plt.title('label: %s' % label_numpy[randidx])
plt.show()
In [ ]:
# Debug the model: build the ResNet symbol and inspect its arguments,
# outputs and graph.
from importlib import import_module

network = 'resnet'
num_classes = 10
num_layers = 20
image_shape = '3,28,28'  # mirrors data_shape used for the training iterator
net = import_module('symbols.' + network)
# NOTE(review): assumes symbols/resnet.py defines
# get_symbol(num_classes, num_layers, image_shape, ...) -- confirm.
sym = net.get_symbol(num_classes, num_layers, image_shape)
# Checkpoint file prefix, intended for mx.callback.do_checkpoint when training.
model_prefix = 'cifar10_resnet'
arg_name = sym.list_arguments()
out_name = sym.list_outputs()
print(arg_name)
print(out_name)
# Derive the title from num_layers so it cannot drift from the network that
# was actually built (it was hard-coded as 'resnet8' for a 20-layer build).
# Kept as the cell's last expression so the notebook displays the graph.
mx.viz.plot_network(sym, hide_weights=True, save_format='pdf',
                    title='%s%d' % (network, num_layers))
In [ ]:
In [ ]: