Decision Trees


In [1]:
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier
iris = load_iris()
X = iris.data[:, 2:] # petal length and width
y = iris.target
tree_clf = DecisionTreeClassifier(max_depth=2)
tree_clf.fit(X, y)


Out[1]:
DecisionTreeClassifier(class_weight=None, criterion='gini', max_depth=2,
            max_features=None, max_leaf_nodes=None,
            min_impurity_decrease=0.0, min_impurity_split=None,
            min_samples_leaf=1, min_samples_split=2,
            min_weight_fraction_leaf=0.0, presort=False, random_state=None,
            splitter='best')

In [4]:
from sklearn.tree import export_graphviz

dot_data = export_graphviz(
 tree_clf, out_file=None,
 feature_names=iris.feature_names[2:],
 class_names=iris.target_names,
 rounded=True,
 filled=True
 )

graph = graphviz.Source(dot_data)
graph


Out[4]:
Tree 0 petal length (cm) <= 2.45 gini = 0.667 samples = 150 value = [50, 50, 50] class = setosa 1 gini = 0.0 samples = 50 value = [50, 0, 0] class = setosa 0->1 True 2 petal width (cm) <= 1.75 gini = 0.5 samples = 100 value = [0, 50, 50] class = versicolor 0->2 False 3 gini = 0.168 samples = 54 value = [0, 49, 5] class = versicolor 2->3 4 gini = 0.043 samples = 46 value = [0, 1, 45] class = virginica 2->4

In [ ]: