In [1]:
# Load the watermark extension and record the author, the last-updated date,
# the Python/IPython versions, and the versions of the listed packages.
%load_ext watermark
%watermark  -d -u -a 'Sebastian Raschka' -v -p numpy,scipy,matplotlib,sklearn


Sebastian Raschka 
last updated: 2019-10-06 

CPython 3.7.1
IPython 7.8.0

numpy 1.17.2
scipy 1.3.1
matplotlib 3.1.0
sklearn 0.21.3

In [2]:
from sklearn import datasets
import numpy as np


# Load the Iris dataset. Columns 2 and 3 are the two features plotted later
# as petal length and petal width; the targets are integer class labels.
iris = datasets.load_iris()
X, y = iris.data[:, [2, 3]], iris.target

print('Class labels:', np.unique(y))


Class labels: [0 1 2]

In [3]:
from sklearn.model_selection import train_test_split


# Hold out 30% of the samples for testing. Stratifying on y keeps the class
# proportions equal in both splits, and the fixed seed makes the split
# reproducible.
split_arrays = train_test_split(
    X, y,
    test_size=0.3,
    random_state=1,
    stratify=y,
)
X_train, X_test, y_train, y_test = split_arrays

In [4]:
# Report how many samples of each class are in the full label vector and in
# each split, to confirm that the stratified split preserved the class balance.
label_reports = [
    ('Labels counts in y:', y),
    ('Labels counts in y_train:', y_train),
    ('Labels counts in y_test:', y_test),
]
for report_caption, report_labels in label_reports:
    print(report_caption, np.bincount(report_labels))


Labels counts in y: [50 50 50]
Labels counts in y_train: [35 35 35]
Labels counts in y_test: [15 15 15]

In [5]:
%matplotlib inline
import matplotlib.pyplot as plt
from sklearn.tree import DecisionTreeClassifier
from mlxtend.plotting import plot_decision_regions
from sklearn import tree


# Fit a shallow (depth-2) decision tree that uses entropy as its split
# criterion; random_state is fixed so repeated runs build the same tree.
dtree = tree.DecisionTreeClassifier(
    criterion='entropy',
    max_depth=2,
    random_state=1,
)
dtree.fit(X_train, y_train)

# Draw the decision regions over the training samples and label the axes
# with the two petal features that make up X.
ax = plot_decision_regions(X_train, y_train, clf=dtree)
ax.set_xlabel('petal length [cm]')
ax.set_ylabel('petal width [cm]')
ax.legend(loc='upper left')
plt.tight_layout()
plt.show()



In [6]:
# If pydotplus or graphviz is not installed, you may need to run the
# following in your command line first:
#
#     conda install pydotplus
#     conda install graphviz

In [7]:
from pydotplus import graph_from_dot_data
from sklearn.tree import export_graphviz


# Names shown in the rendered tree: one per class and one per column of X.
class_names = ['Setosa', 'Versicolor', 'Virginica']
feature_names = ['petal length', 'petal width']

# With out_file=None, export_graphviz returns the DOT source as a string
# instead of writing it to a file.
dot_data = export_graphviz(
    dtree,
    out_file=None,
    feature_names=feature_names,
    class_names=class_names,
    filled=True,
    rounded=True,
)

# Render the DOT source to a PNG; write_png's return value is the cell output.
graph = graph_from_dot_data(dot_data)
graph.write_png('tree.png')


Out[7]:
True

In [8]:
from IPython.display import Image


# Display the PNG file 'tree.png' inline as this cell's output.
Image('tree.png')


Out[8]:

In [11]:
# Draw the fitted tree with matplotlib at the default figure size.
plot_options = {
    'feature_names': ['petal length', 'petal width'],
    'class_names': ['Setosa', 'Versicolor', 'Virginica'],
    'filled': True,
    'rounded': True,
}
tree.plot_tree(dtree, **plot_options)

plt.show()



In [13]:
# Draw the fitted tree again, this time on an explicitly created, larger
# (10 x 7 inch) figure instead of the default-sized one.
fig, ax = plt.subplots(figsize=(10, 7))

tree.plot_tree(
    dtree,
    feature_names=['petal length', 'petal width'],
    class_names=['Setosa', 'Versicolor', 'Virginica'],
    filled=True,
    rounded=True,
    ax=ax,
)

plt.show()



In [ ]: