tree.py代码
# encoding: utf-8
"""ID3 decision-tree construction, classification and persistence.

Data sets are lists of rows; every element but the last is a feature value
and the last element is the class label.
"""
from math import log
import operator

try:
    import treePlotter as tp  # plotting helpers (treePlotter needs matplotlib)
except ImportError:  # keep tree building/classification usable without plotting
    tp = None


def createDataSet():
    """Return a small toy data set and its feature names.

    Rows are ``[no surfacing, flippers, class]``.
    """
    dataSet = [[1, 1, 'yes'],
               [1, 1, 'yes'],
               [1, 0, 'no'],
               [0, 1, 'no'],
               [0, 1, 'no']]
    labels = ['no surfacing', 'flippers']
    return dataSet, labels


def calcShannonEnt(dataSet):
    """Return the base-2 Shannon entropy of the class labels in *dataSet*.

    The class label is the last element of each row. An empty data set
    yields 0.0.
    """
    numEntries = len(dataSet)
    labelCounts = {}
    for featVec in dataSet:
        currentLabel = featVec[-1]
        labelCounts[currentLabel] = labelCounts.get(currentLabel, 0) + 1
    shannonEnt = 0.0
    for count in labelCounts.values():
        prob = float(count) / numEntries
        shannonEnt -= prob * log(prob, 2)
    return shannonEnt


def splitDataSet(dataSet, axis, value):
    """Return the rows whose feature *axis* equals *value*, minus that column.

    Args:
        dataSet: sequence of rows.
        axis: index of the feature to test.
        value: feature value to keep.

    Returns:
        A new list of new rows; the caller's rows are not modified.
    """
    retDataSet = []
    for featVec in dataSet:
        if featVec[axis] == value:
            # Slicing builds a new row, so the original row is left intact.
            retDataSet.append(featVec[:axis] + featVec[axis + 1:])
    return retDataSet


def chooseBestFeatureToSplit(dataSet):
    """Return the index of the feature with the largest information gain.

    Returns -1 when no feature yields a strictly positive gain. *dataSet*
    must be non-empty (its first row determines the feature count).
    """
    numFeatures = len(dataSet[0]) - 1
    baseEntropy = calcShannonEnt(dataSet)
    numEntries = float(len(dataSet))
    bestInfoGain = 0.0
    bestFeature = -1
    for i in range(numFeatures):
        uniqueVals = set(example[i] for example in dataSet)
        # Weighted entropy of the partitions induced by feature i.
        newEntropy = 0.0
        for value in uniqueVals:
            subDataSet = splitDataSet(dataSet, i, value)
            prob = len(subDataSet) / numEntries
            newEntropy += prob * calcShannonEnt(subDataSet)
        infoGain = baseEntropy - newEntropy
        if infoGain > bestInfoGain:
            bestInfoGain = infoGain
            bestFeature = i
    return bestFeature


def majorityCnt(classList):
    """Return the most frequent class in *classList*.

    Ties go to the class seen first (the sort is stable). Raises IndexError
    on an empty list.
    """
    classCount = {}
    for vote in classList:
        classCount[vote] = classCount.get(vote, 0) + 1
    sortedClassCount = sorted(classCount.items(),
                              key=operator.itemgetter(1), reverse=True)
    return sortedClassCount[0][0]


