In [1]:
# See: http://scikit-learn.org/stable/modules/tree.html

import numpy as np
import pandas as pd
from sklearn.datasets import load_iris # test dataset to run classification
from sklearn import tree               # sklearn tree classifier

In [2]:
# Load the iris dataset bundled with scikit-learn. The returned object is
# dictionary-like: it holds the keys 'DESCR', 'data', 'feature_names',
# 'target' and 'target_names', which later cells also read as attributes
# (e.g. iris.data, iris.target).
iris = load_iris() # a dictionary of data/metadata for iris dataset

In [3]:
# Display the whole dataset object: description text, feature matrix,
# feature names, integer class labels and class names.
# NOTE(review): this echoes all 150 rows; a shorter preview such as
# `iris.keys()` or `iris.data[:5]` would keep the notebook easier to skim.
iris


Out[3]:
{'DESCR': 'Iris Plants Database\n====================\n\nNotes\n-----\nData Set Characteristics:\n    :Number of Instances: 150 (50 in each of three classes)\n    :Number of Attributes: 4 numeric, predictive attributes and the class\n    :Attribute Information:\n        - sepal length in cm\n        - sepal width in cm\n        - petal length in cm\n        - petal width in cm\n        - class:\n                - Iris-Setosa\n                - Iris-Versicolour\n                - Iris-Virginica\n    :Summary Statistics:\n\n    ============== ==== ==== ======= ===== ====================\n                    Min  Max   Mean    SD   Class Correlation\n    ============== ==== ==== ======= ===== ====================\n    sepal length:   4.3  7.9   5.84   0.83    0.7826\n    sepal width:    2.0  4.4   3.05   0.43   -0.4194\n    petal length:   1.0  6.9   3.76   1.76    0.9490  (high!)\n    petal width:    0.1  2.5   1.20  0.76     0.9565  (high!)\n    ============== ==== ==== ======= ===== ====================\n\n    :Missing Attribute Values: None\n    :Class Distribution: 33.3% for each of 3 classes.\n    :Creator: R.A. Fisher\n    :Donor: Michael Marshall (MARSHALL%PLU@io.arc.nasa.gov)\n    :Date: July, 1988\n\nThis is a copy of UCI ML iris datasets.\nhttp://archive.ics.uci.edu/ml/datasets/Iris\n\nThe famous Iris database, first used by Sir R.A Fisher\n\nThis is perhaps the best known database to be found in the\npattern recognition literature.  Fisher\'s paper is a classic in the field and\nis referenced frequently to this day.  (See Duda & Hart, for example.)  The\ndata set contains 3 classes of 50 instances each, where each class refers to a\ntype of iris plant.  One class is linearly separable from the other 2; the\nlatter are NOT linearly separable from each other.\n\nReferences\n----------\n   - Fisher,R.A. 
"The use of multiple measurements in taxonomic problems"\n     Annual Eugenics, 7, Part II, 179-188 (1936); also in "Contributions to\n     Mathematical Statistics" (John Wiley, NY, 1950).\n   - Duda,R.O., & Hart,P.E. (1973) Pattern Classification and Scene Analysis.\n     (Q327.D83) John Wiley & Sons.  ISBN 0-471-22361-1.  See page 218.\n   - Dasarathy, B.V. (1980) "Nosing Around the Neighborhood: A New System\n     Structure and Classification Rule for Recognition in Partially Exposed\n     Environments".  IEEE Transactions on Pattern Analysis and Machine\n     Intelligence, Vol. PAMI-2, No. 1, 67-71.\n   - Gates, G.W. (1972) "The Reduced Nearest Neighbor Rule".  IEEE Transactions\n     on Information Theory, May 1972, 431-433.\n   - See also: 1988 MLC Proceedings, 54-64.  Cheeseman et al"s AUTOCLASS II\n     conceptual clustering system finds 3 classes in the data.\n   - Many, many more ...\n',
 'data': array([[ 5.1,  3.5,  1.4,  0.2],
        [ 4.9,  3. ,  1.4,  0.2],
        [ 4.7,  3.2,  1.3,  0.2],
        [ 4.6,  3.1,  1.5,  0.2],
        [ 5. ,  3.6,  1.4,  0.2],
        [ 5.4,  3.9,  1.7,  0.4],
        [ 4.6,  3.4,  1.4,  0.3],
        [ 5. ,  3.4,  1.5,  0.2],
        [ 4.4,  2.9,  1.4,  0.2],
        [ 4.9,  3.1,  1.5,  0.1],
        [ 5.4,  3.7,  1.5,  0.2],
        [ 4.8,  3.4,  1.6,  0.2],
        [ 4.8,  3. ,  1.4,  0.1],
        [ 4.3,  3. ,  1.1,  0.1],
        [ 5.8,  4. ,  1.2,  0.2],
        [ 5.7,  4.4,  1.5,  0.4],
        [ 5.4,  3.9,  1.3,  0.4],
        [ 5.1,  3.5,  1.4,  0.3],
        [ 5.7,  3.8,  1.7,  0.3],
        [ 5.1,  3.8,  1.5,  0.3],
        [ 5.4,  3.4,  1.7,  0.2],
        [ 5.1,  3.7,  1.5,  0.4],
        [ 4.6,  3.6,  1. ,  0.2],
        [ 5.1,  3.3,  1.7,  0.5],
        [ 4.8,  3.4,  1.9,  0.2],
        [ 5. ,  3. ,  1.6,  0.2],
        [ 5. ,  3.4,  1.6,  0.4],
        [ 5.2,  3.5,  1.5,  0.2],
        [ 5.2,  3.4,  1.4,  0.2],
        [ 4.7,  3.2,  1.6,  0.2],
        [ 4.8,  3.1,  1.6,  0.2],
        [ 5.4,  3.4,  1.5,  0.4],
        [ 5.2,  4.1,  1.5,  0.1],
        [ 5.5,  4.2,  1.4,  0.2],
        [ 4.9,  3.1,  1.5,  0.1],
        [ 5. ,  3.2,  1.2,  0.2],
        [ 5.5,  3.5,  1.3,  0.2],
        [ 4.9,  3.1,  1.5,  0.1],
        [ 4.4,  3. ,  1.3,  0.2],
        [ 5.1,  3.4,  1.5,  0.2],
        [ 5. ,  3.5,  1.3,  0.3],
        [ 4.5,  2.3,  1.3,  0.3],
        [ 4.4,  3.2,  1.3,  0.2],
        [ 5. ,  3.5,  1.6,  0.6],
        [ 5.1,  3.8,  1.9,  0.4],
        [ 4.8,  3. ,  1.4,  0.3],
        [ 5.1,  3.8,  1.6,  0.2],
        [ 4.6,  3.2,  1.4,  0.2],
        [ 5.3,  3.7,  1.5,  0.2],
        [ 5. ,  3.3,  1.4,  0.2],
        [ 7. ,  3.2,  4.7,  1.4],
        [ 6.4,  3.2,  4.5,  1.5],
        [ 6.9,  3.1,  4.9,  1.5],
        [ 5.5,  2.3,  4. ,  1.3],
        [ 6.5,  2.8,  4.6,  1.5],
        [ 5.7,  2.8,  4.5,  1.3],
        [ 6.3,  3.3,  4.7,  1.6],
        [ 4.9,  2.4,  3.3,  1. ],
        [ 6.6,  2.9,  4.6,  1.3],
        [ 5.2,  2.7,  3.9,  1.4],
        [ 5. ,  2. ,  3.5,  1. ],
        [ 5.9,  3. ,  4.2,  1.5],
        [ 6. ,  2.2,  4. ,  1. ],
        [ 6.1,  2.9,  4.7,  1.4],
        [ 5.6,  2.9,  3.6,  1.3],
        [ 6.7,  3.1,  4.4,  1.4],
        [ 5.6,  3. ,  4.5,  1.5],
        [ 5.8,  2.7,  4.1,  1. ],
        [ 6.2,  2.2,  4.5,  1.5],
        [ 5.6,  2.5,  3.9,  1.1],
        [ 5.9,  3.2,  4.8,  1.8],
        [ 6.1,  2.8,  4. ,  1.3],
        [ 6.3,  2.5,  4.9,  1.5],
        [ 6.1,  2.8,  4.7,  1.2],
        [ 6.4,  2.9,  4.3,  1.3],
        [ 6.6,  3. ,  4.4,  1.4],
        [ 6.8,  2.8,  4.8,  1.4],
        [ 6.7,  3. ,  5. ,  1.7],
        [ 6. ,  2.9,  4.5,  1.5],
        [ 5.7,  2.6,  3.5,  1. ],
        [ 5.5,  2.4,  3.8,  1.1],
        [ 5.5,  2.4,  3.7,  1. ],
        [ 5.8,  2.7,  3.9,  1.2],
        [ 6. ,  2.7,  5.1,  1.6],
        [ 5.4,  3. ,  4.5,  1.5],
        [ 6. ,  3.4,  4.5,  1.6],
        [ 6.7,  3.1,  4.7,  1.5],
        [ 6.3,  2.3,  4.4,  1.3],
        [ 5.6,  3. ,  4.1,  1.3],
        [ 5.5,  2.5,  4. ,  1.3],
        [ 5.5,  2.6,  4.4,  1.2],
        [ 6.1,  3. ,  4.6,  1.4],
        [ 5.8,  2.6,  4. ,  1.2],
        [ 5. ,  2.3,  3.3,  1. ],
        [ 5.6,  2.7,  4.2,  1.3],
        [ 5.7,  3. ,  4.2,  1.2],
        [ 5.7,  2.9,  4.2,  1.3],
        [ 6.2,  2.9,  4.3,  1.3],
        [ 5.1,  2.5,  3. ,  1.1],
        [ 5.7,  2.8,  4.1,  1.3],
        [ 6.3,  3.3,  6. ,  2.5],
        [ 5.8,  2.7,  5.1,  1.9],
        [ 7.1,  3. ,  5.9,  2.1],
        [ 6.3,  2.9,  5.6,  1.8],
        [ 6.5,  3. ,  5.8,  2.2],
        [ 7.6,  3. ,  6.6,  2.1],
        [ 4.9,  2.5,  4.5,  1.7],
        [ 7.3,  2.9,  6.3,  1.8],
        [ 6.7,  2.5,  5.8,  1.8],
        [ 7.2,  3.6,  6.1,  2.5],
        [ 6.5,  3.2,  5.1,  2. ],
        [ 6.4,  2.7,  5.3,  1.9],
        [ 6.8,  3. ,  5.5,  2.1],
        [ 5.7,  2.5,  5. ,  2. ],
        [ 5.8,  2.8,  5.1,  2.4],
        [ 6.4,  3.2,  5.3,  2.3],
        [ 6.5,  3. ,  5.5,  1.8],
        [ 7.7,  3.8,  6.7,  2.2],
        [ 7.7,  2.6,  6.9,  2.3],
        [ 6. ,  2.2,  5. ,  1.5],
        [ 6.9,  3.2,  5.7,  2.3],
        [ 5.6,  2.8,  4.9,  2. ],
        [ 7.7,  2.8,  6.7,  2. ],
        [ 6.3,  2.7,  4.9,  1.8],
        [ 6.7,  3.3,  5.7,  2.1],
        [ 7.2,  3.2,  6. ,  1.8],
        [ 6.2,  2.8,  4.8,  1.8],
        [ 6.1,  3. ,  4.9,  1.8],
        [ 6.4,  2.8,  5.6,  2.1],
        [ 7.2,  3. ,  5.8,  1.6],
        [ 7.4,  2.8,  6.1,  1.9],
        [ 7.9,  3.8,  6.4,  2. ],
        [ 6.4,  2.8,  5.6,  2.2],
        [ 6.3,  2.8,  5.1,  1.5],
        [ 6.1,  2.6,  5.6,  1.4],
        [ 7.7,  3. ,  6.1,  2.3],
        [ 6.3,  3.4,  5.6,  2.4],
        [ 6.4,  3.1,  5.5,  1.8],
        [ 6. ,  3. ,  4.8,  1.8],
        [ 6.9,  3.1,  5.4,  2.1],
        [ 6.7,  3.1,  5.6,  2.4],
        [ 6.9,  3.1,  5.1,  2.3],
        [ 5.8,  2.7,  5.1,  1.9],
        [ 6.8,  3.2,  5.9,  2.3],
        [ 6.7,  3.3,  5.7,  2.5],
        [ 6.7,  3. ,  5.2,  2.3],
        [ 6.3,  2.5,  5. ,  1.9],
        [ 6.5,  3. ,  5.2,  2. ],
        [ 6.2,  3.4,  5.4,  2.3],
        [ 5.9,  3. ,  5.1,  1.8]]),
 'feature_names': ['sepal length (cm)',
  'sepal width (cm)',
  'petal length (cm)',
  'petal width (cm)'],
 'target': array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
        2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
        2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2]),
 'target_names': array(['setosa', 'versicolor', 'virginica'], 
       dtype='|S10')}

In [4]:
## Define sklearn decision tree
# Constructor parameters and the defaults recorded by the original author:
#   criterion='gini' (Gini impurity) or 'entropy' (information gain)
#   splitter='best' or 'random'
#   max_depth=None        (max depth of the tree -- number of steps from root to leaf node)
#   min_samples_split=2   (min. number of samples required to do a split)
#   min_samples_leaf=1    (min. number of samples required to make a leaf node)
#   min_weight_fraction_leaf=0.0
#   max_features=None
#   random_state=None     (used to seed the random number generator)
#   max_leaf_nodes=None
#   min_impurity_decrease=0.0
#   min_impurity_split=None  (NOTE: no longer accepted by newer scikit-learn releases)
#   class_weight=None
#   presort=False            (NOTE: no longer accepted by newer scikit-learn releases)
#
# random_state is pinned here because the features are shuffled at every
# split: when two candidate splits score equally (e.g. petal length vs. petal
# width at the root) an unseeded tree can differ from one run to the next.

clf = tree.DecisionTreeClassifier(criterion='entropy', min_samples_split=5, random_state=0)

In [5]:
import random

In [6]:
# Show the built-in documentation for random.sample, which the split cell
# below uses to draw k unique row indices for the training set.
help(random.sample)


Help on method sample in module random:

sample(self, population, k) method of random.Random instance
    Chooses k unique random elements from a population sequence.
    
    Returns a new list containing elements from the population while
    leaving the original population unchanged.  The resulting list is
    in selection order so that all sub-slices will also be valid random
    samples.  This allows raffle winners (the sample) to be partitioned
    into grand prize and second place winners (the subslices).
    
    Members of the population need not be hashable or unique.  If the
    population contains repeats, then each occurrence is a possible
    selection in the sample.
    
    To choose a sample in a range of integers, use xrange as an argument.
    This is especially fast and space efficient for sampling from a
    large population:   sample(xrange(10000000), 60)


In [7]:
# Split the iris dataset into training (80%) and testing (20%) data.

n_samples = len(iris.data)

# Draw the training rows from a dedicated, seeded generator so that the split
# is identical on every "Restart & Run All" and the global `random` state is
# left untouched. (Outputs recorded from earlier, unseeded runs will differ.)
tr_idx = random.Random(0).sample(range(n_samples), int(n_samples*0.8))

# Every row index that was not drawn for training, computed once and reused
# for both the features and the labels (np.setdiff1d returns them sorted).
te_idx = np.setdiff1d(np.arange(n_samples), tr_idx)

tr_data = iris.data[tr_idx]     # training data
tr_target = iris.target[tr_idx] # training prediction target

te_data = iris.data[te_idx]     # test data
te_target = iris.target[te_idx] # test prediction target

In [8]:
# Train the decision tree on the training split only.
#   X: tr_data   -- array of feature variable values
#   y: tr_target -- array of prediction target values
# Both arguments should be in array format; the value returned by fit() is
# bound back to `clf`, which the following cells use for prediction.

clf = clf.fit(X=tr_data, y=tr_target)

In [9]:
# Predict the type of iris for every row of the held-out test data and put
# the predicted species next to the true species.
names = np.array(iris.target_names)  # integer class label -> species name lookup

results = pd.DataFrame({"Prediction" : names[clf.predict(te_data)], "Truth" : names[te_target]})
# Compare the two columns element-wise in one vectorised step instead of
# calling a Python lambda once per row; this yields a boolean column.
results['Correct'] = results.Prediction == results.Truth
results


Out[9]:
Prediction Truth Correct
0 setosa setosa True
1 setosa setosa True
2 setosa setosa True
3 setosa setosa True
4 setosa setosa True
5 setosa setosa True
6 setosa setosa True
7 setosa setosa True
8 versicolor versicolor True
9 versicolor versicolor True
10 versicolor versicolor True
11 versicolor versicolor True
12 virginica versicolor False
13 versicolor versicolor True
14 versicolor versicolor True
15 versicolor versicolor True
16 versicolor versicolor True
17 versicolor versicolor True
18 versicolor versicolor True
19 virginica virginica True
20 virginica virginica True
21 virginica virginica True
22 virginica virginica True
23 virginica virginica True
24 virginica virginica True
25 versicolor virginica False
26 versicolor virginica False
27 virginica virginica True
28 virginica virginica True
29 virginica virginica True

