This notebook sets up the notMNIST dataset. The dataset is designed to resemble the classic MNIST dataset while being a little more like real data: it's a harder task, and the data is a lot less 'clean' than MNIST.
This notebook is derived from Assignment 1 of the Udacity TensorFlow course.
In [1]:
%matplotlib inline
from __future__ import print_function
import gzip
import os
import sys
import tarfile
import urllib.request
import matplotlib.pyplot as plt
import numpy as np
from IPython.display import display, Image
from scipy import ndimage
from six.moves import cPickle as pickle
import outputer  # local helper module used below to set up the output directory
Download the dataset of characters 'A' through 'J', rendered in various fonts as 28x28 images. There is a training set of about 500,000 images and a test set of about 19,000 images.
In [2]:
url = "http://yaroslavvb.com/upload/notMNIST/"
data_path = outputer.setup_directory("notMNIST")
def maybe_download(path, filename, expected_bytes):
    """Download a file if not present, and make sure it's the right size."""
    file_path = os.path.join(path, filename)
    if not os.path.exists(file_path):
        file_path, _ = urllib.request.urlretrieve(url + filename, file_path)
    statinfo = os.stat(file_path)
    if statinfo.st_size == expected_bytes:
        print("Found", file_path, "with correct size.")
    else:
        raise Exception("Error downloading " + filename)
    return file_path
train_filename = maybe_download(data_path, "notMNIST_large.tar.gz", 247336696)
test_filename = maybe_download(data_path, "notMNIST_small.tar.gz", 8458043)
Extract the dataset from each compressed .tar.gz file. This should give you a set of directories, labelled 'A' through 'J', for each archive.
In [3]:
def extract(filename, root, class_count):
    # Remove the leading path and the .tar.gz extension.
    dir_name = os.path.splitext(os.path.splitext(os.path.basename(filename))[0])[0]
    path = os.path.join(root, dir_name)
    print("Extracting", filename, "to", path)
    tar = tarfile.open(filename)
    tar.extractall(path=root)
    tar.close()
    data_folders = [os.path.join(path, d) for d in sorted(os.listdir(path))]
    if len(data_folders) != class_count:
        raise Exception("Expected %d folders, one per class. Found %d instead." %
                        (class_count, len(data_folders)))
    print(data_folders)
    return data_folders
train_folders = []
test_folders = []
for name in os.listdir(data_path):
    path = os.path.join(data_path, name)
    target = None
    print("Checking", path)
    if path.endswith("_small"):
        target = test_folders
    elif path.endswith("_large"):
        target = train_folders
    if target is not None:
        target.extend([os.path.join(path, name) for name in os.listdir(path)])
        print("Found", target)
expected_classes = 10
if len(train_folders) < expected_classes:
    train_folders = extract(train_filename, data_path, expected_classes)
if len(test_folders) < expected_classes:
    test_folders = extract(test_filename, data_path, expected_classes)
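As a quick sanity check (this snippet is not part of the original assignment; it only assumes the test_folders list built above), we can count how many PNG files were extracted for each class:

# Hypothetical sanity check: count the extracted images per class folder.
for folder in sorted(test_folders):
    png_count = len([f for f in os.listdir(folder) if f.endswith(".png")])
    print(os.path.basename(folder), png_count)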
In [4]:
Image(filename="notMNIST/notMNIST_small/A/MDEtMDEtMDAudHRm.png")
Out[4]:
In [5]:
Image(filename="notMNIST/notMNIST_large/A/a2F6b28udHRm.png")
Out[5]:
In [6]:
Image(filename="notMNIST/notMNIST_large/C/ZXVyb2Z1cmVuY2UgaXRhbGljLnR0Zg==.png")
Out[6]:
In [7]:
# This 'I' is solid white.
Image(filename="notMNIST/notMNIST_small/I/SVRDIEZyYW5rbGluIEdvdGhpYyBEZW1pLnBmYg==.png")
Out[7]:
Convert the data into an array of normalized, grayscale, floating-point images and an array of classification labels. Unreadable images are skipped.
In [8]:
image_size = 28      # Width and height of each image, in pixels.
pixel_depth = 255.0  # Number of grayscale levels per pixel.
skip_list = []       # Paths of images that could not be read.

def normalize_separator(path):
    return path.replace("\\", "/")

def load(data_folders, set_id, min_count, max_count):
    # Create arrays large enough for the maximum expected amount of data.
    dataset = np.ndarray(shape=(max_count, image_size, image_size), dtype=np.float32)
    labels = np.ndarray(shape=(max_count), dtype=np.int32)
    label_index = 0
    image_index = 0
    solid_blacks = []
    solid_whites = []
    for folder in sorted(data_folders):
        print(folder)
        for image in os.listdir(folder):
            if image_index >= max_count:
                raise Exception("More than %d images!" % (max_count,))
            image_file = os.path.join(folder, image)
            if normalize_separator(image_file) in skip_list:
                continue
            try:
                raw_data = ndimage.imread(image_file)
                # Keep track of images that are solid white or solid black.
                if np.all(raw_data == 0):
                    solid_blacks.append(image_file)
                if np.all(raw_data == int(pixel_depth)):
                    solid_whites.append(image_file)
                # Convert to float and normalize.
                image_data = (raw_data.astype(float) - pixel_depth / 2) / pixel_depth
                if image_data.shape != (image_size, image_size):
                    raise Exception("Unexpected image shape: %s" % str(image_data.shape))
                # Capture the image data and label.
                dataset[image_index, :, :] = image_data
                labels[image_index] = label_index
                image_index += 1
            except IOError as e:
                skip_list.append(normalize_separator(image_file))
                print("Could not read:", image_file, ":", e, "skipping.")
        label_index += 1
    image_count = image_index
    # Trim down to just the used portion of the arrays.
    dataset = dataset[0:image_count, :, :]
    labels = labels[0:image_count]
    if image_count < min_count:
        raise Exception("Fewer images than expected: %d < %d" %
                        (image_count, min_count))
    print("Input data shape:", dataset.shape)
    print("Mean of all normalized pixels:", np.mean(dataset))
    print("Standard deviation of normalized pixels:", np.std(dataset))
    print("Labels shape:", labels.shape)
    print("Found", len(solid_whites), "solid white images, and",
          len(solid_blacks), "solid black images.")
    return dataset, labels
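As a small check of the normalization used in load() (assuming the 0-255 pixel depth defined above), a black pixel maps to -0.5 and a white pixel to 0.5:

# Normalization check: raw values 0 and 255 map to -0.5 and 0.5.
print((0 - pixel_depth / 2) / pixel_depth)    # -0.5
print((255 - pixel_depth / 2) / pixel_depth)  #  0.5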
In [9]:
train_dataset, train_labels = load(train_folders, "train", 450000, 550000)
test_dataset, test_labels = load(test_folders, 'test', 18000, 20000)
skip_list
Out[9]:
In [10]:
exemplar = plt.imshow(train_dataset[0])
train_labels[0]
Out[10]:
In [11]:
exemplar = plt.imshow(train_dataset[373])
train_labels[373]
Out[11]:
In [12]:
exemplar = plt.imshow(test_dataset[18169])
test_labels[18169]
Out[12]:
In [13]:
exemplar = plt.imshow(train_dataset[-9])
train_labels[-9]
Out[13]:
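The numeric labels follow the sorted folder order, so 0 corresponds to 'A' and 9 to 'J'. A small helper (hypothetical, added here only for readability) converts a label back to its letter:

# Hypothetical helper: map a numeric class label back to its letter.
def label_to_letter(label):
    return chr(ord("A") + int(label))

print(label_to_letter(train_labels[0]), label_to_letter(test_labels[18169]))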
In [14]:
pickle_file = 'notMNIST/full.pickle'
try:
    f = gzip.open(pickle_file, 'wb')
    save = {
        'train_dataset': train_dataset,
        'train_labels': train_labels,
        'test_dataset': test_dataset,
        'test_labels': test_labels
    }
    pickle.dump(save, f, pickle.HIGHEST_PROTOCOL)
    f.close()
except Exception as e:
    print('Unable to save data to', pickle_file, ':', e)
    raise

statinfo = os.stat(pickle_file)
print('Compressed pickle size:', statinfo.st_size)
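To read the gzipped pickle back later, here is a minimal sketch mirroring the save code above:

# Minimal sketch: reload the gzipped pickle written above.
with gzip.open(pickle_file, 'rb') as f:
    restored = pickle.load(f)
print(restored['train_dataset'].shape, restored['test_dataset'].shape)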
In [ ]: