Classification - Decision Tree Primer

Classify Iris flowers by their sepal/petal width and length into their species: 'setosa', 'versicolor', 'virginica'


In [10]:
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier
from plotting_utilities import plot_decision_tree, plot_feature_importances
from sklearn.model_selection import train_test_split

import numpy as np

import matplotlib.pyplot as plt
%matplotlib inline

iris = load_iris()
iris.DESCR.split('\n')


Out[10]:
['Iris Plants Database',
 '====================',
 '',
 'Notes',
 '-----',
 'Data Set Characteristics:',
 '    :Number of Instances: 150 (50 in each of three classes)',
 '    :Number of Attributes: 4 numeric, predictive attributes and the class',
 '    :Attribute Information:',
 '        - sepal length in cm',
 '        - sepal width in cm',
 '        - petal length in cm',
 '        - petal width in cm',
 '        - class:',
 '                - Iris-Setosa',
 '                - Iris-Versicolour',
 '                - Iris-Virginica',
 '    :Summary Statistics:',
 '',
 '    ============== ==== ==== ======= ===== ====================',
 '                    Min  Max   Mean    SD   Class Correlation',
 '    ============== ==== ==== ======= ===== ====================',
 '    sepal length:   4.3  7.9   5.84   0.83    0.7826',
 '    sepal width:    2.0  4.4   3.05   0.43   -0.4194',
 '    petal length:   1.0  6.9   3.76   1.76    0.9490  (high!)',
 '    petal width:    0.1  2.5   1.20  0.76     0.9565  (high!)',
 '    ============== ==== ==== ======= ===== ====================',
 '',
 '    :Missing Attribute Values: None',
 '    :Class Distribution: 33.3% for each of 3 classes.',
 '    :Creator: R.A. Fisher',
 '    :Donor: Michael Marshall (MARSHALL%PLU@io.arc.nasa.gov)',
 '    :Date: July, 1988',
 '',
 'This is a copy of UCI ML iris datasets.',
 'http://archive.ics.uci.edu/ml/datasets/Iris',
 '',
 'The famous Iris database, first used by Sir R.A Fisher',
 '',
 'This is perhaps the best known database to be found in the',
 "pattern recognition literature.  Fisher's paper is a classic in the field and",
 'is referenced frequently to this day.  (See Duda & Hart, for example.)  The',
 'data set contains 3 classes of 50 instances each, where each class refers to a',
 'type of iris plant.  One class is linearly separable from the other 2; the',
 'latter are NOT linearly separable from each other.',
 '',
 'References',
 '----------',
 '   - Fisher,R.A. "The use of multiple measurements in taxonomic problems"',
 '     Annual Eugenics, 7, Part II, 179-188 (1936); also in "Contributions to',
 '     Mathematical Statistics" (John Wiley, NY, 1950).',
 '   - Duda,R.O., & Hart,P.E. (1973) Pattern Classification and Scene Analysis.',
 '     (Q327.D83) John Wiley & Sons.  ISBN 0-471-22361-1.  See page 218.',
 '   - Dasarathy, B.V. (1980) "Nosing Around the Neighborhood: A New System',
 '     Structure and Classification Rule for Recognition in Partially Exposed',
 '     Environments".  IEEE Transactions on Pattern Analysis and Machine',
 '     Intelligence, Vol. PAMI-2, No. 1, 67-71.',
 '   - Gates, G.W. (1972) "The Reduced Nearest Neighbor Rule".  IEEE Transactions',
 '     on Information Theory, May 1972, 431-433.',
 '   - See also: 1988 MLC Proceedings, 54-64.  Cheeseman et al"s AUTOCLASS II',
 '     conceptual clustering system finds 3 classes in the data.',
 '   - Many, many more ...',
 '']

In [11]:
# IN: Features aka Predictors
print(iris.data.dtype)
print(iris.data.shape)

print(iris.feature_names)
iris.data[:5,:]


float64
(150, 4)
['sepal length (cm)', 'sepal width (cm)', 'petal length (cm)', 'petal width (cm)']
Out[11]:
array([[ 5.1,  3.5,  1.4,  0.2],
       [ 4.9,  3. ,  1.4,  0.2],
       [ 4.7,  3.2,  1.3,  0.2],
       [ 4.6,  3.1,  1.5,  0.2],
       [ 5. ,  3.6,  1.4,  0.2]])

In [12]:
# OUT: Target, here: species
print(iris.target.dtype)
print(iris.target.shape)

print(iris.target_names)
iris.target[:5]


int64
(150,)
['setosa' 'versicolor' 'virginica']
Out[12]:
array([0, 0, 0, 0, 0])
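The integer codes in target index into target_names, so mapping labels back to species is plain NumPy fancy indexing. A quick check (nothing beyond the objects already loaded above):

# Map the integer-encoded targets back to their species names
print(iris.target_names[iris.target[:5]])   # all 'setosa', matching the zeros above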

Task: Create a Decision Tree

to classify an unseen Iris by its sepal/petal width and length into its species: 'setosa', 'versicolor', 'virginica'


In [13]:
X = iris.data
y = iris.target

# TODO: Try with and without max_depth (limiting the depth also helps avoid overfitting)
# clf = DecisionTreeClassifier().fit(X, y)
clf = DecisionTreeClassifier(max_depth=3).fit(X, y)
plot_decision_tree(clf, iris.feature_names, iris.target_names)


Out[13]:
The rendered tree (depth 3), as a text outline:

petal width (cm) <= 0.8  |  samples = 150, value = [50, 50, 50], class = setosa
|-- True:  leaf  |  samples = 50, value = [50, 0, 0], class = setosa
|-- False: petal width (cm) <= 1.75  |  samples = 100, value = [0, 50, 50], class = versicolor
    |-- True:  petal length (cm) <= 4.95  |  samples = 54, value = [0, 49, 5], class = versicolor
    |   |-- True:  leaf  |  samples = 48, value = [0, 47, 1], class = versicolor
    |   |-- False: leaf  |  samples = 6, value = [0, 2, 4], class = virginica
    |-- False: petal length (cm) <= 4.85  |  samples = 46, value = [0, 1, 45], class = virginica
        |-- True:  leaf  |  samples = 3, value = [0, 1, 2], class = virginica
        |-- False: leaf  |  samples = 43, value = [0, 0, 43], class = virginica
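If plotting_utilities is not available, newer scikit-learn versions (0.21+) can render the same tree as plain text with export_text; a minimal sketch under that assumption:

from sklearn.tree import export_text

# Plain-text rendering of the fitted tree: one line per split/leaf, indented per level
print(export_text(clf, feature_names=iris.feature_names))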

Wait, how do I know that the Decision Tree actually works?

A: Split your data into training and test sets, train only on the training set, and evaluate on the test data.


In [14]:
X_train, X_test, y_train, y_test = train_test_split(iris.data, iris.target, random_state=3)

# Train the classifier only on the training data
clf = DecisionTreeClassifier().fit(X_train, y_train)
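
By default train_test_split holds out 25% of the samples, so here 112 go into training and 38 into the test set; a quick shape check (the 38 shows up again as the total of the confusion matrix below):

# Sanity check: the default test_size of 0.25 splits the 150 samples 112/38
print(X_train.shape, X_test.shape)   # (112, 4) (38, 4)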

In [15]:
# predict for the test data and compare with the actual outcome
y_pred = clf.predict(X_test)

from sklearn.metrics import confusion_matrix

print(" ------ Predicted ")
print(" Actual ")
confusion_matrix(y_test, y_pred)


 ------ Predicted 
 Actual 
Out[15]:
array([[15,  0,  0],
       [ 0, 11,  1],
       [ 0,  1, 10]])

In [16]:
print('Accuracy of Decision Tree classifier on test set = sum(diagonal)/sum(all): {}'.format((15 + 11 + 10) / (15 + 11 + 1 + 1 + 10)))
print('Accuracy of Decision Tree classifier on test set with "score"-function: {:.2f}'
      .format(clf.score(X_test, y_test)))


Accuracy of Decision Tree classifier on test set = sum(diagonal)/sum(all): 0.9473684210526315
Accuracy of Decision Tree classifier on test set with "score"-function: 0.95
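
The manual calculation generalizes: accuracy is the sum of the confusion-matrix diagonal (the correctly classified samples) divided by the sum of all entries. A small NumPy sketch, reusing y_test and y_pred from above:

# Accuracy from the confusion matrix: trace (diagonal) over total
cm = confusion_matrix(y_test, y_pred)
print(np.trace(cm) / cm.sum())   # 0.947..., same value clf.score reports above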

Feature importance

TODO: Compare the importances with each feature's depth (level) in the tree above


In [17]:
plt.figure(figsize=(10,4), dpi=80)
plot_feature_importances(clf, np.array(iris.feature_names))
plt.show()

print('Feature names      : {}'.format(iris.feature_names))
print('Feature importances: {}'.format(clf.feature_importances_))


Feature names      : ['sepal length (cm)', 'sepal width (cm)', 'petal length (cm)', 'petal width (cm)']
Feature importances: [ 0.          0.02457904  0.55984437  0.41557658]
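
To read names and numbers side by side, sorting by importance helps; a small sketch using np.argsort (descending order):

# Rank features from most to least important
order = np.argsort(clf.feature_importances_)[::-1]
for idx in order:
    print('{:20s} {:.3f}'.format(iris.feature_names[idx], clf.feature_importances_[idx]))
# petal length and petal width dominate, mirroring the splits in the tree above;
# the sepal features contribute almost nothing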