Key Requirements for the iRF scikit-learn implementation

  • The following documents the main requirements for the iRF implementation

Typical Setup


In [35]:
%matplotlib inline
import matplotlib.pyplot as plt

import pydotplus
import numpy as np
import pprint
from sklearn import metrics
from sklearn import tree
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.tree import _tree
from IPython.display import display, Image
from sklearn.datasets import load_iris
from sklearn.datasets import load_breast_cancer

from functools import reduce

# Import our custom utilities
# reload() re-imports utils so edits to utils.py take effect without restarting the kernel
from importlib import reload
from utils import utils
reload(utils)


Out[35]:
<module 'utils.utils' from '/Users/shamindras/PERSONAL/LEARNING/REPOS/scikit-learn-sandbox/jupyter/utils/utils.py'>

Step 1: Fit the Initial Random Forest

  • Fit every feature with equal weight, as in the usual random forest code, e.g. RandomForestClassifier in scikit-learn
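A minimal sketch of this step in plain scikit-learn, assuming (this notebook does not show it) that utils.generate_rf_example wraps train_test_split and RandomForestClassifier. The name fit_rf_example is hypothetical. With train_size=0.9 on the 569-row breast cancer data, scikit-learn keeps floor(0.9 * 569) = 512 training rows and 57 test rows, which agrees with the dimensions printed further down.

```python
from sklearn.datasets import load_breast_cancer
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split


def fit_rf_example(sklearn_ds, train_split_propn=0.9, n_estimators=3,
                   random_state_split=2017, random_state_classifier=2018):
    """Hypothetical stand-in for utils.generate_rf_example (assumed behaviour)."""
    X_train, X_test, y_train, y_test = train_test_split(
        sklearn_ds.data, sklearn_ds.target,
        train_size=train_split_propn, random_state=random_state_split)
    # No feature weighting yet: candidate features at each split are drawn uniformly
    rf = RandomForestClassifier(n_estimators=n_estimators,
                                random_state=random_state_classifier)
    rf.fit(X_train, y_train)
    return X_train, X_test, y_train, y_test, rf


X_train, X_test, y_train, y_test, rf = fit_rf_example(load_breast_cancer())
print(X_train.shape, X_test.shape)
```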

In [2]:
X_train, X_test, y_train, y_test, rf = utils.generate_rf_example(sklearn_ds = load_breast_cancer()
                                                           , train_split_propn = 0.9
                                                           , n_estimators = 3
                                                           , random_state_split = 2017
                                                           , random_state_classifier = 2018)

Check out the data


In [31]:
print("Training feature dimensions", X_train.shape, sep = ":\n")
print("\n")
print("Training outcome dimensions", y_train.shape, sep = ":\n")
print("\n")
print("Test feature dimensions", X_test.shape, sep = ":\n")
print("\n")
print("Test outcome dimensions", y_test.shape, sep = ":\n")
print("\n")
print("first 5 rows of the training set features", X_train[:5], sep = ":\n")
print("\n")
print("first 5 rows of the training set outcomes", y_train[:5], sep = ":\n")


Training feature dimensions:
(512, 30)


Training outcome dimensions:
(512,)


Test feature dimensions:
(57, 30)


Test outcome dimensions:
(57,)


first 5 rows of the training set features:
[[  1.98900000e+01   2.02600000e+01   1.30500000e+02   1.21400000e+03
    1.03700000e-01   1.31000000e-01   1.41100000e-01   9.43100000e-02
    1.80200000e-01   6.18800000e-02   5.07900000e-01   8.73700000e-01
    3.65400000e+00   5.97000000e+01   5.08900000e-03   2.30300000e-02
    3.05200000e-02   1.17800000e-02   1.05700000e-02   3.39100000e-03
    2.37300000e+01   2.52300000e+01   1.60500000e+02   1.64600000e+03
    1.41700000e-01   3.30900000e-01   4.18500000e-01   1.61300000e-01
    2.54900000e-01   9.13600000e-02]
 [  2.01800000e+01   1.95400000e+01   1.33800000e+02   1.25000000e+03
    1.13300000e-01   1.48900000e-01   2.13300000e-01   1.25900000e-01
    1.72400000e-01   6.05300000e-02   4.33100000e-01   1.00100000e+00
    3.00800000e+00   5.24900000e+01   9.08700000e-03   2.71500000e-02
    5.54600000e-02   1.91000000e-02   2.45100000e-02   4.00500000e-03
    2.20300000e+01   2.50700000e+01   1.46000000e+02   1.47900000e+03
    1.66500000e-01   2.94200000e-01   5.30800000e-01   2.17300000e-01
    3.03200000e-01   8.07500000e-02]
 [  1.44000000e+01   2.69900000e+01   9.22500000e+01   6.46100000e+02
    6.99500000e-02   5.22300000e-02   3.47600000e-02   1.73700000e-02
    1.70700000e-01   5.43300000e-02   2.31500000e-01   9.11200000e-01
    1.72700000e+00   2.05200000e+01   5.35600000e-03   1.67900000e-02
    1.97100000e-02   6.37000000e-03   1.41400000e-02   1.89200000e-03
    1.54000000e+01   3.19800000e+01   1.00400000e+02   7.34600000e+02
    1.01700000e-01   1.46000000e-01   1.47200000e-01   5.56300000e-02
    2.34500000e-01   6.46400000e-02]
 [  1.44700000e+01   2.49900000e+01   9.58100000e+01   6.56400000e+02
    8.83700000e-02   1.23000000e-01   1.00900000e-01   3.89000000e-02
    1.87200000e-01   6.34100000e-02   2.54200000e-01   1.07900000e+00
    2.61500000e+00   2.31100000e+01   7.13800000e-03   4.65300000e-02
    3.82900000e-02   1.16200000e-02   2.06800000e-02   6.11100000e-03
    1.62200000e+01   3.17300000e+01   1.13500000e+02   8.08900000e+02
    1.34000000e-01   4.20200000e-01   4.04000000e-01   1.20500000e-01
    3.18700000e-01   1.02300000e-01]
 [  2.05500000e+01   2.08600000e+01   1.37800000e+02   1.30800000e+03
    1.04600000e-01   1.73900000e-01   2.08500000e-01   1.32200000e-01
    2.12700000e-01   6.25100000e-02   6.98600000e-01   9.90100000e-01
    4.70600000e+00   8.77800000e+01   4.57800000e-03   2.61600000e-02
    4.00500000e-02   1.42100000e-02   1.94800000e-02   2.68900000e-03
    2.43000000e+01   2.54800000e+01   1.60200000e+02   1.80900000e+03
    1.26800000e-01   3.13500000e-01   4.43300000e-01   2.14800000e-01
    3.07700000e-01   7.56900000e-02]]


first 5 rows of the training set outcomes:
[0 0 1 1 0]

In [30]:
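# Sanity check: 512 training rows + 57 test rows should equal the full dataset size
breast_cancer = load_breast_cancer()
assert X_train.shape[0] + X_test.shape[0] == breast_cancer.data.shape[0]
breast_cancer.data.shape[0]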


Out[30]:
569

Step 2: For each tree, get the core leaf node features

  • For each decision tree in the classifier, get:
    • The list of leaf nodes
    • Depth of the leaf node
    • Predicted class of the leaf node, i.e. {0, 1}
    • Probability of the predicted class in the leaf node
    • Number of observations in the leaf node, i.e. the weight of the node
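A minimal sketch of how these per-leaf quantities can be read from a fitted scikit-learn tree through its low-level tree_ arrays (children_left, children_right, value, weighted_n_node_samples). The function get_leaf_info is hypothetical and is not the notebook's utils.getTreeData. The class probability is obtained by normalizing value, which should work whether a given scikit-learn version stores raw counts or fractions there. For trees inside a bootstrapped forest, weighted_n_node_samples counts repeated draws, whereas n_node_samples counts distinct rows.

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier


def get_leaf_info(dtree):
    """Collect per-leaf summaries from a fitted scikit-learn decision tree."""
    t = dtree.tree_
    leaves = []
    # Iterative depth-first traversal over (node_id, depth) pairs; the root has depth 0
    stack = [(0, 0)]
    while stack:
        node_id, depth = stack.pop()
        left, right = t.children_left[node_id], t.children_right[node_id]
        if left == right:  # both children are -1 (TREE_LEAF) at a leaf
            counts = t.value[node_id][0]
            probs = counts / counts.sum()
            leaves.append({
                "node_id": int(node_id),
                "depth": depth,
                "predicted_class": dtree.classes_[np.argmax(probs)],
                "class_probability": float(probs.max()),
                "weight": float(t.weighted_n_node_samples[node_id]),
            })
        else:
            # Push right first so the left subtree is visited first
            stack.append((right, depth + 1))
            stack.append((left, depth + 1))
    return sorted(leaves, key=lambda d: d["node_id"])


X, y = load_iris(return_X_y=True)
clf = DecisionTreeClassifier(random_state=0).fit(X, y)
info = get_leaf_info(clf)
```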

Get the 3 decision trees to use for testing


In [4]:
# Number of trees in the forest
rf.n_estimators


Out[4]:
3

In [5]:
estimator0 = rf.estimators_[0] # First tree
estimator1 = rf.estimators_[1] # Second tree
estimator2 = rf.estimators_[2] # Third tree

Design a single function to extract the key tree information

Get data from each of the three decision trees


In [6]:
tree_dat0 = utils.getTreeData(X_train = X_train, dtree = estimator0, root_node_id = 0)
tree_dat1 = utils.getTreeData(X_train = X_train, dtree = estimator1, root_node_id = 0)
tree_dat2 = utils.getTreeData(X_train = X_train, dtree = estimator2, root_node_id = 0)
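The path-related keys in the output (all_leaf_node_paths, all_leaf_paths_features, all_uniq_leaf_paths_features) can be reproduced with a short recursive walk. This is a sketch with a hypothetical function name, not the code inside utils.getTreeData. Leaf nodes carry the placeholder feature index -2 (_tree.TREE_UNDEFINED) in tree_.feature, so the leaf is dropped before looking up features. Indexing a 30-column array with -2 silently returns column 28, which appears to be why node_features_idx shows 28 at every leaf position.

```python
import numpy as np
from sklearn.datasets import load_breast_cancer
from sklearn.tree import DecisionTreeClassifier


def get_leaf_paths(dtree):
    """Return root-to-leaf node paths, split features per path, and their unique sets."""
    t = dtree.tree_
    paths = []

    def walk(node_id, path):
        path = path + [node_id]
        if t.children_left[node_id] == t.children_right[node_id]:  # leaf
            paths.append(np.array(path))
        else:
            walk(t.children_left[node_id], path)
            walk(t.children_right[node_id], path)

    walk(0, [])
    # Features split on at each internal node of the path; the leaf (feature == -2) is excluded
    path_features = [t.feature[p[:-1]] for p in paths]
    # np.unique returns the sorted distinct features, like all_uniq_leaf_paths_features
    uniq_path_features = [np.unique(f) for f in path_features]
    return paths, path_features, uniq_path_features


X, y = load_breast_cancer(return_X_y=True)
clf = DecisionTreeClassifier(random_state=0).fit(X, y)
paths, feats, uniq = get_leaf_paths(clf)
```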

Decision Tree 0 (First) - Get output

Check the output against the decision tree graph


In [21]:
# Now plot the trees individually
utils.draw_tree(decision_tree = estimator0)



In [22]:
utils.prettyPrintDict(inp_dict = tree_dat0)


{   'all_leaf_node_classes': [   1,
                                 0,
                                 1,
                                 0,
                                 1,
                                 0,
                                 1,
                                 1,
                                 0,
                                 1,
                                 1,
                                 0,
                                 0,
                                 1,
                                 0,
                                 1,
                                 1,
                                 1,
                                 0,
                                 1,
                                 0,
                                 0],
    'all_leaf_node_paths': [   array([0, 1, 2, 3, 4, 5, 6]),
                               array([0, 1, 2, 3, 4, 5, 7, 8]),
                               array([0, 1, 2, 3, 4, 5, 7, 9]),
                               array([ 0,  1,  2,  3,  4, 10]),
                               array([ 0,  1,  2,  3, 11, 12]),
                               array([ 0,  1,  2,  3, 11, 13, 14]),
                               array([ 0,  1,  2,  3, 11, 13, 15]),
                               array([ 0,  1,  2, 16, 17, 18, 19]),
                               array([ 0,  1,  2, 16, 17, 18, 20, 21, 22]),
                               array([ 0,  1,  2, 16, 17, 18, 20, 21, 23]),
                               array([ 0,  1,  2, 16, 17, 18, 20, 24]),
                               array([ 0,  1,  2, 16, 17, 25]),
                               array([ 0,  1,  2, 16, 26]),
                               array([ 0,  1, 27, 28, 29]),
                               array([ 0,  1, 27, 28, 30]),
                               array([ 0,  1, 27, 31, 32]),
                               array([ 0,  1, 27, 31, 33, 34]),
                               array([ 0,  1, 27, 31, 33, 35, 36]),
                               array([ 0,  1, 27, 31, 33, 35, 37]),
                               array([ 0, 38, 39, 40]),
                               array([ 0, 38, 39, 41]),
                               array([ 0, 38, 42])],
    'all_leaf_node_values': [   array([[  0, 239]]),
                                array([[1, 0]]),
                                array([[0, 8]]),
                                array([[2, 0]]),
                                array([[0, 8]]),
                                array([[7, 0]]),
                                array([[0, 2]]),
                                array([[ 0, 27]]),
                                array([[3, 0]]),
                                array([[0, 1]]),
                                array([[ 0, 10]]),
                                array([[2, 0]]),
                                array([[7, 0]]),
                                array([[0, 7]]),
                                array([[1, 0]]),
                                array([[0, 2]]),
                                array([[0, 2]]),
                                array([[0, 1]]),
                                array([[19,  0]]),
                                array([[0, 6]]),
                                array([[2, 0]]),
                                array([[155,   0]])],
    'all_leaf_nodes': [   6,
                          8,
                          9,
                          10,
                          12,
                          14,
                          15,
                          19,
                          22,
                          23,
                          24,
                          25,
                          26,
                          29,
                          30,
                          32,
                          34,
                          36,
                          37,
                          40,
                          41,
                          42],
    'all_leaf_paths_features': [   array([23, 26,  1,  6, 13,  5]),
                                   array([23, 26,  1,  6, 13,  5,  9]),
                                   array([23, 26,  1,  6, 13,  5,  9]),
                                   array([23, 26,  1,  6, 13]),
                                   array([23, 26,  1,  6, 22]),
                                   array([23, 26,  1,  6, 22, 13]),
                                   array([23, 26,  1,  6, 22, 13]),
                                   array([23, 26,  1,  3, 27, 27]),
                                   array([23, 26,  1,  3, 27, 27, 19, 17]),
                                   array([23, 26,  1,  3, 27, 27, 19, 17]),
                                   array([23, 26,  1,  3, 27, 27, 19]),
                                   array([23, 26,  1,  3, 27]),
                                   array([23, 26,  1,  3]),
                                   array([23, 26,  3, 18]),
                                   array([23, 26,  3, 18]),
                                   array([23, 26,  3,  8]),
                                   array([23, 26,  3,  8, 22]),
                                   array([23, 26,  3,  8, 22,  4]),
                                   array([23, 26,  3,  8, 22,  4]),
                                   array([23, 26, 22]),
                                   array([23, 26, 22]),
                                   array([23, 26])],
    'all_uniq_leaf_paths_features': [   array([ 1,  5,  6, 13, 23, 26]),
                                        array([ 1,  5,  6,  9, 13, 23, 26]),
                                        array([ 1,  5,  6,  9, 13, 23, 26]),
                                        array([ 1,  6, 13, 23, 26]),
                                        array([ 1,  6, 22, 23, 26]),
                                        array([ 1,  6, 13, 22, 23, 26]),
                                        array([ 1,  6, 13, 22, 23, 26]),
                                        array([ 1,  3, 23, 26, 27]),
                                        array([ 1,  3, 17, 19, 23, 26, 27]),
                                        array([ 1,  3, 17, 19, 23, 26, 27]),
                                        array([ 1,  3, 19, 23, 26, 27]),
                                        array([ 1,  3, 23, 26, 27]),
                                        array([ 1,  3, 23, 26]),
                                        array([ 3, 18, 23, 26]),
                                        array([ 3, 18, 23, 26]),
                                        array([ 3,  8, 23, 26]),
                                        array([ 3,  8, 22, 23, 26]),
                                        array([ 3,  4,  8, 22, 23, 26]),
                                        array([ 3,  4,  8, 22, 23, 26]),
                                        array([22, 23, 26]),
                                        array([22, 23, 26]),
                                        array([23, 26])],
    'leaf_nodes_depths': [   6,
                             7,
                             7,
                             5,
                             5,
                             6,
                             6,
                             6,
                             8,
                             8,
                             7,
                             5,
                             4,
                             4,
                             4,
                             4,
                             5,
                             6,
                             6,
                             3,
                             3,
                             2],
    'max_node_depth': 8,
    'n_nodes': 43,
    'node_features_idx': array([23, 26,  1,  6, 13,  5, 28,  9, 28, 28, 28, 22, 28, 13, 28, 28,  3,
       27, 27, 28, 19, 17, 28, 28, 28, 28, 28,  3, 18, 28, 28,  8, 28, 22,
       28,  4, 28, 28, 26, 22, 28, 28, 28]),
    'num_features_used': 16,
    'tot_leaf_node_values': [   239,
                                1,
                                8,
                                2,
                                8,
                                7,
                                2,
                                27,
                                3,
                                1,
                                10,
                                2,
                                7,
                                7,
                                1,
                                2,
                                2,
                                1,
                                19,
                                6,
                                2,
                                155]}

In [23]:
# Count the number of samples passing through the leaf nodes
sum(tree_dat0['tot_leaf_node_values'])


Out[23]:
512
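The total equals the 512 training rows. A sketch of why, assuming utils.getTreeData sums weighted class counts: with the default bootstrap=True, each tree is fit with sample weights equal to how many times each row was drawn, so the weighted leaf totals add up to the number of draws (the training set size by default), while the number of distinct rows reaching the leaves is smaller. For self-containment this example fits on the full 569-row dataset rather than the 512-row training split.

```python
import numpy as np
from sklearn.datasets import load_breast_cancer
from sklearn.ensemble import RandomForestClassifier

X, y = load_breast_cancer(return_X_y=True)
rf = RandomForestClassifier(n_estimators=3, random_state=2018).fit(X, y)

leaf_totals = []
for est in rf.estimators_:
    t = est.tree_
    is_leaf = t.children_left == t.children_right
    # Bootstrap-weighted count: every draw, including repeats, is counted
    weighted = t.weighted_n_node_samples[is_leaf].sum()
    # Distinct training rows that reached the leaves
    distinct = t.n_node_samples[is_leaf].sum()
    leaf_totals.append((weighted, distinct))
    print(f"weighted={weighted:.0f}, distinct={distinct}")
```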

Step 3: Get the Gini Importances of the Features for the Random Forest

  • For the first random forest, we only need the Gini importances of the features

Step 3.1: Get them numerically (most important)


In [24]:
feature_importances = rf.feature_importances_
std = np.std([dtree.feature_importances_ for dtree in rf.estimators_]
             , axis=0)
feature_importances_rank_idx = np.argsort(feature_importances)[::-1]

# Check that the feature importances are normalized to sum to 1
print(sum(feature_importances))


1.0
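For reference, a sketch of how the Gini (mean decrease in impurity) importance can be recomputed from each tree's tree_ arrays and then averaged across the forest. Each tree's vector is normalized to sum to 1, and the forest average is normalized again, which is consistent with the sum of 1.0 printed above. The helper gini_importance is hypothetical; the comparison against rf.feature_importances_ is the check.

```python
import numpy as np
from sklearn.datasets import load_breast_cancer
from sklearn.ensemble import RandomForestClassifier


def gini_importance(dtree):
    """Mean decrease in impurity for one fitted tree, normalized to sum to 1."""
    t = dtree.tree_
    w = t.weighted_n_node_samples
    imp = np.zeros(dtree.n_features_in_)
    for node in range(t.node_count):
        left, right = t.children_left[node], t.children_right[node]
        if left == right:  # leaves do not split, so they contribute nothing
            continue
        # Weighted impurity of the parent minus that of its two children
        imp[t.feature[node]] += (w[node] * t.impurity[node]
                                 - w[left] * t.impurity[left]
                                 - w[right] * t.impurity[right])
    return imp / imp.sum()


X, y = load_breast_cancer(return_X_y=True)
rf = RandomForestClassifier(n_estimators=3, random_state=2018).fit(X, y)

per_tree = np.array([gini_importance(est) for est in rf.estimators_])
forest_imp = per_tree.mean(axis=0)
forest_imp /= forest_imp.sum()
print(np.allclose(forest_imp, rf.feature_importances_))
```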

Step 3.2: Display Feature Importances Graphically (just for interest)


In [11]:
# Print the feature ranking
print("Feature ranking:")

for f in range(X_train.shape[1]):
    print("%d. feature %d (%f)" % (f + 1
                                   , feature_importances_rank_idx[f]
                                   , feature_importances[feature_importances_rank_idx[f]]))
    
# Plot the feature importances of the forest
plt.figure()
plt.title("Feature importances")
plt.bar(range(X_train.shape[1])
        , feature_importances[feature_importances_rank_idx]
        , color="r"
        , yerr = std[feature_importances_rank_idx], align="center")
plt.xticks(range(X_train.shape[1]), feature_importances_rank_idx)
plt.xlim([-1, X_train.shape[1]])
plt.show()


Feature ranking:
1. feature 23 (0.465017)
2. feature 20 (0.232076)
3. feature 7 (0.039853)
4. feature 26 (0.036480)
5. feature 27 (0.033853)
6. feature 3 (0.024435)
7. feature 22 (0.022307)
8. feature 24 (0.021493)
9. feature 28 (0.021055)
10. feature 1 (0.017764)
11. feature 6 (0.015104)
12. feature 17 (0.011860)
13. feature 13 (0.010908)
14. feature 21 (0.009833)
15. feature 0 (0.008362)
16. feature 29 (0.005037)
17. feature 19 (0.004404)
18. feature 18 (0.003789)
19. feature 8 (0.003747)
20. feature 14 (0.002958)
21. feature 4 (0.002603)
22. feature 9 (0.002436)
23. feature 12 (0.002320)
24. feature 10 (0.002012)
25. feature 5 (0.000293)
26. feature 15 (0.000000)
27. feature 11 (0.000000)
28. feature 16 (0.000000)
29. feature 2 (0.000000)
30. feature 25 (0.000000)
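The ranking above reports column indices only. A small sketch that attaches the column names shipped with load_breast_cancer() (its feature_names attribute) makes the list easier to read. rank_features is a hypothetical helper; because this example fits on the full dataset, its numbers will not match the table above, and the relative order of features tied at zero importance is arbitrary.

```python
import numpy as np
from sklearn.datasets import load_breast_cancer
from sklearn.ensemble import RandomForestClassifier


def rank_features(importances, feature_names):
    """Return (rank, feature index, feature name, importance), most important first."""
    order = np.argsort(importances)[::-1]
    return [(rank + 1, int(i), feature_names[i], float(importances[i]))
            for rank, i in enumerate(order)]


data = load_breast_cancer()
rf = RandomForestClassifier(n_estimators=3, random_state=2018).fit(data.data, data.target)
ranking = rank_features(rf.feature_importances_, data.feature_names)
for rank, idx, name, imp in ranking[:5]:
    print(f"{rank}. feature {idx} [{name}] ({imp:.6f})")
```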

Putting it all together

  • Create a dictionary that holds the random forest object together with all of its derived outputs

In [12]:
# CHECK: If the random forest objects are going to be really large in size
#        we could just omit them and only return our custom summary outputs

rf_metrics = utils.getValidationMetrics(rf, y_true = y_test, X_test = X_test)
all_rf_outputs = {"rf_obj" : rf,
                 "feature_importances" : feature_importances,
                 "feature_importances_rank_idx" : feature_importances_rank_idx,
                 "rf_metrics" : rf_metrics}
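The contents of utils.getValidationMetrics are not shown in this notebook. As a hypothetical illustration only, a validation summary for a binary classifier could be built from standard sklearn.metrics functions; the function name and dictionary keys below are assumptions, and the real helper may return different metrics.

```python
from sklearn import metrics
from sklearn.datasets import load_breast_cancer
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split


def validation_metrics_sketch(model, y_true, X_test):
    """Hypothetical stand-in for utils.getValidationMetrics; the real keys may differ."""
    y_pred = model.predict(X_test)
    # Probability of the positive class (label 1) for threshold-free metrics
    y_score = model.predict_proba(X_test)[:, 1]
    return {
        "accuracy_score": metrics.accuracy_score(y_true, y_pred),
        "confusion_matrix": metrics.confusion_matrix(y_true, y_pred),
        "f1_score": metrics.f1_score(y_true, y_pred),
        "roc_auc_score": metrics.roc_auc_score(y_true, y_score),
    }


X, y = load_breast_cancer(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, train_size=0.9, random_state=2017)
rf = RandomForestClassifier(n_estimators=3, random_state=2018).fit(X_train, y_train)
rf_metrics = validation_metrics_sketch(rf, y_true=y_test, X_test=X_test)
print(rf_metrics["accuracy_score"])
```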

In [13]:
# CHECK: The following should be parallelized!
# CHECK: Whether we can maintain X_train correctly as required
for idx, dtree in enumerate(rf.estimators_):
    dtree_out = utils.getTreeData(X_train = X_train, dtree = dtree, root_node_id = 0)
    # Append output to dictionary
    all_rf_outputs["dtree" + str(idx)] = dtree_out
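One way to act on the parallelization note is joblib, which scikit-learn itself depends on. This sketch uses a hypothetical, much smaller summarize_tree in place of utils.getTreeData. It uses threads (prefer="threads") so the fitted trees need not be pickled; for traversal code that is mostly pure Python and holds the GIL, joblib's default process-based backend may give a larger speed-up. Parallel returns results in input order, so the dtree0, dtree1, ... keys still line up with rf.estimators_.

```python
from joblib import Parallel, delayed
from sklearn.datasets import load_breast_cancer
from sklearn.ensemble import RandomForestClassifier


def summarize_tree(dtree):
    """Hypothetical, much smaller stand-in for utils.getTreeData."""
    return {"n_nodes": dtree.tree_.node_count,
            "max_node_depth": dtree.get_depth(),
            "n_leaves": dtree.get_n_leaves()}


X, y = load_breast_cancer(return_X_y=True)
rf = RandomForestClassifier(n_estimators=3, random_state=2018).fit(X, y)

# Results come back in the same order as the input generator
summaries = Parallel(n_jobs=-1, prefer="threads")(
    delayed(summarize_tree)(dtree) for dtree in rf.estimators_)
all_rf_outputs = {"dtree" + str(idx): out for idx, out in enumerate(summaries)}
```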

Check the final dictionary of outputs


In [14]:
utils.prettyPrintDict(inp_dict = all_rf_outputs)


{   'dtree0': {   'all_leaf_node_classes': [   1,
                                               0,
                                               1,
                                               0,
                                               1,
                                               0,
                                               1,
                                               1,
                                               0,
                                               1,
                                               1,
                                               0,
                                               0,
                                               1,
                                               0,
                                               1,
                                               1,
                                               1,
                                               0,
                                               1,
                                               0,
                                               0],
                  'all_leaf_node_paths': [   array([0, 1, 2, 3, 4, 5, 6]),
                                             array([0, 1, 2, 3, 4, 5, 7, 8]),
                                             array([0, 1, 2, 3, 4, 5, 7, 9]),
                                             array([ 0,  1,  2,  3,  4, 10]),
                                             array([ 0,  1,  2,  3, 11, 12]),
                                             array([ 0,  1,  2,  3, 11, 13, 14]),
                                             array([ 0,  1,  2,  3, 11, 13, 15]),
                                             array([ 0,  1,  2, 16, 17, 18, 19]),
                                             array([ 0,  1,  2, 16, 17, 18, 20, 21, 22]),
                                             array([ 0,  1,  2, 16, 17, 18, 20, 21, 23]),
                                             array([ 0,  1,  2, 16, 17, 18, 20, 24]),
                                             array([ 0,  1,  2, 16, 17, 25]),
                                             array([ 0,  1,  2, 16, 26]),
                                             array([ 0,  1, 27, 28, 29]),
                                             array([ 0,  1, 27, 28, 30]),
                                             array([ 0,  1, 27, 31, 32]),
                                             array([ 0,  1, 27, 31, 33, 34]),
                                             array([ 0,  1, 27, 31, 33, 35, 36]),
                                             array([ 0,  1, 27, 31, 33, 35, 37]),
                                             array([ 0, 38, 39, 40]),
                                             array([ 0, 38, 39, 41]),
                                             array([ 0, 38, 42])],
                  'all_leaf_node_values': [   array([[  0, 239]]),
                                              array([[1, 0]]),
                                              array([[0, 8]]),
                                              array([[2, 0]]),
                                              array([[0, 8]]),
                                              array([[7, 0]]),
                                              array([[0, 2]]),
                                              array([[ 0, 27]]),
                                              array([[3, 0]]),
                                              array([[0, 1]]),
                                              array([[ 0, 10]]),
                                              array([[2, 0]]),
                                              array([[7, 0]]),
                                              array([[0, 7]]),
                                              array([[1, 0]]),
                                              array([[0, 2]]),
                                              array([[0, 2]]),
                                              array([[0, 1]]),
                                              array([[19,  0]]),
                                              array([[0, 6]]),
                                              array([[2, 0]]),
                                              array([[155,   0]])],
                  'all_leaf_nodes': [   6,
                                        8,
                                        9,
                                        10,
                                        12,
                                        14,
                                        15,
                                        19,
                                        22,
                                        23,
                                        24,
                                        25,
                                        26,
                                        29,
                                        30,
                                        32,
                                        34,
                                        36,
                                        37,
                                        40,
                                        41,
                                        42],
                  'all_leaf_paths_features': [   array([23, 26,  1,  6, 13,  5]),
                                                 array([23, 26,  1,  6, 13,  5,  9]),
                                                 array([23, 26,  1,  6, 13,  5,  9]),
                                                 array([23, 26,  1,  6, 13]),
                                                 array([23, 26,  1,  6, 22]),
                                                 array([23, 26,  1,  6, 22, 13]),
                                                 array([23, 26,  1,  6, 22, 13]),
                                                 array([23, 26,  1,  3, 27, 27]),
                                                 array([23, 26,  1,  3, 27, 27, 19, 17]),
                                                 array([23, 26,  1,  3, 27, 27, 19, 17]),
                                                 array([23, 26,  1,  3, 27, 27, 19]),
                                                 array([23, 26,  1,  3, 27]),
                                                 array([23, 26,  1,  3]),
                                                 array([23, 26,  3, 18]),
                                                 array([23, 26,  3, 18]),
                                                 array([23, 26,  3,  8]),
                                                 array([23, 26,  3,  8, 22]),
                                                 array([23, 26,  3,  8, 22,  4]),
                                                 array([23, 26,  3,  8, 22,  4]),
                                                 array([23, 26, 22]),
                                                 array([23, 26, 22]),
                                                 array([23, 26])],
                  'all_uniq_leaf_paths_features': [   array([ 1,  5,  6, 13, 23, 26]),
                                                      array([ 1,  5,  6,  9, 13, 23, 26]),
                                                      array([ 1,  5,  6,  9, 13, 23, 26]),
                                                      array([ 1,  6, 13, 23, 26]),
                                                      array([ 1,  6, 22, 23, 26]),
                                                      array([ 1,  6, 13, 22, 23, 26]),
                                                      array([ 1,  6, 13, 22, 23, 26]),
                                                      array([ 1,  3, 23, 26, 27]),
                                                      array([ 1,  3, 17, 19, 23, 26, 27]),
                                                      array([ 1,  3, 17, 19, 23, 26, 27]),
                                                      array([ 1,  3, 19, 23, 26, 27]),
                                                      array([ 1,  3, 23, 26, 27]),
                                                      array([ 1,  3, 23, 26]),
                                                      array([ 3, 18, 23, 26]),
                                                      array([ 3, 18, 23, 26]),
                                                      array([ 3,  8, 23, 26]),
                                                      array([ 3,  8, 22, 23, 26]),
                                                      array([ 3,  4,  8, 22, 23, 26]),
                                                      array([ 3,  4,  8, 22, 23, 26]),
                                                      array([22, 23, 26]),
                                                      array([22, 23, 26]),
                                                      array([23, 26])],
                  'leaf_nodes_depths': [   6,
                                           7,
                                           7,
                                           5,
                                           5,
                                           6,
                                           6,
                                           6,
                                           8,
                                           8,
                                           7,
                                           5,
                                           4,
                                           4,
                                           4,
                                           4,
                                           5,
                                           6,
                                           6,
                                           3,
                                           3,
                                           2],
                  'max_node_depth': 8,
                  'n_nodes': 43,
                  'node_features_idx': array([23, 26,  1,  6, 13,  5, 28,  9, 28, 28, 28, 22, 28, 13, 28, 28,  3,
       27, 27, 28, 19, 17, 28, 28, 28, 28, 28,  3, 18, 28, 28,  8, 28, 22,
       28,  4, 28, 28, 26, 22, 28, 28, 28]),
                  'num_features_used': 16,
                  'tot_leaf_node_values': [   239,
                                              1,
                                              8,
                                              2,
                                              8,
                                              7,
                                              2,
                                              27,
                                              3,
                                              1,
                                              10,
                                              2,
                                              7,
                                              7,
                                              1,
                                              2,
                                              2,
                                              1,
                                              19,
                                              6,
                                              2,
                                              155]},
    'dtree1': {   'all_leaf_node_classes': [   1,
                                               0,
                                               1,
                                               0,
                                               1,
                                               0,
                                               1,
                                               0,
                                               1,
                                               1,
                                               0,
                                               1,
                                               1,
                                               0,
                                               1,
                                               0,
                                               1,
                                               0,
                                               1,
                                               0,
                                               0],
                  'all_leaf_node_paths': [   array([0, 1, 2, 3, 4, 5]),
                                             array([0, 1, 2, 3, 4, 6, 7, 8]),
                                             array([ 0,  1,  2,  3,  4,  6,  7,  9, 10, 11]),
                                             array([ 0,  1,  2,  3,  4,  6,  7,  9, 10, 12]),
                                             array([ 0,  1,  2,  3,  4,  6,  7,  9, 13]),
                                             array([ 0,  1,  2,  3,  4,  6, 14, 15]),
                                             array([ 0,  1,  2,  3,  4,  6, 14, 16]),
                                             array([ 0,  1,  2,  3, 17, 18]),
                                             array([ 0,  1,  2,  3, 17, 19]),
                                             array([ 0,  1,  2, 20, 21, 22, 23]),
                                             array([ 0,  1,  2, 20, 21, 22, 24, 25]),
                                             array([ 0,  1,  2, 20, 21, 22, 24, 26]),
                                             array([ 0,  1,  2, 20, 21, 27]),
                                             array([ 0,  1,  2, 20, 28]),
                                             array([ 0,  1, 29, 30]),
                                             array([ 0,  1, 29, 31]),
                                             array([ 0, 32, 33, 34, 35]),
                                             array([ 0, 32, 33, 34, 36]),
                                             array([ 0, 32, 33, 37, 38]),
                                             array([ 0, 32, 33, 37, 39]),
                                             array([ 0, 32, 40])],
                  'all_leaf_node_values': [   array([[  0, 189]]),
                                              array([[3, 0]]),
                                              array([[0, 5]]),
                                              array([[1, 0]]),
                                              array([[  0, 101]]),
                                              array([[1, 0]]),
                                              array([[0, 1]]),
                                              array([[2, 0]]),
                                              array([[0, 3]]),
                                              array([[0, 2]]),
                                              array([[5, 0]]),
                                              array([[0, 1]]),
                                              array([[0, 7]]),
                                              array([[10,  0]]),
                                              array([[0, 3]]),
                                              array([[12,  0]]),
                                              array([[0, 2]]),
                                              array([[19,  0]]),
                                              array([[0, 7]]),
                                              array([[1, 0]]),
                                              array([[137,   0]])],
                  'all_leaf_nodes': [   5,
                                        8,
                                        11,
                                        12,
                                        13,
                                        15,
                                        16,
                                        18,
                                        19,
                                        23,
                                        25,
                                        26,
                                        27,
                                        28,
                                        30,
                                        31,
                                        35,
                                        36,
                                        38,
                                        39,
                                        40],
                  'all_leaf_paths_features': [   array([20, 24, 27, 10,  0]),
                                                 array([20, 24, 27, 10,  0,  6,  0]),
                                                 array([20, 24, 27, 10,  0,  6,  0, 14, 20]),
                                                 array([20, 24, 27, 10,  0,  6,  0, 14, 20]),
                                                 array([20, 24, 27, 10,  0,  6,  0, 14]),
                                                 array([20, 24, 27, 10,  0,  6, 18]),
                                                 array([20, 24, 27, 10,  0,  6, 18]),
                                                 array([20, 24, 27, 10, 28]),
                                                 array([20, 24, 27, 10, 28]),
                                                 array([20, 24, 27, 21,  6,  6]),
                                                 array([20, 24, 27, 21,  6,  6, 12]),
                                                 array([20, 24, 27, 21,  6,  6, 12]),
                                                 array([20, 24, 27, 21,  6]),
                                                 array([20, 24, 27, 21]),
                                                 array([20, 24, 22]),
                                                 array([20, 24, 22]),
                                                 array([20,  7, 17, 29]),
                                                 array([20,  7, 17, 29]),
                                                 array([20,  7, 17, 28]),
                                                 array([20,  7, 17, 28]),
                                                 array([20,  7])],
                  'all_uniq_leaf_paths_features': [   array([ 0, 10, 20, 24, 27]),
                                                      array([ 0,  6, 10, 20, 24, 27]),
                                                      array([ 0,  6, 10, 14, 20, 24, 27]),
                                                      array([ 0,  6, 10, 14, 20, 24, 27]),
                                                      array([ 0,  6, 10, 14, 20, 24, 27]),
                                                      array([ 0,  6, 10, 18, 20, 24, 27]),
                                                      array([ 0,  6, 10, 18, 20, 24, 27]),
                                                      array([10, 20, 24, 27, 28]),
                                                      array([10, 20, 24, 27, 28]),
                                                      array([ 6, 20, 21, 24, 27]),
                                                      array([ 6, 12, 20, 21, 24, 27]),
                                                      array([ 6, 12, 20, 21, 24, 27]),
                                                      array([ 6, 20, 21, 24, 27]),
                                                      array([20, 21, 24, 27]),
                                                      array([20, 22, 24]),
                                                      array([20, 22, 24]),
                                                      array([ 7, 17, 20, 29]),
                                                      array([ 7, 17, 20, 29]),
                                                      array([ 7, 17, 20, 28]),
                                                      array([ 7, 17, 20, 28]),
                                                      array([ 7, 20])],
                  'leaf_nodes_depths': [   5,
                                           7,
                                           9,
                                           9,
                                           8,
                                           7,
                                           7,
                                           5,
                                           5,
                                           6,
                                           7,
                                           7,
                                           5,
                                           4,
                                           3,
                                           3,
                                           4,
                                           4,
                                           4,
                                           4,
                                           2],
                  'max_node_depth': 9,
                  'n_nodes': 41,
                  'node_features_idx': array([20, 24, 27, 10,  0, 28,  6,  0, 28, 14, 20, 28, 28, 28, 18, 28, 28,
       28, 28, 28, 21,  6,  6, 28, 12, 28, 28, 28, 28, 22, 28, 28,  7, 17,
       29, 28, 28, 28, 28, 28, 28]),
                  'num_features_used': 15,
                  'tot_leaf_node_values': [   189,
                                              3,
                                              5,
                                              1,
                                              101,
                                              1,
                                              1,
                                              2,
                                              3,
                                              2,
                                              5,
                                              1,
                                              7,
                                              10,
                                              3,
                                              12,
                                              2,
                                              19,
                                              7,
                                              1,
                                              137]},
    'dtree2': {   'all_leaf_node_classes': [   1,
                                               0,
                                               1,
                                               1,
                                               0,
                                               1,
                                               1,
                                               0,
                                               1,
                                               0,
                                               0,
                                               1,
                                               0,
                                               1,
                                               0,
                                               1,
                                               0],
                  'all_leaf_node_paths': [   array([0, 1, 2, 3, 4]),
                                             array([0, 1, 2, 3, 5]),
                                             array([0, 1, 2, 6, 7, 8]),
                                             array([ 0,  1,  2,  6,  7,  9, 10]),
                                             array([ 0,  1,  2,  6,  7,  9, 11, 12]),
                                             array([ 0,  1,  2,  6,  7,  9, 11, 13]),
                                             array([ 0,  1,  2,  6, 14, 15]),
                                             array([ 0,  1,  2,  6, 14, 16]),
                                             array([ 0,  1, 17, 18, 19]),
                                             array([ 0,  1, 17, 18, 20]),
                                             array([ 0,  1, 17, 21]),
                                             array([ 0, 22, 23, 24]),
                                             array([ 0, 22, 23, 25]),
                                             array([ 0, 22, 26, 27]),
                                             array([ 0, 22, 26, 28, 29, 30]),
                                             array([ 0, 22, 26, 28, 29, 31]),
                                             array([ 0, 22, 26, 28, 32])],
                  'all_leaf_node_values': [   array([[0, 5]]),
                                              array([[4, 0]]),
                                              array([[  0, 230]]),
                                              array([[ 0, 43]]),
                                              array([[2, 0]]),
                                              array([[0, 7]]),
                                              array([[0, 6]]),
                                              array([[2, 0]]),
                                              array([[ 0, 15]]),
                                              array([[7, 0]]),
                                              array([[14,  0]]),
                                              array([[0, 3]]),
                                              array([[6, 0]]),
                                              array([[0, 1]]),
                                              array([[12,  0]]),
                                              array([[0, 1]]),
                                              array([[154,   0]])],
                  'all_leaf_nodes': [   4,
                                        5,
                                        8,
                                        10,
                                        12,
                                        13,
                                        15,
                                        16,
                                        19,
                                        20,
                                        21,
                                        24,
                                        25,
                                        27,
                                        30,
                                        31,
                                        32],
                  'all_leaf_paths_features': [   array([23,  7, 28,  7]),
                                                 array([23,  7, 28,  7]),
                                                 array([23,  7, 28, 13, 26]),
                                                 array([23,  7, 28, 13, 26, 23]),
                                                 array([23,  7, 28, 13, 26, 23,  3]),
                                                 array([23,  7, 28, 13, 26, 23,  3]),
                                                 array([23,  7, 28, 13, 27]),
                                                 array([23,  7, 28, 13, 27]),
                                                 array([23,  7, 28,  1]),
                                                 array([23,  7, 28,  1]),
                                                 array([23,  7, 28]),
                                                 array([23, 21, 20]),
                                                 array([23, 21, 20]),
                                                 array([23, 21, 27]),
                                                 array([23, 21, 27,  7, 14]),
                                                 array([23, 21, 27,  7, 14]),
                                                 array([23, 21, 27,  7])],
                  'all_uniq_leaf_paths_features': [   array([ 7, 23, 28]),
                                                      array([ 7, 23, 28]),
                                                      array([ 7, 13, 23, 26, 28]),
                                                      array([ 7, 13, 23, 26, 28]),
                                                      array([ 3,  7, 13, 23, 26, 28]),
                                                      array([ 3,  7, 13, 23, 26, 28]),
                                                      array([ 7, 13, 23, 27, 28]),
                                                      array([ 7, 13, 23, 27, 28]),
                                                      array([ 1,  7, 23, 28]),
                                                      array([ 1,  7, 23, 28]),
                                                      array([ 7, 23, 28]),
                                                      array([20, 21, 23]),
                                                      array([20, 21, 23]),
                                                      array([21, 23, 27]),
                                                      array([ 7, 14, 21, 23, 27]),
                                                      array([ 7, 14, 21, 23, 27]),
                                                      array([ 7, 21, 23, 27])],
                  'leaf_nodes_depths': [   4,
                                           4,
                                           5,
                                           6,
                                           7,
                                           7,
                                           5,
                                           5,
                                           4,
                                           4,
                                           3,
                                           3,
                                           3,
                                           3,
                                           5,
                                           5,
                                           4],
                  'max_node_depth': 7,
                  'n_nodes': 33,
                  'node_features_idx': array([23,  7, 28,  7, 28, 28, 13, 26, 28, 23, 28,  3, 28, 28, 27, 28, 28,
       28,  1, 28, 28, 28, 21, 20, 28, 28, 27, 28,  7, 14, 28, 28, 28]),
                  'num_features_used': 11,
                  'tot_leaf_node_values': [   5,
                                              4,
                                              230,
                                              43,
                                              2,
                                              7,
                                              6,
                                              2,
                                              15,
                                              7,
                                              14,
                                              3,
                                              6,
                                              1,
                                              12,
                                              1,
                                              154]},
    'feature_importances': array([  8.36213794e-03,   1.77643891e-02,   0.00000000e+00,
         2.44354801e-02,   2.60300437e-03,   2.93396550e-04,
         1.51044947e-02,   3.98525961e-02,   3.74674872e-03,
         2.43555965e-03,   2.01235226e-03,   0.00000000e+00,
         2.31968525e-03,   1.09078350e-02,   2.95809372e-03,
         0.00000000e+00,   0.00000000e+00,   1.18599588e-02,
         3.78931518e-03,   4.40357883e-03,   2.32076100e-01,
         9.83256387e-03,   2.23069414e-02,   4.65017474e-01,
         2.14928911e-02,   0.00000000e+00,   3.64804969e-02,
         3.38532565e-02,   2.10546188e-02,   5.03703082e-03]),
    'feature_importances_rank_idx': array([23, 20,  7, 26, 27,  3, 22, 24, 28,  1,  6, 17, 13, 21,  0, 29, 19,
       18,  8, 14,  4,  9, 12, 10,  5, 15, 11, 16,  2, 25]),
    'rf_metrics': {   'accuracy_score': 0.94736842105263153,
                      'classification_report': '             precision    '
                                               'recall  f1-score   support\n'
                                               '\n'
                                               '          0       0.92      '
                                               '0.86      0.89        14\n'
                                               '          1       0.95      '
                                               '0.98      0.97        43\n'
                                               '\n'
                                               'avg / total       0.95      '
                                               '0.95      0.95        57\n',
                      'confusion_matrix': array([[12,  2],
       [ 1, 42]]),
                      'f1_score': 0.96551724137931039,
                      'hamming_loss': 0.052631578947368418,
                      'log_loss': 1.8178583926244296,
                      'precision_score': 0.95454545454545459,
                      'recall_score': 0.97674418604651159,
                      'zero_one_loss': 0.052631578947368474},
    'rf_obj': RandomForestClassifier(bootstrap=True, class_weight=None, criterion='gini',
            max_depth=None, max_features='auto', max_leaf_nodes=None,
            min_impurity_split=1e-07, min_samples_leaf=1,
            min_samples_split=2, min_weight_fraction_leaf=0.0,
            n_estimators=3, n_jobs=1, oob_score=False, random_state=2018,
            verbose=0, warm_start=False)}

Now we can start setting up the RIT (Random Intersection Trees) class.

Overview

At its core, the RIT consists of 3 main modules:

  • FILTERING: Subsetting the leaf-node paths to those that predict either class 1 or class 0
  • RANDOM SAMPLING: Sampling the leaf-node paths, optionally weighted, with or without replacement, and from within a single tree or across trees
  • INTERSECTION: Intersecting the selected node paths in a systematic manner
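The three modules can be sketched as a single function. This is a minimal sketch, not the iRF implementation: the name `random_intersection_tree` and its parameters (`depth`, `n_children`, `weights`, `rng`) are hypothetical, filtering is assumed to have already happened, and sampling is always with replacement from one pooled list of paths (it assumes NumPy >= 1.17 for `default_rng`).

```python
import numpy as np


def random_intersection_tree(feature_paths, depth=3, n_children=2,
                             weights=None, rng=None):
    """Grow one (hypothetical) Random Intersection Tree.

    feature_paths : list of 1-D integer arrays, assumed to be already
        FILTERED to the leaf nodes of one class (e.g. class 1).
    depth : number of levels, counting the root as level 1.
    n_children : number of children of every non-leaf node.
    weights : optional non-negative sampling weights, one per path.
    rng : seed or numpy Generator.

    Returns the feature sets held by the nodes on the deepest level.
    """
    rng = np.random.default_rng(rng)
    p = None
    if weights is not None:
        p = np.asarray(weights, dtype=float)
        p = p / p.sum()

    def sample_path():
        # RANDOM SAMPLING: one path per call, drawn with replacement
        idx = rng.choice(len(feature_paths), p=p)
        return np.asarray(feature_paths[idx])

    level = [sample_path()]          # the root keeps its sampled path as-is
    for _ in range(depth - 1):
        next_level = []
        for parent in level:
            for _ in range(n_children):
                # INTERSECTION: a child keeps only the features it shares
                # with its parent and a freshly sampled path
                next_level.append(np.intersect1d(parent, sample_path()))
        level = next_level
    return level
```

With `depth=3` and `n_children=2` this grows the same 7-node shape as the hand-built tree later in this notebook, but with randomly drawn paths. Sampling without replacement or restricting draws to a single tree would need a different `sample_path`.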

For now, we will just work with the outputs of a single decision tree.


In [15]:
utils.prettyPrintDict(inp_dict = all_rf_outputs['rf_metrics'])


{   'accuracy_score': 0.94736842105263153,
    'classification_report': '             precision    recall  f1-score   '
                             'support\n'
                             '\n'
                             '          0       0.92      0.86      '
                             '0.89        14\n'
                             '          1       0.95      0.98      '
                             '0.97        43\n'
                             '\n'
                             'avg / total       0.95      0.95      '
                             '0.95        57\n',
    'confusion_matrix': array([[12,  2],
       [ 1, 42]]),
    'f1_score': 0.96551724137931039,
    'hamming_loss': 0.052631578947368418,
    'log_loss': 1.8178583926244296,
    'precision_score': 0.95454545454545459,
    'recall_score': 0.97674418604651159,
    'zero_one_loss': 0.052631578947368474}
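As a sanity check, most of these metrics can be recomputed by hand from the confusion matrix. The sketch below assumes class 1 is the positive label (scikit-learn's default `pos_label=1` for binary precision, recall and F1); `log_loss` needs the predicted probabilities, so it cannot be recovered this way.

```python
import numpy as np

# Confusion matrix from the output above; scikit-learn's convention is
# rows = true class, columns = predicted class
cm = np.array([[12, 2],
               [1, 42]])

tn, fp, fn, tp = cm.ravel()   # class 1 treated as the positive class
n = cm.sum()

accuracy = (tp + tn) / n                       # 54 / 57
precision = tp / (tp + fp)                     # 42 / 44
recall = tp / (tp + fn)                        # 42 / 43
f1 = 2 * precision * recall / (precision + recall)
zero_one = (fp + fn) / n                       # 3 / 57, equal to the Hamming loss here

print(accuracy, precision, recall, f1, zero_one)
```

The per-class rows of the classification report follow the same way, e.g. class 0 precision is 12 / 13 ≈ 0.92 and recall is 12 / 14 ≈ 0.86.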

In [16]:
all_rf_outputs['dtree0']


Out[16]:
{'all_leaf_node_classes': [1,
  0,
  1,
  0,
  1,
  0,
  1,
  1,
  0,
  1,
  1,
  0,
  0,
  1,
  0,
  1,
  1,
  1,
  0,
  1,
  0,
  0],
 'all_leaf_node_paths': [array([0, 1, 2, 3, 4, 5, 6]),
  array([0, 1, 2, 3, 4, 5, 7, 8]),
  array([0, 1, 2, 3, 4, 5, 7, 9]),
  array([ 0,  1,  2,  3,  4, 10]),
  array([ 0,  1,  2,  3, 11, 12]),
  array([ 0,  1,  2,  3, 11, 13, 14]),
  array([ 0,  1,  2,  3, 11, 13, 15]),
  array([ 0,  1,  2, 16, 17, 18, 19]),
  array([ 0,  1,  2, 16, 17, 18, 20, 21, 22]),
  array([ 0,  1,  2, 16, 17, 18, 20, 21, 23]),
  array([ 0,  1,  2, 16, 17, 18, 20, 24]),
  array([ 0,  1,  2, 16, 17, 25]),
  array([ 0,  1,  2, 16, 26]),
  array([ 0,  1, 27, 28, 29]),
  array([ 0,  1, 27, 28, 30]),
  array([ 0,  1, 27, 31, 32]),
  array([ 0,  1, 27, 31, 33, 34]),
  array([ 0,  1, 27, 31, 33, 35, 36]),
  array([ 0,  1, 27, 31, 33, 35, 37]),
  array([ 0, 38, 39, 40]),
  array([ 0, 38, 39, 41]),
  array([ 0, 38, 42])],
 'all_leaf_node_values': [array([[  0, 239]]),
  array([[1, 0]]),
  array([[0, 8]]),
  array([[2, 0]]),
  array([[0, 8]]),
  array([[7, 0]]),
  array([[0, 2]]),
  array([[ 0, 27]]),
  array([[3, 0]]),
  array([[0, 1]]),
  array([[ 0, 10]]),
  array([[2, 0]]),
  array([[7, 0]]),
  array([[0, 7]]),
  array([[1, 0]]),
  array([[0, 2]]),
  array([[0, 2]]),
  array([[0, 1]]),
  array([[19,  0]]),
  array([[0, 6]]),
  array([[2, 0]]),
  array([[155,   0]])],
 'all_leaf_nodes': [6,
  8,
  9,
  10,
  12,
  14,
  15,
  19,
  22,
  23,
  24,
  25,
  26,
  29,
  30,
  32,
  34,
  36,
  37,
  40,
  41,
  42],
 'all_leaf_paths_features': [array([23, 26,  1,  6, 13,  5]),
  array([23, 26,  1,  6, 13,  5,  9]),
  array([23, 26,  1,  6, 13,  5,  9]),
  array([23, 26,  1,  6, 13]),
  array([23, 26,  1,  6, 22]),
  array([23, 26,  1,  6, 22, 13]),
  array([23, 26,  1,  6, 22, 13]),
  array([23, 26,  1,  3, 27, 27]),
  array([23, 26,  1,  3, 27, 27, 19, 17]),
  array([23, 26,  1,  3, 27, 27, 19, 17]),
  array([23, 26,  1,  3, 27, 27, 19]),
  array([23, 26,  1,  3, 27]),
  array([23, 26,  1,  3]),
  array([23, 26,  3, 18]),
  array([23, 26,  3, 18]),
  array([23, 26,  3,  8]),
  array([23, 26,  3,  8, 22]),
  array([23, 26,  3,  8, 22,  4]),
  array([23, 26,  3,  8, 22,  4]),
  array([23, 26, 22]),
  array([23, 26, 22]),
  array([23, 26])],
 'all_uniq_leaf_paths_features': [array([ 1,  5,  6, 13, 23, 26]),
  array([ 1,  5,  6,  9, 13, 23, 26]),
  array([ 1,  5,  6,  9, 13, 23, 26]),
  array([ 1,  6, 13, 23, 26]),
  array([ 1,  6, 22, 23, 26]),
  array([ 1,  6, 13, 22, 23, 26]),
  array([ 1,  6, 13, 22, 23, 26]),
  array([ 1,  3, 23, 26, 27]),
  array([ 1,  3, 17, 19, 23, 26, 27]),
  array([ 1,  3, 17, 19, 23, 26, 27]),
  array([ 1,  3, 19, 23, 26, 27]),
  array([ 1,  3, 23, 26, 27]),
  array([ 1,  3, 23, 26]),
  array([ 3, 18, 23, 26]),
  array([ 3, 18, 23, 26]),
  array([ 3,  8, 23, 26]),
  array([ 3,  8, 22, 23, 26]),
  array([ 3,  4,  8, 22, 23, 26]),
  array([ 3,  4,  8, 22, 23, 26]),
  array([22, 23, 26]),
  array([22, 23, 26]),
  array([23, 26])],
 'leaf_nodes_depths': [6,
  7,
  7,
  5,
  5,
  6,
  6,
  6,
  8,
  8,
  7,
  5,
  4,
  4,
  4,
  4,
  5,
  6,
  6,
  3,
  3,
  2],
 'max_node_depth': 8,
 'n_nodes': 43,
 'node_features_idx': array([23, 26,  1,  6, 13,  5, 28,  9, 28, 28, 28, 22, 28, 13, 28, 28,  3,
        27, 27, 28, 19, 17, 28, 28, 28, 28, 28,  3, 18, 28, 28,  8, 28, 22,
        28,  4, 28, 28, 26, 22, 28, 28, 28]),
 'num_features_used': 16,
 'tot_leaf_node_values': [239,
  1,
  8,
  2,
  8,
  7,
  2,
  27,
  3,
  1,
  10,
  2,
  7,
  7,
  1,
  2,
  2,
  1,
  19,
  6,
  2,
  155]}
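The fields of this dictionary look derivable from one another. The sketch below checks that on the first three leaves of `dtree0`, using values copied from the output; it is not necessarily how `utils` computes them. One side observation: scikit-learn marks leaf nodes with the feature value `-2` (`_tree.TREE_UNDEFINED`), and indexing a 30-element array with `-2` returns element 28, which may explain why every leaf shows up as 28 in `node_features_idx`. The path features sidestep this by dropping the final (leaf) node of each path.

```python
import numpy as np

# Copied from the dtree0 output above: the first 10 entries of
# node_features_idx and the first three leaves
node_features_idx = np.array([23, 26, 1, 6, 13, 5, 28, 9, 28, 28])
leaf_paths = [np.array([0, 1, 2, 3, 4, 5, 6]),
              np.array([0, 1, 2, 3, 4, 5, 7, 8]),
              np.array([0, 1, 2, 3, 4, 5, 7, 9])]
leaf_values = [np.array([[0, 239]]), np.array([[1, 0]]), np.array([[0, 8]])]

derived = []
for path, value in zip(leaf_paths, leaf_values):
    path_features = node_features_idx[path[:-1]]    # drop the leaf node itself
    derived.append({
        "leaf_node": int(path[-1]),                 # -> all_leaf_nodes
        "path_features": path_features,             # -> all_leaf_paths_features
        "uniq_features": np.unique(path_features),  # -> all_uniq_leaf_paths_features
        "depth": len(path) - 1,                     # -> leaf_nodes_depths (root = 0)
        "leaf_class": int(np.argmax(value[0])),     # -> all_leaf_node_classes
        "tot_value": int(value.sum()),              # -> tot_leaf_node_values
    })

for d in derived:
    print(d)
```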

Get the paths of the class-1 leaf nodes

Get the unique feature paths of the leaf nodes whose predicted class is 1


In [17]:
uniq_feature_paths = all_rf_outputs['dtree0']['all_uniq_leaf_paths_features']
leaf_node_classes  = all_rf_outputs['dtree0']['all_leaf_node_classes']
ones_only = [i for i, j in zip(uniq_feature_paths, leaf_node_classes) 
               if j == 1]
ones_only


Out[17]:
[array([ 1,  5,  6, 13, 23, 26]),
 array([ 1,  5,  6,  9, 13, 23, 26]),
 array([ 1,  6, 22, 23, 26]),
 array([ 1,  6, 13, 22, 23, 26]),
 array([ 1,  3, 23, 26, 27]),
 array([ 1,  3, 17, 19, 23, 26, 27]),
 array([ 1,  3, 19, 23, 26, 27]),
 array([ 3, 18, 23, 26]),
 array([ 3,  8, 23, 26]),
 array([ 3,  8, 22, 23, 26]),
 array([ 3,  4,  8, 22, 23, 26]),
 array([22, 23, 26])]
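The list comprehension above covers FILTERING for class 1 only. A slightly more general helper can filter for either class and also return the leaf sizes (`tot_leaf_node_values`) normalised into sampling weights, which is one possible choice for the weighted RANDOM SAMPLING step. `filter_leaves` is a hypothetical name, and size-based weighting is an assumption here:

```python
import numpy as np

# all_leaf_node_classes and tot_leaf_node_values for dtree0, copied from above
leaf_classes = [1, 0, 1, 0, 1, 0, 1, 1, 0, 1, 1, 0, 0, 1, 0, 1, 1, 1, 0, 1, 0, 0]
leaf_sizes = [239, 1, 8, 2, 8, 7, 2, 27, 3, 1, 10, 2, 7, 7, 1, 2, 2, 1, 19, 6, 2, 155]


def filter_leaves(classes, sizes, target_class=1):
    """Hypothetical FILTERING helper: indices of the leaves predicting
    target_class, plus their sizes normalised into sampling weights."""
    idx = [i for i, c in enumerate(classes) if c == target_class]
    w = np.array([sizes[i] for i in idx], dtype=float)
    return idx, w / w.sum()


ones_idx, ones_w = filter_leaves(leaf_classes, leaf_sizes, target_class=1)
zeros_idx, zeros_w = filter_leaves(leaf_classes, leaf_sizes, target_class=0)
print(ones_idx, ones_w.round(3))
```

For `dtree0`, the twelve class-1 leaves hold 313 samples and the ten class-0 leaves hold 199; the total, 512, appears to match the 569 - 57 training rows of the breast cancer data.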

In [18]:
print("Number of leaf nodes", len(all_rf_outputs['dtree0']['all_uniq_leaf_paths_features']), sep = ":\n")
print("Number of leaf nodes predicting class 1", len(ones_only), sep = ":\n")


Number of leaf nodes:
22
Number of leaf nodes predicting class 1:
12

In [19]:
# Pick the last seven paths: we will manually construct a binary RIT
# of depth 3, i.e. at most 2**3 - 1 = 7 intersection nodes
ones_only_seven = ones_only[-7:]
ones_only_seven


Out[19]:
[array([ 1,  3, 17, 19, 23, 26, 27]),
 array([ 1,  3, 19, 23, 26, 27]),
 array([ 3, 18, 23, 26]),
 array([ 3,  8, 23, 26]),
 array([ 3,  8, 22, 23, 26]),
 array([ 3,  4,  8, 22, 23, 26]),
 array([22, 23, 26])]

In [34]:
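# Construct a binary RIT of depth 3 manually (useful later for unit tests).
# node0 is the root; node1 and node4 are its children; node2 and node3 are
# the children of node1, and node5 and node6 are the children of node4.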
node0 = ones_only_seven[-1]
node1 = np.intersect1d(node0, ones_only_seven[-2])
node2 = np.intersect1d(node1, ones_only_seven[-3])
node3 = np.intersect1d(node1, ones_only_seven[-4])
node4 = np.intersect1d(node0, ones_only_seven[-5])
node5 = np.intersect1d(node4, ones_only_seven[-6])
node6 = np.intersect1d(node4, ones_only_seven[-7])

intersected_nodes_seven = [node0, node1, node2, node3, node4, node5, node6]

for idx, node in enumerate(intersected_nodes_seven):
    print("node" + str(idx), node)


node0 [22 23 26]
node1 [22 23 26]
node2 [22 23 26]
node3 [23 26]
node4 [23 26]
node5 [23 26]
node6 [23 26]
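The hand-built tree numbers its nodes in preorder and consumes `ones_only_seven` from the end. A small recursive helper can reproduce it for any depth, which may be handy for the unit tests mentioned in the comment. `build_binary_rit` is a hypothetical name; this is a sketch under those assumptions:

```python
import numpy as np


def build_binary_rit(paths, depth):
    """Deterministically rebuild the hand-made binary RIT above.

    Nodes are numbered in preorder (node0 = root, then the whole left
    subtree, then the right subtree); paths are consumed from the END of
    `paths`, and every non-root node is its parent intersected with the
    next path.
    """
    n_nodes = 2 ** depth - 1
    if len(paths) < n_nodes:
        raise ValueError(f"need at least {n_nodes} paths, got {len(paths)}")
    remaining = iter(reversed(paths))
    nodes = []

    def grow(parent, level):
        path = next(remaining)
        node = np.asarray(path) if parent is None else np.intersect1d(parent, path)
        nodes.append(node)
        if level < depth:
            grow(node, level + 1)   # left child
            grow(node, level + 1)   # right child

    grow(None, 1)
    return nodes


ones_only_seven = [np.array([1, 3, 17, 19, 23, 26, 27]),
                   np.array([1, 3, 19, 23, 26, 27]),
                   np.array([3, 18, 23, 26]),
                   np.array([3, 8, 23, 26]),
                   np.array([3, 8, 22, 23, 26]),
                   np.array([3, 4, 8, 22, 23, 26]),
                   np.array([22, 23, 26])]

for idx, node in enumerate(build_binary_rit(ones_only_seven, depth=3)):
    print("node" + str(idx), node)
```

This should print the same seven lines as the manual construction above.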

In [37]:
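# The leaves of this tree are node2, node3, node5 and node6:
# take the union of the features they retain
rit_output = reduce(np.union1d, (node2, node3, node5, node6))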
rit_output


Out[37]:
array([22, 23, 26])
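The union merges the four leaf intersections into a single array, so it no longer records that `[23, 26]` survived at three leaves while `[22, 23, 26]` survived at one. If the leaf intersections are meant to be kept as separate candidate interactions (an assumption about how the RIT output will be used later), they can be tallied instead:

```python
from collections import Counter
from functools import reduce

import numpy as np

# Leaf intersections of the hand-built tree: node2, node3, node5, node6
leaves = [np.array([22, 23, 26]), np.array([23, 26]),
          np.array([23, 26]), np.array([23, 26])]

# What the cell above returns: one merged array of features
merged = reduce(np.union1d, leaves)

# Alternative: keep every leaf intersection as its own candidate
# interaction; arrays are unhashable, so convert each one to a tuple
interaction_counts = Counter(tuple(leaf.tolist()) for leaf in leaves)

print(merged)
print(interaction_counts)
```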