https://www.cs.toronto.edu/~kriz/cifar.html
The CIFAR-10 dataset
The CIFAR-10 dataset consists of 60000 32x32 colour images in 10 classes, with 6000 images per class. There are 50000 training images and 10000 test images.
The dataset is divided into five training batches and one test batch, each with 10000 images. The test batch contains exactly 1000 randomly-selected images from each class. The training batches contain the remaining images in random order, but some training batches may contain more images from one class than another. Between them, the training batches contain exactly 5000 images from each class.
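Once the labels are loaded as plain Python lists of class indices (as `y_train` and `y_test` are in the cells of this notebook), these balance claims can be checked directly. The following is a minimal sketch. The helper names `class_counts` and `is_balanced` are hypothetical and do not come from the original notebook.

```python
from collections import Counter

def class_counts(labels, num_classes=10):
    """Return the number of images for each class index 0..num_classes-1."""
    counts = Counter(labels)
    return [counts.get(c, 0) for c in range(num_classes)]

def is_balanced(labels, per_class, num_classes=10):
    """True if every class index occurs exactly `per_class` times."""
    return all(n == per_class for n in class_counts(labels, num_classes))

# Synthetic example: 3 classes, 2 images each, in shuffled order.
example = [2, 0, 1, 1, 0, 2]
print(class_counts(example, num_classes=3))    # [2, 2, 2]
print(is_balanced(example, 2, num_classes=3))  # True
```

For the full training labels, `is_balanced(y_train, 5000)` should hold, and `is_balanced(y_test, 1000)` should hold for the test batch. A single training batch is not guaranteed to be balanced.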
In [9]:
%matplotlib inline
import matplotlib
import scipy.io
import matplotlib.pyplot as plt
import cPickle
import numpy as np
from scipy.misc import imsave  # removed from recent SciPy releases; imageio.imwrite is a common replacement
from IPython.display import Image, display, HTML
In [10]:
!mkdir -p ~/.h2o/datasets/
In [11]:
!wget -c https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz -O ~/.h2o/datasets/cifar-10-python.tar.gz
In [17]:
!tar xvzf ~/.h2o/datasets/cifar-10-python.tar.gz -C ~/.h2o/datasets/
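If shell commands are not available, Python's standard `tarfile` module can do the same work as the `mkdir -p` and `tar xvzf` cells. This is a minimal sketch. The helper name `extract_tarball` is hypothetical, and the safe `filter="data"` extraction mode is only used when the running Python provides it.

```python
import os
import tarfile

def extract_tarball(archive_path, dest_dir):
    """Extract a .tar.gz archive into dest_dir (like `tar xzf ARCHIVE -C DEST`); return member names."""
    archive_path = os.path.expanduser(archive_path)
    dest_dir = os.path.expanduser(dest_dir)
    os.makedirs(dest_dir, exist_ok=True)  # like `mkdir -p`
    with tarfile.open(archive_path, "r:gz") as tar:
        names = tar.getnames()
        if hasattr(tarfile, "data_filter"):
            # Newer Pythons: refuse members that would escape dest_dir or are otherwise unsafe.
            tar.extractall(dest_dir, filter="data")
        else:
            tar.extractall(dest_dir)
    return names
```

For example, `extract_tarball("~/.h2o/datasets/cifar-10-python.tar.gz", "~/.h2o/datasets/")` should produce the same `cifar-10-batches-py/` directory that the shell cell creates.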
In [24]:
import os.path
with open(os.path.expanduser("~/.h2o/datasets/cifar-10-batches-py/batches.meta"), 'rb') as fd:
    meta = cPickle.load(fd)
print meta
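The cell above is written for Python 2 (`cPickle` and the `print` statement). The batch files were pickled by Python 2, so under Python 3 they are usually loaded with an explicit `encoding` argument to `pickle.load`. This is a minimal sketch that assumes Python 3. The helper name `unpickle` is hypothetical. With `encoding="latin1"`, Python 2 byte strings are decoded to `str`, so keys such as `'label_names'`, `'data'`, `'labels'` and `'filenames'` come back as ordinary strings instead of bytes.

```python
import pickle

def unpickle(path):
    """Load one CIFAR-10 pickle file (e.g. batches.meta or data_batch_1) under Python 3.

    encoding="latin1" decodes Python 2 byte strings to str and keeps
    NumPy pixel arrays intact.
    """
    with open(path, "rb") as fd:
        return pickle.load(fd, encoding="latin1")
```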
In [206]:
labels = meta['label_names']
labels
Out[206]:
In [31]:
def load_cifar10_image_list(filepath):
    """Decode one CIFAR-10 batch file, save each image to disk and return (filenames, labels)."""
    images = []
    labels = []
    with open(filepath, 'rb') as fd:
        d = cPickle.load(fd)
    for image, label, filename in zip(d['data'], d['labels'], d['filenames']):
        # 3072 values per image: 1024 red, then 1024 green, then 1024 blue
        x = np.array(image)
        x = np.dstack((x[:1024], x[1024:2048], x[2048:]))
        x = x.reshape(32, 32, 3)
        filename = os.path.expanduser("~/.h2o/datasets/cifar-10-batches-py/" + filename)
        imsave(filename, x)
        images.append(filename)
        labels.append(label)
    return images, labels
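Each row of `d['data']` holds 3072 values: first the 1024 red values, then the 1024 green values, then the 1024 blue values. Each channel is a 32x32 plane stored in row-major order. The `dstack`/`reshape` pair above interleaves the three planes into a height x width x channel array. The sketch below uses NumPy only, and its function names are illustrative. It checks that this decoding matches the more common `reshape(3, 32, 32).transpose(1, 2, 0)` form, which also vectorizes over a whole batch.

```python
import numpy as np

def decode_dstack(row):
    """Decode one 3072-value row the way load_cifar10_image_list does."""
    x = np.dstack((row[:1024], row[1024:2048], row[2048:]))  # shape (1, 1024, 3)
    return x.reshape(32, 32, 3)

def decode_transpose(row):
    """Equivalent decoding: split into 3 channel planes, then move channels last."""
    return row.reshape(3, 32, 32).transpose(1, 2, 0)

def decode_batch(data):
    """Vectorized decoding of an (N, 3072) array into (N, 32, 32, 3)."""
    return data.reshape(-1, 3, 32, 32).transpose(0, 2, 3, 1)

rng = np.random.default_rng(0)
batch = rng.integers(0, 256, size=(4, 3072), dtype=np.uint8)
assert all(np.array_equal(decode_dstack(r), decode_transpose(r)) for r in batch)
print(decode_batch(batch).shape)  # (4, 32, 32, 3)
```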
In [32]:
x_train = []
y_train = []
for batch in range(1, 6):
    batch_name = os.path.expanduser('~/.h2o/datasets/cifar-10-batches-py/data_batch_%d' % batch)
    x, y = load_cifar10_image_list(batch_name)
    x_train.extend(x)
    y_train.extend(y)
In [181]:
!ls ~/.h2o/datasets/cifar-10-batches-py/ | sed -n '1~5000p' # show every 5000th file
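GNU `sed`'s `first~step` address (`1~5000p`) prints lines 1, 5001, 10001 and so on. In Python, the same sampling of the directory listing is a slice with step 5000 over the sorted names. This is a minimal sketch. The helper name is hypothetical, and `ls` may order names differently from Python's `sorted` under some locales.

```python
import os

def every_nth_file(directory, step=5000):
    """Return entries 1, step+1, 2*step+1, ... of the sorted listing (like sed -n '1~STEPp')."""
    names = sorted(os.listdir(os.path.expanduser(directory)))
    return names[::step]
```

For example, `every_nth_file("~/.h2o/datasets/cifar-10-batches-py/")` lists roughly the same files as the shell cell.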
In [183]:
for x in x_train[:10]:
    display(Image(filename=x))
In [198]:
[labels[x] for x in y_train[:10]]
Out[198]:
In [41]:
len(x_train)
Out[41]:
In [43]:
batch_test = os.path.expanduser('~/.h2o/datasets/cifar-10-batches-py/test_batch')
x_test, y_test = load_cifar10_image_list(batch_test)
In [74]:
import h2o
h2o.init()
In [60]:
!nvidia-smi
In [75]:
train_df = {"x0": x_train, "x1": y_train }
In [76]:
test_df = {"x0" : x_test, "x1": y_test }
In [77]:
train_hf = h2o.H2OFrame(train_df)
In [78]:
test_hf = h2o.H2OFrame(test_df)
Let's turn the class label into a factor (a categorical column), so that H2O trains a classifier rather than a regression model:
In [79]:
train_hf['x1'] = train_hf['x1'].asfactor()
test_hf['x1'] = test_hf['x1'].asfactor()
In [80]:
train_hf.head(10)
Out[80]:
In [81]:
from h2o.estimators.deepwater import H2ODeepWaterEstimator
In [82]:
deepwater_model = H2ODeepWaterEstimator(
    epochs=10,                 ## number of passes over the training data
    nfolds=3,                  ## 3-fold cross-validation
    learning_rate=2e-3,
    mini_batch_size=64,
    # problem_type='image',    ## autodetected by default
    network='vgg',
    # network_definition_file="mycnn.json"  ## provide your own mxnet .json model
    image_shape=[32, 32],
    channels=3,
    gpu=True
)
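The settings above imply roughly the following number of mini-batch updates. This is back-of-the-envelope arithmetic only. It assumes that each epoch is one full pass over the training rows in mini-batches of `mini_batch_size`, and that each of the `nfolds` cross-validation models trains on the rows outside its own fold, with one final model then trained on all rows. H2O's actual sampling and fold assignment may differ in detail, and the helper names are hypothetical.

```python
import math

def updates_per_epoch(n_rows, mini_batch_size):
    """Mini-batch updates needed for one pass over n_rows rows."""
    return math.ceil(n_rows / mini_batch_size)

def training_budget(n_rows=50000, nfolds=3, epochs=10, mini_batch_size=64):
    """Approximate updates for the final model and for each cross-validation model."""
    cv_rows = n_rows * (nfolds - 1) // nfolds  # rows outside one fold
    return {
        "final_model": epochs * updates_per_epoch(n_rows, mini_batch_size),
        "each_cv_model": epochs * updates_per_epoch(cv_rows, mini_batch_size),
        "models_trained": nfolds + 1,  # nfolds CV models + 1 final model
    }

print(training_budget())
# {'final_model': 7820, 'each_cv_model': 5210, 'models_trained': 4}
```

Taken together, 3-fold cross-validation roughly doubles or triples the training time compared with fitting a single model.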
In [83]:
deepwater_model.train(x=['x0'], y='x1', training_frame=train_hf)
At any time, especially during training, you can inspect the model in H2O Flow at http://localhost:54321.
Here is the first of the three cross-validation models:
In [87]:
train_error = deepwater_model.model_performance(train=True).mean_per_class_error()
print "training error:", train_error
In [88]:
xval_error = deepwater_model.model_performance(xval=True).mean_per_class_error()
print "cross-validated error:", xval_error
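`mean_per_class_error()` averages the error rate over classes rather than over rows, so each class counts equally no matter how many rows it has. Below is a small NumPy sketch of that definition. It is our own helper, not H2O's implementation.

```python
import numpy as np

def mean_per_class_error(y_true, y_pred):
    """Average, over the classes present in y_true, of the fraction of that class misclassified."""
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    errors = [np.mean(y_pred[y_true == c] != c) for c in np.unique(y_true)]
    return float(np.mean(errors))

# Class 0: 0 of 3 wrong (0.0); class 1: 1 of 1 wrong (1.0).
# The overall row error is 1/4, but the mean per-class error is 0.5.
print(mean_per_class_error([0, 0, 0, 1], [0, 0, 0, 0]))  # 0.5
```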
In [207]:
deepwater_model
Out[207]:
In [173]:
random_test_image_hf = test_hf[int(np.random.random() * test_hf.nrow), :]['x0']  # random row of the test frame
In [174]:
random_test_image_hf
Out[174]:
In [175]:
filename = random_test_image_hf.as_data_frame(use_pandas=False)[1][0]
filename
Out[175]:
In [176]:
Image(filename=filename)
Out[176]:
In [177]:
pred = deepwater_model.predict(random_test_image_hf)
In [178]:
predlabel = int(pred['predict'].as_data_frame(use_pandas=False)[1][0])
In [179]:
labels[predlabel]
Out[179]:
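For a multinomial model, H2O's prediction frame typically includes one probability column per class level in addition to the `predict` column. Once those probabilities are pulled into Python (for example, as a list of floats), ranking them gives the top-k guesses instead of only the single predicted label. This is a minimal sketch. The probability values are invented for illustration, and `top_k_labels` is a hypothetical helper.

```python
import numpy as np

# CIFAR-10 class names in label-index order (the label_names stored in batches.meta).
CIFAR10_LABELS = ["airplane", "automobile", "bird", "cat", "deer",
                  "dog", "frog", "horse", "ship", "truck"]

def top_k_labels(probs, label_names, k=3):
    """Return the k most probable (label, probability) pairs, highest first."""
    probs = np.asarray(probs, dtype=float)
    order = np.argsort(probs)[::-1][:k]
    return [(label_names[i], float(probs[i])) for i in order]

# Illustrative probabilities, not real model output.
probs = [0.01, 0.02, 0.05, 0.60, 0.04, 0.20, 0.03, 0.02, 0.02, 0.01]
print(top_k_labels(probs, CIFAR10_LABELS))
# [('cat', 0.6), ('dog', 0.2), ('bird', 0.05)]
```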