In [1]:
%matplotlib inline
import numpy as np
import matplotlib.pyplot as plt
from pyrallel.ensemble import EnsembleGrower
from pyrallel.ensemble import sub_ensemble
from sklearn.datasets import load_digits
from sklearn.ensemble import ExtraTreesClassifier
from sklearn.cross_validation import train_test_split
In [2]:
# Build a default-constructed IPython.parallel Client and keep a load-balanced
# view on it as `lb`; `lb` is what gets handed to EnsembleGrower further down.
# NOTE(review): Client() with no arguments presumably connects to an
# already-running cluster using the default profile — confirm one is started
# before running this cell.
from IPython.parallel import Client
lb = Client().load_balanced_view()
# Shown as the cell output; presumably the number of engines behind the view — TODO confirm.
len(lb)
Out[2]:
In [3]:
# Load the digits dataset and split its samples/labels into train and test parts.
digits = load_digits()
# Seed the shuffle so the same rows land in the test set on every
# Restart & Run All; without it the score printed at the end of the notebook
# is measured on a different random split each run.
X_train, X_test, y_train, y_test = train_test_split(
    digits.data, digits.target, random_state=0)
In [4]:
# Create the grower from the load-balanced view and a 5-tree
# ExtraTreesClassifier used as the base estimator.
# NOTE(review): presumably each dispatched task fits one copy of this small
# forest — confirm against the pyrallel EnsembleGrower implementation.
grower = EnsembleGrower(
    lb,
    ExtraTreesClassifier(n_estimators=5)
)
In [5]:
# Kick off training on the training split, asking for 100 estimators in total.
# NOTE(review): `folder` / `name` look like the location and label used for
# the intermediate results — confirm what pyrallel writes there.
# The value returned by launch() is displayed as the cell output.
grower.launch(
    X_train,
    y_train,
    n_estimators=100,
    folder='digits',
    name="digits_trees"
)
Out[5]:
In [9]:
# Display the grower's repr as the cell output.
# NOTE(review): presumably this shows the progress of the launched tasks — confirm.
grower
Out[9]:
In [10]:
# NOTE(review): presumably blocks until every task started by launch() has
# finished — confirm against pyrallel. The return value is shown as the cell output.
grower.wait()
Out[10]:
In [12]:
# Time the aggregation step and keep its result as `final_model`.
# NOTE(review): presumably this merges the separately trained sub-ensembles
# into a single estimator — confirm against pyrallel.
%time final_model = grower.aggregate_model()
# Report how many trees the aggregated model holds (read from its n_estimators attribute).
print("number of trees: {}".format(final_model.n_estimators))
In [13]:
# Evaluate the aggregated model on the held-out test split.
# NOTE(review): for a scikit-learn classifier, score() is presumably mean accuracy — confirm.
score = final_model.score(X_test, y_test)
# Print the score rounded to three decimal places.
print("score: {:.3f}".format(score))
In [ ]: