In [5]:
# (p. 167ff) Decision Trees
# Train a depth-2 decision tree on the two petal features of the iris data
# and export the fitted tree as a Graphviz .dot file for visualization.
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier, export_graphviz

# Fixed seed: without it the fitted tree is not guaranteed to be identical
# across fresh-kernel runs, so the exported graph and later predictions could drift.
RANDOM_STATE = 42
TREE_DOT_PATH = "iris_tree.dot"  # Graphviz file written below

iris = load_iris()
X = iris.data[:, 2:]  # petal length and width
y = iris.target

tree_clf = DecisionTreeClassifier(max_depth=2, random_state=RANDOM_STATE)
tree_clf.fit(X, y)

# Write the fitted tree to TREE_DOT_PATH, labelling nodes with the two petal
# feature names and the iris class names.
export_graphviz(
    tree_clf,
    out_file=TREE_DOT_PATH,
    feature_names=iris.feature_names[2:],
    class_names=iris.target_names,
    rounded=True,
    filled=True,
)
In [6]:
# Class-membership probabilities for one flower with petal length 5 and petal width 1.5
query_petal = [[5, 1.5]]
tree_clf.predict_proba(query_petal)
Out[6]:
In [7]:
# Predicted class for one flower with petal length 5 and petal width 1.5
petal_sample = [[5, 1.5]]
tree_clf.predict(petal_sample)
Out[7]:
In [ ]: