In [1]:
    
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

# Load the iris dataset and keep only the last two feature columns.
iris = load_iris()
X = iris.data[:, 2:]  # petal length and width
y = iris.target

# Shallow tree (two levels) so the exported diagram stays small and readable.
# random_state is pinned because scikit-learn permutes the features at every
# split; without a fixed seed, equally good splits may be chosen differently
# between runs and the fitted tree is not reproducible.
tree_clf = DecisionTreeClassifier(max_depth=2, random_state=42)
tree_clf.fit(X, y)
    
    Out[1]:
In [4]:
    
import graphviz  # required for graphviz.Source below; was missing, causing a NameError on a fresh kernel
from sklearn.tree import export_graphviz

# Export the fitted tree as DOT source. With out_file=None the DOT text is
# returned as a string instead of being written to a file.
dot_data = export_graphviz(
    tree_clf,
    out_file=None,
    feature_names=iris.feature_names[2:],  # names matching the two columns used in X
    class_names=iris.target_names,
    rounded=True,
    filled=True,
)

# Wrap the DOT source so the notebook renders it as the cell's output.
graph = graphviz.Source(dot_data)
graph
    
    Out[4]:
In [ ]: