决策树ID3算法预测隐形眼睛类型--python实现

本节讲解如何预测患者需要佩戴的隐形眼镜类型。

1、使用决策树预测隐形眼镜类型的一般流程

(1)收集数据:提供的文本文件(数据来源于UCI数据库)

(2)准备数据:解析tab键分隔的数据行

(3)分析数据:快速检查数据,确保正确地解析数据内容,使用createPlot()函数绘制最终的树形图

(4)训练算法:createTree()函数

(5)测试算法:编写测试函数验证决策树可以正确分类给定的数据实例

(6)使用算法:存储数的数据结构,以使下次使用时无需重新构造树

trees.py如下:

#!/usr/bin/python
# -*- coding: utf-8 -*-
from math import log
#计算给定数据集的香农熵
def calcShannonEnt(dataSet):
  numEntries=len(dataSet)
  labelCounts={}
  for featVec in dataSet:
      currentLabel=featVec[-1]
      if currentLabel not in labelCounts.keys():
          labelCounts[currentLabel]=0
      labelCounts[currentLabel]+=1
  shannonEnt=0.0
  for key in labelCounts:
      prob=float(labelCounts[key])/numEntries
      shannonEnt -=prob*log(prob,2)
  return shannonEnt
#按照给定特征划分数据集
def splitDataSet(dataSet,axis,value):
    retDataSet=[]
    for featVec in dataSet:
        if featVec[axis]==value:
            reducedFeatVec=featVec[:axis]
            reducedFeatVec.extend(featVec[axis+1:])
            retDataSet.append(reducedFeatVec)
    return retDataSet
#选择最好的数据集划分方式
def chooseBestFeatureToSplit(dataSet):
  numFeatures=len(dataSet[0])-1
  baseEntropy=calcShannonEnt(dataSet)   #计算整个数据集的原始香农熵
  bestInfoGain=0.0;bestFeature=-1
  for i in range(numFeatures):    #循环遍历数据集中的所有特征
    featList=[example[i] for example in dataSet]
    uniqueVals=set(featList)
    newEntropy=0.0
    for value in uniqueVals:
      subDataSet=splitDataSet(dataSet,i,value)
      prob=len(subDataSet)/float(len(dataSet))
      newEntropy+=prob*calcShannonEnt(subDataSet)
    infoGain=baseEntropy-newEntropy
    if(infoGain>bestInfoGain):
      bestInfoGain=infoGain
      bestFeature=i
  return bestFeature
def majorityCnt(classList):
  classCount={}
  for vote in classList:
    if vote not in classCount.keys():classCount[vote]=0
    classCount[vote]+=1
  sortedClassCount=sorted(classCount.iteritems(),key=operator.itemgetter(1),reverse=True)
  return sortedClassCount[0][0]
#创建树的函数代码
def createTree(dataSet,labels):
  classList=[example[-1] for example in dataSet]
  if classList.count(classList[0])==len(classList):  #类别完全相同规则停止继续划分
    return classList[0]
  if len(dataSet[0])==1: return majorityCnt(classList)
  bestFeat=chooseBestFeatureToSplit(dataSet)
  bestFeatLabel=labels[bestFeat]
  myTree={bestFeatLabel:{}}
  del(labels[bestFeat])  #得到列表包含的所有属性
  featValues=[example[bestFeat] for example in dataSet]
  uniqueVals=set(featValues)
  for value in uniqueVals:
    subLabels=labels[:]
    myTree[bestFeatLabel][value]=createTree(splitDataSet(dataSet,bestFeat,value),subLabels)
  return myTree
#测试算法:使用决策树执行分类
def classify(inputTree,featLabels,testVec):
  firstStr=inputTree.keys()[0]
  secondDict=inputTree[firstStr]
  featIndex=featLabels.index(firstStr)
  for key in secondDict.keys():
    if testVec[featIndex]==key:
      if type(secondDict[key]).__name__==‘dict‘:
        classLabel=classify(secondDict[key],featLabels,testVec)
      else:classLabel=secondDict[key]
  return classLabel
#使用算法:决策树的存储
def storeTree(inputTree,filename):
  import pickle
  fw=open(filename,‘w‘)
  pickle.dump(inputTree,fw)
  fw.close()

def grabTree(filename):
  import pickle
  fr=open(filename)
  return pickle.load(fr)

treePlotter.py如下:

#!/usr/bin/python
# -*- coding: utf-8 -*-
import matplotlib.pyplot as plt
from numpy import *
import operator
#定义文本框和箭头格式
decisionNode=dict(boxstyle="sawtooth",fc="0.8")
leafNode=dict(boxstyle="round4",fc="0.8")
arrow_args=dict(arrowstyle="<-")
#绘制箭头的注解
def plotNode(nodeTxt,centerPt,parentPt,nodeType):
    createPlot.ax1.annotate(nodeTxt,xy=parentPt,xycoords=‘axes fraction‘,xytext=centerPt,textcoords=‘axes fraction‘,va="center",ha="center",bbox=nodeType,arrowprops=arrow_args)
def createPlot():
    fig=plt.figure(1,facecolor=‘white‘)
    fig.clf()
    createPlot.ax1=plt.subplot(111,frameon=False)
    plotNode(U‘决策节点‘,(0.5,0.1),(0.1,0.5),decisionNode)
    plotNode(U‘叶节点‘,(0.8,0.1),(0.3,0.8),leafNode)
    plt.show()
#获取叶节点的数目和树的层数
def getNumLeafs(myTree):
    numLeafs=0
    firstStr=myTree.keys()[0]
    secondDict=myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__==‘dict‘:
            numLeafs += getNumLeafs(secondDict[key])
        else: numLeafs +=1
    return numLeafs
def getTreeDepth(myTree):
    maxDepth=0
    firstStr=myTree.keys()[0]
    secondDict=myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__==‘dict‘:
            thisDepth=1+getTreeDepth(secondDict[key])
        else:thisDepth=1
        if thisDepth>maxDepth:maxDepth=thisDepth
    return maxDepth

def retrieveTree(i):
    listOfTrees=[{‘no surfacing‘:{0:‘no‘,1:{‘flippers‘:{0:‘no‘,1:‘yes‘}}}},                 {‘no surfacing‘:{0:‘no‘,1:{‘flippers‘:{0:{‘head‘:{0:‘no‘,1:‘yes‘}},1:‘no‘}}}}]
    return listOfTrees[i]
#在父节点间填充文本信息
def plotMidText(cntrPt,parentPt,txtString):
    xMid=(parentPt[0]-cntrPt[0])/2.0+cntrPt[0]
    yMid=(parentPt[1]-cntrPt[1])/2.0+cntrPt[1]
    createPlot.ax1.text(xMid,yMid,txtString)
#计算宽和高
def plotTree(myTree,parentPt,nodeTxt):
    numLeafs=getNumLeafs(myTree)
    depth=getTreeDepth(myTree)
    firstStr=myTree.keys()[0]
    cntrPt=(plotTree.xOff+(1.0+float(numLeafs))/2.0/plotTree.totalW,plotTree.yOff)
    plotMidText(cntrPt,parentPt,nodeTxt)   #计算父节点和子节点的中间位置
    plotNode(firstStr,cntrPt,parentPt,decisionNode)
    secondDict=myTree[firstStr]
    plotTree.yOff=plotTree.yOff-1.0/plotTree.totalD
    for key in secondDict.keys():
        if type(secondDict[key]).__name__==‘dict‘:
            plotTree(secondDict[key],cntrPt,str(key))
        else:
            plotTree.xOff=plotTree.xOff+1.0/plotTree.totalW
            plotNode(secondDict[key],(plotTree.xOff,plotTree.yOff),cntrPt,leafNode)
            plotMidText((plotTree.xOff,plotTree.yOff),cntrPt,str(key))
        plotTree.yOff=plotTree.yOff+1.0/plotTree.totalD
def createPlot(inTree):
    fig=plt.figure(1,facecolor=‘white‘)
    fig.clf()
    axprops=dict(xticks=[],yticks=[])
    createPlot.ax1=plt.subplot(111,frameon=False,**axprops)
    plotTree.totalW=float(getNumLeafs(inTree))
    plotTree.totalD=float(getTreeDepth(inTree))
    plotTree.xOff=-0.5/plotTree.totalW;plotTree.yOff=1.0;
    plotTree(inTree,(0.5,1.0),‘‘)
    plt.show()

lenses.txt如下:

运行如下:

1 >>> import trees
2 >>> import treePlotter
3 >>> fr=open(‘lenses.txt‘)
4 >>> lenses=[inst.strip().split(‘\t‘) for inst in fr.readlines()]
5 >>> lensesLabels=[‘age‘,‘prescript‘,‘astigmatic‘,‘tearRate‘]
6 >>> lensesTree=trees.createTree(lenses,lensesLabels)
7 >>> lensesTree
8 {‘tearRate‘: {‘reduced‘: ‘no lenses‘, ‘normal‘: {‘astigmatic‘: {‘yes‘: {‘prescript‘: {‘hyper‘: {‘age‘: {‘pre‘: ‘no lenses‘, ‘presbyopic‘: ‘no lenses‘, ‘young‘: ‘hard‘}}, ‘myope‘: ‘hard‘}}, ‘no‘: {‘age‘: {‘pre‘: ‘soft‘, ‘presbyopic‘: {‘prescript‘: {‘hyper‘: ‘soft‘, ‘myope‘: ‘no lenses‘}}, ‘young‘: ‘soft‘}}}}}}
9 >>> treePlotter.createPlot(lensesTree)

由图看出决策树非常好地匹配了实验数据,然而这些匹配选项可能太多。我们将这种问题称之为过度匹配(overfitting)。为了减少过度匹配问题,我们可以裁剪决策树,去掉一些不必要的叶子节点。如果叶子节点只能增加少许信息,则可以删除该节点,将它并入到其他叶子节点中。

时间: 2024-10-25 19:06:36

决策树ID3算法预测隐形眼睛类型--python实现的相关文章

决策树---ID3算法(介绍及Python实现)

决策树---ID3算法   决策树: 以天气数据库的训练数据为例. Outlook Temperature Humidity Windy PlayGolf? sunny 85 85 FALSE no sunny 80 90 TRUE no overcast 83 86 FALSE yes rainy 70 96 FALSE yes rainy 68 80 FALSE yes rainy 65 70 TRUE no overcast 64 65 TRUE yes sunny 72 95 FALSE

决策树ID3算法[分类算法]

ID3分类算法的编码实现 1 <?php 2 /* 3 *决策树ID3算法(分类算法的实现) 4 */ 5 6 /* 7 *把.txt中的内容读到数组中保存 8 *$filename:文件名称 9 */ 10 11 //-------------------------------------------------------------------- 12 function gerFileContent($filename) 13 { 14 $array = array(NULL); 15

数据挖掘之决策树ID3算法(C#实现)

决策树是一种非常经典的分类器,它的作用原理有点类似于我们玩的猜谜游戏.比如猜一个动物: 问:这个动物是陆生动物吗? 答:是的. 问:这个动物有鳃吗? 答:没有. 这样的两个问题顺序就有些颠倒,因为一般来说陆生动物是没有鳃的(记得应该是这样的,如有错误欢迎指正).所以玩这种游戏,提问的顺序很重要,争取每次都能够获得尽可能多的信息量. AllElectronics顾客数据库标记类的训练元组 RID age income student credit_rating Class: buys_comput

决策树ID3算法的java实现

决策树的分类过程和人的决策过程比较相似,就是先挑“权重”最大的那个考虑,然后再往下细分.比如你去看医生,症状是流鼻涕,咳嗽等,那么医生就会根据你的流鼻涕这个权重最大的症状先认为你是感冒,接着再根据你咳嗽等症状细分你是否为病毒性感冒等等.决策树的过程其实也是基于极大似然估计.那么我们用一个什么标准来衡量某个特征是权重最大的呢,这里有信息增益和基尼系数两个.ID3算法采用的是信息增益这个量. 根据<统计学习方法>中的描述,G(D,A)表示数据集D在特征A的划分下的信息增益.具体公式: G(D,A)

决策树ID3算法的java实现(基本试用所有的ID3)

已知:流感训练数据集,预定义两个类别: 求:用ID3算法建立流感的属性描述决策树 流感训练数据集 No. 头痛 肌肉痛 体温 患流感 1 是(1) 是(1) 正常(0) 否(0) 2 是(1) 是(1) 高(1) 是(1) 3 是(1) 是(1) 很高(2) 是(1) 4 否(0) 是(1) 正常(0) 否(0) 5 否(0) 否(0) 高(1) 否(0) 6 否(0) 是(1) 很高(2) 是(1) 7 是(1) 否(0) 高(1) 是(1) 原理分析: 在决策树的每一个非叶子结点划分之前,先

决策树 -- ID3算法小结

ID3算法(Iterative Dichotomiser 3 迭代二叉树3代),是一个由Ross Quinlan发明的用于决策树的算法:简单理论是越是小型的决策树越优于大的决策树. 算法归纳: 1.使用所有没有使用的属性并计算与之相关的样本熵值: 2.选取其中熵值最小的属性 3.生成包含该属性的节点 4.使用新的分支表继续前面步骤 ID3算法以信息论为基础,以信息熵和信息增益为衡量标准,从而实现对数据的归纳分类:所以归根结底,是为了从一堆数据中生成决策树而采取的一种归纳方式: 具体介绍: 1.信

决策树ID3算法实现,详细分步实现,参考机器学习实战

一.编写计算历史数据的经验熵函数 In [1]: from math import log def calcShannonEnt(dataSet): numEntries = len(dataSet) labelCounts = {} for elem in dataSet: #遍历数据集中每条样本的类别标签,统计每类标签的数量 currentLabel = elem[-1] if currentLabel not in labelCounts.keys(): #如果当前标签不在字典的key值中

Python实现决策树ID3算法

主要思想: 0.训练集格式:特征1,特征2,...特征n,类别 1.采用Python自带的数据结构字典递归的表示数据 2.ID3计算的信息增益是指类别的信息增益,因此每次都是计算类别的熵 3.ID3每次选择最优特征进行数据划分后都会消耗特征 4.当特征消耗到一定程度,可能会出现数据实例一样,但是类别不一样的情况,这个时候选不出最优特征而返回-1:   因此外面要捕获-1,要不然Python会以为最优特征是最后一列(类别) #coding=utf-8 import operator from ma

【机器学习】决策树-ID3算法的Python实现

''' Created on Jan 30, 2015 @author: 史帅 ''' from math import log import operator import re def fileToDataSet(fileName): ''' 此方法功能是:从文件中读取样本集数据,样本数据的格式为:数据以空白字符分割,最后一列为类标签 参数: fileName:存放样本集数据的文件路径 返回值: dataSet:样本集数据组成的二维数组 ''' file=open(fileName, mod