In [3]:

# %load ../../../preconfig.py
%matplotlib inline

import matplotlib.pyplot as plt
import seaborn as sns
sns.set(color_codes=True)
#sns.set(font='SimHei')
#plt.rcParams['axes.grid'] = False

import numpy as np

import pandas as pd
pd.options.display.max_rows = 20

#import sklearn

#import itertools

import logging
logger = logging.getLogger()



# 决策树简介和 Python 实现

#### 0. 基本介绍

1. 如何分割样本？

2. 如何评价子集的纯净度？

3. 如何找到单个最佳的分割点，其子集最为纯净？

4. 如何找到最佳的分割点序列，其最终分割子集总体最为纯净？

#### 加载数据



In [4]:

from sklearn.datasets import load_iris
data = load_iris()




In [5]:

# 准备特征数据
X = pd.DataFrame(data.data,
columns=["sepal_length", "sepal_width", "petal_length", "petal_width"])
X.head(2)




Out[5]:

sepal_length
sepal_width
petal_length
petal_width

0
5.1
3.5
1.4
0.2

1
4.9
3.0
1.4
0.2




In [6]:

# 准备标签数据
y = pd.DataFrame(data.target, columns=['target'])
y.replace(to_replace=range(3), value=data.target_names, inplace=True)
y.head(3)




Out[6]:

target

0
setosa

1
setosa

2
setosa




In [7]:

# 组建样本 [特征，标签]
samples = pd.concat([X, y], axis=1) #, keys=["x", "y"])
samples.head(3)




Out[7]:

sepal_length
sepal_width
petal_length
petal_width
target

0
5.1
3.5
1.4
0.2
setosa

1
4.9
3.0
1.4
0.2
setosa

2
4.7
3.2
1.3
0.2
setosa



#### 1.0 如何分割样本

\begin{align} X = \begin{cases} X_l, \ \text{if } X[f] < t \\ X_r, \ \text{if } X[f] \geq t \end{cases} \end{align}


In [8]:

def splitter(samples, feature, threshold):
# 按特征 f 和阈值 t 分割样本

left_nodes = samples.query("{f} < {t}".format(f=feature, t=threshold))
right_nodes = samples.query("{f} >= {t}".format(f=feature, t=threshold))

return {"left_nodes": left_nodes, "right_nodes": right_nodes}




In [9]:

split = splitter(samples, "sepal_length", 5)

# 左子集
x_l = split["left_nodes"].loc[:, "target"].value_counts()
x_l




Out[9]:

setosa        20
versicolor     1
virginica      1
Name: target, dtype: int64




In [10]:

# 右子集
x_r = split["right_nodes"].loc[:, "target"].value_counts()
x_r




Out[10]:

virginica     49
versicolor    49
setosa        30
Name: target, dtype: int64



#### 2. 如何评价子集的纯净度？



In [11]:

def calc_class_proportion(node):
# 计算各标签在集合中的占比

y = node["target"]
return y.value_counts() / y.count()




In [12]:

calc_class_proportion(split["left_nodes"])




Out[12]:

setosa        0.909091
versicolor    0.045455
virginica     0.045455
Name: target, dtype: float64




In [13]:

calc_class_proportion(split["right_nodes"])




Out[13]:

virginica     0.382812
versicolor    0.382812
setosa        0.234375
Name: target, dtype: float64



$$\hat{p}_{m k} = \frac{1}{N_m} \displaystyle \sum_{x_i \in R_m} I(y_i = k)$$

##### 1. Misclassification error



In [14]:

def misclassification_error(node):
p_mk = calc_class_proportion(node)

return 1 - p_mk.max()




In [15]:

misclassification_error(split["left_nodes"])




Out[15]:

0.090909090909090939




In [16]:

misclassification_error(split["right_nodes"])




Out[16]:

0.6171875





In [17]:

binary_class = pd.Series(np.arange(0, 1.01, 0.01)).to_frame(name="p")
binary_class["1-p"] = 1 - binary_class["p"]
binary_class.head(3)




Out[17]:

p
1-p

0
0.00
1.00

1
0.01
0.99

2
0.02
0.98





In [18]:

binary_class["misclass"] = binary_class.apply(lambda x: 1 - x.max(), axis=1)
binary_class.plot(x="p", y="misclass")




Out[18]:

<matplotlib.axes._subplots.AxesSubplot at 0x114376080>



##### 2. Gini index

$$G(m) = \displaystyle \sum_{k \neq k'} p_{k m} p_{k' m} \, \overset{乘法分配律}{=} \sum_{k = 1}^{K} p_{k m} (1 - p_{k m})$$


In [19]:

def gini_index(node):
p_mk = calc_class_proportion(node)

return (p_mk * (1 - p_mk)).sum()




In [20]:

gini_index(split["left_nodes"])




Out[20]:

0.1694214876033058




In [21]:

gini_index(split["right_nodes"])




Out[21]:

0.6519775390625





In [22]:

binary_class["gini"] = (binary_class["p"] * binary_class["1-p"] * 2)
binary_class.plot(x="p", y="gini")




Out[22]:

<matplotlib.axes._subplots.AxesSubplot at 0x1143a2630>


##### 3. Cross-entropy

$$C(m) = \displaystyle \sum_{k=1}^K p_{m k} \log (1 / p_{m k}) \, = - \sum_{k=1}^K p_{m k} \log p_{m k}$$


In [23]:

def cross_entropy(node):
p_mk = calc_class_proportion(node)

return - (p_mk * p_mk.apply(np.log)).sum()




In [24]:

cross_entropy(split["left_nodes"])




Out[24]:

0.36764947740014225




In [25]:

cross_entropy(split["right_nodes"])




Out[25]:

1.075199711851601





In [26]:

x = binary_class[["p", "1-p"]]
binary_class["cross_entropy"] = -(x * np.log(x)).sum(axis=1)
binary_class.plot(x="p", y="cross_entropy")




Out[26]:

<matplotlib.axes._subplots.AxesSubplot at 0x116c16668>





In [27]:

binary_class.plot(x="p", y=["misclass", "gini", "cross_entropy"])




Out[27]:

<matplotlib.axes._subplots.AxesSubplot at 0x116dacb00>





In [28]:

binary_class["cross_entropy_scaled"] = binary_class["cross_entropy"] / binary_class["cross_entropy"].max() * 0.5
binary_class.plot(x="p", y=["misclass", "gini", "cross_entropy_scaled"], ylim=[0,0.55])




Out[28]:

<matplotlib.axes._subplots.AxesSubplot at 0x116d3d4a8>



#### 3. 如何找到单个最佳的分割点，其子集最为纯净？

1. 对于单次分割，分割前和分割后，集合的纯净度提升了多少？

2. 给定一个特征，纯净度提升最大的阈值是多少？

3. 对于多个特征，哪一个特征的最佳阈值对纯净度提升最大？

##### 3.1 对于单次分割，分割前和分割后，集合的纯净度提升了多少？

$$G(m) - G(m_l) - G(m_r)$$


In [29]:

def calc_impurity_measure(node, feathure, threshold, measure, min_nodes=5):
child = splitter(node, feathure, threshold)
left = child["left_nodes"]
right = child["right_nodes"]

if left.shape[0] <= min_nodes or right.shape[0] <= min_nodes:
return 0

impurity = pd.DataFrame([],
columns=["score", "rate"],
index=[])

impurity.loc["all"] = [measure(node), node.shape[0]]
impurity.loc["left"] = [-measure(left), left.shape[0]]
impurity.loc["right"] = [-measure(right), right.shape[0]]

impurity["rate"] /= impurity.at["all", "rate"]

logger.info(impurity)

return (impurity["score"] * impurity["rate"]).sum()




In [30]:

calc_impurity_measure(samples, "sepal_length", 5, gini_index)




Out[30]:

0.08546401515151514




In [31]:

calc_impurity_measure(samples, "sepal_length", 1, gini_index)




Out[31]:

0


##### 3.2. 给定一个特征，纯净度提升最大的阈值是多少？



In [32]:

def find_best_threshold(node, feature, measure):
threshold_candidates = node[feature].quantile(np.arange(0, 1, 0.2))

res = pd.Series([], name=feature)
for t in threshold_candidates:
res[t] = calc_impurity_measure(node, feature, t, measure)

logger.info(res)

if res.max() == 0:
return None
else:
return res.argmax()




In [33]:

find_best_threshold(samples, "sepal_width", gini_index)




Out[33]:

3.3999999999999999




In [34]:

find_best_threshold(samples, "sepal_length", gini_index)




Out[34]:

5.5999999999999996


##### 3.3. 对于多个特征，哪一个特征的最佳阈值对纯净度提升最大？



In [35]:

def find_best_split(node, measure):
if node["target"].unique().shape[0] <= 1:
return None

purity_gain = pd.Series([], name="feature")

for f in node.drop("target", axis=1).columns:
purity_gain[f] = find_best_threshold(node, f, measure)

if pd.isnull(purity_gain.max()):
return None
else:
best_split = {"feature": purity_gain.argmax(), "threshold": purity_gain.max()}
best_split["child"] = splitter(node, **best_split)

return best_split




In [36]:

best_split = find_best_split(samples, gini_index)
[best_split[x] for x in ["feature", "threshold"]]




Out[36]:

['sepal_length', 5.5999999999999996]



#### 4. 如何找到最佳的分割点序列，其最终分割子集总体最为纯净？



In [37]:

class BinaryNode:
def __init__(self, samples, max_depth, measure=gini_index):
self.samples = samples
self.max_depth = max_depth
self.measure = measure

self.is_leaf = False
self.class_ = None

self.left = None
self.right = None

self.best_split = None

def split(self, depth):
if depth > self.max_depth:
self.is_leaf = True
self.class_ = self.samples["target"].value_counts().argmax()
return

best_split = find_best_split(self.samples, self.measure)
if pd.isnull(best_split):
self.is_leaf = True
self.class_ = self.samples["target"].value_counts().argmax()
return

self.best_split = best_split
left = self.best_split["child"]["left_nodes"]
self.left = BinaryNode(left.drop(best_split["feature"], axis=1), self.max_depth)

right = self.best_split["child"]["right_nodes"]
self.right = BinaryNode(right.drop(best_split["feature"], axis=1), self.max_depth)

# 先序深度优先
self.left.split(depth+1)
self.right.split(depth+1)




In [38]:

binaryNode = BinaryNode(samples, 3)
binaryNode.split(0)




In [39]:

def show(node, depth):
if node.left:
show(node.left, depth+1)

if node.is_leaf:
print("{}{}".format("\t"*(depth+2), node.class_))
return
else:
print("{}{}: {}".format("\t"*depth,
node.best_split["feature"],
node.best_split["threshold"]))
if node.right:
show(node.right, depth+1)




In [40]:

show(binaryNode, 0)




versicolor
sepal_width: 2.8200000000000003
setosa
petal_length: 1.6
setosa
petal_width: 0.4
setosa
sepal_length: 5.6
versicolor
sepal_width: 3.1
versicolor
petal_length: 4.8
versicolor
petal_width: 1.8
virginica
sepal_width: 2.9
virginica
petal_width: 2.0
virginica