In [10]:
# Overall misclassification rate: percentage of test rows whose prediction
# does not match the true species.
# print(...) with a single parenthesised argument behaves the same on
# Python 2 and Python 3; the bare `print "..."` statement is a SyntaxError
# on Python 3.

n_wrong = int((results.Correct == False).sum())
print("Misclassification rate: %3.2f%%" % (float(n_wrong)/float(len(results))*100))


Misclassification rate: 10.00%

In [11]:
# Per-species accuracy: within each true class, the number of correct
# predictions divided by the number of test rows of that class.

results.groupby('Truth').apply(
    lambda species_rows: float(species_rows.Correct.sum()) / float(len(species_rows))
)


Out[11]:
Truth
setosa        1.000000
versicolor    0.909091
virginica     0.818182
dtype: float64

In [12]:
# Tree visualization: dump the fitted classifier to the text file 'iris.dat'
# as a Graphviz graph description. Nodes are labelled with the iris feature
# and class names and are colour-filled.

tree.export_graphviz(
    clf,
    filled=True,
    class_names=iris.target_names,
    feature_names=iris.feature_names,
    out_file='iris.dat',
)

In [13]:
%%sh 

# Print the exported Graphviz file written by the previous cell so that its
# text appears in the cell output and can be pasted into an online renderer.
more iris.dat # copy and paste text into http://www.webgraphviz.com/


::::::::::::::
iris.dat
::::::::::::::
digraph Tree {
node [shape=box, style="filled", color="black"] ;
0 [label="petal width (cm) <= 0.8\nentropy = 1.5841\nsamples = 120\nvalue = [42, 39, 39]\nclass = setosa", fillcolor="#e5813909"] ;
1 [label="entropy = 0.0\nsamples = 42\nvalue = [42, 0, 0]\nclass = setosa", fillcolor="#e58139ff"] ;
0 -> 1 [labeldistance=2.5, labelangle=45, headlabel="True"] ;
2 [label="petal width (cm) <= 1.75\nentropy = 1.0\nsamples = 78\nvalue = [0, 39, 39]\nclass = versicolor", fillcolor="#39e58100"] ;
0 -> 2 [labeldistance=2.5, labelangle=-45, headlabel="False"] ;
3 [label="petal width (cm) <= 1.45\nentropy = 0.3712\nsamples = 42\nvalue = [0, 39, 3]\nclass = versicolor", fillcolor="#39e581eb"] ;
2 -> 3 ;
4 [label="entropy = 0.0\nsamples = 28\nvalue = [0, 28, 0]\nclass = versicolor", fillcolor="#39e581ff"] ;
3 -> 4 ;
5 [label="sepal length (cm) <= 5.15\nentropy = 0.7496\nsamples = 14\nvalue = [0, 11, 3]\nclass = versicolor", fillcolor="#39e581b9"] ;
3 -> 5 ;
6 [label="entropy = 0.0\nsamples = 1\nvalue = [0, 0, 1]\nclass = virginica", fillcolor="#8139e5ff"] ;
5 -> 6 ;
7 [label="petal length (cm) <= 4.95\nentropy = 0.6194\nsamples = 13\nvalue = [0, 11, 2]\nclass = versicolor", fillcolor="#39e581d1"] ;
5 -> 7 ;
8 [label="entropy = 0.0\nsamples = 9\nvalue = [0, 9, 0]\nclass = versicolor", fillcolor="#39e581ff"] ;
7 -> 8 ;
9 [label="entropy = 1.0\nsamples = 4\nvalue = [0, 2, 2]\nclass = versicolor", fillcolor="#39e58100"] ;
7 -> 9 ;
10 [label="entropy = 0.0\nsamples = 36\nvalue = [0, 0, 36]\nclass = virginica", fillcolor="#8139e5ff"] ;
2 -> 10 ;
}