Number Recognizer

今回は、ブラウザ上に書いた手書きの数字を認識させます。具体的には、canvasに書かれた数字が0~9のどれであるかを当てさせます。
その予測を行うためのモデルを、以下のステップに沿って作成していきます。


In [1]:
# グラフが文章中に表示されるようにするおまじない
%matplotlib inline

Load the Data

scikit-learnには、最初から手書きの数字を認識するための学習データ(手書き数字の画像データと、その画像の数字が0~9の何れであるかという答えのセット)が搭載されているため、それを利用します。


In [2]:
def load_data():
    from sklearn import datasets
    dataset = datasets.load_digits()
    return dataset

digits = load_data()
print(digits.data.shape)


(1797, 64)

1797は行数、64は次元数です。手書き文字の画像データが8×8のサイズであるため、その中のピクセル情報は64となります(今回値はグレースケールですが、RGBの場合3倍になります)。


In [6]:
def show_image(image):
    import matplotlib.pyplot as plt

    plt.figure(1, figsize=(3, 3))
    plt.imshow(image, cmap=plt.cm.gray_r, interpolation='nearest')
    plt.show()

show_image(digits.images[0])


Create the Model

今回扱うのは画像の分類問題になります。そこで、シンプルな線形分類機であるSGDClassifierを利用します。


In [13]:
def make_model():
    from sklearn.linear_model import SGDClassifier    
    clf = SGDClassifier(alpha=0.0001, fit_intercept=True, n_iter=200)
    return clf
    
classifier = make_model()

Training the Model

データとモデルがそろったため、学習させてみます。


In [23]:
classifier.fit(digits.data, digits.target)


Out[23]:
SGDClassifier(alpha=0.0001, average=False, class_weight=None, epsilon=0.1,
       eta0=0.0, fit_intercept=True, l1_ratio=0.15,
       learning_rate='optimal', loss='hinge', n_iter=200, n_jobs=1,
       penalty='l2', power_t=0.5, random_state=None, shuffle=True,
       verbose=0, warm_start=False)

Evaluate the Model

学習させたモデルの精度を計測してみます。predictで予測させることができるので、これで予測させた値と実際の答え(digits.target)を比べてみます。


In [24]:
def calculate_accuracy(model, dataset):
    from sklearn import metrics

    predicted = model.predict(dataset.data)
    score = metrics.accuracy_score(dataset.target, predicted)
    return score

print(calculate_accuracy(classifier, digits))


0.987200890373

Store the Model

最後に、学習させたモデルを保存します。アプリケーション側で、その結果を確認してみてください。


In [16]:
from sklearn.externals import joblib

joblib.dump(classifier, "./machine.pkl")


Out[16]:
['./machine.pkl',
 './machine.pkl_01.npy',
 './machine.pkl_02.npy',
 './machine.pkl_03.npy',
 './machine.pkl_04.npy']

In [ ]: