In [37]:
from __future__ import print_function
import importlib
import sys, os
import tempfile
import h2o
from h2o.estimators.deepwater import H2ODeepWaterEstimator
# Connect to (or launch) a local H2O cluster; every h2o.* call below needs it.
h2o.init()
In [38]:
# Stop here if this H2O build has no DeepWater backend.
if not H2ODeepWaterEstimator.available():
    # `quit()` is an interactive-shell helper injected by the `site` module and
    # is not meant for use in programs; `sys.exit` raises SystemExit, which
    # stops execution both inside a notebook kernel and when the notebook is
    # exported and run as a plain script.
    print("H2O DeepWater is not available; skipping the rest of this notebook.",
          file=sys.stderr)
    sys.exit(0)
In [39]:
def lenet(num_classes):
    """Build a LeNet-style convolutional network as an MXNet symbol graph.

    Architecture:
      * conv 5x5, 20 filters -> tanh -> 2x2 max-pool (stride 2)
      * conv 5x5, 50 filters -> tanh -> 2x2 max-pool (stride 2)
      * flatten -> fully connected, 500 units -> tanh
      * fully connected, ``num_classes`` units -> softmax loss

    Parameters
    ----------
    num_classes : int
        Number of units in the final fully connected layer (one per class).

    Returns
    -------
    mxnet.symbol.Symbol
        The output symbol, named ``'softmax'``; the network's input variable
        is named ``'data'``.
    """
    # Imported inside the function so the cell can be defined without MXNet;
    # it is only required once the network is actually built.
    import mxnet as mx

    data = mx.symbol.Variable('data')

    # first conv block
    conv1 = mx.symbol.Convolution(data=data, kernel=(5, 5), num_filter=20)
    tanh1 = mx.symbol.Activation(data=conv1, act_type="tanh")
    pool1 = mx.symbol.Pooling(data=tanh1, pool_type="max", kernel=(2, 2), stride=(2, 2))

    # second conv block
    conv2 = mx.symbol.Convolution(data=pool1, kernel=(5, 5), num_filter=50)
    tanh2 = mx.symbol.Activation(data=conv2, act_type="tanh")
    pool2 = mx.symbol.Pooling(data=tanh2, pool_type="max", kernel=(2, 2), stride=(2, 2))

    # first fully connected layer
    flatten = mx.symbol.Flatten(data=pool2)
    fc1 = mx.symbol.FullyConnected(data=flatten, num_hidden=500)
    tanh3 = mx.symbol.Activation(data=fc1, act_type="tanh")

    # second fully connected layer: one output per class
    fc2 = mx.symbol.FullyConnected(data=tanh3, num_hidden=num_classes)

    # loss; named `net` so it does not shadow the enclosing function `lenet`
    net = mx.symbol.SoftmaxOutput(data=fc2, name='softmax')
    return net
In [40]:
# Load MNIST as CSV. Each row has N_PIXELS pixel columns (indices 0..783)
# followed by the label column at index N_PIXELS.
MNIST_DIR = "../../bigdata/laptop/mnist"
N_PIXELS = 784  # 28 * 28

train = h2o.import_file(MNIST_DIR + "/train.csv.gz")
test = h2o.import_file(MNIST_DIR + "/test.csv.gz")

predictors = list(range(N_PIXELS))
resp = N_PIXELS

# Make the label categorical in both frames so H2O treats this as classification.
for frame in (train, test):
    frame[resp] = frame[resp].asfactor()

# Number of distinct label values in the training frame.
nclasses = train[resp].nlevels()[0]
Let's create the LeNet model architecture from scratch using the MXNet Python API.
In [41]:
# Build the MXNet symbol graph with one output unit per label class.
# NOTE(review): `model` is rebound to the H2O estimator in a later cell; a
# distinct name (e.g. `lenet_symbol`) would make the notebook easier to follow.
model = lenet(nclasses)
To load the network into the DeepWater training engine, we first need to save its definition to a file:
In [42]:
# Write the symbol definition to the platform's temp directory instead of a
# hard-coded "/tmp", which does not exist on every OS (e.g. Windows).
model_path = os.path.join(tempfile.gettempdir(), "symbol_lenet-py.json")
# Serialize the symbol graph to JSON at model_path.
model.save(model_path)
The saved file contains only the structure of the network, expressed as a JSON object:
In [43]:
# Show the first 10 lines of the saved symbol file. Pure Python is used rather
# than the IPython-only `!head`, so this also runs when exported as a script.
with open(model_path) as symbol_file:
    # zip stops after 10 items, so no more than 10 lines are read.
    symbol_head = [line for _, line in zip(range(10), symbol_file)]
print("".join(symbol_head), end="")
In [44]:
# Configure DeepWater to train the user-supplied network read from the JSON
# symbol file at model_path, on 28x28 inputs with a single channel.
# NOTE(review): no seed is set, so repeated runs may differ.
model = H2ODeepWaterEstimator(
    network='user',                      # custom network definition...
    network_definition_file=model_path,  # ...loaded from this file
    image_shape=[28, 28],
    channels=1,
    epochs=100,
    learning_rate=1e-3,
    mini_batch_size=64,
)
In [45]:
# Fit on `train`. `test` is passed as the validation frame, so the validation
# metrics reported afterwards are computed on it.
# NOTE(review): if early stopping is active it will be driven by this frame,
# making its score an optimistic estimate of held-out performance — confirm
# the stopping settings or split a separate validation set from `train`.
model.train(x=predictors,y=resp, training_frame=train, validation_frame=test)
In [46]:
# Display the trained model's summary (presumably including training and
# validation metrics — verify against the rendered output).
model.show()