In [1]:
# set up our standard computation environment
import numpy as np, pandas as pd, matplotlib.pyplot as plt, seaborn as sns
%matplotlib inline
sns.set_style('darkgrid')
sns.set_context('poster')

In [2]:
# set random seed for reproducibility
np.random.seed(12345)

In [3]:
# load the decision tree module of sklearn
import sklearn.tree

In [4]:
# simulate some data from a familiar distribution
x_true = np.linspace(0,15,1000)
y_true = np.cos(x_true)

sigma_true = .3
x_train = np.random.choice(x_true, size=100)
y_train = np.random.laplace(np.cos(x_train), sigma_true)
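
An aside on the noise: np.random.laplace draws noise with heavier tails than a Gaussian of the same scale parameter, so expect the occasional training point to sit well off the cosine curve. A quick comparison of the two densities a few scale-lengths out (my check, assuming scipy is available):


In [ ]:
# Laplace noise has heavier tails than Gaussian noise with the same
# scale parameter; compare the densities well away from the center
import scipy.stats
scipy.stats.laplace(0, sigma_true).pdf(1), scipy.stats.norm(0, sigma_true).pdf(1)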

In [5]:
plt.plot(x_true, y_true, '-', label='Truth')
plt.plot(x_train, y_train, 's', label='Train')
plt.legend()


Out[5]:
<matplotlib.legend.Legend at 0x7fcc0048ddd0>

In [6]:
# make a DecisionTreeRegressor
dt = sklearn.tree.DecisionTreeRegressor()

In [7]:
x_train[:5,None]


Out[7]:
array([[ 7.23723724],
       [ 7.28228228],
       [ 4.27927928],
       [ 1.93693694],
       [ 6.30630631]])
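
Why the [:,None]? sklearn estimators expect a 2-D design matrix of shape (n_samples, n_features), and indexing with None (an alias for np.newaxis) turns our 1-D array of x values into a one-column matrix:


In [ ]:
# sklearn wants X to be 2-D, even with a single feature
x_train.shape, x_train[:,None].shape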

In [8]:
# fit it to the simulated training data
X_train = x_train[:,None]
dt.fit(X_train, y_train)


Out[8]:
DecisionTreeRegressor(compute_importances=None, criterion='mse',
           max_depth=None, max_features=None, max_leaf_nodes=None,
           min_density=None, min_samples_leaf=1, min_samples_split=2,
           random_state=None, splitter='best')

In [9]:
# predict for a range of x values
X_true = x_true[:,None]  # horrible, but remember it!
y_pred = dt.predict(X_true)

In [10]:
# have a look
plt.plot(x_true, y_true, '-', label='Truth')
plt.plot(x_train, y_train, 's', label='Train')
plt.plot(x_true, y_pred, '-', label='Predicted')
plt.legend()


Out[10]:
<matplotlib.legend.Legend at 0x7fcbfe70a0d0>

In [11]:
# today we are going to look in-depth at the decision tree itself

# can you find it?
dt.tree_


Out[11]:
<sklearn.tree._tree.Tree at 0x7fcc025ab6b0>

An aside on the strange use of underscores in Python, and just how strange it could get if you let it:


In [12]:
_ = 10
__ = _ + 20
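
(Remember that in IPython, _ already means "the most recent output" and __ "the one before that", so assigning to these names is asking for confusion.)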

In [13]:
#dt.tree_ = 5
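
The trailing underscore in tree_ is sklearn's naming convention for attributes that are estimated from the data during fit(), as opposed to constructor parameters like max_depth; it also keeps the fitted attribute names from colliding with the parameter names.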

In [14]:
# what do the sklearn docs have to say about this?
help(dt.tree_)


Help on Tree object:

class Tree(__builtin__.object)
 |  Array-based representation of a binary decision tree.
 |  
 |  The binary tree is represented as a number of parallel arrays. The i-th
 |  element of each array holds information about the node `i`. Node 0 is the
 |  tree's root. You can find a detailed description of all arrays in
 |  `_tree.pxd`. NOTE: Some of the arrays only apply to either leaves or split
 |  nodes, resp. In this case the values of nodes of the other type are
 |  arbitrary!
 |  
 |  Attributes
 |  ----------
 |  node_count : int
 |      The number of nodes (internal nodes + leaves) in the tree.
 |  
 |  capacity : int
 |      The current capacity (i.e., size) of the arrays, which is at least as
 |      great as `node_count`.
 |  
 |  max_depth : int
 |      The maximal depth of the tree.
 |  
 |  children_left : array of int, shape [node_count]
 |      children_left[i] holds the node id of the left child of node i.
 |      For leaves, children_left[i] == TREE_LEAF. Otherwise,
 |      children_left[i] > i. This child handles the case where
 |      X[:, feature[i]] <= threshold[i].
 |  
 |  children_right : array of int, shape [node_count]
 |      children_right[i] holds the node id of the right child of node i.
 |      For leaves, children_right[i] == TREE_LEAF. Otherwise,
 |      children_right[i] > i. This child handles the case where
 |      X[:, feature[i]] > threshold[i].
 |  
 |  feature : array of int, shape [node_count]
 |      feature[i] holds the feature to split on, for the internal node i.
 |  
 |  threshold : array of double, shape [node_count]
 |      threshold[i] holds the threshold for the internal node i.
 |  
 |  value : array of double, shape [node_count, n_outputs, max_n_classes]
 |      Contains the constant prediction value of each node.
 |  
 |  impurity : array of double, shape [node_count]
 |      impurity[i] holds the impurity (i.e., the value of the splitting
 |      criterion) at node i.
 |  
 |  n_node_samples : array of int, shape [node_count]
 |      n_node_samples[i] holds the number of training samples reaching node i.
 |  
 |  weighted_n_node_samples : array of int, shape [node_count]
 |      weighted_n_node_samples[i] holds the weighted number of training samples
 |      reaching node i.
 |  
 |  Methods defined here:
 |  
 |  __getstate__(...)
 |      Getstate re-implementation, for pickling.
 |  
 |  __reduce__(...)
 |      Reduce re-implementation, for pickling.
 |  
 |  __setstate__(...)
 |      Setstate re-implementation, for unpickling.
 |  
 |  apply(...)
 |      Finds the terminal region (=leaf node) for each sample in X.
 |  
 |  compute_feature_importances(...)
 |      Computes the importance of each feature (aka variable).
 |  
 |  predict(...)
 |      Predict target for X.
 |  
 |  ----------------------------------------------------------------------
 |  Data descriptors defined here:
 |  
 |  capacity
 |  
 |  children_left
 |  
 |  children_right
 |  
 |  feature
 |  
 |  impurity
 |  
 |  max_depth
 |  
 |  max_n_classes
 |  
 |  n_classes
 |  
 |  n_features
 |  
 |  n_node_samples
 |  
 |  n_outputs
 |  
 |  node_count
 |  
 |  threshold
 |  
 |  value
 |  
 |  weighted_n_node_samples
 |  
 |  ----------------------------------------------------------------------
 |  Data and other attributes defined here:
 |  
 |  __new__ = <built-in method __new__ of type object>
 |      T.__new__(S, ...) -> a new object with type S, a subtype of T
 |  
 |  __pyx_vtable__ = <capsule object NULL>

Let's take a look at each of these things... guess what each will return before you execute the cell.


In [15]:
dt.tree_.node_count


Out[15]:
191

In [16]:
dt.tree_.capacity


Out[16]:
191

In [17]:
dt.tree_.max_depth


Out[17]:
13

In [18]:
dt.tree_.children_left


Out[18]:
array([  1,   2,   3,   4,  -1,   6,  -1,   8,   9,  -1,  -1,  -1,  13,
        -1,  -1,  16,  -1,  18,  -1,  -1,  21,  22,  23,  24,  -1,  -1,
        27,  28,  -1,  -1,  31,  -1,  -1,  34,  35,  -1,  37,  38,  39,
        -1,  -1,  -1,  43,  -1,  -1,  46,  47,  -1,  49,  50,  51,  52,
        -1,  -1,  -1,  56,  -1,  -1,  -1,  60,  61,  62,  63,  64,  65,
        -1,  -1,  68,  -1,  -1,  -1,  -1,  -1,  74,  -1,  76,  -1,  -1,
        79,  80,  81,  -1,  -1,  84,  85,  86,  87,  -1,  -1,  -1,  91,
        92,  93,  94,  -1,  -1,  -1,  -1,  99,  -1, 101, 102,  -1, 104,
       105,  -1,  -1,  -1, 109, 110,  -1,  -1,  -1, 114,  -1,  -1, 117,
       118, 119, 120, 121,  -1,  -1,  -1,  -1, 126, 127,  -1, 129,  -1,
        -1, 132, 133, 134, 135,  -1,  -1,  -1, 139, 140,  -1,  -1,  -1,
       144,  -1, 146,  -1, 148, 149,  -1,  -1,  -1, 153, 154, 155, 156,
       157, 158,  -1, 160,  -1,  -1, 163,  -1, 165, 166, 167,  -1,  -1,
       170,  -1,  -1, 173,  -1,  -1, 176,  -1,  -1, 179, 180, 181,  -1,
        -1, 184,  -1,  -1,  -1, 188,  -1,  -1,  -1])
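
Those -1 entries are the sentinel sklearn.tree._tree.TREE_LEAF. A consistency check (my addition): since every split here produces exactly two children, the leaves should outnumber the internal nodes by exactly one.


In [ ]:
# -1 marks a leaf; a binary tree where every internal node has two
# children has exactly one more leaf than internal node
n_leaves = (dt.tree_.children_left == -1).sum()
n_leaves, dt.tree_.node_count - n_leaves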

In [19]:
dt.tree_.children_right


Out[19]:
array([ 20,  15,  12,   5,  -1,   7,  -1,  11,  10,  -1,  -1,  -1,  14,
        -1,  -1,  17,  -1,  19,  -1,  -1,  78,  33,  26,  25,  -1,  -1,
        30,  29,  -1,  -1,  32,  -1,  -1,  45,  36,  -1,  42,  41,  40,
        -1,  -1,  -1,  44,  -1,  -1,  59,  48,  -1,  58,  55,  54,  53,
        -1,  -1,  -1,  57,  -1,  -1,  -1,  73,  72,  71,  70,  67,  66,
        -1,  -1,  69,  -1,  -1,  -1,  -1,  -1,  75,  -1,  77,  -1,  -1,
       116,  83,  82,  -1,  -1, 113,  90,  89,  88,  -1,  -1,  -1,  98,
        97,  96,  95,  -1,  -1,  -1,  -1, 100,  -1, 108, 103,  -1, 107,
       106,  -1,  -1,  -1, 112, 111,  -1,  -1,  -1, 115,  -1,  -1, 152,
       125, 124, 123, 122,  -1,  -1,  -1,  -1, 131, 128,  -1, 130,  -1,
        -1, 143, 138, 137, 136,  -1,  -1,  -1, 142, 141,  -1,  -1,  -1,
       145,  -1, 147,  -1, 151, 150,  -1,  -1,  -1, 190, 187, 178, 175,
       162, 159,  -1, 161,  -1,  -1, 164,  -1, 172, 169, 168,  -1,  -1,
       171,  -1,  -1, 174,  -1,  -1, 177,  -1,  -1, 186, 183, 182,  -1,
        -1, 185,  -1,  -1,  -1, 189,  -1,  -1,  -1])

In [20]:
dt.tree_.feature


Out[20]:
array([ 0,  0,  0,  0, -2,  0, -2,  0,  0, -2, -2, -2,  0, -2, -2,  0, -2,
        0, -2, -2,  0,  0,  0,  0, -2, -2,  0,  0, -2, -2,  0, -2, -2,  0,
        0, -2,  0,  0,  0, -2, -2, -2,  0, -2, -2,  0,  0, -2,  0,  0,  0,
        0, -2, -2, -2,  0, -2, -2, -2,  0,  0,  0,  0,  0,  0, -2, -2,  0,
       -2, -2, -2, -2, -2,  0, -2,  0, -2, -2,  0,  0,  0, -2, -2,  0,  0,
        0,  0, -2, -2, -2,  0,  0,  0,  0, -2, -2, -2, -2,  0, -2,  0,  0,
       -2,  0,  0, -2, -2, -2,  0,  0, -2, -2, -2,  0, -2, -2,  0,  0,  0,
        0,  0, -2, -2, -2, -2,  0,  0, -2,  0, -2, -2,  0,  0,  0,  0, -2,
       -2, -2,  0,  0, -2, -2, -2,  0, -2,  0, -2,  0,  0, -2, -2, -2,  0,
        0,  0,  0,  0,  0, -2,  0, -2, -2,  0, -2,  0,  0,  0, -2, -2,  0,
       -2, -2,  0, -2, -2,  0, -2, -2,  0,  0,  0, -2, -2,  0, -2, -2, -2,
        0, -2, -2, -2])
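
For leaves, feature holds the sentinel -2 (TREE_UNDEFINED in _tree); since our X has only one column, every internal node splits on feature 0. The -2s should line up exactly with the leaves we found above:


In [ ]:
# leaves should be exactly the nodes with feature == -2
np.all((dt.tree_.feature == -2) == (dt.tree_.children_left == -1))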

In [21]:
dt.tree_.threshold


Out[21]:
array([  1.11111116,   0.87837839,   0.59309304,   0.2102102 ,
        -2.        ,   0.3903904 ,  -2.        ,   0.52552551,
         0.47297299,  -2.        ,  -2.        ,  -2.        ,
         0.75825822,  -2.        ,  -2.        ,   0.91591591,
        -2.        ,   1.01351345,  -2.        ,  -2.        ,
         4.8948946 ,   1.5915916 ,   1.18618619,   1.14864874,
        -2.        ,  -2.        ,   1.30630636,   1.23123121,
        -2.        ,  -2.        ,   1.47147155,  -2.        ,
        -2.        ,   2.16966963,   1.61411405,  -2.        ,
         1.8993994 ,   1.81681681,   1.69669676,  -2.        ,
        -2.        ,  -2.        ,   2.04204202,  -2.        ,
        -2.        ,   3.69369364,   2.25975966,  -2.        ,
         3.43093085,   2.95045042,   2.63513517,   2.40990973,
        -2.        ,  -2.        ,  -2.        ,   3.25825834,
        -2.        ,  -2.        ,  -2.        ,   4.43693686,
         4.36186218,   4.29429436,   4.23423386,   3.97897887,
         3.93393373,  -2.        ,  -2.        ,   4.09909916,
        -2.        ,  -2.        ,  -2.        ,  -2.        ,
        -2.        ,   4.54954958,  -2.        ,   4.68468475,
        -2.        ,  -2.        ,   7.48498535,   5.23273277,
         5.11261272,  -2.        ,  -2.        ,   6.92942953,
         5.55555534,   5.45045042,   5.3303299 ,  -2.        ,
        -2.        ,  -2.        ,   5.82582569,   5.77327347,
         5.67567587,   5.59309292,  -2.        ,  -2.        ,
        -2.        ,  -2.        ,   5.90840816,  -2.        ,
         6.34384394,   6.02852821,  -2.        ,   6.23123121,
         6.11861849,  -2.        ,  -2.        ,  -2.        ,
         6.60660648,   6.48648643,  -2.        ,  -2.        ,
        -2.        ,   7.2597599 ,  -2.        ,  -2.        ,
        10.77327347,   8.40840912,   8.16066074,   7.9954958 ,
         7.77777767,  -2.        ,  -2.        ,  -2.        ,
        -2.        ,   8.84384346,   8.63363361,  -2.        ,
         8.72372341,  -2.        ,  -2.        ,  10.03753757,
         9.57958031,   9.19669724,   8.90390396,  -2.        ,
        -2.        ,  -2.        ,   9.7897892 ,   9.72222233,
        -2.        ,  -2.        ,  -2.        ,  10.26276302,
        -2.        ,  10.31531525,  -2.        ,  10.61561584,
        10.44294262,  -2.        ,  -2.        ,  -2.        ,
        14.59459496,  13.52102089,  12.70270252,  12.04204178,
        11.30630684,  10.90090179,  -2.        ,  11.08858871,
        -2.        ,  -2.        ,  11.40390396,  -2.        ,
        11.59159184,  11.46396446,  11.43393326,  -2.        ,
        -2.        ,  11.50900936,  -2.        ,  -2.        ,
        11.74924946,  -2.        ,  -2.        ,  12.4399395 ,
        -2.        ,  -2.        ,  13.31081009,  12.86786842,
        12.79279327,  -2.        ,  -2.        ,  13.07056999,
        -2.        ,  -2.        ,  -2.        ,  13.97897911,
        -2.        ,  -2.        ,  -2.        ])

In [22]:
dt.tree_.value


Out[22]:
array([[[ 0.        ]],

       [[ 0.        ]],

       [[ 0.        ]],

       [[ 0.        ]],

       [[ 1.0632149 ]],

       [[ 0.        ]],

       [[ 0.73599075]],

       [[ 0.        ]],

       [[ 0.        ]],

       [[ 1.08155707]],

       [[ 0.88214462]],

       [[ 1.10624047]],

       [[ 0.        ]],

       [[ 1.46292898]],

       [[ 1.15520336]],

       [[ 0.        ]],

       [[ 0.59581346]],

       [[ 0.        ]],

       [[ 0.49014347]],

       [[ 0.57512417]],

       [[ 0.        ]],

       [[ 0.        ]],

       [[ 0.        ]],

       [[ 0.        ]],

       [[ 0.36092103]],

       [[ 0.4052274 ]],

       [[ 0.        ]],

       [[ 0.        ]],

       [[ 0.13316036]],

       [[-0.63906798]],

       [[ 0.        ]],

       [[ 0.25073346]],

       [[ 0.26851766]],

       [[ 0.        ]],

       [[ 0.        ]],

       [[-0.5631036 ]],

       [[ 0.        ]],

       [[ 0.        ]],

       [[ 0.        ]],

       [[-0.25586127]],

       [[-0.15549664]],

       [[ 0.22989407]],

       [[ 0.        ]],

       [[-0.28386537]],

       [[-0.22764543]],

       [[ 0.        ]],

       [[ 0.        ]],

       [[-0.43459162]],

       [[ 0.        ]],

       [[ 0.        ]],

       [[ 0.        ]],

       [[ 0.        ]],

       [[-0.69274519]],

       [[-0.79953867]],

       [[-0.99091452]],

       [[ 0.        ]],

       [[-0.54986492]],

       [[-0.80389053]],

       [[-0.92641722]],

       [[ 0.        ]],

       [[ 0.        ]],

       [[ 0.        ]],

       [[ 0.        ]],

       [[ 0.        ]],

       [[ 0.        ]],

       [[-0.22948057]],

       [[-0.15949298]],

       [[ 0.        ]],

       [[-0.47897252]],

       [[-0.24029049]],

       [[ 0.16239866]],

       [[-0.62971262]],

       [[ 0.28157206]],

       [[ 0.        ]],

       [[-1.09668109]],

       [[ 0.        ]],

       [[ 0.07008881]],

       [[-0.39653549]],

       [[ 0.        ]],

       [[ 0.        ]],

       [[ 0.        ]],

       [[ 0.42861796]],

       [[ 0.44714611]],

       [[ 0.        ]],

       [[ 0.        ]],

       [[ 0.        ]],

       [[ 0.        ]],

       [[ 0.66235772]],

       [[ 0.96197784]],

       [[ 0.44294229]],

       [[ 0.        ]],

       [[ 0.        ]],

       [[ 0.        ]],

       [[ 0.        ]],

       [[ 0.86462875]],

       [[ 0.87896554]],

       [[ 0.76942032]],

       [[ 1.5807009 ]],

       [[ 0.        ]],

       [[ 0.52810617]],

       [[ 0.        ]],

       [[ 0.        ]],

       [[ 0.99316089]],

       [[ 0.        ]],

       [[ 0.        ]],

       [[ 0.64740948]],

       [[ 0.93326817]],

       [[ 0.63890765]],

       [[ 0.        ]],

       [[ 0.        ]],

       [[ 0.98889319]],

       [[ 0.9520065 ]],

       [[ 0.99930373]],

       [[ 0.        ]],

       [[ 0.25072853]],

       [[ 0.7724775 ]],

       [[ 0.        ]],

       [[ 0.        ]],

       [[ 0.        ]],

       [[ 0.        ]],

       [[ 0.        ]],

       [[-0.14549445]],

       [[ 0.11832986]],

       [[-0.3452118 ]],

       [[ 0.18066013]],

       [[ 0.        ]],

       [[ 0.        ]],

       [[-0.71426387]],

       [[ 0.        ]],

       [[-0.34758045]],

       [[-0.16647792]],

       [[ 0.        ]],

       [[ 0.        ]],

       [[ 0.        ]],

       [[ 0.        ]],

       [[-0.72371159]],

       [[-0.69011216]],

       [[-0.52157867]],

       [[ 0.        ]],

       [[ 0.        ]],

       [[-1.08561379]],

       [[-1.18311966]],

       [[-0.78188738]],

       [[ 0.        ]],

       [[-0.23972908]],

       [[ 0.        ]],

       [[-0.8266832 ]],

       [[ 0.        ]],

       [[ 0.        ]],

       [[-0.42265743]],

       [[-0.55689144]],

       [[-0.39144637]],

       [[ 0.        ]],

       [[ 0.        ]],

       [[ 0.        ]],

       [[ 0.        ]],

       [[ 0.        ]],

       [[ 0.        ]],

       [[ 0.11016135]],

       [[ 0.        ]],

       [[ 0.36468479]],

       [[ 0.26236952]],

       [[ 0.        ]],

       [[ 1.26837309]],

       [[ 0.        ]],

       [[ 0.        ]],

       [[ 0.        ]],

       [[ 0.39533984]],

       [[ 0.45276868]],

       [[ 0.        ]],

       [[-0.00349366]],

       [[ 0.38537714]],

       [[ 0.        ]],

       [[ 0.45732052]],

       [[ 0.57083166]],

       [[ 0.        ]],

       [[-0.10863384]],

       [[ 0.25782098]],

       [[ 0.        ]],

       [[ 0.        ]],

       [[ 0.        ]],

       [[ 1.05886266]],

       [[ 1.19469224]],

       [[ 0.        ]],

       [[ 0.80496054]],

       [[ 0.77677155]],

       [[ 0.33869142]],

       [[ 0.        ]],

       [[ 0.17474494]],

       [[-0.24889702]],

       [[-1.26798997]]])
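
As the docstring warned, only the leaf entries of value are meaningful in this version; the internal nodes print as 0. To pull out just the leaf predictions:


In [ ]:
# value has shape (node_count, n_outputs, max_n_classes) = (191, 1, 1);
# keep only the rows that correspond to leaves
is_leaf = dt.tree_.children_left == -1
dt.tree_.value[is_leaf].ravel()[:5]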

In [23]:
np.round(dt.tree_.impurity, 2)


Out[23]:
array([ 0.46,  0.08,  0.04,  0.02,  0.  ,  0.02,  0.  ,  0.01,  0.01,
        0.  ,  0.  ,  0.  ,  0.02,  0.  , -0.  ,  0.  ,  0.  ,  0.  ,
        0.  ,  0.  ,  0.43,  0.18,  0.13,  0.  ,  0.  ,  0.  ,  0.14,
        0.15,  0.  ,  0.  ,  0.  ,  0.  ,  0.  ,  0.14,  0.05,  0.  ,
        0.04,  0.04,  0.  ,  0.  , -0.  ,  0.  ,  0.  ,  0.  ,  0.  ,
        0.16,  0.03,  0.  ,  0.02,  0.02,  0.02,  0.  ,  0.  , -0.  ,
       -0.  ,  0.02,  0.  ,  0.  ,  0.  ,  0.15,  0.09,  0.06,  0.04,
        0.01,  0.  ,  0.  , -0.  ,  0.01,  0.  ,  0.  , -0.  , -0.  ,
       -0.  ,  0.23,  0.  ,  0.05,  0.  ,  0.  ,  0.45,  0.09,  0.  ,
        0.  ,  0.  ,  0.08,  0.07,  0.05,  0.02,  0.  ,  0.  ,  0.  ,
        0.07,  0.12,  0.04,  0.05,  0.07,  0.  , -0.  , -0.  ,  0.03,
        0.  ,  0.02,  0.03,  0.  ,  0.02,  0.02,  0.  ,  0.  ,  0.  ,
        0.  ,  0.  ,  0.  ,  0.  , -0.  ,  0.07,  0.  ,  0.  ,  0.41,
        0.18,  0.2 ,  0.23,  0.27,  0.  ,  0.39,  0.  ,  0.  ,  0.08,
        0.05,  0.  ,  0.01,  0.  ,  0.  ,  0.08,  0.05,  0.01,  0.  ,
        0.  , -0.  , -0.  ,  0.03,  0.  ,  0.  , -0.  ,  0.  ,  0.04,
        0.  ,  0.03,  0.  ,  0.01,  0.  ,  0.  ,  0.  ,  0.  ,  0.3 ,
        0.17,  0.15,  0.11,  0.1 ,  0.01,  0.  ,  0.  ,  0.  , -0.  ,
        0.13,  0.  ,  0.03,  0.03,  0.  ,  0.  ,  0.  ,  0.04,  0.  ,
        0.  ,  0.  ,  0.  ,  0.  ,  0.03,  0.  ,  0.  ,  0.09,  0.03,
        0.  ,  0.  , -0.  ,  0.  ,  0.  , -0.  , -0.  ,  0.04,  0.  ,
        0.  ,  0.  ])

In [24]:
dt.tree_.n_node_samples


Out[24]:
array([100,  11,   8,   6,   2,   4,   1,   3,   2,   1,   1,   1,   2,
         1,   1,   3,   1,   2,   1,   1,  89,  29,   6,   2,   1,   1,
         4,   2,   1,   1,   2,   1,   1,  23,   6,   1,   5,   3,   2,
         1,   1,   1,   2,   1,   1,  17,   7,   1,   6,   5,   3,   2,
         1,   1,   1,   2,   1,   1,   1,  10,   7,   6,   5,   4,   2,
         1,   1,   2,   1,   1,   1,   1,   1,   3,   1,   2,   1,   1,
        60,  21,   2,   1,   1,  19,  17,   3,   2,   1,   1,   1,  14,
         5,   4,   3,   2,   1,   1,   1,   9,   1,   8,   5,   2,   3,
         2,   1,   1,   1,   3,   2,   1,   1,   1,   2,   1,   1,  39,
        19,   5,   4,   3,   1,   2,   1,   1,  14,   3,   1,   2,   1,
         1,  11,   6,   3,   2,   1,   1,   1,   3,   2,   1,   1,   1,
         5,   1,   4,   1,   3,   2,   1,   1,   1,  20,  19,  17,  12,
        10,   3,   1,   2,   1,   1,   7,   1,   6,   4,   2,   1,   1,
         2,   1,   1,   2,   1,   1,   2,   1,   1,   5,   4,   2,   1,
         1,   2,   1,   1,   1,   2,   1,   1,   1])

In [25]:
dt.tree_.weighted_n_node_samples


Out[25]:
array([ 100.,   11.,    8.,    6.,    2.,    4.,    1.,    3.,    2.,
          1.,    1.,    1.,    2.,    1.,    1.,    3.,    1.,    2.,
          1.,    1.,   89.,   29.,    6.,    2.,    1.,    1.,    4.,
          2.,    1.,    1.,    2.,    1.,    1.,   23.,    6.,    1.,
          5.,    3.,    2.,    1.,    1.,    1.,    2.,    1.,    1.,
         17.,    7.,    1.,    6.,    5.,    3.,    2.,    1.,    1.,
          1.,    2.,    1.,    1.,    1.,   10.,    7.,    6.,    5.,
          4.,    2.,    1.,    1.,    2.,    1.,    1.,    1.,    1.,
          1.,    3.,    1.,    2.,    1.,    1.,   60.,   21.,    2.,
          1.,    1.,   19.,   17.,    3.,    2.,    1.,    1.,    1.,
         14.,    5.,    4.,    3.,    2.,    1.,    1.,    1.,    9.,
          1.,    8.,    5.,    2.,    3.,    2.,    1.,    1.,    1.,
          3.,    2.,    1.,    1.,    1.,    2.,    1.,    1.,   39.,
         19.,    5.,    4.,    3.,    1.,    2.,    1.,    1.,   14.,
          3.,    1.,    2.,    1.,    1.,   11.,    6.,    3.,    2.,
          1.,    1.,    1.,    3.,    2.,    1.,    1.,    1.,    5.,
          1.,    4.,    1.,    3.,    2.,    1.,    1.,    1.,   20.,
         19.,   17.,   12.,   10.,    3.,    1.,    2.,    1.,    1.,
          7.,    1.,    6.,    4.,    2.,    1.,    1.,    2.,    1.,
          1.,    2.,    1.,    1.,    2.,    1.,    1.,    5.,    4.,
          2.,    1.,    1.,    2.,    1.,    1.,    1.,    2.,    1.,
          1.,    1.])

In [26]:
# how can limiting the tree's size help? (here: max_depth; min_samples_leaf below)
dt = sklearn.tree.DecisionTreeRegressor(max_depth=4, splitter='best')
# note the reversed order and flipped sign of the training data;
# the predictions get un-flipped when we plot them below
dt.fit(X_train[::-1], -y_train[::-1])
y_pred = dt.predict(X_true)

In [27]:
plt.plot(x_true, y_true, '-', label='Truth')
plt.plot(x_train, y_train, 's', label='Train')
plt.plot(x_true, -y_pred, '-', label='Predicted')
plt.legend()


Out[27]:
<matplotlib.legend.Legend at 0x7fcbfe644110>

Time to refactor that?

Yes, if we are going to keep experimenting with it, though we will be moving on before long.


In [28]:
def experiment_w_dt_options(max_depth=None, min_samples_leaf=1):
    """ fit a DecisionTreeRegressor with the given complexity options
    and plot its predictions against the truth and the training data """
    dt = sklearn.tree.DecisionTreeRegressor(max_depth=max_depth, min_samples_leaf=min_samples_leaf)
    dt.fit(X_train[::-1], -y_train[::-1])
    y_pred = dt.predict(X_true)
    
    plt.plot(x_true, y_true, '-', label='Truth')
    plt.plot(x_train, y_train, 's', label='Train', ms=4)
    plt.plot(x_true, -y_pred, '-', label='Predicted')

In [ ]:
i = 0
for max_depth in [8, 4, 2]:
    for min_samples_leaf in [1,10,25]:
        i += 1
        plt.subplot(3,3,i)
        experiment_w_dt_options(max_depth, min_samples_leaf)
        plt.text(0,2,'max_depth: %d, min_samples_leaf: %d'%(max_depth, min_samples_leaf), va='top')
plt.legend(loc=(1,.1))


Out[ ]:
<matplotlib.legend.Legend at 0x7fb5fd101450>

Describing the contents of a tree, with recursion


In [ ]:
# the next line is a tricky little trick to get some tab-completion help
t = dt.tree_

In [ ]:
# here is a tricky python thing about strings:
'abie'*0, 'abie'*10


Out[ ]:
('', 'abieabieabieabieabieabieabieabieabieabie')

In [30]:
def print_tree(t, root=0, depth=0):
    """ print the contents of a decision tree
    as a python function

    Parameters
    ----------
    t : sklearn.tree._tree.Tree
    root : int, optional - node to start from
    depth : int, optional - how deep we are in the original tree
    """
    indent = '    '*depth
    
    left_child = t.children_left[root]
    right_child = t.children_right[root]
    
    if left_child == sklearn.tree._tree.TREE_LEAF:  # magic number is -1
        print indent + 'return %.2f # (node %d)' % (t.value[root], root)
    else:
        print indent + 'if X_i[%d] < %.2f: # (node %d)' % (t.feature[root], t.threshold[root], root)
        print_tree(t, root=left_child, depth=depth+1)
        
        print indent + 'else:'
        print_tree(t, root=right_child, depth=depth+1)
    
print_tree(dt.tree_)


if X_i[0] < 1.11: # (node 0)
    return 0.93 # (node 1)
else:
    if X_i[0] < 4.89: # (node 2)
        if X_i[0] < 1.90: # (node 3)
            return 0.00 # (node 4)
        else:
            return -0.44 # (node 5)
    else:
        if X_i[0] < 7.48: # (node 6)
            if X_i[0] < 5.91: # (node 7)
                return 0.77 # (node 8)
            else:
                return 0.82 # (node 9)
        else:
            if X_i[0] < 10.77: # (node 10)
                return -0.46 # (node 11)
            else:
                if X_i[0] < 12.04: # (node 12)
                    return 0.43 # (node 13)
                else:
                    return 0.30 # (node 14)

Can you draw a tree version of that?

No, it has too many nodes...

Is it right?

If I tried to do it, it probably would not be...
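
(If you do want a picture, sklearn can dump the fitted tree as Graphviz source, which the command-line dot tool can render, assuming you have it installed:)


In [ ]:
# dump the fitted tree as Graphviz source; render it with, e.g.,
#   dot -Tpng tree.dot -o tree.png
sklearn.tree.export_graphviz(dt, out_file='tree.dot')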

Pruning


In [29]:
dt = sklearn.tree.DecisionTreeRegressor(min_samples_leaf=10)
dt.fit(X_train, y_train)

# make a copy of the decision tree regressor dt (called pt, "p" for pruned)
pt = sklearn.tree.DecisionTreeRegressor(min_samples_leaf=10)
pt.fit(X_train, y_train)


Out[29]:
DecisionTreeRegressor(compute_importances=None, criterion='mse',
           max_depth=None, max_features=None, max_leaf_nodes=None,
           min_density=None, min_samples_leaf=10, min_samples_split=2,
           random_state=None, splitter='best')
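
(A note on the execution counters: the In [35] cell below was re-run after the pruning cells In [32] through In [34], so the tree it prints already shows node 12 collapsed into a leaf.)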

In [35]:
print_tree(pt.tree_)


if X_i[0] < 1.11: # (node 0)
    return 0.93 # (node 1)
else:
    if X_i[0] < 4.89: # (node 2)
        if X_i[0] < 1.90: # (node 3)
            return 0.00 # (node 4)
        else:
            return -0.44 # (node 5)
    else:
        if X_i[0] < 7.48: # (node 6)
            if X_i[0] < 5.91: # (node 7)
                return 0.77 # (node 8)
            else:
                return 0.82 # (node 9)
        else:
            if X_i[0] < 10.77: # (node 10)
                return -0.46 # (node 11)
            else:
                return 0.36 # (node 12)

In [32]:
# have a look at node 12
pt.tree_.value[12]


Out[32]:
array([[ 0.]])

In [33]:
# need to set value at soon-to-be leaf node
pt.tree_.value[12] = 1  # NOTE: this is not the right value

In [34]:
# find the left and right children of node 12
left_child = pt.tree_.children_left[12]
right_child = pt.tree_.children_right[12]

# find the weight of these nodes in the training dataset
wt_left = pt.tree_.weighted_n_node_samples[left_child]
wt_right = pt.tree_.weighted_n_node_samples[right_child]

# find the value of these nodes in the training dataset
val_left = pt.tree_.value[left_child]
val_right = pt.tree_.value[right_child]

# calculate the value of node 12 after pruning
pt.tree_.value[12] = (wt_left*val_left + wt_right*val_right) / (wt_left + wt_right)


pt.tree_.children_left[12] = sklearn.tree._tree.TREE_LEAF
pt.tree_.children_right[12] = sklearn.tree._tree.TREE_LEAF
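
If we plan to prune more than one node, it might be worth wrapping those steps in a function. A sketch (my addition), which assumes both children of the pruned node are already leaves; in this sklearn version internal nodes store value 0, so averaging non-leaf children would give nonsense:


In [ ]:
def prune_at(t, node):
    """ turn `node` into a leaf in-place, setting its value to the
    weighted average of its children's values """
    left = t.children_left[node]
    right = t.children_right[node]
    if left == sklearn.tree._tree.TREE_LEAF:
        return  # already a leaf
    wt_left = t.weighted_n_node_samples[left]
    wt_right = t.weighted_n_node_samples[right]
    t.value[node] = (wt_left*t.value[left] + wt_right*t.value[right]) / (wt_left + wt_right)
    t.children_left[node] = sklearn.tree._tree.TREE_LEAF
    t.children_right[node] = sklearn.tree._tree.TREE_LEAF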

In [36]:
# have a look at the original tree compared to the pruned version
y_pred = dt.predict(X_true)
plt.plot(x_true, y_pred, '-', label='Original Pred')

y_pred = pt.predict(X_true)
plt.plot(x_true, y_pred, '-', label='Pruned Pred')

#plt.plot(x_train, y_train, 's', label='Train')

plt.legend()


Out[36]:
<matplotlib.legend.Legend at 0x7fcbfcd71d10>

Another look at bounded-depth

We will use our skills from last week's class to sweep over a range of depth values, and see which is best


In [ ]:
import sklearn.cross_validation  # needed for KFold below

d_vals = [1, 2, 4, 8, 16]

# initialize rmse dict
rmse = {}
for d in d_vals:
    rmse[d] = []

# 10 repetitions of 10-fold cross-validation
for rep in range(10):
    cv = sklearn.cross_validation.KFold(len(y_train), n_folds=10, shuffle=True)

    for train, validate in cv:
        for d in d_vals:
            dt = sklearn.tree.DecisionTreeRegressor(max_depth=d)
            dt.fit(X_train[train], y_train[train])

            y_pred = dt.predict(X_train[validate])

            rmse[d].append(np.sqrt(np.mean((y_pred - y_train[validate])**2)))

In [ ]:
pd.DataFrame(rmse)

In [ ]:
pd.DataFrame(rmse).mean().plot(marker='s')


Out[ ]:
<matplotlib.axes.AxesSubplot at 0x7fb5fc8bbf50>
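
The mean alone hides the fold-to-fold spread; error bars across the 100 train/validate splits are cheap to add (a sketch using plt.errorbar, to stay clear of pandas version differences):


In [ ]:
# show the spread across the 10x10 folds, not just the mean
df_rmse = pd.DataFrame(rmse)
plt.errorbar(d_vals, df_rmse.mean(), yerr=df_rmse.std(), marker='s')
plt.xlabel('max_depth')
plt.ylabel('RMSE')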

In [ ]:
dt = sklearn.tree.DecisionTreeRegressor(max_depth=4)
dt.fit(X_train, y_train)
print_tree(dt.tree_)


if X_i[0] < 1.11: # (node 0)
    if X_i[0] < 0.88: # (node 1)
        if X_i[0] < 0.59: # (node 2)
            if X_i[0] < 0.21: # (node 3)
                return 1.06 # (node 4)
            else:
                return 0.95 # (node 5)
        else:
            if X_i[0] < 0.76: # (node 6)
                return 1.46 # (node 7)
            else:
                return 1.16 # (node 8)
    else:
        if X_i[0] < 0.92: # (node 9)
            return 0.60 # (node 10)
        else:
            if X_i[0] < 1.01: # (node 11)
                return 0.49 # (node 12)
            else:
                return 0.58 # (node 13)
else:
    if X_i[0] < 4.89: # (node 14)
        if X_i[0] < 1.59: # (node 15)
            if X_i[0] < 1.19: # (node 16)
                return 0.38 # (node 17)
            else:
                return 0.00 # (node 18)
        else:
            if X_i[0] < 2.17: # (node 19)
                return -0.21 # (node 20)
            else:
                return -0.47 # (node 21)
    else:
        if X_i[0] < 7.48: # (node 22)
            if X_i[0] < 5.23: # (node 23)
                return 0.44 # (node 24)
            else:
                return 0.83 # (node 25)
        else:
            if X_i[0] < 10.77: # (node 26)
                return -0.46 # (node 27)
            else:
                return 0.36 # (node 28)

In [ ]:
dt = sklearn.tree.DecisionTreeRegressor(max_depth=2)
dt.fit(X_train, y_train)
print_tree(dt.tree_)


if X_i[0] < 1.11: # (node 0)
    if X_i[0] < 0.88: # (node 1)
        return 1.07 # (node 2)
    else:
        return 0.55 # (node 3)
else:
    if X_i[0] < 4.89: # (node 4)
        return -0.29 # (node 5)
    else:
        return 0.25 # (node 6)

An aside: bootstrap + decision trees = good


In [ ]:
dt = sklearn.tree.DecisionTreeRegressor(max_depth=4)

y_pred = {}
for rep in range(500):
    train = np.random.choice(range(len(y_train)), size=len(y_train))
    
    dt.fit(X_train[train], y_train[train])
    y_pred[rep] = dt.predict(X_true)

In [ ]:
y_pred = pd.DataFrame(y_pred)
y_pred = y_pred.mean(axis=1)

In [ ]:
plt.plot(x_true, y_true, '-', label='Truth')
plt.plot(x_train, y_train, 's', label='Train')
plt.plot(x_true, y_pred, label='Mean of Bootstrapped Prediction')
plt.legend()


Out[ ]:
<matplotlib.legend.Legend at 0x7fb5fc784890>
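
Averaging predictions over bootstrap resamples is exactly "bagging", and sklearn packages the idea (plus random feature subsetting at each split) as a random forest. A sketch of the one-liner version, for comparison:


In [ ]:
# a random forest bags decision trees for us (the per-split feature
# subsampling is a no-op here, since we have only one feature)
import sklearn.ensemble
rf = sklearn.ensemble.RandomForestRegressor(n_estimators=500, max_depth=4)
rf.fit(X_train, y_train)
plt.plot(x_true, y_true, '-', label='Truth')
plt.plot(x_true, rf.predict(X_true), label='Random Forest')
plt.legend()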

One missing piece: measuring split quality


In [ ]:
# here are the criteria for regression tree quality that sklearn knows
sklearn.tree.tree.CRITERIA_REG


Out[ ]:
{'friedman_mse': sklearn.tree._tree.FriedmanMSE, 'mse': sklearn.tree._tree.MSE}

In [ ]:
# here is a super-tricky way to modify the print_tree function
# so that it includes the impurity

# can you understand how it works?

old_print_tree = print_tree

def print_tree(t, root=0, depth=0):
    indent = '    '*depth
    print indent + '# node %d: impurity = %.2f' % (root, t.impurity[root])

    old_print_tree(t, root, depth)
    
print_tree(dt.tree_)


# node 0: impurity = 0.39
if X_i[0] < 0.74: # (node 0)
    # node 1: impurity = 0.02
    if X_i[0] < 0.21: # (node 1)
        # node 2: impurity = 0.00
        return 1.06 # (node 2)
    else:
        # node 3: impurity = 0.02
        if X_i[0] < 0.39: # (node 3)
            # node 4: impurity = 0.00
            return 0.74 # (node 4)
        else:
            # node 5: impurity = 0.01
            if X_i[0] < 0.47: # (node 5)
                # node 6: impurity = 0.00
                return 1.08 # (node 6)
            else:
                # node 7: impurity = 0.01
                return 0.94 # (node 7)
else:
    # node 8: impurity = 0.36
    if X_i[0] < 10.83: # (node 8)
        # node 9: impurity = 0.36
        if X_i[0] < 7.78: # (node 9)
            # node 10: impurity = 0.28
            if X_i[0] < 4.89: # (node 10)
                # node 11: impurity = 0.14
                return -0.07 # (node 11)
            else:
                # node 12: impurity = 0.12
                return 0.74 # (node 12)
        else:
            # node 13: impurity = 0.08
            if X_i[0] < 10.01: # (node 13)
                # node 14: impurity = 0.07
                return -0.78 # (node 14)
            else:
                # node 15: impurity = 0.02
                return -0.44 # (node 15)
    else:
        # node 16: impurity = 0.14
        if X_i[0] < 13.31: # (node 16)
            # node 17: impurity = 0.15
            if X_i[0] < 12.30: # (node 17)
                # node 18: impurity = 0.09
                return 0.45 # (node 18)
            else:
                # node 19: impurity = 0.03
                return 1.08 # (node 19)
        else:
            # node 20: impurity = 0.00
            if X_i[0] < 13.52: # (node 20)
                # node 21: impurity = 0.00
                return 0.34 # (node 21)
            else:
                # node 22: impurity = -0.00
                return 0.17 # (node 22)
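
The trick: old_print_tree keeps a reference to the original function, but the recursive calls inside that original body still look up the global name print_tree, which now points to the wrapper. So the impurity line gets printed at every level of the recursion, not just at the top.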

In [46]:
# here is a less-tricky way to do the same thing
# still tricky, since it uses recursion

def print_tree(t, root=0, depth=0):
    indent = '    '*depth
    print indent + '# node %s: impurity = %.2f' % (str(root), t.impurity[root])
    left_child = t.children_left[root]
    right_child = t.children_right[root]
    
    if left_child == sklearn.tree._tree.TREE_LEAF:
        print indent + 'return %s # (node %d)' % (str(t.value[root]), root)
    else:
        print indent + 'if X_i[%d] < %.2f: # (node %d)' % (t.feature[root], t.threshold[root], root)
        print_tree(t, root=left_child, depth=depth+1)
        
        print indent + 'else:'
        print_tree(t, root=right_child, depth=depth+1)
    
print_tree(dt.tree_)


# node 0: impurity = 0.48
if X_i[0] < 1.11: # (node 0)
    # node 1: impurity = 0.08
    if X_i[0] < 0.88: # (node 1)
        # node 2: impurity = 0.04
        if X_i[0] < 0.59: # (node 2)
            # node 3: impurity = 0.02
            if X_i[0] < 0.21: # (node 3)
                # node 4: impurity = 0.00
                return [[ 1.0632149]] # (node 4)
            else:
                # node 5: impurity = 0.02
                if X_i[0] < 0.39: # (node 5)
                    # node 6: impurity = 0.00
                    return [[ 0.73599075]] # (node 6)
                else:
                    # node 7: impurity = 0.01
                    if X_i[0] < 0.53: # (node 7)
                        # node 8: impurity = 0.01
                        if X_i[0] < 0.47: # (node 8)
                            # node 9: impurity = 0.00
                            return [[ 1.08155707]] # (node 9)
                        else:
                            # node 10: impurity = 0.00
                            return [[ 0.88214462]] # (node 10)
                    else:
                        # node 11: impurity = 0.00
                        return [[ 1.10624047]] # (node 11)
        else:
            # node 12: impurity = 0.02
            if X_i[0] < 0.76: # (node 12)
                # node 13: impurity = 0.00
                return [[ 1.46292898]] # (node 13)
            else:
                # node 14: impurity = -0.00
                return [[ 1.15520336]] # (node 14)
    else:
        # node 15: impurity = 0.00
        if X_i[0] < 0.92: # (node 15)
            # node 16: impurity = 0.00
            return [[ 0.59581346]] # (node 16)
        else:
            # node 17: impurity = 0.00
            if X_i[0] < 1.01: # (node 17)
                # node 18: impurity = 0.00
                return [[ 0.49014347]] # (node 18)
            else:
                # node 19: impurity = 0.00
                return [[ 0.57512417]] # (node 19)
else:
    # node 20: impurity = 0.43
    if X_i[0] < 4.89: # (node 20)
        # node 21: impurity = 0.19
        if X_i[0] < 1.59: # (node 21)
            # node 22: impurity = 0.13
            if X_i[0] < 1.19: # (node 22)
                # node 23: impurity = 0.00
                if X_i[0] < 1.15: # (node 23)
                    # node 24: impurity = 0.00
                    return [[ 0.36092103]] # (node 24)
                else:
                    # node 25: impurity = 0.00
                    return [[ 0.4052274]] # (node 25)
            else:
                # node 26: impurity = 0.14
                if X_i[0] < 1.31: # (node 26)
                    # node 27: impurity = 0.15
                    if X_i[0] < 1.23: # (node 27)
                        # node 28: impurity = 0.00
                        return [[ 0.13316036]] # (node 28)
                    else:
                        # node 29: impurity = 0.00
                        return [[-0.63906798]] # (node 29)
                else:
                    # node 30: impurity = 0.00
                    if X_i[0] < 1.47: # (node 30)
                        # node 31: impurity = 0.00
                        return [[ 0.25073346]] # (node 31)
                    else:
                        # node 32: impurity = 0.00
                        return [[ 0.26851766]] # (node 32)
        else:
            # node 33: impurity = 0.15
            if X_i[0] < 3.69: # (node 33)
                # node 34: impurity = 0.12
                if X_i[0] < 2.26: # (node 34)
                    # node 35: impurity = 0.06
                    if X_i[0] < 1.61: # (node 35)
                        # node 36: impurity = 0.00
                        return [[-0.5631036]] # (node 36)
                    else:
                        # node 37: impurity = 0.05
                        if X_i[0] < 1.90: # (node 37)
                            # node 38: impurity = 0.04
                            if X_i[0] < 1.82: # (node 38)
                                # node 39: impurity = 0.00
                                if X_i[0] < 1.70: # (node 39)
                                    # node 40: impurity = 0.00
                                    return [[-0.25586127]] # (node 40)
                                else:
                                    # node 41: impurity = -0.00
                                    return [[-0.15549664]] # (node 41)
                            else:
                                # node 42: impurity = 0.00
                                return [[ 0.22989407]] # (node 42)
                        else:
                            # node 43: impurity = 0.01
                            if X_i[0] < 2.06: # (node 43)
                                # node 44: impurity = 0.00
                                return [[-0.28386537]] # (node 44)
                            else:
                                # node 45: impurity = 0.00
                                return [[-0.43459162]] # (node 45)
                else:
                    # node 46: impurity = 0.02
                    if X_i[0] < 3.43: # (node 46)
                        # node 47: impurity = 0.02
                        if X_i[0] < 2.95: # (node 47)
                            # node 48: impurity = 0.02
                            if X_i[0] < 2.64: # (node 48)
                                # node 49: impurity = 0.00
                                if X_i[0] < 2.41: # (node 49)
                                    # node 50: impurity = 0.00
                                    return [[-0.69274519]] # (node 50)
                                else:
                                    # node 51: impurity = -0.00
                                    return [[-0.79953867]] # (node 51)
                            else:
                                # node 52: impurity = -0.00
                                return [[-0.99091452]] # (node 52)
                        else:
                            # node 53: impurity = 0.02
                            if X_i[0] < 3.26: # (node 53)
                                # node 54: impurity = 0.00
                                return [[-0.54986492]] # (node 54)
                            else:
                                # node 55: impurity = 0.00
                                return [[-0.80389053]] # (node 55)
                    else:
                        # node 56: impurity = 0.00
                        return [[-0.92641722]] # (node 56)
            else:
                # node 57: impurity = 0.15
                if X_i[0] < 4.44: # (node 57)
                    # node 58: impurity = 0.09
                    if X_i[0] < 4.36: # (node 58)
                        # node 59: impurity = 0.06
                        if X_i[0] < 4.29: # (node 59)
                            # node 60: impurity = 0.04
                            if X_i[0] < 4.23: # (node 60)
                                # node 61: impurity = 0.01
                                if X_i[0] < 3.98: # (node 61)
                                    # node 62: impurity = 0.00
                                    if X_i[0] < 3.93: # (node 62)
                                        # node 63: impurity = 0.00
                                        return [[-0.22948057]] # (node 63)
                                    else:
                                        # node 64: impurity = -0.00
                                        return [[-0.15949298]] # (node 64)
                                else:
                                    # node 65: impurity = 0.01
                                    if X_i[0] < 4.10: # (node 65)
                                        # node 66: impurity = 0.00
                                        return [[-0.47897252]] # (node 66)
                                    else:
                                        # node 67: impurity = 0.00
                                        return [[-0.24029049]] # (node 67)
                            else:
                                # node 68: impurity = -0.00
                                return [[ 0.16239866]] # (node 68)
                        else:
                            # node 69: impurity = -0.00
                            return [[-0.62971262]] # (node 69)
                    else:
                        # node 70: impurity = -0.00
                        return [[ 0.28157206]] # (node 70)
                else:
                    # node 71: impurity = 0.23
                    if X_i[0] < 4.55: # (node 71)
                        # node 72: impurity = 0.00
                        return [[-1.09668109]] # (node 72)
                    else:
                        # node 73: impurity = 0.05
                        if X_i[0] < 4.68: # (node 73)
                            # node 74: impurity = 0.00
                            return [[ 0.07008881]] # (node 74)
                        else:
                            # node 75: impurity = 0.00
                            return [[-0.39653549]] # (node 75)
    else:
        # node 76: impurity = 0.47
        if X_i[0] < 7.58: # (node 76)
            # node 77: impurity = 0.09
            if X_i[0] < 5.23: # (node 77)
                # node 78: impurity = 0.00
                if X_i[0] < 5.11: # (node 78)
                    # node 79: impurity = 0.00
                    return [[ 0.42861796]] # (node 79)
                else:
                    # node 80: impurity = 0.00
                    return [[ 0.44714611]] # (node 80)
            else:
                # node 81: impurity = 0.09
                if X_i[0] < 6.93: # (node 81)
                    # node 82: impurity = 0.08
                    if X_i[0] < 5.56: # (node 82)
                        # node 83: impurity = 0.05
                        if X_i[0] < 5.45: # (node 83)
                            # node 84: impurity = 0.02
                            if X_i[0] < 5.33: # (node 84)
                                # node 85: impurity = 0.00
                                return [[ 0.66235772]] # (node 85)
                            else:
                                # node 86: impurity = 0.00
                                return [[ 0.96197784]] # (node 86)
                        else:
                            # node 87: impurity = 0.00
                            return [[ 0.44294229]] # (node 87)
                    else:
                        # node 88: impurity = 0.07
                        if X_i[0] < 6.03: # (node 88)
                            # node 89: impurity = 0.08
                            if X_i[0] < 5.77: # (node 89)
                                # node 90: impurity = 0.04
                                if X_i[0] < 5.68: # (node 90)
                                    # node 91: impurity = 0.05
                                    if X_i[0] < 5.59: # (node 91)
                                        # node 92: impurity = 0.07
                                        return [[ 0.86462875]] # (node 92)
                                    else:
                                        # node 93: impurity = 0.00
                                        return [[ 0.87896554]] # (node 93)
                                else:
                                    # node 94: impurity = 0.00
                                    return [[ 0.76942032]] # (node 94)
                            else:
                                # node 95: impurity = 0.08
                                if X_i[0] < 5.89: # (node 95)
                                    # node 96: impurity = 0.00
                                    return [[ 1.5807009]] # (node 96)
                                else:
                                    # node 97: impurity = 0.00
                                    return [[ 0.99316089]] # (node 97)
                        else:
                            # node 98: impurity = 0.03
                            if X_i[0] < 6.46: # (node 98)
                                # node 99: impurity = 0.02
                                if X_i[0] < 6.23: # (node 99)
                                    # node 100: impurity = 0.02
                                    if X_i[0] < 6.12: # (node 100)
                                        # node 101: impurity = 0.00
                                        return [[ 0.64740948]] # (node 101)
                                    else:
                                        # node 102: impurity = 0.00
                                        return [[ 0.93326817]] # (node 102)
                                else:
                                    # node 103: impurity = 0.00
                                    return [[ 0.63890765]] # (node 103)
                            else:
                                # node 104: impurity = 0.00
                                return [[ 0.99930373]] # (node 104)
                else:
                    # node 105: impurity = 0.07
                    if X_i[0] < 7.26: # (node 105)
                        # node 106: impurity = 0.00
                        return [[ 0.25072853]] # (node 106)
                    else:
                        # node 107: impurity = 0.00
                        return [[ 0.7724775]] # (node 107)
        else:
            # node 108: impurity = 0.42
            if X_i[0] < 10.83: # (node 108)
                # node 109: impurity = 0.11
                if X_i[0] < 8.84: # (node 109)
                    # node 110: impurity = 0.08
                    if X_i[0] < 8.41: # (node 110)
                        # node 111: impurity = 0.09
                        if X_i[0] < 8.16: # (node 111)
                            # node 112: impurity = 0.01
                            if X_i[0] < 8.00: # (node 112)
                                # node 113: impurity = 0.00
                                return [[-0.50464141]] # (node 113)
                            else:
                                # node 114: impurity = -0.00
                                return [[-0.3452118]] # (node 114)
                        else:
                            # node 115: impurity = -0.00
                            return [[ 0.18066013]] # (node 115)
                    else:
                        # node 116: impurity = 0.05
                        if X_i[0] < 8.63: # (node 116)
                            # node 117: impurity = 0.00
                            return [[-0.71426387]] # (node 117)
                        else:
                            # node 118: impurity = 0.01
                            if X_i[0] < 8.72: # (node 118)
                                # node 119: impurity = 0.00
                                return [[-0.34758045]] # (node 119)
                            else:
                                # node 120: impurity = 0.00
                                return [[-0.16647792]] # (node 120)
                else:
                    # node 121: impurity = 0.08
                    if X_i[0] < 10.04: # (node 121)
                        # node 122: impurity = 0.05
                        if X_i[0] < 9.58: # (node 122)
                            # node 123: impurity = 0.01
                            if X_i[0] < 9.20: # (node 123)
                                # node 124: impurity = 0.00
                                if X_i[0] < 8.90: # (node 124)
                                    # node 125: impurity = 0.00
                                    return [[-0.72371159]] # (node 125)
                                else:
                                    # node 126: impurity = -0.00
                                    return [[-0.69011216]] # (node 126)
                            else:
                                # node 127: impurity = -0.00
                                return [[-0.52157867]] # (node 127)
                        else:
                            # node 128: impurity = 0.03
                            if X_i[0] < 9.79: # (node 128)
                                # node 129: impurity = 0.00
                                if X_i[0] < 9.72: # (node 129)
                                    # node 130: impurity = 0.00
                                    return [[-1.08561379]] # (node 130)
                                else:
                                    # node 131: impurity = -0.00
                                    return [[-1.18311966]] # (node 131)
                            else:
                                # node 132: impurity = 0.00
                                return [[-0.78188738]] # (node 132)
                    else:
                        # node 133: impurity = 0.01
                        if X_i[0] < 10.31: # (node 133)
                            # node 134: impurity = 0.00
                            return [[-0.23972908]] # (node 134)
                        else:
                            # node 135: impurity = 0.01
                            if X_i[0] < 10.62: # (node 135)
                                # node 136: impurity = 0.00
                                if X_i[0] < 10.44: # (node 136)
                                    # node 137: impurity = 0.00
                                    return [[-0.42265743]] # (node 137)
                                else:
                                    # node 138: impurity = 0.00
                                    return [[-0.55689144]] # (node 138)
                            else:
                                # node 139: impurity = 0.00
                                return [[-0.39144637]] # (node 139)
            else:
                # node 140: impurity = 0.33
                if X_i[0] < 14.59: # (node 140)
                    # node 141: impurity = 0.18
                    if X_i[0] < 13.52: # (node 141)
                        # node 142: impurity = 0.16
                        if X_i[0] < 12.70: # (node 142)
                            # node 143: impurity = 0.12
                            if X_i[0] < 11.46: # (node 143)
                                # node 144: impurity = 0.13
                                if X_i[0] < 11.31: # (node 144)
                                    # node 145: impurity = 0.00
                                    if X_i[0] < 11.09: # (node 145)
                                        # node 146: impurity = 0.00
                                        return [[ 0.36468479]] # (node 146)
                                    else:
                                        # node 147: impurity = -0.00
                                        return [[ 0.26236952]] # (node 147)
                                else:
                                    # node 148: impurity = 0.16
                                    if X_i[0] < 11.40: # (node 148)
                                        # node 149: impurity = 0.00
                                        return [[ 1.26837309]] # (node 149)
                                    else:
                                        # node 150: impurity = 0.00
                                        if X_i[0] < 11.43: # (node 150)
                                            # node 151: impurity = 0.00
                                            return [[ 0.39533984]] # (node 151)
                                        else:
                                            # node 152: impurity = 0.00
                                            return [[ 0.45276868]] # (node 152)
                            else:
                                # node 153: impurity = 0.06
                                if X_i[0] < 12.04: # (node 153)
                                    # node 154: impurity = 0.06
                                    if X_i[0] < 11.51: # (node 154)
                                        # node 155: impurity = 0.00
                                        return [[-0.00349366]] # (node 155)
                                    else:
                                        # node 156: impurity = 0.01
                                        if X_i[0] < 11.70: # (node 156)
                                            # node 157: impurity = 0.00
                                            return [[ 0.38537714]] # (node 157)
                                        else:
                                            # node 158: impurity = 0.00
                                            return [[ 0.57083166]] # (node 158)
                                else:
                                    # node 159: impurity = 0.03
                                    if X_i[0] < 12.44: # (node 159)
                                        # node 160: impurity = 0.00
                                        return [[-0.10863384]] # (node 160)
                                    else:
                                        # node 161: impurity = 0.00
                                        return [[ 0.25782098]] # (node 161)
                        else:
                            # node 162: impurity = 0.11
                            if X_i[0] < 13.14: # (node 162)
                                # node 163: impurity = 0.03
                                if X_i[0] < 12.87: # (node 163)
                                    # node 164: impurity = 0.00
                                    if X_i[0] < 12.79: # (node 164)
                                        # node 165: impurity = 0.00
                                        return [[ 1.05886266]] # (node 165)
                                    else:
                                        # node 166: impurity = -0.00
                                        return [[ 1.19469224]] # (node 166)
                                else:
                                    # node 167: impurity = -0.00
                                    return [[ 0.80496054]] # (node 167)
                            else:
                                # node 168: impurity = -0.00
                                return [[ 0.33869142]] # (node 168)
                    else:
                        # node 169: impurity = 0.04
                        if X_i[0] < 13.98: # (node 169)
                            # node 170: impurity = 0.00
                            return [[ 0.17474494]] # (node 170)
                        else:
                            # node 171: impurity = 0.00
                            return [[-0.24889702]] # (node 171)
                else:
                    # node 172: impurity = -0.00
                    return [[-1.26798997]] # (node 172)

The mean squared error (MSE) criterion is relatively straightforward:


In [ ]:
dt.criterion


Out[ ]:
'mse'

In [ ]:
y_train.mean()


Out[ ]:
0.16938346878102944

In [ ]:
np.mean((y_train - y_train.mean())**2)


Out[ ]:
0.46270472251902189

In [ ]:
np.var(y_train)


Out[ ]:
0.46270472251902189
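
That number should look familiar: it is the impurity we printed for the root node back in Out[23] (0.46). Before any split, the best constant prediction is the mean of the training labels, and its MSE is exactly their variance.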

Alternatives: mean absolute deviation, median absolute deviation, (something specific to your regression situation?)

So far, we've looked entirely at DT regression, because it is prettier.

In the Week 4 homework, we had a DT classification problem. How did it work?


In [ ]:
# here are the criteria that sklearn knows for classification problems
sklearn.tree.tree.CRITERIA_CLF


Out[ ]:
{'entropy': sklearn.tree._tree.Entropy, 'gini': sklearn.tree._tree.Gini}
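
As a reminder (my notation): for class proportions p_k at a node, the gini impurity is sum_k p_k (1 - p_k) and the entropy is -sum_k p_k log p_k (the base of the log only rescales it). A toy computation:


In [ ]:
# gini and entropy impurity for a node with class proportions p
p = np.array([.5, .3, .2])
np.sum(p * (1 - p)), -np.sum(p * np.log2(p))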

Which did you use in Homework 4? Does using the other one change anything?

We used gini; let's check if entropy is any different!


In [41]:
# Load and prep data from Week 4:
df = pd.read_csv('../Week_4/IHME_PHMRC_VA_DATA_ADULT_Y2013M09D11_0.csv', low_memory=False)
X = np.array(df.filter(like='word_'))
df['Cause'] = df.gs_text34.map({'Stroke':'Stroke', 'Diabetes':'Diabetes'}).fillna('Other')
y = np.array(df.Cause)

# 10-rep 10-fold cross-validate decision tree with class weights
import sklearn.tree, sklearn.cross_validation, sklearn.metrics
weights = 1000. / df.Cause.value_counts()
sample_weight = np.array(weights[y])

In [ ]:
%%time
clf = sklearn.tree.DecisionTreeClassifier(max_depth=4, criterion='gini')

scores = []
for rep in range(10):
    print rep,
    cv = sklearn.cross_validation.StratifiedKFold(y, n_folds=10, shuffle=True, random_state=123456+rep)
    for train, test in cv:
        clf.fit(X[train], y[train], sample_weight=sample_weight[train])

        y_pred = clf.predict(X[test])
        scores.append(sklearn.metrics.accuracy_score(y[test], y_pred, sample_weight=sample_weight[test]))
                   
print
print np.mean(scores)


0 1 2 3 4 5 6 7 8 9
0.562316995709
CPU times: user 1min 6s, sys: 987 ms, total: 1min 7s
Wall time: 1min 8s

In [ ]:
%%time
clf = sklearn.tree.DecisionTreeClassifier(max_depth=4, criterion='entropy')

scores = []
for rep in range(10):
    print rep,
    cv = sklearn.cross_validation.StratifiedKFold(y, n_folds=10, shuffle=True, random_state=123456+rep)
    for train, test in cv:
        clf.fit(X[train], y[train], sample_weight=sample_weight[train])

        y_pred = clf.predict(X[test])
        scores.append(sklearn.metrics.accuracy_score(y[test], y_pred, sample_weight=sample_weight[test]))
                   
print
print np.mean(scores)


0 1 2 3 4 5 6 7 8 9
0.558928990728
CPU times: user 1min 6s, sys: 1.02 s, total: 1min 7s
Wall time: 1min 8s

Some people say that a strength of Decision Trees (and hence Random Forests) is that there is no need to transform your features.

Show that this is correct for categorical labels (classification setting):

Show that transforming the features still doesn't matter, but that transforming the labels does, when the labels are numeric (regression setting):


In [44]:
# we will test this out with simulated data
np.random.seed(123456)

# do you understand what this is?
X = np.random.normal(size=(100,10))
beta = np.random.normal(size=10)

# if so, how about this?
y_numeric = np.dot(X, beta)
y_categorical = (y_numeric > 0)

# if you really get it, what would be a better way to learn from (X, y_categorical)?
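
(One answer to that last question: since `y_categorical` is a thresholded linear function of `X`, a linear classifier is arguably a better match than axis-aligned splits. A quick sketch for later comparison, using the same old-style cross-validation API as above:)


In [ ]:
# a linear model should do well here, since the labels come from a
# linear threshold rule (a sketch, for comparison with the trees below)
import sklearn.linear_model, sklearn.cross_validation
lr = sklearn.linear_model.LogisticRegression()
print np.mean(sklearn.cross_validation.cross_val_score(lr, X, y_categorical, cv=10))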

In [47]:
# case one: categorical labels

# with original data:
clf = sklearn.tree.DecisionTreeClassifier(max_depth=2)
clf.fit(X, y_categorical)
print_tree(clf.tree_)


# node 0: impurity = 0.48
if X_i[6] < -0.15: # (node 0)
    # node 1: impurity = 0.41
    if X_i[7] < -0.42: # (node 1)
        # node 2: impurity = 0.40
        return [[ 3.  8.]] # (node 2)
    else:
        # node 3: impurity = 0.20
        return [[ 24.   3.]] # (node 3)
else:
    # node 4: impurity = 0.33
    if X_i[9] < -0.24: # (node 4)
        # node 5: impurity = 0.06
        return [[  1.  31.]] # (node 5)
    else:
        # node 6: impurity = 0.48
        return [[ 12.  18.]] # (node 6)

In [48]:
# with transformed data:
clf = sklearn.tree.DecisionTreeClassifier(max_depth=2)
clf.fit(np.exp(X), y_categorical)
print_tree(clf.tree_)


# node 0: impurity = 0.48
if X_i[6] < 0.86: # (node 0)
    # node 1: impurity = 0.41
    if X_i[7] < 0.66: # (node 1)
        # node 2: impurity = 0.40
        return [[ 3.  8.]] # (node 2)
    else:
        # node 3: impurity = 0.20
        return [[ 24.   3.]] # (node 3)
else:
    # node 4: impurity = 0.33
    if X_i[9] < 0.79: # (node 4)
        # node 5: impurity = 0.06
        return [[  1.  31.]] # (node 5)
    else:
        # node 6: impurity = 0.48
        return [[ 12.  18.]] # (node 6)

In [49]:
# same tree, right?
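
Right: the split structure is identical because $\exp$ is monotone increasing, so the rule $x < t$ simply becomes $e^x < e^t$. You can check the printed thresholds directly:


In [ ]:
# the transformed-space thresholds are just exp() of the originals
print np.exp(-0.15), np.exp(-0.42), np.exp(-0.24)  # ~0.86, 0.66, 0.79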

In [50]:
# case two: with numeric labels
# with original data:
clf = sklearn.tree.DecisionTreeRegressor(max_depth=2)
clf.fit(X, y_numeric)
print_tree(clf.tree_)


# node 0: impurity = 5.94
if X_i[9] < -0.34: # (node 0)
    # node 1: impurity = 3.67
    if X_i[6] < -0.92: # (node 1)
        # node 2: impurity = 2.09
        return [[-0.61808797]] # (node 2)
    else:
        # node 3: impurity = 2.75
        return [[ 2.27810542]] # (node 3)
else:
    # node 4: impurity = 5.06
    if X_i[6] < -0.15: # (node 4)
        # node 5: impurity = 3.73
        return [[-2.05021916]] # (node 5)
    else:
        # node 6: impurity = 3.31
        return [[ 0.49294032]] # (node 6)

In [51]:
# with transformed feature vectors:
clf = sklearn.tree.DecisionTreeRegressor(max_depth=2)
clf.fit(np.exp(X), y_numeric)
print_tree(clf.tree_)


# node 0: impurity = 5.94
if X_i[9] < 0.71: # (node 0)
    # node 1: impurity = 3.67
    if X_i[6] < 0.40: # (node 1)
        # node 2: impurity = 2.09
        return [[-0.61808797]] # (node 2)
    else:
        # node 3: impurity = 2.75
        return [[ 2.27810542]] # (node 3)
else:
    # node 4: impurity = 5.06
    if X_i[6] < 0.86: # (node 4)
        # node 5: impurity = 3.73
        return [[-2.05021916]] # (node 5)
    else:
        # node 6: impurity = 3.31
        return [[ 0.49294032]] # (node 6)

In [52]:
# same again, except the thresholds are in the transformed space

In [53]:
# case three: transformed numeric labels
# first, a baseline with untransformed labels (same as case two):
clf = sklearn.tree.DecisionTreeRegressor(max_depth=2)
clf.fit(X, y_numeric)
print_tree(clf.tree_)


# node 0: impurity = 5.94
if X_i[9] < -0.34: # (node 0)
    # node 1: impurity = 3.67
    if X_i[6] < -0.92: # (node 1)
        # node 2: impurity = 2.09
        return [[-0.61808797]] # (node 2)
    else:
        # node 3: impurity = 2.75
        return [[ 2.27810542]] # (node 3)
else:
    # node 4: impurity = 5.06
    if X_i[6] < -0.15: # (node 4)
        # node 5: impurity = 3.73
        return [[-2.05021916]] # (node 5)
    else:
        # node 6: impurity = 3.31
        return [[ 0.49294032]] # (node 6)

In [54]:
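# now with transformed labels: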
clf = sklearn.tree.DecisionTreeRegressor(max_depth=2)
clf.fit(X, np.exp(y_numeric))
print_tree(clf.tree_)


# node 0: impurity = 2415.69
if X_i[3] < 1.99: # (node 0)
    # node 1: impurity = 993.82
    if X_i[6] < 1.18: # (node 1)
        # node 2: impurity = 248.07
        return [[ 7.75697332]] # (node 2)
    else:
        # node 3: impurity = 5526.67
        return [[ 54.44719139]] # (node 3)
else:
    # node 4: impurity = 34480.55
    if X_i[2] < -0.37: # (node 4)
        # node 5: impurity = 0.00
        return [[ 394.9192137]] # (node 5)
    else:
        # node 6: impurity = 0.64
        return [[ 1.01498272]] # (node 6)

In [55]:
# got that?
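
The point: a regression leaf predicts the mean of its training labels, and means do not commute with nonlinear transforms, so transforming $y$ changes which splits look good. One line makes this concrete:


In [ ]:
# exp(mean(y)) != mean(exp(y)) in general, so the mse criterion sees
# a genuinely different target after the transform
print np.exp(y_numeric.mean()), np.exp(y_numeric).mean()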

A systematic approach to pruning

One proposed approach to limiting the complexity of decision trees is through pruning to minimize the sum $$ \sum_{m=1}^{|T|} \sum_{i: x_i \in R_m} (y_i - \hat{y}_{R_m})^2 + \alpha |T|. $$

This can be accomplished recursively: for a tree with a root and two leaves, check whether the MSE at the root is less than the MSE of the leaves plus $\alpha$. For a more complicated tree, i.e. a root with two non-leaf subtrees, prune each subtree separately, and then check whether you end up in the root-and-two-leaves case.
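
With made-up numbers, the base-case comparison looks like this (prune iff keeping the root as a single leaf costs no more than keeping two leaves):


In [ ]:
# hypothetical impurities: prune iff root_mse + alpha*1 <= leaf_mse + alpha*2
root_mse, leaf_mse, alpha = 5.0, 4.2, 1.0
print root_mse + alpha <= leaf_mse + 2*alpha   # True --> prune to the root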

Super-hard extra bonus homework: implement this, and use cross-validation to see what it changes in the Exercise 4 example.


In [56]:
# refactor pruning code from above into a function
def prune(tree, node):
    """ prune decision tree so that specified node becomes a leaf
    Parameters
    ----------
    tree : sklearn.tree.tree._tree.Tree
    node : int
    
    Results
    -------
    changes internal arrays of `tree` so that node `node` is a leaf
    """
    
    # find the left and right children of node
    left_child = tree.children_left[node]
    right_child = tree.children_right[node]

    # find the weight of these nodes in the training dataset
    wt_left = tree.weighted_n_node_samples[left_child]
    wt_right = tree.weighted_n_node_samples[right_child]

    # find the value of these nodes in the training dataset
    val_left = tree.value[left_child]
    val_right = tree.value[right_child]

    # calculate the value of `node` after pruning (weighted average of its children)
    tree.value[node] = (wt_left*val_left + wt_right*val_right) / (wt_left + wt_right)

    # remove children of node
    tree.children_left[node] = sklearn.tree._tree.TREE_LEAF
    tree.children_right[node] = sklearn.tree._tree.TREE_LEAF

In [57]:
# test that:
clf.fit(X, np.exp(y_numeric))
print_tree(clf.tree_)


# node 0: impurity = 2415.69
if X_i[3] < 1.99: # (node 0)
    # node 1: impurity = 993.82
    if X_i[6] < 1.18: # (node 1)
        # node 2: impurity = 248.07
        return [[ 7.75697332]] # (node 2)
    else:
        # node 3: impurity = 5526.67
        return [[ 54.44719139]] # (node 3)
else:
    # node 4: impurity = 34480.55
    if X_i[4] < 0.10: # (node 4)
        # node 5: impurity = 0.00
        return [[ 394.9192137]] # (node 5)
    else:
        # node 6: impurity = 0.64
        return [[ 1.01498272]] # (node 6)

In [58]:
prune(clf.tree_, 4)
print_tree(clf.tree_)


# node 0: impurity = 2415.69
if X_i[3] < 1.99: # (node 0)
    # node 1: impurity = 993.82
    if X_i[6] < 1.18: # (node 1)
        # node 2: impurity = 248.07
        return [[ 7.75697332]] # (node 2)
    else:
        # node 3: impurity = 5526.67
        return [[ 54.44719139]] # (node 3)
else:
    # node 4: impurity = 34480.55
    return [[ 132.31639304]] # (node 4)
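
Sanity check on the new leaf: its value should be the weight-averaged value of its old children. The printed 132.32 is consistent with node 5 having carried one third of node 4's weight (weights inferred from the output, for illustration):


In [ ]:
# hypothetical weights (1, 2) reproduce the pruned node's value
print (1*394.9192137 + 2*1.01498272) / 3.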

Now make a recursive pruning function that uses an $\alpha$ value to decide what to keep:


In [59]:
tree = clf.tree_

In [60]:
def recursive_prune(tree, alpha, node=0):
    """ Prune tree so to min tree.score[node] + \alpha |T|
    Parameters
    ----------
    tree : sklearn.tree.tree._tree.Tree
    alpha : float, pruning parameter
    node : int, optional; node to consider root of tree, for recursion
    
    Results
    -------
    Prunes `tree` in place to minimize the objective; returns the impurity
    contribution of the resulting subtree and the number of leaves that attains it
    """
    # find the left and right children of node
    left_child = tree.children_left[node]
    right_child = tree.children_right[node]

    if left_child == sklearn.tree._tree.TREE_LEAF:
        assert right_child == sklearn.tree._tree.TREE_LEAF, 'Expected binary tree'
        
        return (tree.impurity[node],1)
    else:

        # find the weight of these nodes in the training dataset
        wt_left = tree.weighted_n_node_samples[left_child]
        wt_right = tree.weighted_n_node_samples[right_child]

        # calculate contribution of children to objective function
        left_impurity_sum, left_leaves = recursive_prune(tree, alpha, left_child)
        right_impurity_sum, right_leaves = recursive_prune(tree, alpha, right_child)
        
        current_obj_contrib = (wt_left*left_impurity_sum + wt_right*right_impurity_sum) \
                            / (wt_left + wt_right)
        
        # compare to contribution to objective function if tree is pruned so that current node
        # is a leaf
        pruned_obj_contrib = tree.impurity[node]
        if pruned_obj_contrib + alpha <= current_obj_contrib + (left_leaves + right_leaves)*alpha:
            prune(tree, node)
            return (tree.impurity[node],1)
        else:
            return current_obj_contrib, left_leaves + right_leaves

In [61]:
# test that:
clf.fit(X, np.exp(y_numeric))
print_tree(clf.tree_)


# node 0: impurity = 2415.69
if X_i[3] < 1.99: # (node 0)
    # node 1: impurity = 993.82
    if X_i[6] < 1.18: # (node 1)
        # node 2: impurity = 248.07
        return [[ 7.75697332]] # (node 2)
    else:
        # node 3: impurity = 5526.67
        return [[ 54.44719139]] # (node 3)
else:
    # node 4: impurity = 34480.55
    if X_i[9] < -0.26: # (node 4)
        # node 5: impurity = 0.00
        return [[ 394.9192137]] # (node 5)
    else:
        # node 6: impurity = 0.64
        return [[ 1.01498272]] # (node 6)

In [62]:
clf.fit(X, np.exp(y_numeric))
recursive_prune(clf.tree_, alpha=0)  # should not change tree
print_tree(clf.tree_)


# node 0: impurity = 2415.69
if X_i[3] < 1.99: # (node 0)
    # node 1: impurity = 993.82
    if X_i[6] < 1.18: # (node 1)
        # node 2: impurity = 248.07
        return [[ 7.75697332]] # (node 2)
    else:
        # node 3: impurity = 5526.67
        return [[ 54.44719139]] # (node 3)
else:
    # node 4: impurity = 34480.55
    if X_i[4] < 0.10: # (node 4)
        # node 5: impurity = 0.00
        return [[ 394.9192137]] # (node 5)
    else:
        # node 6: impurity = 0.64
        return [[ 1.01498272]] # (node 6)

In [63]:
clf.fit(X, np.exp(y_numeric))
recursive_prune(clf.tree_, alpha=np.inf)  # should prune to the root
print_tree(clf.tree_)


# node 0: impurity = 2415.69
return [[ 16.16277772]] # (node 0)

In [64]:
clf.fit(X, np.exp(y_numeric))
recursive_prune(clf.tree_, alpha=500)  # should prune node 1, not node 4
print_tree(clf.tree_)


# node 0: impurity = 2415.69
if X_i[3] < 1.99: # (node 0)
    # node 1: impurity = 993.82
    return [[ 12.57039787]] # (node 1)
else:
    # node 4: impurity = 34480.55
    if X_i[7] < -0.55: # (node 4)
        # node 5: impurity = 0.00
        return [[ 394.9192137]] # (node 5)
    else:
        # node 6: impurity = 0.64
        return [[ 1.01498272]] # (node 6)
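
And, toward the super-hard bonus homework: a sketch of sweeping $\alpha$ with cross-validation, using the same old-style API as the Week 4 example (illustrative, not a definitive implementation):


In [ ]:
# sweep alpha and cross-validate out-of-sample mse of the pruned trees
# (a sketch; assumes prune/recursive_prune defined above)
import sklearn.cross_validation, sklearn.metrics

for alpha in [0., 1., 10., 100., 1000.]:
    scores = []
    cv = sklearn.cross_validation.KFold(len(y_numeric), n_folds=10,
                                        shuffle=True, random_state=123456)
    for train, test in cv:
        clf = sklearn.tree.DecisionTreeRegressor(max_depth=2)
        clf.fit(X[train], np.exp(y_numeric[train]))
        recursive_prune(clf.tree_, alpha=alpha)
        scores.append(sklearn.metrics.mean_squared_error(np.exp(y_numeric[test]),
                                                         clf.predict(X[test])))
    print alpha, np.mean(scores)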

In [ ]: