In [20]:
from sklearn import datasets
digits = datasets.load_digits()
print("Набор данных для обучения (2D NumPy Array): \n", digits.data, '\n')
print("Набор целей для данных (1D NumPy Array): \n", digits.target)


Набор данных для обучения (2D NumPy Array): 
 [[  0.   0.   5. ...,   0.   0.   0.]
 [  0.   0.   0. ...,  10.   0.   0.]
 [  0.   0.   0. ...,  16.   9.   0.]
 ..., 
 [  0.   0.   1. ...,   6.   0.   0.]
 [  0.   0.   2. ...,  12.   0.   0.]
 [  0.   0.  10. ...,  12.   1.   0.]] 

Набор целей для данных (1D NumPy Array): 
 [0 1 2 ..., 8 9 8]

In [21]:
print("Форма массива данных для обучения: \n", digits.data.shape)
print("Форма массива целей: \n", digits.target.shape)


Форма массива данных для обучения: 
 (1797, 64)
Форма массива целей: 
 (1797,)

Какую задачу можно поставить для этого набора данных?


In [2]:
%matplotlib inline
import matplotlib.pyplot as plt
n = 19
print("Каждая цифра представлена матрицей формы ", digits.data[n, :].shape)


Каждая цифра представлена матрицей формы  (64,)

Чтобы отобразить её на экране, нужно применить метод reshape. Целевая форма — $8 \times 8$.


In [3]:
digit = 255 - digits.data[n, :].reshape(8, 8)
plt.imshow(digit, cmap='gray', interpolation='none')
plt.title("This is " + str(digits.target[n]))
plt.show()


Возьмем один из методов прошлой лекции. Например, метод классификации, основанный на деревьях (CART).


In [4]:
from sklearn.tree import DecisionTreeClassifier

Почти у всех классов, отвечающих за методы классификации в scikit-learn, есть следующие методы:

  • fit — обучение модели;
  • predict — классификация примера обученным классификатором;
  • score —оценка качества классификации в соответствии с некоторым критерием.

Чтобы создать дерево-классификатор, достаточно создать объект класса DecisionTreeClassifier


In [24]:
clf = DecisionTreeClassifier(random_state=0)

Обучим классификатор на всех цифрах, кроме последних 10.


In [25]:
clf.fit(digits.data[:-10], digits.target[:-10])


Out[25]:
DecisionTreeClassifier(class_weight=None, criterion='gini', max_depth=None,
            max_features=None, max_leaf_nodes=None,
            min_impurity_split=1e-07, min_samples_leaf=1,
            min_samples_split=2, min_weight_fraction_leaf=0.0,
            presort=False, random_state=0, splitter='best')

Теперь попробуем классифицировать оставшиеся 10 картинок.


In [26]:
errors = 0
for i in range(1, 11):
    k = clf.predict(digits.data[-i].reshape(1, -1))
    print("Классификатор предсказал число {}, на самом деле это {}. Числа {}совпали."
          .format(k[0], digits.target[-i], 
                  "" if k[0] == digits.target[-i] else "не "))
    
    if k[0] != digits.target[-i]:
        errors += 1


Классификатор предсказал число 8, на самом деле это 8. Числа совпали.
Классификатор предсказал число 9, на самом деле это 9. Числа совпали.
Классификатор предсказал число 8, на самом деле это 8. Числа совпали.
Классификатор предсказал число 0, на самом деле это 0. Числа совпали.
Классификатор предсказал число 9, на самом деле это 9. Числа совпали.
Классификатор предсказал число 4, на самом деле это 4. Числа совпали.
Классификатор предсказал число 1, на самом деле это 8. Числа не совпали.
Классификатор предсказал число 3, на самом деле это 8. Числа не совпали.
Классификатор предсказал число 4, на самом деле это 4. Числа совпали.
Классификатор предсказал число 3, на самом деле это 5. Числа не совпали.

Давайте посмотрим на "проблемные" числа:


In [29]:
fig = plt.figure(figsize=(12, 4))
frame = 1
for i in range(1, 11):
    k = clf.predict(digits.data[-i].reshape(1, -1))
    if k[0] != digits.target[-i]:
        digit = 255 - digits.data[-i, :].reshape(8, 8)
        
        ax = fig.add_subplot(1, errors, frame)        
        ax.imshow(digit, cmap='gray', interpolation='none')
        ax.set_title("This is {}, recognized as {}".format(digits.target[-i], k[0]))
        frame += 1


Можно согласиться, что по крайней мере в двух из этих чисел могут ошибиться и люди.