《机器学习实战》学习笔记:绘制树形图&使用决策树预测隐形眼镜类型

上一节实现了决策树,但只是使用包含树结构信息的嵌套字典来实现,其表示形式较难理解,显然,绘制直观的二叉树图是十分必要的。Python没有提供自带的绘制树工具,需要自己编写函数,结合Matplotlib库创建自己的树形图。这一部分的代码多而复杂,涉及二维坐标运算;书里的代码虽然可用,但函数和各种变量非常多,感觉非常凌乱,同时大量使用递归,因此只能反复研究,反反复复用了一天多时间,才差不多搞懂,因此需要备注一下。

一.绘制属性图

这里使用Matplotlib的注解工具annotations实现决策树绘制的各种细节,包括生成节点处的文本框、添加文本注释、提供对文字着色等等。在画一整颗树之前,最好先掌握单个树节点的绘制。一个简单实例如下:

# -*- coding: utf-8 -*-
"""
Created on Fri Sep 04 01:15:01 2015

@author: Herbert
"""

import matplotlib.pyplot as plt

nonLeafNodes = dict(boxstyle = "sawtooth", fc = "0.8")
leafNodes = dict(boxstyle = "round4", fc = "0.8")
line = dict(arrowstyle = "<-")

def plotNode(nodeName, targetPt, parentPt, nodeType):
    createPlot.ax1.annotate(nodeName, xy = parentPt, xycoords =                             ‘axes fraction‘, xytext = targetPt,                             textcoords = ‘axes fraction‘, va =                             "center", ha = "center", bbox = nodeType,                             arrowprops = line)

def createPlot():
    fig = plt.figure(1, facecolor = ‘white‘)
    fig.clf()
    createPlot.ax1 = plt.subplot(111, frameon = False)
    plotNode(‘nonLeafNode‘, (0.2, 0.1), (0.4, 0.8), nonLeafNodes)
    plotNode(‘LeafNode‘, (0.8, 0.1), (0.6, 0.8), leafNodes)
    plt.show()

createPlot()

输出结果:

该实例中,plotNode()函数用于绘制箭头和节点,该函数每调用一次,将绘制一个箭头和一个节点。后面对于该函数有比较详细的解释。createPlot()函数创建了输出图像的对话框并对齐进行一些简单的设置,同时调用了两次plotNode(),生成一对节点和指向节点的箭头。

绘制整颗树

这部分的函数和变量较多,为方便日后扩展功能,需要给出必要的标注:

# -*- coding: utf-8 -*-
"""
Created on Fri Sep 04 01:15:01 2015

@author: Herbert
"""

import matplotlib.pyplot as plt

# 部分代码是对绘制图形的一些定义,主要定义了文本框和剪头的格式
nonLeafNodes = dict(boxstyle = "sawtooth", fc = "0.8")
leafNodes = dict(boxstyle = "round4", fc = "0.8")
line = dict(arrowstyle = "<-")

# 使用递归计算树的叶子节点数目
def getLeafNum(tree):
    num = 0
    firstKey = tree.keys()[0]
    secondDict = tree[firstKey]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == ‘dict‘:
            num += getLeafNum(secondDict[key])
        else:
            num += 1
    return num

# 同叶子节点计算函数,使用递归计算决策树的深度
def getTreeDepth(tree):
    maxDepth = 0
    firstKey = tree.keys()[0]
    secondDict = tree[firstKey]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == ‘dict‘:
            depth = getTreeDepth(secondDict[key]) + 1
        else:
            depth = 1
        if depth > maxDepth:
            maxDepth = depth
    return maxDepth

# 在前面例子已实现的函数,用于注释形式绘制节点和箭头
def plotNode(nodeName, targetPt, parentPt, nodeType):
    createPlot.ax1.annotate(nodeName, xy = parentPt, xycoords =                             ‘axes fraction‘, xytext = targetPt,                             textcoords = ‘axes fraction‘, va =                             "center", ha = "center", bbox = nodeType,                             arrowprops = line)

# 用于绘制剪头线上的标注,涉及坐标计算,其实就是两个点坐标的中心处添加标注
def insertText(targetPt, parentPt, info):
    xCoord = (parentPt[0] - targetPt[0]) / 2.0 + targetPt[0]
    yCoord = (parentPt[1] - targetPt[1]) / 2.0 + targetPt[1]
    createPlot.ax1.text(xCoord, yCoord, info)

# 实现整个树的绘制逻辑和坐标运算,使用的递归,重要的函数
# 其中两个全局变量plotTree.xOff和plotTree.yOff
# 用于追踪已绘制的节点位置,并放置下个节点的恰当位置
def plotTree(tree, parentPt, info):
    # 分别调用两个函数算出树的叶子节点数目和树的深度
    leafNum = getLeafNum(tree)
    treeDepth = getTreeDepth(tree)
    firstKey = tree.keys()[0] # the text label for this node
    firstPt = (plotTree.xOff + (1.0 + float(leafNum)) / 2.0/plotTree.totalW,                plotTree.yOff)
    insertText(firstPt, parentPt, info)
    plotNode(firstKey, firstPt, parentPt, nonLeafNodes)
    secondDict = tree[firstKey]
    plotTree.yOff = plotTree.yOff - 1.0 / plotTree.totalD
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == ‘dict‘:
            plotTree(secondDict[key], firstPt, str(key))
        else:
            plotTree.xOff = plotTree.xOff + 1.0 / plotTree.totalW
            plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff),                     firstPt, leafNodes)
            insertText((plotTree.xOff, plotTree.yOff), firstPt, str(key))
    plotTree.yOff = plotTree.yOff + 1.0 / plotTree.totalD

# 以下函数执行真正的绘图操作,plotTree()函数只是树的一些逻辑和坐标运算
def createPlot(inTree):
    fig = plt.figure(1, facecolor = ‘white‘)
    fig.clf()
    createPlot.ax1 = plt.subplot(111, frameon = False) #, **axprops)
    # 全局变量plotTree.totalW和plotTree.totalD
    # 用于存储树的宽度和树的深度
    plotTree.totalW = float(getLeafNum(inTree))
    plotTree.totalD = float(getTreeDepth(inTree))
    plotTree.xOff = -0.5 / plotTree.totalW
    plotTree.yOff = 1.0
    plotTree(inTree, (0.5, 1.0), ‘ ‘)
    plt.show()

# 一个小的测试集
def retrieveTree(i):
    listOfTrees = [{‘no surfacing‘:{0: ‘no‘, 1:{‘flippers‘:{0:‘no‘, 1:‘yes‘}}}},                    {‘no surfacing‘:{0: ‘no‘, 1:{‘flippers‘:{0:{‘head‘:{0:‘no‘,                     1:‘yes‘}}, 1:‘no‘}}}}]
    return listOfTrees[i]

createPlot(retrieveTree(1)) # 调用测试集中一棵树进行绘制

retrieveTree()函数中包含两颗独立的树,分别输入参数即可返回树的参数tree,最后执行createPlot(tree)即得到画图的结果,如下所示:

书中关于递归计算树的叶子节点和深度这部分十分简单,在编写绘制属性图的函数时,难度在于这本书中一些绘图坐标的取值以及在计算节点坐标所作的处理,书中对于这部分的解释比较散乱。博客:http://www.cnblogs.com/fantasy01/p/4595902.html 给出了十分详尽的解释,包括坐标的求解和公式的分析,以下只摘取一部分作为了解:

这里说一下具体绘制的时候是利用自定义,如下图:

这里绘图,作者选取了一个很聪明的方式,并不会因为树的节点的增减和深度的增减而导致绘制出来的图形出现问题,当然不能太密集。这里利用整 棵树的叶子节点数作为份数将整个x轴的长度进行平均切分,利用树的深度作为份数将y轴长度作平均切分,并利用plotTree.xOff作为最近绘制的一 个叶子节点的x坐标,当再一次绘制叶子节点坐标的时候才会plotTree.xOff才会发生改变;用plotTree.yOff作为当前绘制的深 度,plotTree.yOff是在每递归一层就会减一份(上边所说的按份平均切分),其他时候是利用这两个坐标点去计算非叶子节点,这两个参数其实就可 以确定一个点坐标,这个坐标确定的时候就是绘制节点的时候

plotTree函数的整体步骤分为以下三步:

  1. 绘制自身
  2. 若当前子节点不是叶子节点,递归
  3. 若当子节点为叶子节点,绘制该节点

以下是plotTreecreatePlot函数的详细解析,因此把两个函数的代码单独拿出来了:

# 实现整个树的绘制逻辑和坐标运算,使用的递归,重要的函数
# 其中两个全局变量plotTree.xOff和plotTree.yOff
# 用于追踪已绘制的节点位置,并放置下个节点的恰当位置
def plotTree(tree, parentPt, info):
    # 分别调用两个函数算出树的叶子节点数目和树的深度
    leafNum = getLeafNum(tree)
    treeDepth = getTreeDepth(tree)
    firstKey = tree.keys()[0] # the text label for this node
    firstPt = (plotTree.xOff + (1.0 + float(leafNum)) / 2.0/plotTree.totalW,                plotTree.yOff)
    insertText(firstPt, parentPt, info)
    plotNode(firstKey, firstPt, parentPt, nonLeafNodes)
    secondDict = tree[firstKey]
    plotTree.yOff = plotTree.yOff - 1.0 / plotTree.totalD
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == ‘dict‘:
            plotTree(secondDict[key], firstPt, str(key))
        else:
            plotTree.xOff = plotTree.xOff + 1.0 / plotTree.totalW
            plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff),                     firstPt, leafNodes)
            insertText((plotTree.xOff, plotTree.yOff), firstPt, str(key))
    plotTree.yOff = plotTree.yOff + 1.0 / plotTree.totalD

# 以下函数执行真正的绘图操作,plotTree()函数只是树的一些逻辑和坐标运算
def createPlot(inTree):
    fig = plt.figure(1, facecolor = ‘white‘)
    fig.clf()
    createPlot.ax1 = plt.subplot(111, frameon = False) #, **axprops)
    # 全局变量plotTree.totalW和plotTree.totalD
    # 用于存储树的宽度和树的深度
    plotTree.totalW = float(getLeafNum(inTree))
    plotTree.totalD = float(getTreeDepth(inTree))
    plotTree.xOff = -0.5 / plotTree.totalW
    plotTree.yOff = 1.0
    plotTree(inTree, (0.5, 1.0), ‘ ‘)
    plt.show()

首先代码对整个画图区间根据叶子节点数和深度进行平均切分,并且xy轴的总长度均为1,如同下图:

解释如下

1.图中的方形为非叶子节点的位置,@是叶子节点的位置,因此上图的一个表格的长度应该为: 1/plotTree.totalW,但是叶子节点的位置应该为@所在位置,则在开始的时候 plotTree.xOff 的赋值为: -0.5/plotTree.totalW,即意为开始x 轴位置为第一个表格左边的半个表格距离位置,这样作的好处是在以后确定@位置时候可以直接加整数倍的 1/plotTree.totalW

2.plotTree函数中的一句代码如下:

firstPt = (plotTree.xOff + (1.0 + float(leafNum)) / 2.0/ plotTree.totalW, plotTree.yOff)

其中,变量plotTree.xOff即为最近绘制的一个叶子节点的x轴坐标,在确定当前节点位置时每次只需确定当前节点有几个叶子节点,因此其叶子节点所占的总距离就确定了即为: float(numLeafs)/plotTree.totalW,因此当前节点的位置即为其所有叶子节点所占距离的中间即一半为: float(numLeafs)/2.0/plotTree.totalW,但是由于开始plotTree.xOff赋值并非从0开始,而是左移了半个表格,因此还需加上半个表格距离即为: 1/2/plotTree.totalW,则加起来便为: (1.0 + float(numLeafs))/2.0/plotTree.totalW,因此偏移量确定,则x轴的位置变为: plotTree.xOff + (1.0 + float(numLeafs))/2.0/plotTree.totalW

3.关于plotTree()函数的参数

plotTree(inTree, (0.5, 1.0), ‘ ‘)

plotTree()函数的第二个参数赋值为(0.5, 1.0),因为开始的根节点并不用划线,因此父节点和当前节点的位置需要重合,利用2中的确定当前节点的位置为(0.5, 1.0)

总结:利用这样的逐渐增加x 轴的坐标,以及逐渐降低y轴的坐标能能够很好的将树的叶子节点数和深度考虑进去,因此图的逻辑比例就很好的确定了,即使图像尺寸改变,我们仍然可以看到按比例绘制的树形图。

二.使用决策树预测隐形眼镜类型

这里实现一个例子,即利用决策树预测一个患者需要佩戴的隐形眼镜类型。以下是整个预测的大体步骤:

  1. 收集数据:使用书中提供的小型数据集
  2. 准备数据:对文本中的数据进行预处理,如解析数据行
  3. 分析数据:快速检查数据,并使用createPlot()函数绘制最终的树形图
  4. 训练决策树:使用createTree()函数训练
  5. 测试决策树:编写简单的测试函数验证决策树的输出结果&绘图结果
  6. 使用决策树:这部分可选择将训练好的决策树进行存储,以便随时使用

    此处新建脚本文件saveTree.py,将训练好的决策树保存在磁盘中,这里需要使用Python模块的pickle序列化对象。storeTree()函数负责把tree存放在当前目录下的filename(.txt)文件中,而getTree(filename)则是在当前目录下的filename(.txt)文件中读取决策树的相关数据。

# -*- coding: utf-8 -*-
"""
Created on Sat Sep 05 01:56:04 2015

@author: Herbert
"""

import pickle

def storeTree(tree, filename):
    fw = open(filename, ‘w‘)
    pickle.dump(tree, fw)
    fw.close()

def getTree(filename):
    fr = open(filename)
    return pickle.load(fr)

以下代码实现了决策树预测隐形眼镜模型的实例,使用的数据集是隐形眼镜数据集,它包含很多患者的眼部状况的观察条件以及医生推荐的隐形眼镜类型,其中隐形眼镜类型包括:硬材质(hard)、软材质(soft)和不适合佩戴隐形眼镜(no lenses) , 数据来源于UCI数据库。代码最后调用了之前准备好的createPlot()函数绘制树形图。

# -*- coding: utf-8 -*-
"""
Created on Sat Sep 05 14:21:43 2015

@author: Herbert
"""
import tree
import plotTree
import saveTree

fr = open(‘lenses.txt‘)
lensesData = [data.strip().split(‘\t‘) for data in fr.readlines()]
lensesLabel = [‘age‘, ‘prescript‘, ‘astigmatic‘, ‘tearRate‘]
lensesTree = tree.buildTree(lensesData, lensesLabel)
#print lensesData
print lensesTree

print plotTree.createPlot(lensesTree)

可以看到,前期实现了决策树的构建和绘制,使用不同的数据集都可以得到很直观的结果,从图中可以看到,沿着决策树的不同分支,可以得到不同患者需要佩戴的隐形眼镜的类型。

三.关于本章使用的决策树的总结

回到决策树的算法层面,以上代码的实现基于ID3决策树构造算法,它是一个非常经典的算法,但其实缺点也不少。实际上决策树的使用中常常会遇到一个问题,即“过度匹配”。有时候,过多的分支选择或匹配选项会给决策带来负面的效果。为了减少过度匹配的问题,通常算法设计者会在一些实际情况中选择“剪枝”。简单说来,如果叶子节点只能增加少许信息,则可以删除该节点。

另外,还有几种目前很流行的决策树构造算法:C4.5、C5.0和CART,后期需继续深入研究。

参考资料:http://blog.sina.com.cn/s/blog_7399ad1f01014wec.html

版权声明:本文为博主原创文章,未经博主允许不得转载。

时间: 2024-10-10 18:28:03

《机器学习实战》学习笔记:绘制树形图&使用决策树预测隐形眼镜类型的相关文章

机器学习实战学习笔记(一)

1.k-近邻算法 算法原理: 存在一个样本数据集(训练样本集),并且我们知道样本集中的每个数据与其所属分类的对应关系.输入未知类别的数据后将新数据的每个特征与样本集中数据对应的特征进行比较,然后算法提取样本集中特征最相似(最近邻)的k组数据.然后将k组数据中出现次数最多的分类,来作为新数据的分类. 算法步骤: 计算已知类别数据集中的每一个点与当前点之前的距离.(相似度度量) 按照距离递增次序排序 选取与当前点距离最小的k个点 确定k个点所在类别的出现频率 返回频率最高的类别作为当前点的分类 py

决策树-预测隐形眼镜类型 (ID3算法,C4.5算法,CART算法,GINI指数,剪枝,随机森林)

