In [1]:
%pylab
%matplotlib inline
In [2]:
class Sample:
    def __init__(self):
        self.data = {}
        self.label = None
    def __setitem__(self, key, value):
        self.data[key] = value
    def __getitem__(self, key):
        return self.data[key]
    def keys(self):
        return self.data.keys()
    def __str__(self):
        return str(self.data) + '--label-' + str(self.label)
In [3]:
def label0(sample):
    return sample.label == 0
def label1(sample):
    return sample.label == 1
In [104]:
import copy

class Dataset:
    def __init__(self):
        self.data = []
        self.ks = []
    def __getitem__(self, index):
        return self.data[index]
    def append(self, sample):
        self.data.append(sample)
        for k in sample.data.keys():
            if k not in self.ks:
                self.ks.append(k)
    def __str__(self):
        listOfSample = []
        for sample in self.data:
            listOfSample.append(str(sample) + '\n')
        return '\n'.join(listOfSample) + '\n'
    def __len__(self):
        return len(self.data)
    def count(self, pred):
        cot = 0
        for sample in self:
            if pred(sample):
                cot += 1
        return cot
    def split(self, key, val):
        # split dataset according to a feature value, dropping the used key
        dataset = Dataset()
        for sample in self:
            if sample[key] == val:
                dataset.append(sample)
        dataset.ks = copy.deepcopy(self.ks)
        # print(dataset.ks)
        dataset.ks.remove(key)
        return dataset
    def canSplit(self):
        if len(self.ks) == 0:
            return False
        return True
    def valsIssame(self):
        first = self[0]
        for sample in self:
            for key in self.ks:
                if sample[key] != first[key]:
                    return False
        return True
    def labelIssame(self):
        first = self[0]
        for sample in self:
            if first.label != sample.label:
                return False
        return True
    # all samples look the same, so there is no need to split further
    def needSplit(self):
        if self.valsIssame() or self.labelIssame():
            return False
        else:
            return True
    def majorityVote(self):
        n0 = self.count(label0)
        n1 = self.count(label1)
        if n0 < n1:
            return 1
        else:
            return 0
In [91]:
import random

class Filler:
    def filler(self):
        pass

class RandomFiller(Filler):
    def __init__(self):
        # feature values are drawn uniformly from [begin, end]
        self.begin = 0
        self.end = 4
        self.nfeat = 2
        self.nsample = 20
    def fillOne(self, sample):
        for i in range(self.nfeat):
            sample[i] = random.randint(self.begin, self.end)
        sample.label = random.randint(0, 1)
    def fillDataset(self, dataset):
        for i in range(self.nsample):
            sample = Sample()
            self.fillOne(sample)
            dataset.append(sample)
In [92]:
sample = Sample()
f=RandomFiller()
f.fillOne(sample)
print(sample)
print(sample.data.keys())
In [93]:
dataset = Dataset()
f.fillDataset(dataset)
print(dataset)
In [94]:
print(dataset.ks)
In [105]:
import math

def calEntropy(dataset):
    nl0 = dataset.count(label0)
    nl1 = dataset.count(label1)
    if nl0 == 0 or nl1 == 0:
        return 0
    n = len(dataset)
    p0 = float(nl0) / n
    p1 = float(nl1) / n
    print('p0 : ' + str(p0))
    print('p1 : ' + str(p1))
    return -(p0 * math.log(p0) + p1 * math.log(p1))

def getVals(dataset, key):
    # distinct values taken by the feature `key`
    sets = []
    for sample in dataset:
        if sample[key] not in sets:
            sets.append(sample[key])
    return sets

def getNumForVals(dataset, key, val):
    nums = 0
    for sample in dataset:
        if sample[key] == val:
            nums += 1
    return nums

def calGain(dataset, key):
    # information gain = entropy before the split minus the
    # size-weighted entropy of the subsets produced by splitting on `key`
    newEntropy = 0
    sets = getVals(dataset, key)
    oldEntropy = calEntropy(dataset)
    n = len(dataset)
    for val in sets:
        newDataset = dataset.split(key, val)
        nums = getNumForVals(dataset, key, val)
        p = float(nums) / n
        newEntropy += p * calEntropy(newDataset)
    return oldEntropy - newEntropy
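A quick sanity check (illustrative, not part of the original run): calEntropy uses the natural log, so a set with three label-0 samples and one label-1 sample should give -(0.75*ln(0.75) + 0.25*ln(0.25)) ≈ 0.562, and calGain subtracts the size-weighted entropies of the subsets from that value.
In [ ]:
# illustrative check of calEntropy on a tiny hand-built dataset
tiny = Dataset()
for lab in [0, 0, 0, 1]:
    s = Sample()
    s[0] = 0
    s.label = lab
    tiny.append(s)
print(calEntropy(tiny))  # expect roughly 0.562 (after the p0/p1 debug prints)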
In [117]:
# test data.split()
print('---key0----')
sets = getVals(dataset, 0)
for val in sets:
    _data = dataset.split(0, val)
    print(_data)
# print('---key1-----')
# sets = getVals(dataset,1)
# for val in sets:
#     _data = dataset.split(1,val)
#     print(_data)
# # test math function
# entropy = calEntropy(dataset)
# print(entropy) #pass
# sets = getVals(dataset,0)
# print(sets) #pass
# print(calEntropy(_data))
print(calGain(dataset, 0))
print(calGain(dataset, 1))
In [ ]:
In [119]:
class DTNode:
    def __init__(self):
        self.key = None
        self.val = None
        self.data = Dataset()
        # majority label of the samples reaching this node (filled during build)
        self.label = None
        # self.parent=None
        self.children = []
In [120]:
class Model:
    def build(self):
        pass

class DTModel(Model):
    def __init__(self):
        DTRoot = DTNode()
        # mark the root with a dummy feature / value
        DTRoot.key = -1
        DTRoot.val = -1
        self.DTRoot = DTRoot
    def build(self, parentNode):
        data = parentNode.data
        # remember the majority label so the node can act as a leaf
        parentNode.label = data.majorityVote()
        if (not data.canSplit()) or (not data.needSplit()):
            return
        # pick the feature with the largest information gain
        maxGain = 0
        maxKey = None
        for key in data.ks:
            nowGain = calGain(data, key)
            if nowGain > maxGain:
                maxGain = nowGain
                maxKey = key
        if maxKey is None:
            # no feature reduces the entropy, so stop here
            return
        sets = getVals(data, maxKey)
        for val in sets:
            node = DTNode()
            node.key = maxKey
            node.val = val
            # the child holds the subset of samples with this feature value
            node.data = data.split(maxKey, val)
            parentNode.children.append(node)
            self.build(node)
    def trainModel(self, dataset):
        print('---train decisionTree Model---\n')
        self.DTRoot.data = dataset
        self.build(self.DTRoot)
In [118]:
decisionTree = DTModel()
decisionTree.trainModel(dataset)
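A minimal sketch, not part of the original notebook, of how the trained tree could be used: descend into the child whose stored (key, val) pair matches the sample, and fall back to the node's majority label when the node is a leaf or the feature value was never seen during training. The classify helper below is an illustrative assumption that relies on the label field filled in by build.
In [ ]:
# illustrative sketch: classify a sample by walking the tree
def classify(node, sample):
    for child in node.children:
        if sample[child.key] == child.val:
            return classify(child, sample)
    # leaf node, or no child matches this feature value:
    # use the majority label stored at this node
    return node.label

testSample = Sample()
f.fillOne(testSample)
print(str(testSample) + ' -> predicted ' + str(classify(decisionTree.DTRoot, testSample)))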
In [ ]: