Import


In [ ]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import shutil

from sklearn import datasets, metrics, cross_validation
from tensorflow.contrib import learn

import chainer.functions as F
import chainer.links as L
from chainer import serializers, optimizers, Chain
from commonml.skchainer import ChainerEstimator, SoftmaxCrossEntropyClassifier

import logging
# Emit INFO-and-above records formatted as "<LEVEL> : <message>".
logging.basicConfig(format='%(levelname)s : %(message)s', level=logging.INFO)
# basicConfig() is a no-op when the root logger already has handlers, so the
# level is also set explicitly. setLevel() is used rather than assigning the
# `.level` attribute so the logging module's own bookkeeping stays consistent.
logging.root.setLevel(logging.INFO)

In [ ]:
# Load the iris data and hold out 20% of the samples for evaluation.
# random_state is fixed so the split is the same on every run.
iris = datasets.load_iris()
X_train, X_test, y_train, y_test = cross_validation.train_test_split(
    iris.data,
    iris.target,
    test_size=0.2,
    random_state=42,
)

In [ ]:
class Model(Chain):
    """Single fully-connected layer mapping input features to output units.

    Args:
        in_size: Number of input features per sample.
        out_size: Number of output units of the linear layer. Defaults to 3,
            the value that was previously hard-coded.
    """

    def __init__(self, in_size, out_size=3):
        # The layer is registered under the name ``l1`` so it is reachable as
        # ``self.l1`` in ``__call__``.
        super(Model, self).__init__(l1=L.Linear(in_size, out_size))

    def __call__(self, x):
        """Apply the linear layer to ``x`` and return its output unchanged."""
        return self.l1(x)

# Build the estimator around the linear model; the loss wrapper, optimizer,
# batch size and stopping condition are all supplied at construction time.
classifier = ChainerEstimator(
    model=SoftmaxCrossEntropyClassifier(Model(X_train.shape[1])),  # input width = number of feature columns
    optimizer=optimizers.AdaGrad(lr=0.1),
    batch_size=100,
    device=0,  # NOTE(review): presumably a GPU id, which would require CUDA — confirm
    stop_trigger=(100, 'epoch'),
)
classifier.fit(X_train, y_train)

# Score the fitted estimator on the held-out split.
y_pred = classifier.predict(X_test)
score = metrics.accuracy_score(y_test, y_pred)
print('Accuracy: {0:f}'.format(score))

Clean the checkpoint folder if it exists


In [ ]:
# Remove checkpoints left over from a previous run. Only the "folder is
# missing" case is skipped; a genuine failure while deleting is allowed to
# surface instead of being silently ignored.
if os.path.isdir('/tmp/chainer_examples'):
    shutil.rmtree('/tmp/chainer_examples')

Save the trained model parameters and the optimizer state.


In [ ]:
# Create the checkpoint folder only when it is missing, so this cell does not
# fail if the folder is already there.
if not os.path.isdir('/tmp/chainer_examples/'):
    os.makedirs('/tmp/chainer_examples/')
# Persist the network held as `predictor` inside the classifier wrapper, and
# the optimizer, to separate HDF5 files.
serializers.save_hdf5('/tmp/chainer_examples/iris_custom_model', classifier.model.predictor)
serializers.save_hdf5('/tmp/chainer_examples/iris_custom_optimizer', classifier.optimizer)
# Drop the reference to the trained estimator so that anything used afterwards
# has to be rebuilt from the saved files.
classifier = None

Restore everything


In [ ]:
# Rebuild the network with the same input width and load the saved weights.
model = Model(X_train.shape[1])
serializers.load_hdf5('/tmp/chainer_examples/iris_custom_model', model)

# Recreate the estimator with the same settings that were used for training,
# then load the saved optimizer into it.
new_classifier = ChainerEstimator(
    model=SoftmaxCrossEntropyClassifier(model),
    optimizer=optimizers.AdaGrad(lr=0.1),
    batch_size=100,
    device=0,  # NOTE(review): presumably a GPU id, which would require CUDA — confirm
    stop_trigger=(100, 'epoch'),
)
serializers.load_hdf5('/tmp/chainer_examples/iris_custom_optimizer', new_classifier.optimizer)

# fit() is not called here: the score reflects the restored parameters only.
restored_pred = new_classifier.predict(X_test)
score = metrics.accuracy_score(y_test, restored_pred)
print('Accuracy: {0:f}'.format(score))