1. 1.问题的引入 2.一个实例 3.基本概念 4.ID3 5.C4.5 6.CART 7.随机森林 2. 我们应该设计什么的算法,使得计算机对贷款申请人员的申请信息自动进行分类,以决定能否贷款? 一个女孩的母亲要给这个女孩介绍男朋友,于是有了下面的对话: 女儿:多大年纪了? 母亲:26. 女儿:长的帅不帅? 母亲:挺帅的. 女儿:收入高不? 母亲:不算很高,中等情况. 女儿:是公务员不? 母亲:是,在税务局上班呢. 女儿:那好,我去见见. 决策过程: 这个女孩的决策过程就是典型的分类树决策.

机器学习实战-学习笔记-第一章

Added C:\Anaconda and C:\Anaconda\Scripts to PATH. C:\Anaconda>pythonPython 2.7.10 |Anaconda 2.3.0 (64-bit)| (default, May 28 2015, 16:44:52) [MSC v.1500 64 bit (AMD64)] on win32Type "help", "copyright", "credits" or "

机器学习实战-学习笔记-第十四章

1.将代码拷贝到F:\studio\MachineLearningInAction\ch14下 2.启动ipython 3.在ipython中改变工作目录到F:\studio\MachineLearningInAction\ch14 In [17]: cd F:\\studio\\MachineLearningInAction\\ch14 F:\studio\MachineLearningInAction\ch14 4.在工作目录下新建一个svdRec.py文件并加入如下代码: from num

机器学习实战读书笔记(三)决策树

3.1 决策树的构造 优点:计算复杂度不高,输出结果易于理解,对中间值的缺失不敏感,可以处理不相关特征数据. 缺点:可能会产生过度匹配问题. 适用数据类型:数值型和标称型. 一般流程: 1.收集数据 2.准备数据 3.分析数据 4.训练算法 5.测试算法 6.使用算法 3.1.1 信息增益 创建数据集 def createDataSet(): dataSet = [[1, 1, 'yes'], [1, 1, 'yes'], [1, 0, 'no'], [0, 1, 'no'], [0, 1, '

《机器学习》学习笔记(一)

今天看了两集Stanford 的Machine Learning,先说说感受,在看的过程中,脑海里冒出来一个念头:在中国的大学里,教授们都是好像在做研究,而学生们都是好像在上课,到头来不知道学到了什么,我在屏幕的这边都能感受到他们和我们的不一样. 其实对于机器学习,我是真心不懂,也不知道为什么忽然就想学习一下了,然后看了第一集就觉得实在是太牛X了,他们做的那个爬越障碍物的狗和快速避障的小车,都不是我们能搞出来的,说来也奇怪,我们不是也有他们一样的课程体系吗?照理说在大学里能做出来的东西,我们也应

C++ Primer 学习笔记_102_特殊工具与技术 --运行时类型识别[续]

特殊工具与技术 --运行时类型识别[续] 三.RTTI的使用 当比较两个派生类对象的时候,我们希望比较可能特定于派生类的数据成员.如果形参是基类引用,就只能比较基类中出现的成员,我们不能访问在派生类中但不在基类中出现的成员. 因此我们可以使用RTTI,在试图比较不同类型的对象时返回假(false). 我们将定义单个相等操作符.每个类定义一个虚函数 equal,该函数首先将操作数强制转换为正确的类型.如果转换成功,就进行真正的比较:如果转换失败,equal 操作就返回 false. 1.类层次 c

[Guava学习笔记]Collections: 不可变集合, 新集合类型

不可变集合 不接受null值. 创建:ImmutableSet.copyOf(set); ImmutableMap.of(“a”, 1, “b”, 2); public static final ImmutableSet<Color> GOOGLE_COLORS = ImmutableSet.<Color>builder() .addAll(WEBSAFE_COLORS) .add(new Color(0, 191, 255)) .build(); 可以有序(如ImmutableS

springmvc学习笔记(13)-springmvc注解开发之集合类型參数绑定

springmvc学习笔记(13)-springmvc注解开发之集合类型參数绑定 springmvc学习笔记13-springmvc注解开发之集合类型參数绑定 数组绑定 需求 表现层实现 list绑定 需求 表现层实现 map绑定 本文主要介绍注解开发的集合类型參数绑定,包含数组绑定,list绑定以及map绑定 数组绑定 需求 商品批量删除,用户在页面选择多个商品.批量删除. 表现层实现 关键:将页面选择(多选)的商品id,传到controller方法的形參,方法形參使用数组接收页面请求的多个商