ID3算法:不使用sklearn中的决策树方法,根据数据集自己利用python编写决策树构建程序。
文章目录1.熵的计算2.最佳属性划分的选择信息熵的计算3.决策树的构建4.采用python matplotlib模块画决策树,使其决策树可视化:5.全部代码:ID3算法:不使用sklearn中的决策树方法,根据数据集自己进行编写决策树构建程序。在代码中用到的data数据,以及属性值。也可以根据自己的实际情况进行修改。data = [[1, 0, 1, ‘no’],[0, 1, 1, ‘no’],[
ID3算法:不使用sklearn中的决策树方法,根据数据集自己进行编写决策树构建程序。
在代码中用到的data数据,以及属性值。也可以根据自己的实际情况进行修改。
data = [
[1, 0, 1, ‘no’],
[0, 1, 1, ‘no’],
[0, 0, 0, ‘no’],
[1, 1, 1, ‘no’],
[0, 2, 1, ‘yes’],
[0, 1, 0, ‘no’],
[1, 2, 1, ‘no’],
[0, 0, 1, ‘yes’],
[0, 1, 0, ‘no’],
[0, 0, 1, ‘yes’], ]
label = [‘Refund’, ‘Marital Status’, ‘Taxable Income’]
ID3算法用信息增益方式进行划分特征选择,以此建立决策树。
因此,在代码中需要计算熵和信息增益。
1.熵的计算
具体代码实现:
#返回信息熵函数
def infoEntropy(data):
#curlist用来存储数据集中不同的标签值 分别的个数。如{'no': 7, 'yes': 3}
curlist={}
for key in data:
if key[-1] not in curlist:
curlist[key[-1]]=1
else:
curlist[key[-1]]+=1
dataEntropy = 0 #熵初始化为0
#循环遍历curlist,计算信息熵
for key in curlist:
dataEntropy-=(curlist[key]/len(data)) * math.log2((curlist[key]/len(data)))
return dataEntropy
2.最佳属性划分的选择
在选择最佳划分属性时,即我们要通过计算信息增益,选择出划分后信息增益最大的属性。
在选择某个属性划分后,会产生多个数据集合,对于每个数据集合我们都要计算它的熵,因此我们需要用到spliteData函数来返回某个数据集合,以便计算熵。
同时构建树时选择好最佳划分属性后,依然要调用该函数来返回根据此最佳划分属性划分后的数据集合,以进行递归建树。要注意的是:某个属性在之前已经成为过最佳划分属性后,之后不能再被使用,因此该函数在返回数据集合时应该删除选中的属性那一列数据。
'''
该函数主要是用某个离散属性a对数据集进行划分时,因为a有很多个取值,当取value时,产生的样本个数
i表示第i个属性
'''
def spliteData(data,value,i):
newdata=[]
for row in data:
if row[i]==value:
newVec = row[:i]
newVec.extend(row[i + 1:])
newdata.append(newVec)
return newdata
信息熵的计算
具体代码实现:
#选择一个属性作为最优划分的过程,在此函数中要计算各个特征划分后产生的信息增益,并选择信息增益最大的作为最优划分选择
def optimalPartition(data,label):
#计算某个结点不进行划分时,目前的系统熵为多少
originalEntroy=infoEntropy(data)
bestGainEnt=0
bestopt=-1
labelLen=len(label)
for i in range(labelLen):
#获取data第i列的各个值 如i=0时,allValue=[1, 0, 0, 1, 0, 0, 1, 0, 0, 0]
allValue = [example[i] for example in data]
#获取allValue中不同的类别,set具有去重的作用 ,如i=1时,上面的allValue现在等于{0, 1}
allValue=set(allValue)
newEntropy=0
for key in allValue:
newdata=spliteData(data,key,i)
newEntropy=newEntropy+len(newdata)/len(data) *infoEntropy(newdata)
gainEntropy=originalEntroy-newEntropy #计算信息增益
if gainEntropy>bestGainEnt: #因为需要选择信息增益最大的,因此每次计算出从该属性进行划分时,信息增益是否为当前最大,如果最大,则更新
bestGainEnt=gainEntropy
bestopt=i
return bestopt
3.决策树的构建
下图为周志华《机器学习》一书中描述道德决策树学习算法。下面将根据该伪代码进行建立决策树。
从该决策树的构建过程中,可以知道决策树停止生长有两个条件:
- 如果当前节点所包含的数据集合的属性都属于同一个类的时候(此时数据集合熵为0),则不需要再进一步进行划分。
- 所有属性(特征)都已经被用来划分过了,即没有更多的属性可以进行分割时,即便数据集合仍然不纯,也停止生长。
对于第二个条件,此时应该将node标记为叶节点,此时数据中包含多个不同类别,我们应该选择样本中类别最多的作为最终该节点的类别。因此需要一个函数来计算此时数据集合中哪个类别最多。此处majorityCnt函数就是用来实现此目的
def majorityCnt(datalabel):
labelCount={}
for key in datalabel:
if labelCount[key] not in labelCount:
labelCount[key]=1
else:
labelCount[key]+=1
labelCount = sorted(label.items(), key=operator.itemgetter(1), reverse=True)
return labelCount[0][0]
def createTree(data,label):
label2=label[:]
labellist=[example[-1] for example in data]
#如果样本中全属于同一类别,则返回,把其作为叶子结点
if labellist.count(labellist[0])==len(labellist):
return labellist[0]
#如果没有更多的属性可以进行分割。即只剩下标签值那一列时,停止分割,返回
if len(data[0])==1:
return majorityCnt(labellist)
#如果不是上述两种情况,则进行选择最佳属性,进行分割
bestopt=optimalPartition(data,label)
bestlabel=label[bestopt]
myTree={bestlabel:{}}
allValue = [example[bestopt] for example in data]
allValue = set(allValue)
#因为在该节点label2[bestopt]属性值已经被利用划分了,因此在后续进行分割时,不再进行考虑,所以要把该属性从label2中删除掉
del (label2[bestopt])
for key in allValue:
myTree[bestlabel][key]=createTree(spliteData(data,key,bestopt),label2)
print(myTree)
return myTree
决策树构建的结果:
4.采用python matplotlib模块画决策树,使其决策树可视化:
decisionNode = dict(boxstyle="sawtooth", fc="0.8")
leafNode = dict(boxstyle="round4", fc="0.8")
arrow_args = dict(arrowstyle="<-")
def getTreeDepth(myTree): # 获取树的深度
maxDepth = 0
firstStr = list(myTree.keys())[0]
secondDict = myTree[firstStr]
for key in secondDict.keys():
if type(secondDict[
key]).__name__ == 'dict': # test to see if the nodes are dictonaires, if not they are leaf nodes
thisDepth = 1 + getTreeDepth(secondDict[key])
else:
thisDepth = 1
if thisDepth > maxDepth: maxDepth = thisDepth
return maxDepth
# 绘制带箭头的注释
def plotNode(nodeTxt, centerPt, parentPt, nodeType):
createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction',
xytext=centerPt, textcoords='axes fraction',
va="center", ha="center", bbox=nodeType, arrowprops=arrow_args)
# 计算父节点和子节点的中间位置,在父节点间填充文本的信息
def plotMidText(cntrPt, parentPt, txtString):
xMid = (parentPt[0] - cntrPt[0]) / 2.0 + cntrPt[0]
yMid = (parentPt[1] - cntrPt[1]) / 2.0 + cntrPt[1]
createPlot.ax1.text(xMid, yMid, txtString, va="center", ha="center", rotation=30)
# 画决策树的准备方法
def plotTree(myTree, parentPt, nodeTxt): # if the first key tells you what feat was split on
numLeafs = getNumLeafs(myTree) # 计算树的宽度
depth = getTreeDepth(myTree) # 计算树的深度
firstStr = list(myTree.keys())[0] # the text label for this node should be this
cntrPt = (plotTree.xOff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW, plotTree.yOff)
plotMidText(cntrPt, parentPt, nodeTxt)
plotNode(firstStr, cntrPt, parentPt, decisionNode)
secondDict = myTree[firstStr]
plotTree.yOff = plotTree.yOff - 1.0 / plotTree.totalD
for key in secondDict.keys():
if type(secondDict[
key]).__name__ == 'dict': # test to see if the nodes are dictonaires, if not they are leaf nodes
plotTree(secondDict[key], cntrPt, str(key)) # recursion
else: # it's a leaf node print the leaf node
plotTree.xOff = plotTree.xOff + 1.0 / plotTree.totalW
plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
plotTree.yOff = plotTree.yOff + 1.0 / plotTree.totalD
def createPlot(inTree):
fig = plt.figure(1, facecolor='white')
fig.clf()
axprops = dict(xticks=[], yticks=[])
createPlot.ax1 = plt.subplot(111, frameon=False, **axprops) # no ticks
# createPlot.ax1 = plt.subplot(111, frameon=False) #ticks for demo puropses
plotTree.totalW = float(getNumLeafs(inTree))
plotTree.totalD = float(getTreeDepth(inTree))
plotTree.xOff = -0.5 / plotTree.totalW;
plotTree.yOff = 1.0;
plotTree(inTree, (0.5, 1.0), '')
plt.show()
截图:
5.全部代码:
import math
import operator
import matplotlib.pyplot as plt
decisionNode = dict(boxstyle="sawtooth", fc="0.8") # 定义文本框与箭头的格式
leafNode = dict(boxstyle="round4", fc="0.8")
arrow_args = dict(arrowstyle="<-")
#返回信息熵函数
def infoEntropy(data):
#curlist用来存储数据集中不同的标签值 分别的个数。如{'no': 7, 'yes': 3}
curlist={}
for key in data:
if key[-1] not in curlist:
curlist[key[-1]]=1
else:
curlist[key[-1]]+=1
dataEntropy = 0 #熵初始化为0
#循环遍历curlist,计算信息熵
for key in curlist:
dataEntropy-=(curlist[key]/len(data)) * math.log2((curlist[key]/len(data)))
return dataEntropy
'''
该函数主要是用某个离散属性a对数据集进行划分时,因为a有很多个取值,当取value时,产生的样本个数
i表示第i个属性
'''
def spliteData(data,value,i):
newdata=[]
for row in data:
if row[i]==value:
newVec = row[:i]
newVec.extend(row[i + 1:])
newdata.append(newVec)
return newdata
def majorityCnt(datalabel):
labelCount={}
for key in datalabel:
if labelCount[key] not in labelCount:
labelCount[key]=1
else:
labelCount[key]+=1
labelCount = sorted(label.items(), key=operator.itemgetter(1), reverse=True)
return labelCount[0][0]
#选择一个属性作为最优划分的过程,在此函数中要计算各个特征划分后产生的信息增益,并选择信息增益最大的作为最优划分选择
def optimalPartition(data,label):
#计算某个结点不进行划分时,目前的系统熵为多少
originalEntroy=infoEntropy(data)
bestGainEnt=0
bestopt=-1
labelLen=len(label)
for i in range(labelLen):
#获取data第i列的各个值 如i=0时,allValue=[1, 0, 0, 1, 0, 0, 1, 0, 0, 0]
allValue = [example[i] for example in data]
#获取allValue中不同的类别,set具有去重的作用 ,如i=1时,上面的allValue现在等于{0, 1}
allValue=set(allValue)
newEntropy=0
for key in allValue:
newdata=spliteData(data,key,i)
newEntropy=newEntropy+len(newdata)/len(data) *infoEntropy(newdata)
gainEntropy=originalEntroy-newEntropy #计算信息增益
if gainEntropy>bestGainEnt: #因为需要选择信息增益最大的,因此每次计算出从该属性进行划分时,信息增益是否为当前最大,如果最大,则更新
bestGainEnt=gainEntropy
bestopt=i
return bestopt
def createTree(data,label):
label2=label[:]
labellist=[example[-1] for example in data]
#如果样本中全属于同一类别,则返回,把其作为叶子结点
if labellist.count(labellist[0])==len(labellist):
return labellist[0]
#如果没有更多的属性可以进行分割。即只剩下标签值那一列时,停止分割,返回
if len(data[0])==1:
return majorityCnt(labellist)
#如果不是上述两种情况,则进行选择最佳属性,进行分割
bestopt=optimalPartition(data,label)
bestlabel=label[bestopt]
myTree={bestlabel:{}}
allValue = [example[bestopt] for example in data]
allValue = set(allValue)
#因为在该节点label2[bestopt]属性值已经被利用划分了,因此在后续进行分割时,不再进行考虑,所以要把该属性从label2中删除掉
del (label2[bestopt])
for key in allValue:
myTree[bestlabel][key]=createTree(spliteData(data,key,bestopt),label2)
print(myTree)
return myTree
def getNumLeafs(myTree): # 获取树叶节点的数目
numLeafs = 0
firstStr = list(myTree.keys())[0]
secondDict = myTree[firstStr]
for key in secondDict.keys():
if type(secondDict[key]).__name__ == 'dict': # 测试节点的数据类型是不是字典,如果是则就需要递归的调用getNumLeafs()函数
numLeafs += getNumLeafs(secondDict[key])
else:
numLeafs += 1
return numLeafs
def getTreeDepth(myTree): # 获取树的深度
maxDepth = 0
firstStr = list(myTree.keys())[0]
secondDict = myTree[firstStr]
for key in secondDict.keys():
if type(secondDict[
key]).__name__ == 'dict': # test to see if the nodes are dictonaires, if not they are leaf nodes
thisDepth = 1 + getTreeDepth(secondDict[key])
else:
thisDepth = 1
if thisDepth > maxDepth: maxDepth = thisDepth
return maxDepth
# 绘制带箭头的注释
def plotNode(nodeTxt, centerPt, parentPt, nodeType):
createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction',
xytext=centerPt, textcoords='axes fraction',
va="center", ha="center", bbox=nodeType, arrowprops=arrow_args)
# 计算父节点和子节点的中间位置,在父节点间填充文本的信息
def plotMidText(cntrPt, parentPt, txtString):
xMid = (parentPt[0] - cntrPt[0]) / 2.0 + cntrPt[0]
yMid = (parentPt[1] - cntrPt[1]) / 2.0 + cntrPt[1]
createPlot.ax1.text(xMid, yMid, txtString, va="center", ha="center", rotation=30)
# 画决策树的准备方法
def plotTree(myTree, parentPt, nodeTxt): # if the first key tells you what feat was split on
numLeafs = getNumLeafs(myTree) # 计算树的宽度
depth = getTreeDepth(myTree) # 计算树的深度
firstStr = list(myTree.keys())[0] # the text label for this node should be this
cntrPt = (plotTree.xOff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW, plotTree.yOff)
plotMidText(cntrPt, parentPt, nodeTxt)
plotNode(firstStr, cntrPt, parentPt, decisionNode)
secondDict = myTree[firstStr]
plotTree.yOff = plotTree.yOff - 1.0 / plotTree.totalD
for key in secondDict.keys():
if type(secondDict[
key]).__name__ == 'dict': # test to see if the nodes are dictonaires, if not they are leaf nodes
plotTree(secondDict[key], cntrPt, str(key)) # recursion
else: # it's a leaf node print the leaf node
plotTree.xOff = plotTree.xOff + 1.0 / plotTree.totalW
plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
plotTree.yOff = plotTree.yOff + 1.0 / plotTree.totalD
def createPlot(inTree):
fig = plt.figure(1, facecolor='white')
fig.clf()
axprops = dict(xticks=[], yticks=[])
createPlot.ax1 = plt.subplot(111, frameon=False, **axprops) # no ticks
# createPlot.ax1 = plt.subplot(111, frameon=False) #ticks for demo puropses
plotTree.totalW = float(getNumLeafs(inTree))
plotTree.totalD = float(getTreeDepth(inTree))
plotTree.xOff = -0.5 / plotTree.totalW;
plotTree.yOff = 1.0;
plotTree(inTree, (0.5, 1.0), '')
plt.show()
if __name__ == '__main__':
data = [
[1, 0, 1, 'no'],
[0, 1, 1, 'no'],
[0, 0, 0, 'no'],
[1, 1, 1, 'no'],
[0, 2, 1, 'yes'],
[0, 1, 0, 'no'],
[1, 2, 1, 'no'],
[0, 0, 1, 'yes'],
[0, 1, 0, 'no'],
[0, 0, 1, 'yes'], ]
label = ['Refund', 'Marital Status', 'Taxable Income']
myTree=createTree(data,label)
createPlot(myTree)
参考书籍: 周志华《机器学习》
《数据挖掘导论》(完整版)Pang-Ning Tan , Michael Steinbach , Vipin Kumar (作者) 范明 , 范宏建 (译者)
更多推荐
所有评论(0)