In [1]:
# グラフが文章中に表示されるようにするおまじない
%matplotlib inline
scikit-learnには、最初から手書きの数字を認識するための学習データ(手書き数字の画像データと、その画像の数字が0~9の何れであるかという答えのセット)が搭載されているため、それを利用します。
In [2]:
from sklearn import datasets
digits = datasets.load_digits()
print(digits.data.shape)
1797
は行数、64
は次元数です。手書き文字の画像データが8×8のサイズであるため、その中のピクセル情報は64となります。
In [3]:
import matplotlib.pyplot as plt
plt.figure(1, figsize=(3, 3))
plt.imshow(digits.images[0], cmap=plt.cm.gray_r, interpolation='nearest')
plt.show()
今回扱うのは画像の分類問題になります。そこで、分類問題で非常によく使われるSupport Vector Machineを利用します。
In [4]:
from sklearn import svm
create_model = lambda : svm.SVC(C=1, gamma=0.0001)
classifier = create_model()
データとモデルがそろったため、学習させてみます。
In [5]:
classifier.fit(digits.data, digits.target)
Out[5]:
学習させたモデルの精度を計測してみます。predict
で予測させることができるので、これで予測させた値と実際の答え(digits.target
)を比べてみます。
In [6]:
from sklearn import metrics
predicted = classifier.predict(digits.data)
score = metrics.accuracy_score(digits.target, predicted)
print(score)
最後に、学習させたモデルを保存します。アプリケーション側で、その結果を確認してみてください。
In [17]:
from sklearn.externals import joblib
joblib.dump(classifier, "./machine.pkl")
Out[17]:
In [ ]: