In [1]:
# Toy training set: one row per sample, one column per feature
# (column order matches col_name below).
x = [
    [0, 0],
    [0, 1],
    [0, 1],
    [1, 0],
    [1, 1],
    [1, 2],
]
# Class label of each row of x, in the same order.
y = [1, 0, 1, 1, 0, 1]
# Human-readable name of each feature column.
col_name = ["gender", "age"]
In [2]:
from collections import Counter
from math import log
class DT_hzj(object):
    """Multi-way decision tree classifier for categorical features.

    Training repeatedly splits the impure leaf whose best split gives the
    largest sample-weighted impurity reduction; a split creates one child
    per observed value ("modality") of the chosen feature.  The fitted
    tree is stored in ``self.tree`` as nested dictionaries keyed by
    ``"<feature name>=<value>"`` strings whose terminal values are class
    labels.
    """

    def __init__(self):
        print("求豆麻袋")
        # No feature names until fit(); DT() then reports column indices.
        self.col_name = None
        # Default impurity measure so DT()/DT_total() are usable directly.
        self.set_metric("gini")

    def set_metric(self, metric):
        """Select the impurity measure used for training.

        Args:
            metric (str): ``"gini"`` or ``"entropy"``.

        Raises:
            ValueError: if ``metric`` is not a supported name (previously
                this silently left the impurity function unset).
        """
        measures = {"gini": DT_hzj.gini, "entropy": DT_hzj.entropy}
        if metric not in measures:
            raise ValueError(
                "unknown metric {!r}; expected 'gini' or 'entropy'".format(metric))
        self.imp = measures[metric]
        self.metric = metric

    @staticmethod
    def gini(y):
        """Return the Gini impurity ``1 - sum(p_k ** 2)`` of the labels ``y``."""
        ttl = float(len(y))  # total number of samples
        cnt = Counter(y)     # class distribution of the samples
        result = 1
        for key in cnt:
            result -= (cnt[key] / ttl) ** 2
        return result

    @staticmethod
    def entropy(y):
        """Return the entropy ``-sum(p_k * ln(p_k))`` of the labels ``y``.

        The natural logarithm is used, so the value can exceed 1 when there
        are more than two classes.
        """
        ttl = float(len(y))
        cnt = Counter(y)
        result = 0
        for key in cnt:
            pi = cnt[key] / ttl
            result -= pi * log(pi)
        return result

    def DT(self, x, y, indices, modalities, metric="gini"):
        """Find the best single split of a subset of ``(x, y)``.

        Args:
            x, y: the full data set (rows of feature values, and labels).
            indices (iterable): positions in ``x``/``y`` of the samples
                currently held by the leaf being split.
            modalities (iterable): for each column, the collection of
                values that column can take.
            metric (str): kept for backward compatibility and ignored; the
                measure chosen through ``set_metric`` is used.

        Returns:
            tuple: ``(impurity_reduction, attribute, leaves)`` where
            ``attribute`` is the chosen column's name from ``self.col_name``
            (or its index when no names are set) and ``leaves`` is a list of
            ``{"sample", "modality", "impurity"}`` dicts, one per non-empty
            child.  ``(0.0, None, [])`` is returned when no column yields
            any child (e.g. ``indices`` is empty).
        """
        names = getattr(self, "col_name", None)
        num_samples = float(len(indices))
        x_sample = [x[index] for index in indices]
        y_sample = [y[index] for index in indices]
        impurity = self.imp(y_sample)

        # Start from +inf, not 1: entropy with more than two classes can be
        # >= 1, which used to leave the chosen attribute undefined.
        impurity_min = float("inf")
        best_attr = None
        leaves = []
        for col in range(len(x[0])):
            impurity_split = 0.0
            candidates = []
            for modality in modalities[col]:
                positions = [i for i, row in enumerate(x_sample)
                             if row[col] == modality]
                if not positions:  # never create an empty leaf
                    continue
                impurity_k = self.imp([y_sample[i] for i in positions])
                impurity_split += len(positions) / num_samples * impurity_k
                candidates.append({
                    "sample": [indices[i] for i in positions],
                    "modality": modality,
                    "impurity": impurity_k,
                })
            # Strict "<" keeps the first column on ties.
            if candidates and impurity_split < impurity_min:
                impurity_min = impurity_split
                best_attr = names[col] if names else col
                leaves = candidates

        if not leaves:
            return 0.0, None, []
        return impurity - impurity_min, best_attr, leaves

    @staticmethod
    def stop_cond(leaves):
        """Return True when every leaf is pure (impurity not above zero)."""
        return not any(leaf["impurity"] > 0 for leaf in leaves)

    @staticmethod
    def majority(y, samples):
        """Return the most common label among ``y[i] for i in samples``."""
        y_sample = [y[i] for i in samples]
        return Counter(y_sample).most_common(1)[0][0]

    def DT_total(self, x, y):
        """Grow the tree on ``(x, y)`` and return its nodes.

        Nodes are dictionaries:
          - non-leaf fields: 'attr', 'sample', 'modality', 'impurity',
            'children', 'name' (the root has no 'modality');
          - leaf fields: 'modality', 'name', 'sample', 'impurity'.

        Returns:
            tuple: ``(leaves, non_leaves)``, also stored on the instance.

        Raises:
            ValueError: if ``x`` is empty, ``x`` and ``y`` differ in length,
                or the rows of ``x`` do not all have the same length.
        """
        # Explicit checks instead of assert, which is stripped under -O.
        if len(x) != len(y):
            raise ValueError("The length of x and y must be identical")
        if len(x) == 0:
            raise ValueError("x must contain at least one sample")

        self.col_num = len(x[0])
        # Collect each feature's value space while checking row lengths.
        self.modalities = [set() for _ in range(self.col_num)]
        for x_i in x:
            if len(x_i) != self.col_num:
                raise ValueError("The length of x must be equal")
            for col in range(self.col_num):
                self.modalities[col].add(x_i[col])

        ttl = len(x)  # sample size
        root = {
            "impurity": self.imp(y),
            "sample": list(range(ttl)),
            "name": 'v0',
        }
        leaves = [root]   # current leaves
        non_leaves = []   # split nodes; each lists its children by name
        i = 0             # counter used to name new nodes v1, v2, ...

        while not DT_hzj.stop_cond(leaves):
            leaf2split = None
            children_generated = []
            attr_chosen = ''
            impu_reduc_max = 0
            for leaf in leaves:
                if not leaf["impurity"] > 0:
                    continue
                samples = leaf['sample']
                impu_reduc, attr, children = self.DT(
                    x, y, indices=samples, modalities=self.modalities)
                # Weight by leaf size so gains are comparable across leaves.
                impu_reduc_tmp = impu_reduc * len(samples)
                if impu_reduc_tmp > impu_reduc_max:
                    # Remember the best gain so far; without this update the
                    # last improving leaf won instead of the best one.
                    impu_reduc_max = impu_reduc_tmp
                    leaf2split = leaf
                    # A real list: it is iterated twice below, which a
                    # Python 3 filter() iterator does not survive.
                    children_generated = [child for child in children if child]
                    attr_chosen = attr
            if leaf2split is None:
                # No split lowers the impurity any further.
                break

            leaves.remove(leaf2split)
            leaf2split["children"] = []
            for child in children_generated:
                i += 1
                name = "v{}".format(i)
                child["name"] = name
                leaf2split["children"].append(name)
            leaf2split["attr"] = attr_chosen
            leaves.extend(children_generated)
            non_leaves.append(leaf2split)

        self.leaves = leaves
        self.non_leaves = non_leaves
        return leaves, non_leaves

    @staticmethod
    def get_subtree(leaves, non_leaves, root='v0', y=None):
        """Turn the node lists into nested decision rules.

        The rules are nested dictionaries of the form::

            {
                'attr1=modality1': label,
                'attr1=modality2': {
                    'attr2=modality1': label,
                    'attr2=modality2': {...},
                },
            }

        Args:
            leaves, non_leaves: node lists as returned by ``DT_total``.
            root (str): name of the node to start from.
            y: training labels, indexed by the nodes' 'sample' entries; used
                to compute each leaf's majority label.

        Returns:
            The rule dictionary, or a bare label when ``root`` is itself a
            leaf (a tree that was never split).

        Raises:
            ValueError: if ``y`` is missing or ``root`` names no known node.
        """
        if y is None:
            # Previously the labels were read from a module-level global.
            raise ValueError("y (the training labels) is required")

        leaf_by_name = dict((node["name"], node) for node in leaves)
        inner_by_name = dict((node["name"], node) for node in non_leaves)

        def build(name):
            if name in inner_by_name:
                node = inner_by_name[name]
                subtree = dict()
                for child_name in node["children"]:
                    if child_name in inner_by_name:
                        child = inner_by_name[child_name]
                    else:
                        child = leaf_by_name[child_name]
                    key = "{}={}".format(node["attr"], child["modality"])
                    subtree[key] = build(child_name)
                return subtree
            if name in leaf_by_name:
                return DT_hzj.majority(y, leaf_by_name[name]['sample'])
            raise ValueError("unknown node name: {}".format(name))

        return build(root)

    def fit(self, x, y, col_name, metric='gini'):
        """Train the tree on ``(x, y)`` and store the rules in ``self.tree``.

        Args:
            x: rows of categorical feature values.
            y: class label of each row.
            col_name: name of each feature column, in column order.
            metric (str): impurity measure, ``'gini'`` or ``'entropy'``.

        Raises:
            ValueError: on an unknown metric or inconsistent input sizes.
        """
        self.set_metric(metric)
        if len(x) > 0 and len(col_name) != len(x[0]):
            raise ValueError("col_name must name every column of x")
        self.col_name = col_name
        self.col_num = len(col_name)
        leaves, non_leaves = self.DT_total(x, y)
        self.tree = DT_hzj.get_subtree(leaves, non_leaves, y=y)

    def pred(self, x_test):
        """Classify one sample, printing each step of the tree walk.

        Args:
            x_test: feature values of the sample, in training column order.

        Returns:
            The predicted label, or the string ``"can't predict"`` when the
            sample follows a branch the tree does not have.  The same value
            is stored in ``self.prediction``.

        Raises:
            ValueError: if the model is not fitted or ``x_test`` has the
                wrong number of features.
        """
        if not hasattr(self, "tree"):
            raise ValueError("fit must be called before pred")
        if self.col_num != len(x_test):
            raise ValueError("x is not of the right dimensionality")

        x_conv = ["{}={}".format(self.col_name[index], x_test[index])
                  for index in range(self.col_num)]
        print("new data point is  {}".format(x_conv))

        step = 1
        subtree = self.tree
        print("{}  step :  {}".format(step, subtree))
        step += 1
        while isinstance(subtree, dict):
            for x_i in x_conv:
                if x_i in subtree:
                    subtree = subtree[x_i]
                    print("{}  step :  {}".format(step, subtree))
                    step += 1
                    break
            else:
                # No feature of the sample matches any branch at this node.
                subtree = "can't predict"
        self.prediction = subtree
        return subtree
In [3]:
# Create the classifier; the constructor prints a short message.
clf = DT_hzj()
In [4]:
# Train on the toy data set, using entropy as the impurity measure.
clf.fit(x,y,col_name, metric='entropy')
In [5]:
# Classify one sample (column order follows col_name); the walk is printed
# and the result is stored in clf.prediction.
clf.pred(x_test = [1,1])
In [6]:
# Classify a second sample, passing the features positionally.
clf.pred([0,0])
In [7]:
# Show the fitted rules (the value fit() stored from get_subtree).
clf.tree
Out[7]:
In [8]:
# Show the feature names the classifier was fitted with.
clf.col_name
Out[8]:
In [9]:
# Show the leaf nodes produced by DT_total.
clf.leaves
Out[9]:
In [10]:
# Show the internal (split) nodes produced by DT_total.
clf.non_leaves
Out[10]: