In [1]:
# Toy training set: each row of x is [gender, age] encoded as small integers,
# y holds the binary (0/1) class label of the matching row, and col_name gives
# the feature name for each column of x.
x = [
    [0, 0], [0, 1], [0, 1],
    [1, 0], [1, 1], [1, 2],
]
y = [1, 0, 1, 1, 0, 1]
col_name = ["gender", "age"]

In [2]:
from collections import Counter
from math import log

class DT_hzj():
    """Decision tree classifier for categorical features.

    Every split is multiway: a node split on feature ``attr`` gets one child per
    value of ``attr`` that occurs among its samples. The tree is grown greedily,
    one split per iteration, until every leaf is pure or no split lowers the
    impurity any further.

    Typical use::

        clf = DT_hzj()
        clf.fit(x, y, col_name, metric='entropy')
        clf.pred(x_test)

    Attributes set by ``fit``: ``tree`` (nested rule dict or a single label),
    ``leaves``, ``non_leaves``, ``modalities``, ``col_name``, ``col_num``,
    ``metric`` and ``imp`` (the impurity function).
    """

    def __init__(self):
        # Only prints a fixed banner message; all state is created by fit().
        print("求豆麻袋")

    def set_metric(self, metric):
        """Select the impurity function used for every split.

        Args:
            metric (str): 'gini' or 'entropy'.

        Raises:
            ValueError: for any other value. (Previously an unknown metric was
                stored in ``self.metric`` while ``self.imp`` kept its old value
                or stayed undefined, so the tree used the wrong criterion.)
        """
        impurity_functions = {"gini": DT_hzj.gini, "entropy": DT_hzj.entropy}
        if metric not in impurity_functions:
            raise ValueError(
                "unknown metric {!r}; expected 'gini' or 'entropy'".format(metric))
        self.imp = impurity_functions[metric]
        self.metric = metric

    @staticmethod
    def gini(y):
        """Gini impurity of a label sequence: 1 - sum_k p_k ** 2.

        Args:
            y: sequence of class labels.

        Returns:
            float; 0.0 for an empty sequence (consistent with ``entropy``).
        """
        if len(y) == 0:
            return 0.0
        ttl = float(len(y))  # float() keeps the division true under Python 2
        result = 1.0
        for count in Counter(y).values():
            result -= (count / ttl) ** 2
        return result

    @staticmethod
    def entropy(y):
        """Shannon entropy of a label sequence, in nats (natural logarithm).

        Args:
            y: sequence of class labels.

        Returns:
            float; 0 for an empty sequence.
        """
        ttl = float(len(y))
        result = 0
        for count in Counter(y).values():
            pi = count / ttl
            result -= pi * log(pi)
        return result

    def DT(self, x, y, indices, modalities, metric="gini"):
        """Find the best single split of the samples listed in ``indices``.

        Every feature column is tried as a multiway split (one child per value
        of ``modalities[col]`` that occurs among the samples); the column with
        the lowest size-weighted child impurity wins.

        Args:
            x, y: the full training set (rows of feature values, labels).
            indices (iterable of int): rows of x/y sitting on the current leaf.
            modalities (sequence of iterables): modalities[col] is the value
                space of feature ``col``.
            metric (str): ignored, kept for backward compatibility; the impurity
                function is the one chosen by ``set_metric``.

        Returns:
            (impurity_reduction, attr, leaves): the node's impurity minus the
            weighted impurity after the best split; the chosen feature name
            (taken from ``self.col_name``; None if x has no columns); and the
            candidate children as dicts with keys 'sample', 'modality',
            'impurity'.
        """
        indices = list(indices)  # accept any iterable; it is indexed below
        col_num = len(x[0])
        num_samples = float(len(indices))

        x_sample = [x[index] for index in indices]
        y_sample = [y[index] for index in indices]

        impurity = self.imp(y_sample)

        # Start from +inf rather than 1: entropy can exceed 1 (ln 3 > 1 with
        # three classes), in which case no column was ever selected and the
        # returned attribute name was unbound.
        impurity_min = float("inf")
        A = None
        leaves = []

        for col in range(col_num):
            impurity_A = 0
            leaves_candidate = []

            for modality in modalities[col]:
                positions = [i for i, row in enumerate(x_sample) if row[col] == modality]
                if not positions:  # never create empty children
                    continue
                y_k = [y_sample[i] for i in positions]
                impurity_k = self.imp(y_k)
                impurity_A += len(y_k) / num_samples * impurity_k
                leaves_candidate.append({
                    "sample": [indices[i] for i in positions],
                    "modality": modality,
                    "impurity": impurity_k,
                })

            if impurity_A < impurity_min:
                impurity_min = impurity_A
                # The instance's feature names, not a notebook-global col_name.
                A = self.col_name[col]
                leaves = leaves_candidate

        return impurity - impurity_min, A, leaves

    @staticmethod
    def stop_cond(leaves):
        """Stopping rule: True when no leaf has positive impurity."""
        return not any(leaf["impurity"] > 0 for leaf in leaves)

    @staticmethod
    def majority(y, samples):
        """Most common label among ``y[i] for i in samples``.

        Ties are resolved by ``Counter.most_common`` (first encountered label on
        Python 3.7+).
        """
        y_sample = [y[i] for i in samples]
        return Counter(y_sample).most_common(1)[0][0]

    def DT_total(self, x, y):
        """Grow the tree on (x, y) and return its leaf and internal node lists.

        Requires ``self.imp`` and ``self.col_name`` to be set (``fit`` does
        this). Nodes are dicts:

        - internal nodes: 'attr', 'sample', 'impurity', 'children', 'name' and,
          except for the root, 'modality';
        - leaves: 'name', 'sample', 'impurity' and, unless the root itself is a
          leaf, 'modality'.

        Node names are 'v0' (root), 'v1', 'v2', ... in creation order. Each
        iteration splits the impure leaf whose best split yields the largest
        impurity reduction weighted by its sample count; growth stops when all
        leaves are pure or no split reduces impurity.

        Returns:
            (leaves, non_leaves), also stored on ``self``.

        Raises:
            Exception: if x and y differ in length or rows of x differ in width.
            ValueError: if x is empty.
        """
        # Explicit checks instead of assert: asserts are stripped under -O.
        if len(x) != len(y):
            raise Exception("The length of x and y must be identical")
        if len(x) == 0:
            raise ValueError("x and y must contain at least one sample")

        self.col_num = len(x[0])

        # Value space of every feature; also ensures no row is missing a value.
        self.modalities = [set() for _ in range(self.col_num)]
        for x_i in x:
            if len(x_i) != self.col_num:
                raise Exception("The length of x must be equal")
            for col in range(self.col_num):
                self.modalities[col].add(x_i[col])

        root = {
            "impurity": self.imp(y),
            "sample": list(range(len(x))),
            "name": 'v0',
        }
        leaves = [root]
        non_leaves = []  # every internal node lists its children by name
        node_count = 0

        while not DT_hzj.stop_cond(leaves):
            leaf2split = None
            children_generated = []
            attr_chosen = ''
            impu_reduc_max = 0

            for leaf in leaves:
                if leaf["impurity"] <= 0:
                    continue
                samples = leaf["sample"]
                impu_reduc, attr, children = self.DT(
                    x, y, indices=samples, modalities=self.modalities)
                impu_reduc_tmp = impu_reduc * len(samples)
                if impu_reduc_tmp > impu_reduc_max:
                    # Track the running maximum (it was never updated before,
                    # so the last improvable leaf won instead of the best one).
                    impu_reduc_max = impu_reduc_tmp
                    leaf2split = leaf
                    # A list, not filter(): it is iterated twice below and
                    # filter() is a one-shot iterator on Python 3.
                    children_generated = [child for child in children if child]
                    attr_chosen = attr

            if leaf2split is None:  # no split can lower the impurity: stop
                break

            leaves.remove(leaf2split)
            leaf2split["attr"] = attr_chosen
            leaf2split["children"] = []
            for child in children_generated:
                node_count += 1
                child["name"] = "v{}".format(node_count)
                leaf2split["children"].append(child["name"])

            leaves.extend(children_generated)
            non_leaves.append(leaf2split)

        self.leaves = leaves
        self.non_leaves = non_leaves

        return leaves, non_leaves

    @staticmethod
    def get_subtree(leaves, non_leaves, root='v0', y=None):
        """Convert the node lists from ``DT_total`` into nested rules.

        Rules are a dict keyed by 'attr=modality'; each value is either a
        predicted label (the majority class of a leaf) or another rule dict::

            {
                'attr1=modality1': y,
                'attr1=modality2': {
                    'attr2=modality1': y,
                    'attr2=modality2': {...},
                },
            }

        If ``root`` was never split, the result is simply its majority label
        instead of a dict (``pred`` handles both forms).

        Args:
            leaves, non_leaves: node lists returned by ``DT_total``.
            root (str): name of the node to start from.
            y: training labels used to compute each leaf's majority class
                (previously the notebook-global ``y`` was read implicitly).

        Raises:
            ValueError: if ``y`` is not supplied.
        """
        if y is None:
            raise ValueError("get_subtree needs the training labels y")
        leaves_by_name = {node["name"]: node for node in leaves}
        non_leaves_by_name = {node["name"]: node for node in non_leaves}
        return DT_hzj._build_subtree(leaves_by_name, non_leaves_by_name, root, y)

    @staticmethod
    def _build_subtree(leaves_by_name, non_leaves_by_name, name, y):
        # Recursive worker for get_subtree over name -> node lookup tables.
        if name not in non_leaves_by_name:
            return DT_hzj.majority(y, leaves_by_name[name]["sample"])

        node = non_leaves_by_name[name]
        subtree = dict()
        for child_name in node["children"]:
            if child_name in non_leaves_by_name:
                modality = non_leaves_by_name[child_name]["modality"]
            else:
                modality = leaves_by_name[child_name]["modality"]
            key = "{}={}".format(node["attr"], modality)
            subtree[key] = DT_hzj._build_subtree(
                leaves_by_name, non_leaves_by_name, child_name, y)
        return subtree

    def fit(self, x, y, col_name, metric='gini'):
        """Grow the tree on (x, y) and store its rules in ``self.tree``.

        Args:
            x: list of rows, each holding one categorical value per feature.
            y: class labels, one per row of x.
            col_name: feature names, one per column of x.
            metric (str): 'gini' or 'entropy'.

        Raises:
            ValueError: unknown metric, or ``col_name`` not matching x's width.
        """
        self.set_metric(metric)
        if len(x) > 0 and len(col_name) != len(x[0]):
            raise ValueError("col_name must have one entry per column of x")
        self.col_name = col_name
        self.col_num = len(col_name)
        leaves, non_leaves = self.DT_total(x, y)
        self.tree = DT_hzj.get_subtree(leaves, non_leaves, y=y)

    def pred(self, x_test):
        """Predict the label of one sample by walking ``self.tree``.

        Prints the sample (as 'attr=value' strings) and every subtree visited.
        The result is a label, or the string "can't predict" when none of the
        sample's 'attr=value' strings matches a branch. It is stored in
        ``self.prediction`` and also returned.

        Raises:
            Exception: if ``x_test`` does not hold one value per feature.
        """
        if len(x_test) != self.col_num:
            raise Exception("x is not of the right dimensionality")

        x_conv = ["{}={}".format(self.col_name[index], x_test[index])
                  for index in range(self.col_num)]
        # Single pre-formatted argument: identical output on Python 2 and 3.
        print("new data point is  %s" % (x_conv,))

        step = 1
        subtree = self.tree
        print("%d  step :  %s" % (step, subtree))
        while isinstance(subtree, dict):
            branch = next((cond for cond in x_conv if cond in subtree), None)
            if branch is None:
                subtree = "can't predict"
                break
            subtree = subtree[branch]
            step += 1
            print("%d  step :  %s" % (step, subtree))

        self.prediction = subtree
        return subtree

