In [3]:
from sklearn import datasets
digits = datasets.load_digits()
print digits.DESCR
手写体数字图像的数据,则存储在digit.images,数组中每个元素表示一张图像,每个元素为 $8 \times 8$形状的矩阵,矩阵各项为数值类型,每个数值对应着一种灰度等级,0代表白色,15代表黑色
In [4]:
digits.images[0]
Out[4]:
借助matplotlib库,生成图像
In [6]:
import matplotlib.pyplot as plt
%matplotlib inline
plt.imshow(digits.images[0],cmap=plt.cm.gray_r,interpolation='nearest')
Out[6]:
In [7]:
digits.target
Out[7]:
In [8]:
digits.target.size
Out[8]:
digits数据集有1797个元素,考虑使用前1791个作为训练集,剩余的6个作为验证集,查看细节
In [10]:
import matplotlib.pyplot as plt
%matplotlib inline
plt.subplot(321)
plt.imshow(digits.images[1791], cmap=plt.cm.gray_r, interpolation='nearest')
plt.subplot(322)
plt.imshow(digits.images[1792], cmap=plt.cm.gray_r, interpolation='nearest')
plt.subplot(323)
plt.imshow(digits.images[1793], cmap=plt.cm.gray_r, interpolation='nearest')
plt.subplot(324)
plt.imshow(digits.images[1794], cmap=plt.cm.gray_r, interpolation='nearest')
plt.subplot(325)
plt.imshow(digits.images[1795], cmap=plt.cm.gray_r, interpolation='nearest')
plt.subplot(326)
plt.imshow(digits.images[1796], cmap=plt.cm.gray_r, interpolation='nearest')
Out[10]:
In [13]:
from sklearn import svm
svc = svm.SVC(gamma=0.0001,C=100.)
svc.fit(digits.data[1:1790],digits.target[1:1790])
Out[13]:
In [14]:
svc.predict(digits.data[1791:1976])
Out[14]:
In [19]:
digits.target[1791:1976]
Out[19]:
In [ ]: