Clasificación de dígitos MNIST


In [23]:
import mnist
import matplotlib.pyplot as plt
import numpy as np

from sklearn import svm

In [22]:
# Load the MNIST train/test images and labels, then report each array's shape.
train_images = mnist.train_images()
train_labels = mnist.train_labels()
test_images = mnist.test_images()
test_labels = mnist.test_labels()

print('dimensions:')
# One line per array, formatted as "<name>:  <shape>". A generator is used so
# that no loop variables are left behind in the kernel namespace.
print(
    *(
        f'{label}:  {array.shape}'
        for label, array in (
            ('train_images', train_images),
            ('train_labels', train_labels),
            ('test_images', test_images),
            ('test_labels', test_labels),
        )
    ),
    sep='\n',
)


dimensions:
train_images:  (60000, 28, 28)
train_labels:  (60000,)
test_images:  (10000, 28, 28)
test_labels:  (10000,)

In [18]:
# Visual sanity check: display one training image together with its label.
sample_index = 69  # arbitrary example; change to inspect a different digit

fig, ax = plt.subplots()
# The images are single-channel, so a grayscale colormap shows the pixel
# intensities without the false colours of the default colormap.
ax.imshow(train_images[sample_index], cmap='gray')
ax.set_title(f'Training sample {sample_index} (label: {train_labels[sample_index]})')
plt.show()



In [ ]:
# Training
# Flatten every 2-D image into a single feature row: the first axis (one entry
# per sample) is kept and the remaining pixel axes are merged into one.
n_samples = train_images.shape[0]
X_train = np.reshape(train_images, (n_samples, -1))

# Support vector classifier with scikit-learn's default settings.
classifier = svm.SVC()
classifier.fit(X_train, train_labels)

In [27]:



Out[27]:
(60000, 784)