In [20]:
from sklearn import datasets
digits = datasets.load_digits()
print("Набор данных для обучения (2D NumPy Array): \n", digits.data, '\n')
print("Набор целей для данных (1D NumPy Array): \n", digits.target)
In [21]:
print("Форма массива данных для обучения: \n", digits.data.shape)
print("Форма массива целей: \n", digits.target.shape)
Какую задачу можно поставить для этого набора данных?
In [2]:
%matplotlib inline
import matplotlib.pyplot as plt
n = 19
print("Каждая цифра представлена матрицей формы ", digits.data[n, :].shape)
Чтобы отобразить её на экране, нужно применить метод reshape. Целевая форма — $8 \times 8$.
In [3]:
digit = 255 - digits.data[n, :].reshape(8, 8)
plt.imshow(digit, cmap='gray', interpolation='none')
plt.title("This is " + str(digits.target[n]))
plt.show()
Возьмем один из методов прошлой лекции. Например, метод классификации, основанный на деревьях (CART).
In [4]:
from sklearn.tree import DecisionTreeClassifier
Почти у всех классов, отвечающих за методы классификации в scikit-learn, есть следующие методы:
Чтобы создать дерево-классификатор, достаточно создать объект класса DecisionTreeClassifier
In [24]:
clf = DecisionTreeClassifier(random_state=0)
Обучим классификатор на всех цифрах, кроме последних 10.
In [25]:
clf.fit(digits.data[:-10], digits.target[:-10])
Out[25]:
Теперь попробуем классифицировать оставшиеся 10 картинок.
In [26]:
errors = 0
for i in range(1, 11):
k = clf.predict(digits.data[-i].reshape(1, -1))
print("Классификатор предсказал число {}, на самом деле это {}. Числа {}совпали."
.format(k[0], digits.target[-i],
"" if k[0] == digits.target[-i] else "не "))
if k[0] != digits.target[-i]:
errors += 1
Давайте посмотрим на "проблемные" числа:
In [29]:
fig = plt.figure(figsize=(12, 4))
frame = 1
for i in range(1, 11):
k = clf.predict(digits.data[-i].reshape(1, -1))
if k[0] != digits.target[-i]:
digit = 255 - digits.data[-i, :].reshape(8, 8)
ax = fig.add_subplot(1, errors, frame)
ax.imshow(digit, cmap='gray', interpolation='none')
ax.set_title("This is {}, recognized as {}".format(digits.target[-i], k[0]))
frame += 1
Можно согласиться, что по крайней мере в двух из этих чисел могут ошибиться и люди.