Autor notebooka: Jakub Nowacki.
Oprócz danych numerycznych i tekstowych, używa się uczenia maszynowego do klasyfikacji obrazu. Jednym z tradycyjnych zadań klasyfikacji obrazu jest MINST. Dane są dostępne do pobrania za pomocą narzędzi scikit-learn.
In [3]:
import matplotlib.pyplot as plt
from sklearn.datasets import fetch_mldata
from sklearn.neural_network import MLPClassifier
from sklearn.metrics import classification_report
import numpy as np
%matplotlib inline
mnist = fetch_mldata("MNIST original")
mnist.data[0]
Out[3]:
In [4]:
np.unique(mnist.target)
Out[4]:
Podzielmy dane w sposób tradycyjny.
In [5]:
X, y = mnist.data / 255., mnist.target
X_train, X_test = X[:60000], X[60000:]
y_train, y_test = y[:60000], y[60000:]
Dane są w istocie obrazami w rozmiarze 28 na 28 pikseli i są to liczby od 0 do 9.
In [6]:
sample_data = np.array([mnist.data[mnist.target == c][np.random.randint(0, 1000)] for c in range(10)])
fig, axes = plt.subplots(3, 3, figsize=(10, 8))
for data, ax in zip(sample_data, axes.ravel()):
ax.matshow(data.reshape(28, 28), cmap=plt.cm.gray)
ax.set_xticks(())
ax.set_yticks(())
Wielowarstwowy perceptron (multi-layer perceptron (MLP)) jest prostą siecią neuronową, która składa się z przynajmniej 3 warstw:
Najczęściej stosowaną funkcją aktywacyjną jest sigmoida w postaci tangensa hyperbolicznego.
MLP stosuje się zarówno do klasyfikacji jak i do regresji, niemniej, przykładzie wykorzystamy klasyfikator.
In [7]:
mlp = MLPClassifier(hidden_layer_sizes=(50,), max_iter=10, alpha=1e-4,
solver='sgd', verbose=10, tol=1e-4, random_state=1,
learning_rate_init=.1)
mlp.fit(X_train, y_train)
print("Training set score: %f" % mlp.score(X_train, y_train))
print("Test set score: %f" % mlp.score(X_test, y_test))
print(classification_report(y_test, mlp.predict(X_test)))
fig, axes = plt.subplots(4, 4, figsize=(10, 8))
vmin, vmax = mlp.coefs_[0].min(), mlp.coefs_[0].max()
for coef, ax in zip(mlp.coefs_[0].T, axes.ravel()):
ax.matshow(coef.reshape(28, 28), cmap=plt.cm.gray, vmin=.5 * vmin,
vmax=.5 * vmax)
ax.set_xticks(())
ax.set_yticks(())