def createTree(dataSet, labels):
    """Build an ID3 decision tree.

    Args:
        dataSet: non-empty list of rows (features followed by class label).
        labels: feature names, one per feature column. Not modified.

    Returns:
        Either a class label (leaf) or a nested dict
        ``{featureName: {featureValue: subtree, ...}}``.
    """
    classList = [example[-1] for example in dataSet]
    # Every row has the same class: nothing left to split.
    if classList.count(classList[0]) == len(classList):
        return classList[0]
    # Only the class column remains: fall back to a majority vote.
    if len(dataSet[0]) == 1:
        return majorityCnt(classList)
    bestFeat = chooseBestFeatureToSplit(dataSet)
    # No feature reduces entropy (e.g. identical feature vectors with
    # different classes). Index -1 would otherwise split on the class column.
    if bestFeat < 0:
        return majorityCnt(classList)
    bestFeatLabel = labels[bestFeat]
    myTree = {bestFeatLabel: {}}
    # Build the reduced label list without mutating the caller's list, so it
    # can still be passed to classify() afterwards.
    subLabels = labels[:bestFeat] + labels[bestFeat + 1:]
    uniqueVals = set(example[bestFeat] for example in dataSet)
    for value in uniqueVals:
        myTree[bestFeatLabel][value] = createTree(
            splitDataSet(dataSet, bestFeat, value), subLabels)
    return myTree


def classify(inputTree, featLabels, testVec):
    """Return the class predicted by *inputTree* for *testVec*.

    Args:
        inputTree: a tree from createTree (nested dict or a bare class label).
        featLabels: feature names, in the column order of *testVec*.
        testVec: feature values to classify.

    Raises:
        KeyError: if *testVec* has a feature value the tree never saw.
    """
    # createTree returns a bare label when all training rows share a class.
    if not isinstance(inputTree, dict):
        return inputTree
    firstStr = next(iter(inputTree))  # the single root feature name
    secondDict = inputTree[firstStr]  # {featureValue: subtree-or-label}
    featIndex = featLabels.index(firstStr)
    valueOfFeat = secondDict[testVec[featIndex]]
    if isinstance(valueOfFeat, dict):
        return classify(valueOfFeat, featLabels, testVec)
    return valueOfFeat


def storeTree(inputTree, filename):
    """Pickle *inputTree* to *filename*."""
    import pickle
    # pickle writes bytes, so the file must be opened in binary mode.
    with open(filename, 'wb') as fw:
        pickle.dump(inputTree, fw)


def grabTree(filename):
    """Load a tree pickled by storeTree.

    NOTE: unpickling can execute arbitrary code; only load trusted files.
    """
    import pickle
    with open(filename, 'rb') as fr:
        return pickle.load(fr)


if __name__ == '__main__':
    # Walkthrough on the toy data set (uncomment to run):
    # dataSet, labels = createDataSet()
    # print("Shannon entropy: %f" % calcShannonEnt(dataSet))
    # print(splitDataSet(dataSet, 0, 1))
    # print(chooseBestFeatureToSplit(dataSet))
    # myTree = createTree(dataSet, labels)
    # print(myTree)
    # storeTree(myTree, 'myTree.txt')
    # print(grabTree('myTree.txt'))
    # print(classify(myTree, labels, [1, 0]))

    # Predict contact-lens type with a decision tree.
    with open('lenses.txt') as fr:
        # Skip blank lines, which would otherwise become one-column rows.
        lenses = [line.strip().split('\t') for line in fr if line.strip()]
    lensesLabels = ['age', 'prescript', 'astigmatic', 'tearRate']
    lensesTree = createTree(lenses, lensesLabels)
    print(lensesTree)
    if tp is not None:
        tp.createPlot(lensesTree)
    else:
        print('treePlotter/matplotlib not available; skipping plot')
treePlotter.py代码
# encoding: utf-8
"""Draw a nested-dict decision tree (``{feature: {value: subtree}}``) with matplotlib."""
try:
    import matplotlib.pyplot as plt
except ImportError:  # getNumLeafs/getTreeDepth/retrieveTree work without it
    plt = None

decisionNode = dict(boxstyle="sawtooth", fc="0.8")  # box style for internal nodes
leafNode = dict(boxstyle="round4", fc="0.8")  # box style for leaves
arrow_args = dict(arrowstyle="<-")


def getNumLeafs(myTree):
    """Return the number of leaves in *myTree*.

    A bare (non-dict) tree counts as a single leaf.
    """
    if not isinstance(myTree, dict):
        return 1
    secondDict = myTree[next(iter(myTree))]
    return sum(getNumLeafs(child) for child in secondDict.values())


def getTreeDepth(myTree):
    """Return the number of decision levels in *myTree*.

    A bare (non-dict) tree has depth 0.
    """
    if not isinstance(myTree, dict):
        return 0
    secondDict = myTree[next(iter(myTree))]
    return max([1 + getTreeDepth(child) for child in secondDict.values()] or [0])


def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    """Draw *nodeTxt* in a *nodeType* box at *centerPt*, with an arrow from *parentPt*.

    Coordinates are axes fractions; draws on ``createPlot.ax1``.
    """
    createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction',
                            xytext=centerPt, textcoords='axes fraction',
                            va="center", ha="center", bbox=nodeType,
                            arrowprops=arrow_args)


def plotMidText(cntrPt, parentPt, txtString):
    """Write *txtString* halfway between a child and its parent node."""
    xMid = (parentPt[0] - cntrPt[0]) / 2.0 + cntrPt[0]
    yMid = (parentPt[1] - cntrPt[1]) / 2.0 + cntrPt[1]
    createPlot.ax1.text(xMid, yMid, txtString, va="center", ha="center", rotation=30)


def plotTree(myTree, parentPt, nodeTxt):
    """Recursively draw the dict tree *myTree* below *parentPt*.

    Layout state lives in function attributes set up by createPlot:
    ``totalW``/``totalD`` (total leaves and depth) and ``xOff``/``yOff``
    (current drawing position).
    """
    numLeafs = getNumLeafs(myTree)  # width of this subtree, in leaves
    firstStr = next(iter(myTree))  # feature this node splits on
    # Center this node horizontally over the leaves it spans.
    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW,
              plotTree.yOff)
    plotMidText(cntrPt, parentPt, nodeTxt)
    plotNode(firstStr, cntrPt, parentPt, decisionNode)
    secondDict = myTree[firstStr]
    plotTree.yOff = plotTree.yOff - 1.0 / plotTree.totalD  # descend one level
    for key, child in secondDict.items():
        if isinstance(child, dict):
            plotTree(child, cntrPt, str(key))
        else:
            # Leaves are laid out left to right, one slot each.
            plotTree.xOff = plotTree.xOff + 1.0 / plotTree.totalW
            plotNode(child, (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
            plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
    plotTree.yOff = plotTree.yOff + 1.0 / plotTree.totalD  # back up a level


def createPlot(inTree):
    """Draw *inTree* (a dict tree or a bare class label) in figure 1 and show it.

    Raises:
        ImportError: if matplotlib is not installed.
    """
    if plt is None:
        raise ImportError('matplotlib is required for createPlot()')
    fig = plt.figure(1, facecolor='white')
    fig.clf()
    axprops = dict(xticks=[], yticks=[])
    createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)  # no ticks
    if isinstance(inTree, dict):
        plotTree.totalW = float(getNumLeafs(inTree))
        plotTree.totalD = float(getTreeDepth(inTree))
        # Start half a slot left so the first leaf lands at 0.5 / totalW.
        plotTree.xOff = -0.5 / plotTree.totalW
        plotTree.yOff = 1.0
        plotTree(inTree, (0.5, 1.0), '')
    else:
        # A bare label (all training rows shared one class): single leaf box.
        createPlot.ax1.text(0.5, 0.5, str(inTree), va="center", ha="center",
                            bbox=leafNode)
    plt.show()


def retrieveTree(i):
    """Return canned test tree number *i* (0 or 1)."""
    listOfTrees = [{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}},
                   {'no surfacing': {0: 'no', 1: {'flippers': {0: {'head': {0: 'no', 1: 'yes'}}, 1: 'no'}}}}
                   ]
    return listOfTrees[i]


if __name__ == '__main__':
    myTree = retrieveTree(0)
    createPlot(myTree)
    # print(getNumLeafs(myTree))
    # print(getTreeDepth(myTree))
时间: 2024-10-26 09:26:22