基于卷积神经网络的手写数字序列识别(训练)

导入库


In [ ]:
import logging
import random
import sys
from io import BytesIO
import gzip
import struct
import mxnet as mx
import numpy as np
from captcha.image import ImageCaptcha
from collections import namedtuple
import matplotlib.pyplot as plt
import cv2
# Log line layout: timestamp (left-aligned, minimum width 15) followed by the message.
head = '%(asctime)-15s %(message)s'
# Configure the root logger at DEBUG level using the layout above.
logging.basicConfig(level=logging.DEBUG, format=head)

读取数据

这块的功能是读取手写数字数据集(mnist),代码由mxnet官方教程提供


In [ ]:
def read_data(label_url, image_url):
    """Load a gzip-compressed label file and image file (MNIST IDX layout).

    Args:
        label_url: Path to the gzip-compressed label file. An 8-byte
            big-endian header (magic, count) is followed by one byte per
            label.
        image_url: Path to the gzip-compressed image file. A 16-byte
            big-endian header (magic, count, rows, cols) is followed by one
            byte per pixel.

    Returns:
        A ``(label, image)`` tuple. ``label`` is a 1-D ``int8`` array and
        ``image`` is a ``uint8`` array of shape ``(len(label), rows, cols)``.
        Both arrays are writable copies.
    """
    with gzip.open(label_url, 'rb') as flbl:
        # The header is unpacked only to move past it; its values are unused.
        _magic, _num = struct.unpack(">II", flbl.read(8))
        # np.fromstring's binary mode is deprecated; np.frombuffer is the
        # replacement. It returns a read-only view, so copy to keep the
        # result writable like the original arrays were.
        label = np.frombuffer(flbl.read(), dtype=np.int8).copy()
    with gzip.open(image_url, 'rb') as fimg:
        _magic, _num, rows, cols = struct.unpack(">IIII", fimg.read(16))
        image = np.frombuffer(fimg.read(), dtype=np.uint8).copy().reshape(
            len(label), rows, cols)
    return (label, image)

设置训练网络模型

训练网络模型有四层卷积层和两层全连接层,其中第二层全连接层分为三个分支,分别识别三个数字


In [ ]:
def get_ocrtrain():
    """Build the training symbol for recognising a sequence of three digits.

    The graph is four convolution/pooling/ReLU stages, a 256-unit fully
    connected layer, and three 11-way fully connected heads (one per digit
    position) whose outputs are concatenated along axis 0 and fed to a
    single softmax output together with the label.

    NOTE: no layer is given an explicit name, so parameter names presumably
    come from MXNet's automatic naming and depend on the order in which the
    layers are created here - keep get_predictnet() in the same order.

    Returns:
        The ``SoftmaxOutput`` symbol named "softmax".
    """
    # Input variables
    data = mx.symbol.Variable('data')
    label = mx.symbol.Variable('softmax_label')

    # Convolution stage 1
    conv1 = mx.symbol.Convolution(data=data, kernel=(5, 5), num_filter=32)
    # Pooling 1 (max)
    pool1 = mx.symbol.Pooling(
        data=conv1, pool_type="max", kernel=(2, 2), stride=(1, 1))
    # Activation 1
    relu1 = mx.symbol.Activation(data=pool1, act_type="relu")

    # Convolution stage 2
    conv2 = mx.symbol.Convolution(data=relu1, kernel=(5, 5), num_filter=32)
    # Pooling 2 (average)
    pool2 = mx.symbol.Pooling(
        data=conv2, pool_type="avg", kernel=(2, 2), stride=(1, 1))
    # Activation 2
    relu2 = mx.symbol.Activation(data=pool2, act_type="relu")

    # Convolution stage 3
    conv3 = mx.symbol.Convolution(data=relu2, kernel=(3, 3), num_filter=32)
    # Pooling 3 (average)
    pool3 = mx.symbol.Pooling(
        data=conv3, pool_type="avg", kernel=(2, 2), stride=(1, 1))
    # Activation 3
    relu3 = mx.symbol.Activation(data=pool3, act_type="relu")

    # Convolution stage 4
    conv4 = mx.symbol.Convolution(data=relu3, kernel=(3, 3), num_filter=32)
    # Pooling 4 (average)
    pool4 = mx.symbol.Pooling(
        data=conv4, pool_type="avg", kernel=(2, 2), stride=(1, 1))
    # Activation 4
    relu4 = mx.symbol.Activation(data=pool4, act_type="relu")

    # Flatten the feature maps for the dense layers
    flatten = mx.symbol.Flatten(data=relu4)
    # Fully connected layer 1 (shared by all three heads)
    fc1 = mx.symbol.FullyConnected(data=flatten, num_hidden=256)
    # Head for the first digit (11 outputs; label 10 is used elsewhere in
    # this file for a blanked-out digit)
    fc21 = mx.symbol.FullyConnected(data=fc1, num_hidden=11)
    # Head for the second digit
    fc22 = mx.symbol.FullyConnected(data=fc1, num_hidden=11)
    # Head for the third digit
    fc23 = mx.symbol.FullyConnected(data=fc1, num_hidden=11)
    # Stack the three heads along axis 0: all first-digit rows, then all
    # second-digit rows, then all third-digit rows
    fc2 = mx.symbol.Concat(*[fc21, fc22, fc23], dim=0)

    # Label handling: transpose so the label order matches the stacked heads
    # (digit position first, then sample)
    label = mx.symbol.transpose(data=label)
    # presumably flattens the transposed label to 1-D - TODO confirm the
    # semantics of target_shape=(0, ) against the MXNet Reshape docs
    label = mx.symbol.Reshape(data=label, target_shape=(0, ))

    # Output layer
    SoftmaxOut = mx.symbol.SoftmaxOutput(data=fc2, label=label, name="softmax")
    return SoftmaxOut

