In [ ]:
import logging
import random
import sys
from io import BytesIO
import gzip
import struct
import mxnet as mx
import numpy as np
from captcha.image import ImageCaptcha
from collections import namedtuple
import matplotlib.pyplot as plt
import cv2
# Log record layout: the timestamp left-justified in a 15-character field,
# followed by the log message.
head = '%(asctime)-15s %(message)s'
# Configure the root logger to emit DEBUG-and-above records in that layout.
logging.basicConfig(level=logging.DEBUG, format=head)
In [ ]:
def read_data(label_url, image_url):
    """Load a label/image pair from two gzip-compressed IDX files.

    Args:
        label_url: Path of the gzip-compressed label file. It starts with an
            8-byte header (two big-endian unsigned 32-bit integers: magic
            number and item count) followed by one byte per label.
        image_url: Path of the gzip-compressed image file. It starts with a
            16-byte header (four big-endian unsigned 32-bit integers: magic
            number, item count, rows, cols) followed by one byte per pixel.

    Returns:
        A ``(label, image)`` tuple where ``label`` is a 1-D ``int8`` array
        and ``image`` is a ``uint8`` array of shape
        ``(len(label), rows, cols)``.

    Raises:
        ValueError: If the pixel payload cannot be reshaped to
            ``(len(label), rows, cols)``.
    """
    with gzip.open(label_url, 'rb') as flbl:
        # The header values are read only to skip past them.
        _magic, _num = struct.unpack(">II", flbl.read(8))
        # np.frombuffer replaces the deprecated binary mode of np.fromstring;
        # .copy() keeps the result writable, as fromstring's result was.
        label = np.frombuffer(flbl.read(), dtype=np.int8).copy()
    with gzip.open(image_url, 'rb') as fimg:
        _magic, _num, rows, cols = struct.unpack(">IIII", fimg.read(16))
        pixels = np.frombuffer(fimg.read(), dtype=np.uint8).copy()
        image = pixels.reshape(len(label), rows, cols)
    return (label, image)
In [ ]:
def Get_image_lable(img, lable):
    """Randomly blank out digit slots of a three-digit image, in place.

    For each of the three 28-pixel-wide slots a number in 0..9 is drawn;
    when it is 0 the slot's pixels are overwritten with zeros and the
    matching label entry is set to 10.

    Args:
        img: Array whose columns 0..83 hold three 28x28 tiles side by side.
            It is modified in place.
        lable: Indexable sequence of three labels. It is modified in place.

    Returns:
        The same ``img`` and ``lable`` objects as an ``(img, lable)`` tuple.
    """
    # All three draws happen before any slot is touched.
    draws = [random.randint(0, 9) for _ in range(3)]
    blank_tile = np.zeros((28, 28), dtype='uint8')
    for slot, draw in enumerate(draws):
        if draw != 0:
            continue
        left = slot * 28
        img[:, left:left + 28] = blank_tile
        lable[slot] = 10
    return img, lable
In [ ]:
def get_image():
    """Build one random three-digit input image from the MNIST test files.

    Three samples are picked at random from the first 5000 entries of the
    test set and placed side by side. ``Get_image_lable`` may then blank
    some of the digit slots. A colour-inverted copy of the result is written
    to ``img.jpg`` in the current directory.

    Returns:
        A float array of shape ``(1, 1, 28, 84)`` with pixel values scaled
        by ``1 / 255``.
    """
    (lable, image) = read_data(
        't10k-labels-idx1-ubyte.gz', 't10k-images-idx3-ubyte.gz')
    num = [random.randint(0, 5000 - 1) for _ in range(3)]
    # np.hstack needs a real sequence; passing a generator is deprecated in
    # NumPy, so build lists first.
    tiles = [image[idx] for idx in num]
    digits = np.array([lable[idx] for idx in num])
    img, _ = Get_image_lable(np.hstack(tiles), digits)
    # Invert so the saved preview shows dark digits on a light background.
    imgw = 255 - img
    cv2.imwrite("img.jpg", imgw)
    img = np.multiply(img, 1 / 255.0)
    img = img.reshape(1, 1, 28, 84)
    return img
