In [ ]:
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
from sklearn import datasets

In [ ]:
# Workaround for the mldata.org shutdown; see:
# https://github.com/scikit-learn/scikit-learn/issues/8588#issuecomment-292634781
from shutil import copyfileobj
from six.moves import urllib
from sklearn.datasets.base import get_data_home
import os

def fetch_mnist(data_home=None):
    """Download ``mnist-original.mat`` into the ``mldata`` cache directory.

    The file is fetched from a GitHub mirror and stored as
    ``<data home>/mldata/mnist-original.mat``, where ``<data home>`` is
    whatever ``get_data_home(data_home=data_home)`` returns.  The ``mldata``
    directory is created if missing.  If the target file already exists,
    nothing is downloaded.

    Parameters
    ----------
    data_home : optional
        Forwarded unchanged to ``get_data_home`` (presumably a directory
        path, or ``None`` for the default location -- confirm against the
        scikit-learn documentation).

    Returns
    -------
    None

    Raises
    ------
    Any exception raised while opening the URL or writing the file is
    propagated to the caller; in that case no file is left at the final
    path.
    """
    # Local import: keeps this cell self-contained without touching the
    # notebook's import cells.
    from contextlib import closing

    mnist_alternative_url = "https://github.com/amplab/datascience-sp14/raw/master/lab7/mldata/mnist-original.mat"
    data_home = get_data_home(data_home=data_home)
    data_home = os.path.join(data_home, 'mldata')
    if not os.path.exists(data_home):
        os.makedirs(data_home)
    mnist_save_path = os.path.join(data_home, "mnist-original.mat")
    if os.path.exists(mnist_save_path):
        # Already cached by an earlier run.
        return

    # Stream into a sibling ".part" file and rename only after the copy has
    # finished.  Writing straight to the final path would leave a truncated
    # file behind on an interrupted download, and the existence check above
    # would then mistake it for a complete one on every later run.
    partial_path = mnist_save_path + ".part"
    try:
        # closing() guarantees the HTTP response is released even when the
        # copy fails (and works on Python 2, where the response object is
        # not a context manager).
        with closing(urllib.request.urlopen(mnist_alternative_url)) as mnist_url:
            with open(partial_path, "wb") as matlab_file:
                copyfileobj(mnist_url, matlab_file)
        os.rename(partial_path, mnist_save_path)
    finally:
        # After a successful rename the partial file no longer exists, so
        # this only fires when something went wrong.
        if os.path.exists(partial_path):
            os.remove(partial_path)

In [ ]:
# Run the download helper first, then load the dataset by name through
# scikit-learn's mldata loader (order matters: the helper call comes before
# the load).
fetch_mnist()
mnist = datasets.fetch_mldata("MNIST original")

In [ ]:
# Display the dimensions of the loaded data array.
mnist.data.shape

In [ ]:
# Display the dimensions of the loaded target array.
mnist.target.shape

In [ ]:
# List the distinct values that occur in mnist.target.
np.unique(mnist.target)

In [ ]: