In [33]:
import random
import sys
from io import BytesIO
import gzip
import struct
import mxnet as mx
from mxnet import nd
from mxnet.gluon import nn
from mxnet import gluon
import numpy as np
import cv2
import matplotlib.pyplot as plt
from collections import namedtuple
In [2]:
def read_data(label_url, image_url):
    """Load a gzip-compressed IDX label file and its matching image file.

    Parameters
    ----------
    label_url : str
        Path to the gzip-compressed label file: an 8-byte big-endian header
        (two unsigned 32-bit ints) followed by one byte per label.
    image_url : str
        Path to the gzip-compressed image file: a 16-byte big-endian header
        (four unsigned 32-bit ints, the last two being rows and cols)
        followed by one byte per pixel.

    Returns
    -------
    tuple of numpy.ndarray
        ``(label, image)`` where ``label`` is a 1-D ``int8`` array and
        ``image`` is a ``uint8`` array of shape ``(len(label), rows, cols)``.
    """
    with gzip.open(label_url, 'rb') as flbl:
        # Consume the header; its two fields are not needed afterwards.
        struct.unpack(">II", flbl.read(8))
        # np.fromstring's binary mode is deprecated (removed in recent NumPy).
        # frombuffer returns a read-only view of the bytes, so copy() keeps
        # the result writable like the array fromstring used to return.
        label = np.frombuffer(flbl.read(), dtype=np.int8).copy()
    with gzip.open(image_url, 'rb') as fimg:
        _, _, rows, cols = struct.unpack(">IIII", fimg.read(16))
        pixels = np.frombuffer(fimg.read(), dtype=np.uint8).copy()
        # The sample count is taken from the label file, as before.
        image = pixels.reshape(len(label), rows, cols)
    return (label, image)
In [3]:
def Get_image_lable(img, lable):
    """Randomly blank out digit slots of a three-digit strip, in place.

    One integer in [0, 9] is drawn per slot. For every slot whose draw is 0,
    the 28-pixel-wide column band of ``img`` belonging to that slot is
    overwritten with zeros and the slot's entry in ``lable`` is set to 10.

    Parameters
    ----------
    img : numpy.ndarray
        2-D image whose columns hold three 28-pixel-wide slots side by side.
        Modified in place.
    lable : numpy.ndarray
        Per-slot labels (at least three entries). Modified in place.

    Returns
    -------
    tuple
        The same ``(img, lable)`` objects that were passed in.
    """
    # All three draws happen up front, before any slot is touched.
    draws = [random.randint(0, 9) for _ in range(3)]
    blank = np.zeros((28, 28), dtype='uint8')
    for slot, draw in enumerate(draws):
        if draw != 0:
            continue
        left = slot * 28
        img[:, left:left + 28] = blank
        lable[slot] = 10  # 10 is the label written for a blanked slot
    return img, lable
In [10]:
def GetImage(image, lable):
    """Build one random three-digit test strip and its display image.

    Three samples are picked at random, stacked horizontally, and passed
    through ``Get_image_lable`` (which may blank some of the digits).

    Parameters
    ----------
    image : numpy.ndarray
        Stack of single-digit images, indexed by sample.
    lable : numpy.ndarray
        Labels aligned with ``image``.

    Returns
    -------
    tuple
        ``(img, imgw)``: ``img`` is the strip scaled by 1/255 and reshaped to
        ``(1, 1, 28, 84)``; ``imgw`` is the inverted strip (``255 - strip``),
        taken before scaling.

    Side effects
    ------------
    Writes the inverted strip to ``img.jpg`` in the current working directory.
    """
    # Samples are drawn from at most the first 5000 entries, as before; the
    # bound is clamped so a smaller data set cannot produce an out-of-range index.
    upper = min(5000, len(image)) - 1
    num = [random.randint(0, upper) for _ in range(3)]
    # np.hstack needs a real sequence: passing a generator is deprecated and
    # rejected by newer NumPy releases.
    img, _ = Get_image_lable(
        np.hstack([image[x] for x in num]),
        np.array([lable[x] for x in num]))
    imgw = 255 - img
    cv2.imwrite("img.jpg", imgw)
    img = np.multiply(img, 1 / 255.0)
    img = img.reshape(1, 1, 28, 84)
    return img, imgw
In [4]:
class OCRIter():
    """Iterable that yields random batches of multi-digit strips.

    Each sample is built by picking ``num_label`` random digits, stacking
    them horizontally and passing the strip through ``Get_image_lable``.

    Parameters
    ----------
    count : int
        Number of source samples to draw from; also determines the number of
        batches per pass (``count // batch_size``).
    batch_size : int
        Number of strips per yielded batch.
    num_label : int
        Number of digits per strip.
    height, width : int
        Shape each strip is reshaped to (as ``(1, height, width)``).
    lable : numpy.ndarray
        Labels aligned with ``image``.
    image : numpy.ndarray
        Stack of single-digit images, indexed by sample.
    ctx : mxnet.Context, optional
        Device the yielded NDArrays are created on. Defaults to ``mx.gpu()``,
        the previously hard-coded device.
    """

    def __init__(self, count, batch_size, num_label, height, width, lable, image,
                 ctx=None):
        self.batch_size = batch_size
        self.count = count
        self.height = height
        self.width = width
        self.lable = lable
        self.image = image
        self.num_label = num_label
        self.ctx = mx.gpu() if ctx is None else ctx

    def __iter__(self):
        """Yield ``(data, label)`` NDArray pairs, one per batch."""
        # Floor division: true division yields a float on Python 3, which
        # range() rejects.
        for _batch in range(self.count // self.batch_size):
            data = []
            label = []
            for _sample in range(self.batch_size):
                num = [random.randint(0, self.count - 1)
                       for _ in range(self.num_label)]
                # np.hstack needs a real sequence, not a generator.
                img, lab = Get_image_lable(
                    np.hstack([self.image[x] for x in num]),
                    np.array([self.lable[x] for x in num]))
                img = np.multiply(img, 1 / 255.0)
                data.append(img.reshape(1, self.height, self.width))
                label.append(lab)
            data_all = nd.array(data, ctx=self.ctx)
            label_all = nd.array(label, ctx=self.ctx)
            yield data_all, label_all
In [5]:
def Accuracy(label, pred):
    """Count groups of three consecutive rows whose predictions all match.

    Parameters
    ----------
    label : array-like with ``.T`` and ``.reshape``
        Labels; transposed and flattened before comparison.
    pred : array-like
        One row of class scores per label entry; ``pred.shape[0]`` rows.

    Returns
    -------
    int
        Number of consecutive row triples ``(3*i, 3*i+1, 3*i+2)`` for which
        ``argmax(pred[k]) == label[k]`` holds on all three rows.
    """
    label = label.T.reshape((-1, ))
    hit = 0
    # Floor division: true division yields a float on Python 3, which range()
    # rejects.
    # NOTE(review): rows are grouped as consecutive triples. Cont concatenates
    # its three heads along axis 0 and the labels are transposed before
    # flattening, so for batch size B > 1 the digits of one sample appear to
    # sit at rows b, B+b, 2B+b rather than in one triple - confirm this is the
    # intended metric.
    for i in range(pred.shape[0] // 3):
        base = i * 3
        if all(np.argmax(pred[base + j]) == int(label[base + j])
               for j in range(3)):
            hit += 1
    return hit
In [6]:
class Cont(nn.HybridBlock):
    """Output head made of three parallel 11-unit dense layers.

    The same input is fed to each dense layer, and the three outputs are
    concatenated along axis 0.
    """

    def __init__(self, **kwargs):
        super(Cont, self).__init__(**kwargs)
        with self.name_scope():
            # Attribute names and creation order are kept as they are; saved
            # parameters may be keyed on them.
            self.dese0 = nn.Dense(11)
            self.dese1 = nn.Dense(11)
            self.dese2 = nn.Dense(11)

    def hybrid_forward(self, F, X):
        """Apply the three dense heads to ``X`` and join the results on axis 0."""
        first = self.dese0(X)
        second = self.dese1(X)
        third = self.dese2(X)
        return F.concat(first, second, third, dim=0)
In [7]:
def GetNet():
    """Assemble the convolutional network used for three-digit recognition.

    Four conv + pool stages (32 channels each, ReLU activation, 2x2 pooling
    with stride 1) are followed by Flatten, a 256-unit dense layer and the
    ``Cont`` output head.

    Returns
    -------
    mxnet.gluon.nn.HybridSequential
        The assembled network; parameters are not initialised or loaded here.
    """
    # One entry per stage: (convolution kernel size, pooling layer class).
    stages = [
        (5, nn.MaxPool2D),
        (5, nn.AvgPool2D),
        (3, nn.AvgPool2D),
        (3, nn.AvgPool2D),
    ]
    net = nn.HybridSequential()
    with net.name_scope():
        # Layers are created in the same order as they are added.
        for kernel, pool_cls in stages:
            net.add(nn.Conv2D(channels=32, kernel_size=kernel, activation='relu'))
            net.add(pool_cls(pool_size=2, strides=1))
        net.add(nn.Flatten())
        net.add(nn.Dense(256))
        net.add(Cont())
    return net
In [35]:
def predict(img, mod):
    """Run ``mod`` on ``img`` and decode the first output into a string.

    Parameters
    ----------
    img : array-like
        Input passed to ``mx.nd.array`` and wrapped in a ``Batch``.
    mod : object with ``forward`` and ``get_outputs``
        The bound module used for inference.

    Returns
    -------
    str
        One character per output row: the argmax class index as text, or a
        space when the argmax is 10.
    """
    mod.forward(Batch([mx.nd.array(img)]))
    prob = mod.get_outputs()[0].asnumpy()
    chars = []
    for row in range(prob.shape[0]):
        best = np.argmax(prob[row])
        # Class 10 is rendered as a space.
        chars.append(' ' if int(best) == 10 else str(best))
    return ''.join(chars)
In [9]:
# Load the training and test label/image pairs from the gzip files one
# directory above the notebook. The test pair feeds the demo cell at the end.
(train_lable, train_image) = read_data(
'../train-labels-idx1-ubyte.gz', '../train-images-idx3-ubyte.gz')
(test_lable, test_image) = read_data(
'../t10k-labels-idx1-ubyte.gz', '../t10k-images-idx3-ubyte.gz')
将Gluon模型转换成Symbol类型保存到硬盘,并且保存Gluon模型参数至硬盘
In [22]:
# Rebuild the network and load its saved parameters onto the CPU.
net = GetNet()
net.load_params('../cnn_mnist_gluon',ctx=mx.cpu())
net.hybridize()
# Call the network on a symbolic variable named 'data'; the result `y` is
# what gets written out as the model's JSON file below.
x = mx.sym.var('data')
y = net(x)
# Write both artefacts next to each other: the graph as JSON and the
# parameters collected from the Gluon network.
y.save('../cnn_mnist_gluon.json')
net.collect_params().save('../cnn_mnist_gluon.params')
加载Symbol模型网络,绑定模型至mod,加载模型参数,设置模型参数
In [45]:
# Load the symbol graph that was saved to disk and bind it for CPU inference
# on a single (1, 1, 28, 84) input.
symnet = mx.symbol.load('../cnn_mnist_gluon.json')
# Build the module from the symbol just loaded. Previously the in-memory `y`
# from the export cell was used, which left `symnet` unused and made this cell
# depend on kernel state instead of the files on disk.
mod = mx.mod.Module(symbol=symnet, context=mx.cpu())
mod.bind(data_shapes=[('data', (1, 1, 28, 84))])
params = nd.load('../cnn_mnist_gluon.params')
mod.set_params(arg_params=params, aux_params={})
# Minimal batch container; predict() wraps its input in this before forward().
Batch = namedtuple('Batch', ['data'])
In [56]:
# Build one random three-digit strip from the test set and decode it with the
# bound module.
img,imgw = GetImage(test_image,test_lable)
line = predict(img,mod)
# Show the inverted strip returned by GetImage (axes hidden), then print the
# decoded string underneath.
plt.imshow(imgw, cmap='Greys_r')
plt.axis('off')
plt.show()
print(line)