In [3]:
# Construct an untrained classifier; the constructor only prints a banner message.
clf = DT_hzj()


求豆麻袋

In [4]:
# Grow the tree with the entropy criterion on the toy data defined in the first cell.
clf.fit(x, y, col_name=col_name, metric="entropy")

In [5]:
clf.pred([1, 1])  # sample with gender=1, age=1


new data point is  ['gender=1', 'age=1']
1  step :  {'age=2': 1, 'age=1': {'gender=0': 0, 'gender=1': 0}, 'age=0': 1}
2  step :  {'gender=0': 0, 'gender=1': 0}
3  step :  0

In [6]:
clf.pred(x_test=[0, 0])  # sample with gender=0, age=0


new data point is  ['gender=0', 'age=0']
1  step :  {'age=2': 1, 'age=1': {'gender=0': 0, 'gender=1': 0}, 'age=0': 1}
2  step :  1

In [7]:
# Learned rules: keys are 'attr=value', values are a label or a deeper rule dict.
clf.tree


Out[7]:
{'age=0': 1, 'age=1': {'gender=0': 0, 'gender=1': 0}, 'age=2': 1}

In [8]:
# Feature names given to fit(); they form the 'attr' part of every rule key.
clf.col_name


Out[8]:
['gender', 'age']

In [9]:
# Terminal nodes: each has 'name', 'sample' (training row indices), 'impurity' and 'modality'.
clf.leaves


Out[9]:
[{'impurity': 0.0, 'modality': 0, 'name': 'v1', 'sample': [0, 3]},
 {'impurity': 0.0, 'modality': 2, 'name': 'v3', 'sample': [5]},
 {'impurity': 0.6931471805599453,
  'modality': 0,
  'name': 'v4',
  'sample': [1, 2]},
 {'impurity': 0.0, 'modality': 1, 'name': 'v5', 'sample': [4]}]

In [10]:
# Internal (split) nodes: 'attr' is the feature split on, 'children' the child node names.
clf.non_leaves


Out[10]:
[{'attr': 'age',
  'children': ['v1', 'v2', 'v3'],
  'impurity': 0.6365141682948128,
  'name': 'v0',
  'sample': [0, 1, 2, 3, 4, 5]},
 {'attr': 'gender',
  'children': ['v4', 'v5'],
  'impurity': 0.6365141682948128,
  'modality': 1,
  'name': 'v2',
  'sample': [1, 2, 4]}]