In [ ]:
def get_predictnet():
    """Build the prediction symbol for the three-digit recogniser.

    The graph is four convolution -> pooling -> ReLU stages, a flatten
    layer, a 256-unit fully connected layer, and three 11-way fully
    connected heads (one per digit) concatenated along axis 0 and fed to a
    softmax output.

    Returns:
        An ``mx.symbol.Group`` holding the softmax output followed by the
        raw outputs of the four convolution layers.
    """
    # Input layer
    net = mx.symbol.Variable('data')
    # (convolution kernel, pooling type) for stages one to four. Operators
    # are created in the same order as a hand-unrolled version would.
    stage_specs = [
        ((5, 5), "max"),
        ((5, 5), "avg"),
        ((3, 3), "avg"),
        ((3, 3), "avg"),
    ]
    conv_outputs = []
    for kernel, pool_type in stage_specs:
        # Convolution layer
        conv = mx.symbol.Convolution(data=net, kernel=kernel, num_filter=32)
        # Pooling layer
        pooled = mx.symbol.Pooling(
            data=conv, pool_type=pool_type, kernel=(2, 2), stride=(1, 1))
        # Activation layer
        net = mx.symbol.Activation(data=pooled, act_type="relu")
        conv_outputs.append(conv)
    # Flatten layer
    flatten = mx.symbol.Flatten(data=net)
    # First fully connected layer
    fc1 = mx.symbol.FullyConnected(data=flatten, num_hidden=256)
    # One fully connected head per digit (first, second, third)
    digit_heads = [
        mx.symbol.FullyConnected(data=fc1, num_hidden=11) for _ in range(3)
    ]
    # Concat layer: joins the per-digit head outputs together
    fc2 = mx.symbol.Concat(*digit_heads, dim=0)
    # Output layer
    softmax_out = mx.symbol.SoftmaxOutput(data=fc2, name="softmax")
    return mx.symbol.Group([softmax_out] + conv_outputs)
In [ ]:
# Load the parameters saved at epoch 2 of the "cnn-ocr-mnist" checkpoint.
# The stored symbol is discarded because the prediction graph is rebuilt by
# get_predictnet().
_, arg_params, aux_params = mx.model.load_checkpoint("cnn-ocr-mnist", 2)
net = get_predictnet()
predictmod = mx.mod.Module(symbol=net, context=mx.cpu())
# This module is only used for inference, so bind it with for_training=False
# rather than in the default training mode.
predictmod.bind(for_training=False,
                data_shapes=[('data', (1, 1, 28, 84))])
predictmod.set_params(arg_params, aux_params)
# Minimal batch container exposing the `data` attribute used by forward().
Batch = namedtuple('Batch', ['data'])
In [ ]:
def predict(out):
    """Decode the network outputs and dump feature-map grids to disk.

    For each of the four outputs following the first one, the first 32
    feature maps of sample 0 are laid out on a white canvas (4 tiles across,
    8 tiles down, filled column by column, 10 pixels of padding after each
    tile) and written to ``c0.jpg`` .. ``c3.jpg``.

    Args:
        out: Sequence of at least five outputs, each providing
            ``asnumpy()``. ``out[0]`` holds the class scores, one row per
            digit.

    Returns:
        A string with one character per row of ``out[0]``: the arg-max
        class index, or a space when that index is 10.
    """
    prob = out[0].asnumpy()
    pad = 10
    for layer in range(4):
        maps = out[layer + 1].asnumpy()[0]
        # Naming follows the axis order: axis 1 first, then axis 2.
        width = int(np.shape(maps)[1])
        height = int(np.shape(maps)[2])
        canvas = np.zeros((width * 8 + 80, height * 4 + 40), dtype=float) + 255
        for col in range(4):
            for row in range(8):
                # Channels are consumed column by column.
                tile = np.multiply(maps[col * 8 + row].reshape(width, height),
                                   255)
                top = row * (width + pad)
                left = col * (height + pad)
                # The padding strip keeps the canvas fill value of 255.
                canvas[top:top + width, left:left + height] = tile
        cv2.imwrite("c" + str(layer) + ".jpg", canvas)
    chars = []
    for scores in prob:
        best = np.argmax(scores)
        chars.append(' ' if int(best) == 10 else str(best))
    return ''.join(chars)
In [ ]:
# Build a random three-digit image and run it through the bound module.
img = get_image()
predictmod.forward(Batch([mx.nd.array(img)]), is_train=False)
out = predictmod.get_outputs()
line = predict(out)
# Show the input image that get_image() saved to disk.
plt.imshow(cv2.imread('img.jpg'), cmap='Greys_r')
plt.axis('off')
plt.show()
# The label text means "Prediction result:". A parenthesised single argument
# works both as a Python 2 print statement and as a Python 3 print() call.
print('预测结果:\''+line+'\'')
# Show the four feature-map grids written by predict().
for i in range(4):
    plt.imshow(cv2.imread('c'+str(i)+'.jpg'), cmap='Greys_r')
    plt.axis('off')
    plt.show()