网络可视化

可视化整个网络


In [ ]:
# Build the training symbol; `network` is reused later when the Module is created.
network = get_ocrtrain()
# Input shapes handed to the plotter: 8 single-channel 28x84 images, 3 labels each.
shape = {"data": (8, 1, 28, 84), "softmax_label": (8, 3)}
# Plot the network and open the rendering (written under the name 'net');
# cleanup=True presumably removes the intermediate source file - TODO confirm.
g = mx.viz.plot_network(network, shape=shape).view(cleanup=True,filename='net')

构建数据生成器

继承mxnet提供的生成器接口,并且从手写数字数据集中随机生成拼接成连续数字

图像和标签的合成


In [ ]:
def Get_image_lable(img, lable):
    """Randomly blank out digits of a three-digit image strip, in place.

    Three integers are drawn uniformly from 0..9 up front, one per digit
    slot. For every slot whose draw is 0, the corresponding 28-pixel-wide
    column band of ``img`` is zeroed and its entry in ``lable`` is set to 10.

    Args:
        img: 2-D array with 28 rows, holding three 28-pixel-wide digits side
            by side. Modified in place.
        lable: Indexable sequence of three labels. Modified in place.

    Returns:
        The same ``(img, lable)`` objects that were passed in.
    """
    draws = [random.randint(0, 9) for _ in range(3)]
    blank_tile = np.zeros((28, 28), dtype='uint8')
    for slot, draw in enumerate(draws):
        if draw != 0:
            continue
        left = slot * 28
        img[:, left:left + 28] = blank_tile
        lable[slot] = 10
    return img, lable

迭代器


In [ ]:
class OCRIter(mx.io.DataIter):
    """Data iterator that synthesises multi-digit images from single digits.

    Each sample is built by picking ``num_label`` random images from
    ``image``, stacking them horizontally, and passing the strip and its
    labels through ``Get_image_lable`` (which may blank out digits).

    Args:
        count: Number of samples per pass; also the upper bound (exclusive)
            of the random indices drawn from ``image`` / ``lable``.
        batch_size: Number of samples per yielded batch.
        num_label: Number of digits per sample.
        height: Height of one composed sample image.
        width: Width of one composed sample image.
        lable: Sequence of per-digit labels, indexed like ``image``.
        image: Sequence of single-digit images.
    """

    def __init__(self, count, batch_size, num_label, height, width, lable, image):
        super(OCRIter, self).__init__()

        self.batch_size = batch_size
        self.count = count
        self.height = height
        self.width = width
        self.provide_data = [('data', (batch_size, 1, height, width))]
        self.provide_label = [('softmax_label', (self.batch_size, num_label))]
        self.lable = lable
        self.image = image
        self.num_label = num_label

    def __iter__(self):
        """Yield ``count // batch_size`` freshly generated ``OCRBatch`` objects."""
        # Floor division keeps the batch count an int under both Python 2
        # and Python 3 (true division would make range() fail).
        for _ in range(self.count // self.batch_size):
            data = []
            label = []
            for _ in range(self.batch_size):
                num = [random.randint(0, self.count - 1)
                       for _ in range(self.num_label)]
                # np.hstack needs a real sequence; passing a generator is
                # rejected by current NumPy releases.
                img, lab = Get_image_lable(
                    np.hstack([self.image[x] for x in num]),
                    np.array([self.lable[x] for x in num]))
                # Scale pixel values from 0..255 to 0..1.
                img = np.multiply(img, 1 / 255.0)
                data.append(img.reshape(1, self.height, self.width))
                label.append(lab)

            data_all = [mx.nd.array(data)]
            label_all = [mx.nd.array(label)]
            data_names = ['data']
            label_names = ['softmax_label']

            data_batch = OCRBatch(data_names, data_all, label_names, label_all)
            yield data_batch

    def reset(self):
        """No-op: samples are drawn at random, so there is no position to rewind."""
        pass

数据准备


In [ ]:
class OCRBatch(object):
    """One batch of named data arrays and named label arrays.

    Args:
        data_names: Names of the data inputs, parallel to ``data``.
        data: List of data arrays (anything exposing ``.shape``).
        label_names: Names of the label inputs, parallel to ``label``.
        label: List of label arrays (anything exposing ``.shape``).
    """

    def __init__(self, data_names, data, label_names, label):
        self.data_names, self.label_names = data_names, label_names
        self.data, self.label = data, label

    @staticmethod
    def _name_shape_pairs(names, arrays):
        """Pair each name with the shape of the array at the same position."""
        pairs = []
        for name, array in zip(names, arrays):
            pairs.append((name, array.shape))
        return pairs

    @property
    def provide_data(self):
        """List of ``(name, shape)`` tuples for the data arrays."""
        return self._name_shape_pairs(self.data_names, self.data)

    @property
    def provide_label(self):
        """List of ``(name, shape)`` tuples for the label arrays."""
        return self._name_shape_pairs(self.label_names, self.label)

验证函数


In [ ]:
def Accuracy(label, pred, num_label=3):
    """Return the fraction of samples whose every digit was predicted correctly.

    The network concatenates its per-digit heads along axis 0, so ``pred``
    is laid out digit-major: rows ``0..batch-1`` are the first digit of each
    sample, rows ``batch..2*batch-1`` the second digit, and so on. The row
    for sample ``i``, digit ``j`` is therefore ``j * batch + i``. Grouping
    consecutive rows (``i * num_label + j``) would mix digits from different
    samples, so the label is transposed to the same layout and both are
    compared per sample.

    Args:
        label: Array of shape ``(batch, num_label)`` with the true classes.
        pred: Array of shape ``(num_label * batch, num_classes)`` with the
            class scores, in the digit-major layout described above.
        num_label: Number of digits per sample (default 3).

    Returns:
        A float in ``[0.0, 1.0]``; ``0.0`` when ``pred`` holds no complete
        sample.
    """
    # Floor division keeps this an int on both Python 2 and Python 3.
    batch = pred.shape[0] // num_label
    if batch == 0:
        # Nothing to score; avoid dividing by zero.
        return 0.0
    used = batch * num_label
    expected = np.asarray(label).T.reshape((-1, ))[:used].astype(np.int64)
    predicted = np.argmax(pred[:used], axis=1)
    # Row j holds digit position j for every sample; a sample is a hit only
    # if its whole column matches.
    matches = (predicted == expected).reshape(num_label, batch)
    hit = int(np.count_nonzero(matches.all(axis=0)))
    return 1.0 * hit / batch

训练


In [ ]:
# Compute devices: a single CPU context.
devs = [mx.cpu(i) for i in range(1)]

# Uncomment (together with arg_params in model.fit below) to start from a
# previously saved checkpoint instead of a fresh initialisation.
#_, arg_params, __ = mx.model.load_checkpoint("cnn-ocr-mnist", 1)

# Create the training module.
model = mx.mod.Module(network, context=devs)

# Load the data sets.
(train_lable, train_image) = read_data(
    'train-labels-idx1-ubyte.gz', 'train-images-idx3-ubyte.gz')
(test_lable, test_image) = read_data(
    't10k-labels-idx1-ubyte.gz', 't10k-images-idx3-ubyte.gz')

# Build the data generators.
batch_size = 8
data_train = OCRIter(60000, batch_size, 3, 28,
                     84, train_lable, train_image)
data_test = OCRIter(5000, batch_size, 3, 28, 84, test_lable, test_image)

# Train.
num_epoch = 1
model.fit(
    data_train,
    eval_data=data_test,
    num_epoch=num_epoch,
    optimizer='sgd',
    eval_metric=Accuracy,
    #arg_params=arg_params,
    initializer=mx.init.Xavier(factor_type="in", magnitude=2.34),
    optimizer_params={'learning_rate': 0.001, 'wd': 0.00001},
    batch_end_callback=mx.callback.Speedometer(batch_size, 50),
)
# Save the trained model under the number of epochs actually trained, so the
# checkpoint written here is the one the recognition code loads with
# mx.model.load_checkpoint("cnn-ocr-mnist", 1). Saving as epoch 2 after a
# single epoch left that load pointing at a file this run never wrote.
model.save_checkpoint(prefix="cnn-ocr-mnist", epoch=num_epoch)

识别

创建识别网络

因为训练网络需要提供标签,而识别时不需要提供标签,所以重写了识别网络:去掉了与标签相关的层,并在最后增加了一层组合层,将每层卷积层的输出也一并返回


In [ ]:
def get_predictnet():
    """Build the inference symbol: the training network without label inputs.

    The layers are created in the same order and with the same settings as
    in get_ocrtrain(), which presumably keeps MXNet's automatic parameter
    names identical so a training checkpoint can be loaded - do not reorder.
    The returned group also exposes the raw output of each convolution layer.

    Returns:
        A ``Group`` symbol whose outputs are, in order: the softmax output,
        conv1, conv2, conv3 and conv4.
    """
    # Input variable
    data = mx.symbol.Variable('data')

    # Convolution stage 1
    conv1 = mx.symbol.Convolution(data=data, kernel=(5, 5), num_filter=32)
    # Pooling 1 (max)
    pool1 = mx.symbol.Pooling(
        data=conv1, pool_type="max", kernel=(2, 2), stride=(1, 1))
    # Activation 1
    relu1 = mx.symbol.Activation(data=pool1, act_type="relu")

    # Convolution stage 2
    conv2 = mx.symbol.Convolution(data=relu1, kernel=(5, 5), num_filter=32)
    # Pooling 2 (average)
    pool2 = mx.symbol.Pooling(
        data=conv2, pool_type="avg", kernel=(2, 2), stride=(1, 1))
    # Activation 2
    relu2 = mx.symbol.Activation(data=pool2, act_type="relu")

    # Convolution stage 3
    conv3 = mx.symbol.Convolution(data=relu2, kernel=(3, 3), num_filter=32)
    # Pooling 3 (average)
    pool3 = mx.symbol.Pooling(
        data=conv3, pool_type="avg", kernel=(2, 2), stride=(1, 1))
    # Activation 3
    relu3 = mx.symbol.Activation(data=pool3, act_type="relu")

    # Convolution stage 4
    conv4 = mx.symbol.Convolution(data=relu3, kernel=(3, 3), num_filter=32)
    # Pooling 4 (average)
    pool4 = mx.symbol.Pooling(
        data=conv4, pool_type="avg", kernel=(2, 2), stride=(1, 1))
    # Activation 4
    relu4 = mx.symbol.Activation(data=pool4, act_type="relu")

    # Flatten the feature maps for the dense layers
    flatten = mx.symbol.Flatten(data=relu4)
    # Fully connected layer 1 (shared by all three heads)
    fc1 = mx.symbol.FullyConnected(data=flatten, num_hidden=256)
    # Head for the first digit
    fc21 = mx.symbol.FullyConnected(data=fc1, num_hidden=11)
    # Head for the second digit
    fc22 = mx.symbol.FullyConnected(data=fc1, num_hidden=11)
    # Head for the third digit
    fc23 = mx.symbol.FullyConnected(data=fc1, num_hidden=11)
    # Stack the three heads along axis 0 (one block of rows per digit position)
    fc2 = mx.symbol.Concat(*[fc21, fc22, fc23], dim=0)


    # Output layer: softmax without a label input, grouped with the raw
    # convolution outputs so callers can inspect the feature maps
    SoftmaxOut = mx.symbol.SoftmaxOutput(data=fc2, name="softmax")
    out = mx.symbol.Group([SoftmaxOut, conv1, conv2, conv3, conv4])

    return out

获取图像

随机组合手写数字,并将图像处理成适合网络输入的格式


In [ ]:
def get_image():
    """Compose a random three-digit test image and prepare it as network input.

    Three images are picked at random from the first 5000 entries of the
    test set, stacked horizontally and passed through ``Get_image_lable``
    (which may blank out digits). An inverted copy is written to
    ``img.jpg`` for display.

    Returns:
        Array of shape ``(1, 1, 28, 84)`` with pixel values scaled to 0..1.
    """
    (lable, image) = read_data(
        't10k-labels-idx1-ubyte.gz', 't10k-images-idx3-ubyte.gz')

    num = [random.randint(0, 5000 - 1)
           for _ in range(3)]

    # np.hstack needs a real sequence; passing a generator is rejected by
    # current NumPy releases.
    img, _ = Get_image_lable(
        np.hstack([image[x] for x in num]),
        np.array([lable[x] for x in num]))

    # Invert the image for the saved preview.
    imgw = 255 - img
    cv2.imwrite("img.jpg", imgw)
    # Scale to 0..1 and add batch and channel axes.
    img = np.multiply(img, 1 / 255.0)
    img = img.reshape(1, 1, 28, 84)
    return img

加载模型设置参数


In [ ]:
# Load the parameters stored for prefix "cnn-ocr-mnist" at epoch 1; the
# symbol returned by the checkpoint is discarded in favour of the inference net.
_, arg_params, aux_params = mx.model.load_checkpoint("cnn-ocr-mnist", 1)
net = get_predictnet()

# Bind the inference network on CPU for a single 1x28x84 image and load the weights.
predictmod = mx.mod.Module(symbol=net, context=mx.cpu())
predictmod.bind(data_shapes=[('data', (1, 1, 28, 84))])
predictmod.set_params(arg_params, aux_params)
# Minimal batch container exposing only a `data` field, used for forward().
Batch = namedtuple('Batch', ['data'])

处理识别结果


In [ ]:
def predict(out):
    """Decode the network outputs and dump the convolution feature maps.

    ``out[0]`` holds the per-digit class scores; ``out[1]`` .. ``out[4]``
    hold the four convolution outputs. For each convolution output, the
    first 32 channels of the first sample are tiled into an image of 8 tiles
    down by 4 tiles across (filled column by column, with a 10-pixel white
    gutter after each tile) and written to ``c0.jpg`` .. ``c3.jpg``.

    Args:
        out: Sequence of five arrays exposing ``asnumpy()``.

    Returns:
        A string with one character per row of ``out[0]``: the arg-max
        class, or a space when that class is 10.
    """
    prob = out[0].asnumpy()

    for layer in range(4):
        fmap = out[layer + 1].asnumpy()[0]
        rows = int(np.shape(fmap)[1])
        cols = int(np.shape(fmap)[2])
        tile_h = rows + 10
        tile_w = cols + 10
        # White canvas large enough for 8 x 4 padded tiles.
        canvas = np.full((rows * 8 + 80, cols * 4 + 40), 255, dtype=float)
        for channel in range(32):
            # Channels run down a column of 8 tiles before moving right.
            col_slot, row_slot = divmod(channel, 8)
            tile = np.full((tile_h, tile_w), 255, dtype=float)
            tile[0:rows, 0:cols] = np.multiply(
                fmap[channel].reshape(rows, cols), 255)
            top = row_slot * tile_h
            left = col_slot * tile_w
            canvas[top:top + tile_h, left:left + tile_w] = tile
        cv2.imwrite("c" + str(layer) + ".jpg", canvas)

    chars = []
    for scores in prob:
        best = np.argmax(scores)
        chars.append(' ' if int(best) == 10 else str(best))
    return ''.join(chars)

识别


In [ ]:
# Compose a random test image and run it through the inference module.
img = get_image()

predictmod.forward(Batch([mx.nd.array(img)]),is_train=False)
out = predictmod.get_outputs()

# Decode the digits; this also writes the feature-map images c0.jpg .. c3.jpg.
line = predict(out)


# Show the input image that get_image() saved.
plt.imshow(cv2.imread('img.jpg'), cmap='Greys_r')
plt.axis('off')
plt.show()
# Single-argument call form: prints the same text on Python 2 and is valid
# syntax on Python 3, unlike the bare print statement.
print('预测结果:\''+line+'\'')
# Show the feature maps of the four convolution layers.
for i in range(4):
    plt.imshow(cv2.imread('c'+str(i)+'.jpg'), cmap='Greys_r')
    plt.axis('off')
    plt.show()

开启web服务


In [ ]:
import BaseHTTPServer
import CGIHTTPServer

# Empty host string: presumably listens on all network interfaces - confirm
# this is intended before exposing a CGI-capable server.
HOST = ''
PORT = 8000

# Create the server, CGIHTTPRequestHandler is pre-defined handler
server = BaseHTTPServer.HTTPServer(
    (HOST, PORT), CGIHTTPServer.CGIHTTPRequestHandler)
# Start the server; this call blocks and handles requests until interrupted.
server.serve_forever()