day-8 python自带库实现ID3决策树算法

  前一天,我们基于sklearn科学库实现了ID3的决策树程序,本文将基于python自带库实现ID3决策树算法。

一、代码涉及基本知识

  1、 为了绘图方便,引入了一个第三方treePlotter模块进行图形绘制。该模块使用方法简单,调用模块createPlot接口,传入一个树型结构对象,即可绘制出相应图像。

  2、  在python中,如何定义一个树型结构对象

    可以使用了python自带的字典数据类型来定义一个树型对象。例如下面代码,我们定义一个根节点和两个左右子节点:

    rootNode = {‘rootNode‘: {}}
    leftNode = {‘leftNode‘: {‘yes‘:‘yes‘}}
    rightNode = {‘rightNode‘: {‘no‘:‘no‘}}
    rootNode[‘rootNode‘][‘left‘] = leftNode
    rootNode[‘rootNode‘][‘right‘] = rightNode
    treePlotter.createPlot(rootNode)

    通过调用treePlotter模块,绘制出如下树的图像

    

  2、  递归调用

    为了求每个节点的各个子节点,要用到递归的方法来实现,基本思想和二叉树的遍历方法一致,后面我们还会用Python实现一个二叉树源码,此处不再进行介绍。

  3、  此外,还需要对python常用的数据类型及其操作比较了解,例如字典、列表、集合等

二、程序主要流程

三、测试数据集


age


income


student


credit_rating


class_buys_computer


youth


high


no


fair


no


youth


high


no


excellent


no


middle_aged


high


no


fair


yes


senior


medium


no


fair


yes


senior


low


yes


fair


yes


senior


low


yes


excellent


no


middle_aged


low


yes


excellent


yes


youth


medium


no


fair


no


youth


low


yes


fair


yes


senior


medium


yes


fair


yes


youth


medium


yes


excellent


yes


middle_aged


medium


no


excellent


yes


middle_aged


high


yes


fair


yes


senior


medium


no


excellent


no

四、程序代码

1、计算测试集熵及信息增益

# 求最优的根节点
def chooseBestFeatureToSplit(dataset,headerList):
    # 定义一个初始值
    bestInfoGainRate = 0.0
    bestFeature = 0
    # 求特征列项的数量
    numFeatures = len(dataset[0]) -1
    # 获取整个测试数据集的熵
    baseShnnonEnt = calcShannonEnt(dataset)
    print("total‘s shannonEnt = %f" % (baseShnnonEnt))
    # 遍历每一个特征列,求取信息增益
    for i in range(numFeatures):
        # 获取某一列所有特征值
        featureList = [example[i] for example in dataset]
        uniqueVals = set(featureList)
        newEntropy = 0.0
        # 求得某一列某一个特征值的概率和熵
        newShannonEnt = 0.0
        for value in uniqueVals:
            # 计算熵
            subDataset = splitDataSet(dataset,i,value)
            newEntropy = calcShannonEnt(subDataset)
            # 计算某一列某一个特征值的概率
            newProbability = len(subDataset) / float(len(dataset))
            newShannonEnt += newProbability*calcShannonEnt(subDataset)
        infoGainRate = baseShnnonEnt - newShannonEnt
        print("%s‘s infoGainRate = %f - %f = %f"%(headerList[i],baseShnnonEnt,newShannonEnt,infoGainRate))
        if infoGainRate > bestInfoGainRate:
            bestInfoGainRate = infoGainRate
            bestFeature = i
    return bestFeature

  该结果和前一天计算结果一致,age特征对应信息增益最大,因此设为根节点:

2、程序源码

treePlotter.py        

import matplotlib.pyplot as plt

# 定义决策树决策结果属性
descisionNode = dict(boxstyle=‘sawtooth‘, fc=‘0.8‘)
leafNode = dict(boxstyle=‘round4‘, fc=‘0.8‘)
arrow_args = dict(arrowstyle=‘<-‘)

def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    # nodeTxt为要显示的文本,centerNode为文本中心点, nodeType为箭头所在的点, parentPt为指向文本的点
    createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords=‘axes fraction‘,
                             xytext=centerPt, textcoords=‘axes fraction‘,
                              va=‘center‘, ha=‘center‘, bbox=nodeType, arrowprops=arrow_args)
def getNumLeafs(myTree):
    numLeafs = 0
    firstStr = list(myTree.keys())[0]
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == ‘dict‘:
            numLeafs += getNumLeafs(secondDict[key])
        else:
            numLeafs += 1
    return numLeafs

def getTreeDepth(myTree):
    maxDepth = 0
    firstStr = list(myTree.keys())[0]     # 这个是改的地方,原来myTree.keys()返回的是dict_keys类,不是列表,运行会报错。有好几个地方这样
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == ‘dict‘:
            thisDepth = 1 + getTreeDepth(secondDict[key])
        else:
            thisDepth = 1
        if thisDepth > maxDepth:
            maxDepth = thisDepth
    return maxDepth

def createPlot(inTree):
    fig = plt.figure(1, facecolor=‘white‘)
    fig.clf()
    axprops = {‘xticks‘: None, ‘yticks‘: None}
    createPlot.ax1 = plt.subplot(111, frameon=False)
    plotTree.totalW = float(getNumLeafs(inTree))     # 全局变量宽度 = 叶子数目
    plotTree.totalD = float(getTreeDepth(inTree))     # 全局变量高度 = 深度
    plotTree.xOff = -0.5/plotTree.totalW
    plotTree.yOff = 1.0
    plotTree(inTree, (0.5, 1.0), ‘‘)
    plt.show()

def plotTree(myTree, parentPt, nodeTxt):
    numLeafs = getNumLeafs(myTree)
    depth = getTreeDepth(myTree)
    firstStr = list(myTree.keys())[0]
    # cntrPt文本中心点, parentPt指向文本中心的点
    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs))/2.0/plotTree.totalW, plotTree.yOff)
    plotMidText(cntrPt, parentPt, nodeTxt)
    plotNode(firstStr, cntrPt, parentPt, descisionNode)
    seconDict = myTree[firstStr]
    plotTree.yOff = plotTree.yOff - 1.0/plotTree.totalD
    for key in seconDict.keys():
        if type(seconDict[key]).__name__ == ‘dict‘:
            plotTree(seconDict[key], cntrPt, str(key))
        else:
            plotTree.xOff = plotTree.xOff + 1.0/plotTree.totalW
            plotNode(seconDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
            plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
    plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalD

def plotMidText(cntrPt, parentPt, txtString):
    xMid = (parentPt[0] - cntrPt[0])/2.0 + cntrPt[0]
    yMid = (parentPt[1] - cntrPt[1])/2.0 + cntrPt[1]
    createPlot.ax1.text(xMid, yMid, txtString, va=‘center‘, ha=‘center‘, rotation=30)

         decision_tree_ID3.py

# 导入库
import csv
import math
import operator
import treePlotter

# 导入数据集
def readDataset(file_path,file_mode):
    allElectronicsData = open(file_path, file_mode)
    reader = csv.reader(allElectronicsData)
    # 读取特征名称
    headers = next(reader)
    # 读取测试数据集
    dataset = []
    for row in reader:
        dataset.append(row)
    return headers,dataset

# 求某个数据集的熵
def calcShannonEnt(dataset):
    shannonEnt = 0.0
    labelList = {}
    for vec_now in dataset:
        labelValue = vec_now[-1]
        if vec_now[-1] not in labelList.keys():
            labelList[labelValue] = 0
        labelList[labelValue] += 1
    for labelKey in labelList:
        probability = float(labelList[labelKey] / len(dataset))
        shannonEnt -= probability*math.log(probability,2)
    return shannonEnt

# 根据给定的列特征值,分理出给定的特征量
def splitDataSet(dataset,feature_seq,value):
    new_dataset = []
    for vec_row in dataset:
        feature_Value = vec_row[feature_seq]
        if feature_Value == value:
            temp_vec = []
            temp_vec = vec_row[:feature_seq]
            temp_vec.extend(vec_row[feature_seq+1:])
            new_dataset.append(temp_vec)
    return new_dataset

# 求最优的根节点
def chooseBestFeatureToSplit(dataset,headerList):
    # 定义一个初始值
    bestInfoGainRate = 0.0
    bestFeature = 0
    # 求特征列项的数量
    numFeatures = len(dataset[0]) -1
    # 获取整个测试数据集的熵
    baseShnnonEnt = calcShannonEnt(dataset)
    #print("total‘s shannonEnt = %f" % (baseShnnonEnt))
    # 遍历每一个特征列,求取信息增益
    for i in range(numFeatures):
        # 获取某一列所有特征值
        featureList = [example[i] for example in dataset]
        uniqueVals = set(featureList)
        newEntropy = 0.0
        # 求得某一列某一个特征值的概率和熵
        newShannonEnt = 0.0
        for value in uniqueVals:
            # 计算熵
            subDataset = splitDataSet(dataset,i,value)
            newEntropy = calcShannonEnt(subDataset)
            # 计算某一列某一个特征值的概率
            newProbability = len(subDataset) / float(len(dataset))
            newShannonEnt += newProbability*calcShannonEnt(subDataset)
        infoGainRate = baseShnnonEnt - newShannonEnt
        #print("%s‘s infoGainRate = %f - %f = %f"%(headerList[i],baseShnnonEnt,newShannonEnt,infoGainRate))
        if infoGainRate > bestInfoGainRate:
            bestInfoGainRate = infoGainRate
            bestFeature = i
    return bestFeature

# 标签判定,通过少数服从多数原则
def majorityCnt(classList):
    classcount = {}
    for cl in classList:
        if cl not in classcount.keys():
            classcount[cl] = 0
        classcount[cl] += 1
    sortedClassCount = sorted(classcount.items(),key = operator.itemgetter(1),reverse= True)
    return sortedClassCount[0][0]

# 创建一个决策树
def createTree(dataSet, labels):
    classList = [example[-1] for example in dataSet]
    # 1 所有特征值都是相同的时候直接返回
    if classList.count(classList[0]) == len(classList):
        return classList[0]
    # 2 遍历完所有特征值,投票原则,返回出现次数最多的标签
    if len(dataSet[0])  == 1:
        return majorityCnt(classList)
    # 3 如果不满足上面两者,求最优特征
    bestFeature = chooseBestFeatureToSplit(dataSet,labels)
    bestFeatureLabel = labels[bestFeature]
    myTree = {bestFeatureLabel: {}}
    del (labels[bestFeature])
    featurValues = [example[bestFeature] for example in dataSet]
    uniqueVals = set(featurValues)
    # 使用递归的方法,获得整个树
    for value in uniqueVals:
        subLabels = labels[:]
        myTree[bestFeatureLabel][value] = createTree(splitDataSet(dataSet, bestFeature, value), subLabels)
    return myTree

def classify(inputTree, featLabels, testVec):
    firstStr = list(inputTree.keys())[0]
    secondDict = inputTree[firstStr]
    featIndex = featLabels.index(firstStr)
    for key in secondDict.keys():
        if testVec[featIndex] == key:
            if type(secondDict[key]).__name__ == ‘dict‘:
                classLabel = classify(secondDict[key], featLabels, testVec)
            else:
                classLabel = secondDict[key]
    return classLabel

def classifyAll(inputTree, featLabels, testDataSet):
    classLabelAll = []
    for testVec in testDataSet:
        classLabelAll.append(classify(inputTree, featLabels, testVec))
    return classLabelAll

def storeTree(inputTree, filename):
    import pickle
    fw = open(filename, ‘wb‘)
    pickle.dump(inputTree, fw)
    fw.close()

def grabTree(filename):
    import pickle
    fr = open(filename, ‘rb‘)
    return pickle.load(fr)

def main():
    # 读取数据集
    labels, dataSet = readDataset(file_path=r‘D:\test.csv‘, file_mode=‘r‘)
    labels_tmp = labels[:] # 拷贝,createTree会改变labels
    desicionTree = createTree(dataSet, labels_tmp)
    storeTree(desicionTree, ‘classifierStorage.txt‘)
    desicionTree = grabTree(‘classifierStorage.txt‘)
    treePlotter.createPlot(desicionTree)
    testSet = [[‘youth‘, ‘high‘, ‘no‘, ‘fair‘, ‘no‘]]
    print(‘classifyResult:\n‘, classifyAll(desicionTree, labels, testSet))

if __name__ == ‘__main__‘:
    main()

五、测试结果及结论

  我们从上面求解信息增益的公式中,其实可以看出,信息增益准则其实是对可取值数目较多的属性有所偏好!
  现在假如我们把数据集中的“编号”也作为一个候选划分属性。我们可以算出“编号”的信息增益是0.998
  因为每一个样本的编号都是不同的(由于编号独特唯一,条件熵为0了,每一个结点中只有一类,纯度非常高啊),也就是说,来了一个预测样本,你只要告诉我编号,其它特征就没有用了,这样生成的决策树显然不具有泛化能力。

  参考链接:

  http://www.cnblogs.com/wsine/p/5180310.html

  https://zhuanlan.zhihu.com/p/26760551

原文地址:https://www.cnblogs.com/python-frog/p/8721065.html

时间: 2024-11-07 04:23:18

day-8 python自带库实现ID3决策树算法的相关文章

python自带库及第三方库api察看

今天发现一个很有意思的功能,python自带了所有库的文档查看器,配置如下: 配置pydoc服务,cmd中输入如下代码: python –m pydoc –p 1234 回车后 ,使用过程中,该窗口不要关闭. 然后,在浏览器中输入网址:http://localhost:1234/ ,打开python自带的及集成的第三方库的api. 界面类似如下: 此工具对于察看第三方库的帮助文档非常方便

ID3决策树算法原理及C++实现(其中代码转自别人的博客)

分类是数据挖掘中十分重要的组成部分. 分类作为一种无监督学习方式被广泛的使用. 之前关于"数据挖掘中十大经典算法"中,基于ID3核心思想的分类算法 C4.5榜上有名.所以不难看出ID3在数据分类中是多么的重要了. ID3又称为决策树算法,虽然现在广义的决策树算法不止ID3一种,但是由 于ID3的重要性,习惯是还是把ID3和决策树算法等价起来. 另外无监督学习方式我还要多说两句.无监督学习方式包括决策树算法, 基于规则的分类,神经网络等.这些分类方式是初始分类已知,将样本分为 训练样本和

ID3决策树算法实现(Python版)

1 # -*- coding:utf-8 -*- 2 3 from numpy import * 4 import numpy as np 5 import pandas as pd 6 from math import log 7 import operator 8 9 #计算数据集的香农熵 10 def calcShannonEnt(dataSet): 11 numEntries=len(dataSet) 12 labelCounts={} 13 #给所有可能分类创建字典 14 for fe

数据挖掘之决策树算法ID3算法的相关原理

ID3决策树:针对属性选择问题,是决策树算法中最为典型和最具影响力的决策树算法. ID3决策树算法使用信息增益度作为选择测试属性. 其中p(ai) 表示ai 发生的概率. 假设有n个互不相容的事件a1,a2,a3,-.,an,它们中有且仅有一个 发生,则其平均的信息量可如下度量: 对数底数可以为任何数,不同的取值对应了熵的不同单位. 通常取2,并规定当p(ai)=0时 =0 Entropy(S,A)=∑(|Sv|/|S|)* Entropy(Sv)公式2 以去不去打羽毛球为例子 A:属性:out

python自带的用于解析HTML的库HtmlParser

转自博客https://www.cnblogs.com/masako/p/5868367.html HtmlParser,顾名思义,是解析Html的一个工具.python自带的. 一.常用属性和方法介绍 HtmlParser是一个类,在使用时一般继承它然后重载它的方法,来达到解析出需要的数据的目的. 1.常用属性: lasttag,保存上一个解析的标签名,是字符串. 2.常用方法: handle_starttag(tag, attrs) ,处理开始标签,比如<div>:这里的attrs获取到的

Python的常用库

读者您好.今天我将介绍20个属于我常用工具的Python库,我相信你看完之后也会觉得离不开它们.他们是: Requests.Kenneth Reitz写的最富盛名的http库.每个Python程序员都应该有它. Scrapy.如果你从事爬虫相关的工作,那么这个库也是必不可少的.用过它之后你就不会再想用别的同类库了. wxPython.Python的一个GUI(图形用户界面)工具.我主要用它替代tkinter.你一定会爱上它的. Pillow.它是PIL(Python图形库)的一个友好分支.对于用

Python常用的库简单介绍一下

Python常用的库简单介绍一下fuzzywuzzy ,字符串模糊匹配. esmre ,正则表达式的加速器. colorama 主要用来给文本添加各种颜色,并且非常简单易用. Prettytable 主要用于在终端或浏览器端构建格式化的输出. difflib ,[Python]标准库,计算文本差异 . Levenshtein ,快速计算字符串相似度. Chardet 字符编码探测器,可以自动检测文本.网页.xml的编码. shortuuid ,一组简洁URL/UUID函数库. ftfy ,Uni

python/c++接口库比较(SWIG,boost.python, pycxx, py++, sip, Weave, Pyrex )

<python/c++接口库比较(SWIG,boost.python, pycxx, py++, sip, Weave, Pyrex )>http://blog.csdn.net/lainegates/article/details/19565823 目前有很多开源的Python/c++ 绑定工具,搜了好多岾子,这里稍微总结下. SWIG 支持 Python 2 and 3 配置正确的话,可以全自动完成封装(*.i文件需要自己写) 当不是全自动的时候,它大多会重复你的.h文件并给出提示 除了P

[转载]python实现带验证码网站的自动登陆

    原文地址:python实现带验证码网站的自动登陆作者:TERRY-V 早听说用python做网络爬虫非常方便,正好这几天单位也有这样的需求,需要登陆XX网站下载部分文档,于是自己亲身试验了一番,效果还不错. 本例所登录的某网站需要提供用户名,密码和验证码,在此使用了python的urllib2直接登录网站并处理网站的Cookie. Cookie的工作原理: Cookie由服务端生成,然后发送给浏览器,浏览器会将Cookie保存在某个目录下的文本文件中.在下次请求同一网站时,会发送该Coo