验证码识别


In [2]:
import time
import os
from multiprocessing import Pool

from captcha.image import ImageCaptcha
import numpy as np
import skimage.io as io
import tensorflow as tf

import matplotlib.pylab as plt
%matplotlib inline

In [3]:
GEN_IMG_NUM = 1000000
IMG_H = 64
IMG_W = 160
IMG_CHANNALS = 1
CAPTCHA_SIZE = 4
CAPTCHA_NUM =  36
N_CLASSES = CAPTCHA_SIZE * CAPTCHA_NUM

生成验证码


In [4]:
def gen_baptcha(text):
    image = ImageCaptcha()
    img = image.generate_image(text)
    img = img.convert("L").resize([IMG_W, IMG_H])
    i = hash(time.time())
    img.save('data/%s_%d.png' % (text, i), format='png')


def gen_text(num, size):
    char_set = list('1234567890abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ')
    text = np.random.choice(char_set, num*size, replace=True)
    ret = [''.join(t) for t in np.split(text,num)]
    return ret


def gen_baptcha_batch(num, size=CAPTCHA_SIZE):
    text_list = gen_text(num, size)
    pool = Pool(12)
    pool.map(gen_baptcha, text_list)

读取数据


In [6]:
def get_file(file_dir, size=CAPTCHA_SIZE):
    image_list = []
    label_list = []
    filename_list = os.listdir(file_dir)
    for filename in filename_list:
        image_list.append(os.path.join(file_dir, filename))
        label = filename[:size].upper()
        label_list.append(label)
    return image_list, label_list


def _process_image(img_file):
    img = io.imread(img_file)
    return img.tobytes()


def _process_label(label):
    key_list = list('0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ')
    value_list = np.eye(BAPTCHA_NUM, dtype=np.int32).tolist()
    label_dict = dict(zip(key_list, value_list))
    label_ = map(lambda t: label_dict[t], list(label.upper()))
    ret = np.array(label_, dtype=np.uint8).flatten().tobytes()
    return ret


def convert_2_tfrecord(images, labels, save_filename):
    n_samples = len(labels)
    if np.shape(images)[0] != n_samples:
        raise ValueError("Images size %d does not match label size %d" % (images.shape[0], n_samples))

    writer = tf.python_io.TFRecordWriter(save_filename)
    print("\nTransform start ...")
    for i in range(n_samples):
        try:
            image_raw = _process_image(images[i])
            label_raw = _process_label(labels[i])
            example = tf.train.Example(
                features=tf.train.Features(feature={
                    "image_raw": tf.train.Feature(bytes_list=tf.train.BytesList(value=[image_raw])),
                    "label_raw": tf.train.Feature(bytes_list=tf.train.BytesList(value=[label_raw]))
                })
            )
            writer.write(example.SerializeToString())
        except IOError as e:
            print("could not read %s, error %s, skip it" % (images[i], e))
    writer.close()
    print("Transorm done!")


def read_and_decode(tfrecords_file, batch_size):
    filename_queue = tf.train.string_input_producer([tfrecords_file])
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)
    img_features = tf.parse_single_example(serialized_example, features={
        "label_raw": tf.FixedLenFeature([], tf.string),
        "image_raw": tf.FixedLenFeature([], tf.string)
    })
    image = tf.decode_raw(img_features['image_raw'], tf.uint8)
    image = tf.reshape(image, [IMG_H, IMG_W, IMG_CHANNALS])
    image = tf.cast(image, tf.float32)
    label = tf.decode_raw(img_features['label_raw'], tf.uint8)
    label = tf.reshape(label, [N_CLASSES])
    label = tf.cast(label, tf.int32)
    # data argumentaion
    image = tf.image.per_image_standardization(image)
    image_batch, label_batch = tf.train.batch([image, label],
                                              batch_size=batch_size,
                                              num_threads=12,
                                              capacity=1000)
    return image_batch, label_batch


def convert_all():
    val_ratio = 0.2
    image_list, label_list = get_file('data')
    val_size = int(len(image_list) * val_ratio)
    val_image_list, tra_image_list = image_list[:val_size], image_list[val_size:]
    val_label_list, tra_label_list = label_list[:val_size], label_list[val_size:]
    convert_2_tfrecord(tra_image_list, tra_label_list, 'tfrecord/train.tfrecord')
    convert_2_tfrecord(val_image_list, val_label_list, 'tfrecord/validation.tfrecord')


def test_read(batch_size):
    image_batch, label_batch = read_and_decode("./tfrecord/validation.tfrecord", batch_size)
    with tf.Session() as sess:
        i = 0
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord)
        try:
            while not coord.should_stop() and i < 1:
                images, labels = sess.run([image_batch, label_batch])
                i += 1
        except tf.errors.OutOfRangeError:
            print("done!")
        finally:
            coord.request_stop()
        coord.join(threads)
    return images, labels
    
def test_plot():
    nr,nc = 10, 5
    batch = nr*nc
    images,_ = test_read(batch)
    plt.figure(figsize=(12,5))
    for i in range(batch):
        plt.subplot(nr,nc,i+1)
        plt.axis("off")
        plt.subplots_adjust(top=1.5)
        plt.imshow(images[i,:,:,0])
    plt.show()

模型


In [7]:
def batchnorm(x):
    epsilon = 1e-3
    batch_mean, batch_var = tf.nn.moments(x, [0])
    x = tf.nn.batch_normalization(x,
                                  mean=batch_mean,
                                  variance=batch_var,
                                  offset=None,
                                  scale=None,
                                  variance_epsilon=epsilon)
    return x


def conv(layer_name, x, out_channels, kernel_size=[3, 3], strides=[1, 1, 1, 1]):
    in_channels = x.get_shape()[-1]
    with tf.variable_scope(layer_name):
        w = tf.get_variable(name="weight",
                            shape=[kernel_size[0], kernel_size[1], in_channels, out_channels],
                            trainable=True,
                            initializer=tf.contrib.layers.xavier_initializer()
                            )
        b = tf.get_variable(name="biases",
                            shape=[out_channels],
                            trainable=True,
                            initializer=tf.constant_initializer(0.0)
                            )
        x = tf.nn.conv2d(x, w, strides, padding="SAME")
        x = tf.nn.bias_add(x, b)
        x = batchnorm(x)
        x = tf.nn.relu(x)
        return x


def pool(layer_name, x, kernel=[1, 2, 2, 1], strides=[1, 2, 2, 1], is_maxpool = True):
    if is_maxpool:
        x = tf.nn.max_pool(x, ksize=kernel, strides=strides, padding="SAME", name=layer_name)
    else:
        x = tf.nn.avg_pool(x, ksize=kernel, strides=strides, padding="SAME", name=layer_name)
    return x



def fc_layer(layer_name, x, out_nodes):
    shape = x.get_shape()
    if len(shape) == 4:
        size = shape[1].value * shape[2].value * shape[3].value
    else:
        size = shape[-1].value

    with tf.variable_scope(layer_name):
        w = tf.get_variable(name="weight",
                            shape=[size, out_nodes],
                            initializer=tf.contrib.layers.xavier_initializer()
                            )
        b = tf.get_variable(name="biases", shape=[out_nodes], initializer=tf.constant_initializer(0.0))
        xflat = tf.reshape(x, [-1, size])
        x = tf.nn.bias_add(tf.matmul(xflat, w), b)
        x = batchnorm(x)
        x = tf.nn.relu(x)
        return x

In [8]:
def interface(x, n_class):
    x = conv("conv-1", x, 32, kernel_size=[3, 3], strides=[1, 1, 1, 1])
    x = pool("pool-1", x, kernel=[1, 2, 2, 1], strides=[1, 2, 2, 1], is_maxpool = True)
    x = conv("conv-2", x, 64, kernel_size=[3, 3], strides=[1, 1, 1, 1])
    x = pool("pool-2", x, kernel=[1, 2, 2, 1], strides=[1, 2, 2, 1], is_maxpool = True)
    x = conv("conv-3", x, 64, kernel_size=[3, 3], strides=[1, 1, 1, 1])
    x = pool("pool-3", x, kernel=[1, 2, 2, 1], strides=[1, 2, 2, 1], is_maxpool = True)
    x = fc_layer("fc4", x, out_nodes=1024)
    x = fc_layer("fc5", x, out_nodes=n_class)
    return x

学习


In [10]:
BATCH_SIZE = 64
MAX_STEP = 100000
learning_rate = 0.001

