In [ ]:
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
from sklearn import datasets
In [ ]:
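# Workaround for the mldata.org shutdown: download the MNIST .mat file from a
# mirror into scikit-learn's local data directory.
# Source: https://github.com/scikit-learn/scikit-learn/issues/8588#issuecomment-292634781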
from shutil import copyfileobj
from six.moves import urllib
from sklearn.datasets.base import get_data_home
import os
def fetch_mnist(data_home=None):
    """Download the MNIST ``.mat`` file into scikit-learn's data directory.

    The file is saved as ``<data_home>/mldata/mnist-original.mat``, where
    ``<data_home>`` is whatever ``get_data_home(data_home=data_home)``
    returns. If that file already exists, nothing is downloaded.

    The download is written to a temporary ``.part`` file and only renamed
    to its final name once the transfer has completed, so an interrupted
    download never leaves a truncated file that later calls would mistake
    for a complete one.

    Parameters
    ----------
    data_home : str or None, optional
        Passed through unchanged to ``get_data_home``.

    Returns
    -------
    None
    """
    mnist_alternative_url = "https://github.com/amplab/datascience-sp14/raw/master/lab7/mldata/mnist-original.mat"
    data_home = get_data_home(data_home=data_home)
    data_home = os.path.join(data_home, 'mldata')
    if not os.path.exists(data_home):
        os.makedirs(data_home)
    mnist_save_path = os.path.join(data_home, "mnist-original.mat")
    if os.path.exists(mnist_save_path):
        # Already downloaded; nothing to do.
        return

    partial_path = mnist_save_path + ".part"
    mnist_url = urllib.request.urlopen(mnist_alternative_url)
    try:
        with open(partial_path, "wb") as matlab_file:
            copyfileobj(mnist_url, matlab_file)
        # Publish the file under its final name only after a full transfer.
        os.rename(partial_path, mnist_save_path)
    except BaseException:
        # Covers KeyboardInterrupt as well: drop the incomplete download so
        # the next call starts from scratch, then propagate the error.
        if os.path.exists(partial_path):
            os.remove(partial_path)
        raise
    finally:
        # Explicit close() rather than ``with`` so this also works where the
        # response object is not a context manager (Python 2 via six).
        mnist_url.close()
In [ ]:
# Make sure the MNIST .mat file is present locally before asking
# scikit-learn for the dataset.
fetch_mnist(data_home=None)
# NOTE(review): presumably fetch_mldata picks up the pre-downloaded file
# instead of contacting mldata.org -- confirm against the installed version.
mnist = datasets.fetch_mldata(dataname='MNIST original')
In [ ]:
# Inspect the shape of the loaded data array (shown as the cell's output).
mnist.data.shape
In [ ]:
# Inspect the shape of the loaded target array (shown as the cell's output).
mnist.target.shape
In [ ]:
# List the sorted distinct values that occur in the target array.
np.unique(mnist.target)
In [ ]: