In [1]:
# %load /Users/facai/Study/book_notes/preconfig.py
%matplotlib inline

import matplotlib.pyplot as plt
import seaborn as sns
sns.set(color_codes=True)
sns.set(font='SimHei')
plt.rcParams['axes.grid'] = False

from IPython.display import SVG

def show_image(filename, figsize=None):
    if figsize:
        plt.figure(figsize=figsize)

    plt.imshow(plt.imread(filename))

The Tree-Building Module _tree.* in Detail

0. Outline

This module contains two kinds of classes: Tree, the class that implements the binary tree itself, and TreeBuilder, the method class that builds the whole tree. We focus on the builder class TreeBuilder, and briefly touch on a few functions of Tree at the end.


In [3]:
SVG("./res/uml/Model___tree_3.svg")


Out[3]:
[UML class diagram: Tree owns 1..* Node records («dataType»: left_child, right_child, feature, threshold, impurity, n_node_samples, weighted_n_node_samples) and exposes _add_node(), _resize(), predict(), apply(), decision_path(), compute_feature_importances(). TreeBuilder (fields: splitter, min_samples_split, min_samples_leaf, min_weight_leaf, max_depth; methods: build(), _check_input()) has two subclasses: DepthFirstTreeBuilder, which uses a _utils.Stack (push/pop), and BestFirstTreeBuilder, which uses a _utils.PriorityHeap (push/pop).]

1. Builder Classes

1.0 TreeBuilder

TreeBuilder provides the interface method build and a concrete input-checking method _check_input; there is not much more to say about it.

1.1 DepthFirstTreeBuilder

DepthFirstTreeBuilder generates the whole decision tree in a pre-order-traversal-like fashion, with a stack as the auxiliary data structure.

The main flow is:

  1. Pop a node off the stack.
  2. Compute a split for it:
    • if a leaf condition is met, no further processing;
    • otherwise, push the right child onto the stack, then the left child.
  3. When the stack is empty, the tree is complete.

The core code is as follows (a simplified pure-Python sketch of the same loop appears right after the excerpt):

with nogil:
    # push root node onto stack
    rc = stack.push(0, n_node_samples, 0, _TREE_UNDEFINED, 0, INFINITY, 0)
    # ... 4 lines elided: if rc == -1: ...

    while not stack.is_empty():
        stack.pop(&stack_record)
        # ... 9 lines elided: start = stack_record.start ...
        n_node_samples = end - start
        splitter.node_reset(start, end, &weighted_n_node_samples)

        is_leaf = ((depth >= max_depth) or
                   (n_node_samples < min_samples_split) or
                   (n_node_samples < 2 * min_samples_leaf) or
                   (weighted_n_node_samples < min_weight_leaf))

        if first:
            impurity = splitter.node_impurity()
            first = 0

        is_leaf = is_leaf or (impurity <= MIN_IMPURITY_SPLIT)

        if not is_leaf:
            splitter.node_split(impurity, &split, &n_constant_features)
            is_leaf = is_leaf or (split.pos >= end)

        node_id = tree._add_node(parent, is_left, is_leaf, split.feature,
                                 split.threshold, impurity, n_node_samples,
                                 weighted_n_node_samples)

        # ... 4 lines elided: if node_id == <SIZE_t>(-1): ...

        # Store value for all nodes, to facilitate tree/model
        # inspection and interpretation
        splitter.node_value(tree.value + node_id * tree.value_stride)

        if not is_leaf:
            # Push right child on stack
            rc = stack.push(split.pos, end, depth + 1, node_id, 0,
                            split.impurity_right, n_constant_features)
            # ... 3 lines elided: if rc == -1: ...

            # Push left child on stack
            rc = stack.push(start, split.pos, depth + 1, node_id, 1,
                            split.impurity_left, n_constant_features)
            # ... 2 lines elided: if rc == -1: ...
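To strip away the Cython details, here is a minimal pure-Python sketch of the same stack-driven loop. Everything in it (the node records, the toy median split) is a made-up stand-in for illustration, not sklearn's actual API:

def build_depth_first(values, max_depth=3, min_samples_split=2):
    nodes = []                                  # the tree as a flat list of records
    stack = [(list(values), 0, None, None)]     # (samples, depth, parent_id, is_left)

    while stack:                                # pop one record per iteration
        subset, depth, parent, is_left = stack.pop()
        is_leaf = depth >= max_depth or len(subset) < min_samples_split

        node_id = len(nodes)
        nodes.append({"parent": parent, "is_left": is_left,
                      "leaf": is_leaf, "n_samples": len(subset)})

        if not is_leaf:
            mid = len(subset) // 2              # toy split point (median position)
            # Push the right child first so the left child is popped first:
            # this yields the pre-order (root, left, right) visiting order.
            stack.append((subset[mid:], depth + 1, node_id, False))
            stack.append((subset[:mid], depth + 1, node_id, True))

    return nodes

tree = build_depth_first(range(16), max_depth=2)
print(len(tree))  # 7: a full binary tree of depth 2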

1.2 BestFirstTreeBuilder

BestFirstTreeBuilder always expands the most promising frontier node first (the heap is ordered by the impurity improvement a split would bring), using a max-heap as the auxiliary data structure. The flow is otherwise much the same as DepthFirstTreeBuilder, so it is not repeated here.
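Just for intuition, the max-heap ordering can be mimicked with Python's heapq (a min-heap, hence the negated priorities). The node names and improvement values below are made up; sklearn itself uses a dedicated Cython heap in _utils:

import heapq

# Toy frontier of candidate nodes: (name, expected impurity improvement).
# heapq is a min-heap, so improvements are negated to pop the largest first.
frontier = []
for name, improvement in [("A", 0.10), ("B", 0.42), ("C", 0.27)]:
    heapq.heappush(frontier, (-improvement, name))

while frontier:
    neg_improvement, name = heapq.heappop(frontier)
    print(name, -neg_improvement)   # expands B (0.42), then C (0.27), then A (0.10)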

2. Implementation Classes

2.0 Tree

sklearn implements the binary tree with flat arrays; the function I find most interesting is compute_feature_importances, which computes the feature importances.

The idea behind this function is simple: traverse the internal (decision) nodes of the tree and accumulate, per feature, the contribution of each split to the impurity reduction. The code is short and easy to follow.

cpdef compute_feature_importances(self, normalize=True):
    """Computes the importance of each feature (aka variable)."""
    # ... 3 lines elided: cdef Node* left ...
    cdef Node* node = nodes
    cdef Node* end_node = node + self.node_count
    # ... 3 lines elided: cdef double normalizer = 0. ...
    cdef np.ndarray[np.float64_t, ndim=1] importances
    # ... 2 lines elided: importances = np.zeros((self.n_features,)) ...

    with nogil:
        while node != end_node:
            if node.left_child != _TREE_LEAF:
                # ... and node.right_child != _TREE_LEAF:
                left = &nodes[node.left_child]
                right = &nodes[node.right_child]

                importance_data[node.feature] += (
                    node.weighted_n_node_samples * node.impurity -
                    left.weighted_n_node_samples * left.impurity -
                    right.weighted_n_node_samples * right.impurity)
            node += 1

    importances /= nodes[0].weighted_n_node_samples

    if normalize:
        normalizer = np.sum(importances)
        # ... 3 lines elided: if normalizer > 0.0: ...
            importances /= normalizer

    return importances
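As a sanity check on the reading above, the same accumulation can be reproduced in Python from the public tree_ arrays of a fitted tree. The attribute names (children_left, weighted_n_node_samples, ...) are sklearn's; the iris data is just a stand-in example:

import numpy as np
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

X, y = load_iris(return_X_y=True)
clf = DecisionTreeClassifier(random_state=0).fit(X, y)
t = clf.tree_

importances = np.zeros(t.n_features)
for node in range(t.node_count):
    left, right = t.children_left[node], t.children_right[node]
    if left == -1:                      # leaf (_TREE_LEAF == -1): contributes nothing
        continue
    # weighted impurity decrease contributed by this node's split
    importances[t.feature[node]] += (
        t.weighted_n_node_samples[node] * t.impurity[node]
        - t.weighted_n_node_samples[left] * t.impurity[left]
        - t.weighted_n_node_samples[right] * t.impurity[right])

importances /= t.weighted_n_node_samples[0]   # divide by the root's weight
importances /= importances.sum()              # the normalize=True branch
                                              # (sklearn also guards normalizer > 0)

print(np.allclose(importances, clf.feature_importances_))  # should print True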

In addition, in the tree.py module the decision tree classifier can also produce probability estimates; this probability is simply the fraction of training samples of each class in the leaf that the sample falls into.

The full computation path is:

  1. _tree.pyx:Tree.predict finds the leaf node for each sample via _tree.pyx:Tree.apply, then uses _tree.pyx:Tree._get_value_ndarray to obtain the per-class sample counts stored at that leaf.

  2. tree.py:DecisionTreeClassifier.predict_proba turns those counts into proportions (checked in the sketch below).
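This path is easy to verify in Python. A quick sketch, assuming a single-output classifier (tree_.value has shape (node_count, n_outputs, n_classes)); iris again is just a stand-in:

import numpy as np
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

X, y = load_iris(return_X_y=True)
clf = DecisionTreeClassifier(random_state=0).fit(X, y)

leaf_ids = clf.apply(X)                          # leaf index for every sample
counts = clf.tree_.value[leaf_ids, 0, :]         # per-class stats at each leaf
proba = counts / counts.sum(axis=1, keepdims=True)

print(np.allclose(proba, clf.predict_proba(X)))  # should print True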

Summary

This article covered the two strategies for building a decision tree, along with how feature importances and predicted class probabilities are computed.