def train():
    train_log_dir = "./logs/train/"
    val_log_dir = "./logs/val"
    train_data_dir = './tfrecord/train.tfrecord'
    val_data_dir = './tfrecord/validation.tfrecord'
    model_dir = './logs/model'

    with tf.name_scope('input'):
        tra_img_bt, tra_label_bt = read_and_decode(train_data_dir,  BATCH_SIZE)
        val_img_bt, val_label_bt = read_and_decode(val_data_dir,  BATCH_SIZE)

    x = tf.placeholder(dtype=tf.float32, shape=[BATCH_SIZE, IMG_H, IMG_W, IMG_CHANNALS])
    y = tf.placeholder(dtype=tf.int32, shape=[BATCH_SIZE, N_CLASSES])

    logits = interface(tra_img_bt, N_CLASSES)
    logits_ = tf.reshape(logits, [-1, CAPTCHA_SIZE, CAPTCHA_NUM])
    labels_ = tf.reshape(tra_label_bt, [-1, CAPTCHA_SIZE, CAPTCHA_NUM])

    with tf.name_scope("loss"):
        cross_entropy = tf.nn.softmax_cross_entropy_with_logits(logits=logits_, labels=labels_, dim=-1)
        loss = tf.reduce_mean(cross_entropy, name="loss")
        tf.summary.scalar("loss", loss)

    with tf.name_scope("accuracy"):
        correct = tf.equal(tf.argmax(logits_, -1), tf.argmax(labels_, -1))
        accuracy_one = tf.reduce_mean(tf.cast(correct, tf.float32), name="accuracy_one")
        tf.summary.scalar("accuracy-one", accuracy_one)
        correct_all = tf.reduce_all(correct, axis=-1)
        accuracy_all = tf.reduce_mean(tf.cast(correct_all, tf.float32), name="accuracy_all")
        tf.summary.scalar("accuracy-all", accuracy_all)

    global_step = tf.Variable(0, name="global_step", trainable=False)
    with tf.name_scope("optimizer"):
        #optimizer = tf.train.GradientDescentOptimizer(learning_rate=learning_rate)
        optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)
        train_op = optimizer.minimize(loss, global_step=global_step)

    saver = tf.train.Saver(tf.global_variables())
    summary_op = tf.summary.merge_all()

    init = tf.global_variables_initializer()
    sess = tf.Session()
    sess.run(init)

    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    tra_summary_write = tf.summary.FileWriter(train_log_dir, sess.graph)
    val_summary_write = tf.summary.FileWriter(val_log_dir, sess.graph)

    try:
        for step in np.arange(MAX_STEP):
            if coord.should_stop():
                break

            tra_img, tra_label = sess.run([tra_img_bt, tra_label_bt])
            _, tra_loss, tra_acc_one, tra_acc_all = sess.run([train_op, loss, accuracy_one, accuracy_all],
                                            feed_dict={x: tra_img, y:tra_label})

            if step % 100 == 0 or (step+1) == MAX_STEP:
                print('train step: %d, loss: %.4f, accuracy: %.4f | %.4f'
                      % (step, tra_loss, tra_acc_one, tra_acc_all))
                summary_str = sess.run(summary_op)
                tra_summary_write.add_summary(summary_str, step)

            if step % 500 == 0 or (step+1) == MAX_STEP:
                val_img, val_label = sess.run([val_img_bt, val_label_bt])
                val_loss, val_acc_one, val_acc_all = sess.run([loss, accuracy_one, accuracy_all],
                                                              feed_dict={x:val_img, y:val_label})
                print('validation step: %d, loss: %.4f, accuracy: %.4f | %.4f'
                      % (step, val_loss, val_acc_one, val_acc_all))
                summary_str = sess.run(summary_op)
                val_summary_write.add_summary(summary_str, step)

            if step % 10000 == 0 or (step+1) == MAX_STEP:
                checkpoint_path = os.path.join(model_dir, "model.ckpt")
                saver.save(sess, checkpoint_path, step)

    except tf.errors.OutOfRangeError:
        print("Done training --epoch limit reached")

    finally:
        coord.request_stop()

    coord.join(threads)
    sess.close()

In [10]:
print("生成数据")
start = time.clock()
gen_baptcha_batch(GEN_IMG_NUM, CAPTCHA_SIZE)
end = time.clock()
print("生成数据Over! Running time .%s S" % (end-start))

print("数据转化为tfrecord格式")
start = time.clock()
convert_all()
end = time.clock()
print("数据转化为tfrecord格式Over! Running time .%s S" % (end-start))

print("训练数据")
start = time.clock()
train()
end = time.clock()
print("训练数据over! Running time .%s S" % (end-start))


Over! Running time .3.60000000015e-05 S

In [ ]: