In [1]:
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

# Load the iris dataset and keep only the last two feature columns.
iris = load_iris()
X = iris.data[:, 2:]  # petal length and width
y = iris.target

# Fit a shallow (depth-2) decision tree.
# random_state is pinned because scikit-learn randomly permutes the features
# at every split, even with the default "best" splitter; when candidate splits
# tie, an unseeded tree can differ from one run to the next.
tree_clf = DecisionTreeClassifier(max_depth=2, random_state=42)
tree_clf.fit(X, y)
Out[1]:
In [4]:
# `graphviz` must be imported here: the cell calls graphviz.Source below, and
# without this import it raises NameError on a fresh kernel.
import graphviz
from sklearn.tree import export_graphviz

# out_file=None makes export_graphviz return the DOT description as a string
# instead of writing it to a file.
dot_data = export_graphviz(
    tree_clf,
    out_file=None,
    feature_names=iris.feature_names[2:],  # same two columns used to build X
    class_names=iris.target_names,
    rounded=True,
    filled=True,
)

# Wrap the DOT text in a graphviz.Source and leave it as the cell's last
# expression so the notebook displays it.
graph = graphviz.Source(dot_data)
graph
Out[4]:
In [ ]: