In [1]:
%pylab
%matplotlib inline


Using matplotlib backend: Qt4Agg
Populating the interactive namespace from numpy and matplotlib

In [2]:
class Sample:
    def __init__(self):
        self.data={}
        self.lable=None
    def __setitem__(self,key,value):
        self.data[key]=value
    def __getitem__(self,key):
        return self.data[key]
    def keys(self):
        return self.data.keys()
    def __str__(self):
        return str(self.data)+'--lable-'+str(self.lable)

In [3]:
def lable0(sample):
    return sample.lable==0
def lable1(sample):
    return sample.lable==1

In [104]:
import copy
class Dataset:
    def __init__(self):
        self.data=[]
        self.ks=[]
    def __getitem__(self,index):
        return self.data[index]
    def append(self,sample):
        self.data.append(sample)
        for k in sample.data.keys():
            if k not in self.ks:
                self.ks.append(k)
    def __str__(self):
        listOfSample=[]
        for sample in self.data:
            SampleOne = str(sample)+'\n'
            listOfSample.append(SampleOne)
        return '\n'.join(listOfSample)+'\n'
    def __len__(self):
        return len(self.data)
    def count(self,pred):
        cot=0
        for sample in self:
            if(pred(sample)):
                cot+=1
        return cot
    def split(self,key,val):
#         split dataset according to feature 
        dataset = Dataset() 
        for sample in self:
            if(sample[key]==val):
                dataset.append(sample)
        dataset.ks=copy.deepcopy(self.ks) 
#         print(dataset.ks)
        dataset.ks.remove(key)
        return dataset
    
    def canSplit(self):
        if(len(self.ks)==0):
            return False
        return True
    def valsIssame(self):
        first = self[0]
        for sample in self:
            for key in self.ks:
                if(sample[key]!=first[key]):
                    return False
        return True
#     def lableIssame(self):
#         first = self[0]
#         for sample in self:
#             if(first.lable!=sample.lable):
#                 return Fa
#     all the same so that no necessary to split
    def needSplit(self):
        if(self.valsIssame(self)):
            return False
        else:
            return True
    def majorityVote(self):
        n0 = self.count(label0)
        n1 = self.count(label1)
        if n0<n1 :
            return 1
        else:
            return 0

In [91]:
class Filler:
    def filler(self):
        pass
import random
class RandomFiller(Filler):
    def __init__(self):
        self.begin=0
        self.end=4
        self.nfeat=2
        self.nsample=20
    def fillOne(self,sample):
        for i in range(self.nfeat):
            sample[i]=random.randint(self.begin,self.end)
        sample.lable=random.randint(0,1)
    def fillDataset(self,dataset):
        for i in range(self.nsample):
            sample=Sample()
            self.fillOne(sample)
            dataset.append(sample)

In [92]:
sample = Sample()
f=RandomFiller()
f.fillOne(sample)
print(sample)
print(sample.data.keys())


{0: 1, 1: 0}--lable-1
[0, 1]

In [93]:
dataset = Dataset()
f.fillDataset(dataset)
print(dataset)


{0: 0, 1: 0}--lable-1

{0: 4, 1: 4}--lable-0

{0: 0, 1: 4}--lable-0

{0: 4, 1: 1}--lable-1

{0: 0, 1: 0}--lable-0

{0: 2, 1: 1}--lable-1

{0: 1, 1: 1}--lable-0

{0: 1, 1: 1}--lable-0

{0: 1, 1: 1}--lable-1

{0: 3, 1: 2}--lable-1

{0: 2, 1: 0}--lable-0

{0: 0, 1: 2}--lable-1

{0: 2, 1: 0}--lable-0

{0: 0, 1: 1}--lable-1

{0: 4, 1: 1}--lable-0

{0: 4, 1: 0}--lable-1

{0: 0, 1: 3}--lable-1

{0: 3, 1: 3}--lable-0

{0: 4, 1: 1}--lable-0

{0: 1, 1: 1}--lable-0



In [94]:
print(dataset.ks)


[0, 1]

In [105]:
import math
def calEntropy(dataset):
    nl0 = dataset.count(lable0)
    nl1 = dataset.count(lable1)
    if nl0==0 or nl1==0:
        return 0
    n = len(dataset)
    p0 = float(nl0)/n
    p1 = float(nl1)/n
    print('p0 : '+str(p0))
    print('p1 : '+str(p1))
    return -(p0 * math.log(p0)+p1 * math.log(p1))
def getVals(dataset,key):
    sets=[]
    for sample in dataset:
        if sample[key] not in sets:
            sets.append(sample[key])
    return sets
def getNumForVals(dataset,key,val):
    nums=0
    for sample in dataset:
        if(sample[key]==val):
            nums+=1
    return nums
def calGain(dataset,key):
    newEntropy=0
    sets = getVals(dataset,key)
    oldEntropy = calEntropy(dataset)
    n = len(dataset)
    for val in sets:
        newDataset = dataset.split(key,val)
        nums = getNumForVals(dataset,key,val)
        p = float(nums)/n
        newEntropy += p * calEntropy(newDataset)
    return oldEntropy - newEntropy

In [117]:
#test data.split()
print('---key0----')
sets = getVals(dataset,0)
for val in sets:
    _data = dataset.split(0,val)
    print(_data)
# print('---key1-----')
# sets = getVals(dataset,1)
# for val in sets:
#     _data = dataset.split(1,val)
#     print(_data)
# # test math function
# entropy = calEntropy(dataset)
# print(entropy) #pass
# sets = getVals(dataset,0) 
# print(sets) #pass
# print(calEntropy(_data))
print(calGain(dataset,0))
print(calGain(dataset,1))


---key0----
{0: 0, 1: 0}--lable-1

{0: 0, 1: 4}--lable-0

{0: 0, 1: 0}--lable-0

{0: 0, 1: 2}--lable-1

{0: 0, 1: 1}--lable-1

{0: 0, 1: 3}--lable-1


{0: 4, 1: 4}--lable-0

{0: 4, 1: 1}--lable-1

{0: 4, 1: 1}--lable-0

{0: 4, 1: 0}--lable-1

{0: 4, 1: 1}--lable-0


{0: 2, 1: 1}--lable-1

{0: 2, 1: 0}--lable-0

{0: 2, 1: 0}--lable-0


{0: 1, 1: 1}--lable-0

{0: 1, 1: 1}--lable-0

{0: 1, 1: 1}--lable-1

{0: 1, 1: 1}--lable-0


{0: 3, 1: 2}--lable-1

{0: 3, 1: 3}--lable-0


p0 : 0.55
p1 : 0.45
p0 : 0.333333333333
p1 : 0.666666666667
p0 : 0.6
p1 : 0.4
p0 : 0.666666666667
p1 : 0.333333333333
p0 : 0.75
p1 : 0.25
p0 : 0.5
p1 : 0.5
0.0516727742489
p0 : 0.55
p1 : 0.45
p0 : 0.6
p1 : 0.4
p0 : 0.555555555556
p1 : 0.444444444444
p0 : 0.5
p1 : 0.5
0.141438469436

In [ ]:


In [119]:
class DTNode:
    def __init__(self):
        self.key=None
        self.val=None
        self.data=Dataset()
#         self.parent=None
        self.children=[]

In [120]:
class Model:
    def build(self):
        pass
    
class DTModel(Model):
    def __init__(self):
        DTRoot = DTNode()
#         set feature of root to -1 
        DTRoot.key=-1
        DTRoot.val=-1
        self.DTRoot=DTRoot
    def build(self,parentNode):
        data = parentNode.data
        if(not data.needSplit()):
            return
        maxGain=0
        for key in data.ks:
            nowGain=calGain(data,key)
            if(nowGain > maxGain):
                maxGain = nowGain
                maxKey = key
        sets = getVals(data,maxKey)
        for val in sets:
            node = DTNode()
            node.key = maxKey
            node.val = val
            childrenData = []
            for sample in data:
                if(sample[key]==val):
                    childrenData.append(sample)
            node.data = childrenData
            parentNode.children.append(node)
            build(node)
    def trainModel(self,dataset):
        print('---train decisionTree Model---\n')
        self.DTRoot.data = dataset
        build(self.DTRoot)

In [118]:
decisionTree = DTModel()
decisionTree.trainModel(dataset)


---train decisionTree Model---

p0 : 0.55
p1 : 0.45
p0 : 0.333333333333
p1 : 0.666666666667
p0 : 0.6
p1 : 0.4
p0 : 0.666666666667
p1 : 0.333333333333
p0 : 0.75
p1 : 0.25
p0 : 0.5
p1 : 0.5
p0 : 0.55
p1 : 0.45
p0 : 0.6
p1 : 0.4
p0 : 0.555555555556
p1 : 0.444444444444
p0 : 0.5
p1 : 0.5
---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
<ipython-input-118-d499a2d9b8b6> in <module>()
      1 decisionTree = DTModel()
----> 2 decisionTree.trainModel(dataset)

<ipython-input-111-4657fc24f948> in trainModel(self, dataset)
     35         print('---train decisionTree Model---\n')
     36         self.DTRoot.data = dataset
---> 37         self.build(self.DTRoot)

<ipython-input-111-4657fc24f948> in build(self, parentNode)
     31             node.data = childrenData
     32             parentNode.children.append(node)
---> 33             self.build(node)
     34     def trainModel(self,dataset):
     35         print('---train decisionTree Model---\n')

<ipython-input-111-4657fc24f948> in build(self, parentNode)
     12     def build(self,parentNode):
     13         data = parentNode.data
---> 14         if(not data.needSplit()):
     15             return
     16         maxGain=0

AttributeError: 'list' object has no attribute 'needSplit'

In [ ]: