学习日志---树回归(回归树,模型树)

CART算法的树回归:

返回的每个节点最后是一个最终确定的平均值。

#coding:utf-8

import numpy as np

# 加载文件数据
def loadDataSet(fileName):      #general function to parse tab -delimited floats
    dataMat = []                #assume last column is target value
    fr = open(fileName)
    for line in fr.readlines():
        curLine = line.strip().split(‘\t‘)
        fltLine = map(float,curLine) #map all elements to float()
        dataMat.append(fltLine)
    return dataMat

#在dataset中选择特征为feature的这一列,以value值分成两部分
def binSplitDataSet(dataSet, feature, value):
    mat0 = dataSet[np.nonzero(dataSet[:,feature] > value)[0],:][0]
    mat1 = dataSet[np.nonzero(dataSet[:,feature] <= value)[0],:][0]
    return mat0,mat1

#计算此矩阵的最后一列结果的平均值,用平均值来当做最后的返回结果,后面的模型树返回的是一个 线性模型
def regLeaf(dataSet):
    return np.mean(dataSet[:,-1])

#计算dataset结果的混乱程度,用方差反应,因为是连续数据
def regErr(dataSet):
    return np.var(dataSet[:,-1]) * np.shape(dataSet)[0]

#选择最佳的分离特征和该特征的分离点
#这里的ops是预先的给定值,1是差别太小就不分了,4是分开后的各自样本数,太小就舍去,这是一  种预剪枝方法
def chooseBestSplit(dataSet, leafType=regLeaf, errType=regErr, ops=(1,4)):
    tolS = ops[0]; tolN = ops[1]
    #if all the target variables are the same value: quit and return value
    if len(set(dataSet[:,-1].T.tolist()[0])) == 1: #exit cond 1
        return None, leafType(dataSet)
    m,n = np.shape(dataSet)
    #the choice of the best feature is driven by Reduction in RSS error from mean
    S = errType(dataSet)
    bestS = np.inf; bestIndex = 0; bestValue = 0
    #循环所有的特征
    for featIndex in range(n-1):
        #循环该特征下的所有特征值
        for splitVal in set(dataSet[:,featIndex]):
            mat0, mat1 = binSplitDataSet(dataSet, featIndex, splitVal)
            #如果更具这个特征值分成的两类有一个小与预先给定值,说明分类太偏,则不考虑
            if (np.shape(mat0)[0] < tolN) or (np.shape(mat1)[0] < tolN): continue
            newS = errType(mat0) + errType(mat1)
            if newS < bestS:
                bestIndex = featIndex
                bestValue = splitVal
                bestS = newS
    #if the decrease (S-bestS) is less than a threshold don‘t do the split
    if (S - bestS) < tolS:
        return None, leafType(dataSet) #exit cond 2
    mat0, mat1 = binSplitDataSet(dataSet, bestIndex, bestValue)
    if (np.shape(mat0)[0] < tolN) or (np.shape(mat1)[0] < tolN):  #exit cond 3
        return None, leafType(dataSet)
    return bestIndex,bestValue              

#创建树
def createTree(dataSet, leafType=regLeaf, errType=regErr, ops=(1,4)):   
    feat, val = chooseBestSplit(dataSet, leafType, errType, ops)       
    if feat == None: return val                                        
    retTree = {}
    retTree[‘spInd‘] = feat
    retTree[‘spVal‘] = val
    lSet, rSet = binSplitDataSet(dataSet, feat, val)
    retTree[‘left‘] = createTree(lSet, leafType, errType, ops)
    retTree[‘right‘] = createTree(rSet, leafType, errType, ops)
    return retTree

myDat = loadDataSet(‘ex0.txt‘)
myMat = np.mat(myDat)
result = createTree(myMat)
print result

结果:

{‘spInd‘: 1, ‘spVal‘: matrix([[ 0.39435]]), ‘right‘: {‘spInd‘: 1, ‘spVal‘: matrix([[ 0.197834]]), ‘right‘: -0.023838155555555553, ‘left‘: 1.0289583666666666}, ‘left‘: {‘spInd‘: 1, ‘spVal‘: matrix([[ 0.582002]]), ‘right‘: 1.980035071428571, ‘left‘: {‘spInd‘: 1, ‘spVal‘: matrix([[ 0.797583]]), ‘right‘: 2.9836209534883724, ‘left‘: 3.9871631999999999}}}

结果的意思是:第几个特征,以多大作为特征值分开,分成左右,依次分下去。

这个算法很好,但是对数据的分类太过于高,容易造成过拟合。因此要采用剪枝技术。

通过降低决策树的复杂度来避免过拟合的过程称为剪枝。

#判断obj是否是一个子树
def isTree(obj):
    return (type(obj).__name__==‘dict‘)

#用于坍塌处理,当测试数据集是空是,则取整个树的平均值
def getMean(tree):
    if isTree(tree[‘right‘]): tree[‘right‘] = getMean(tree[‘right‘])
    if isTree(tree[‘left‘]): tree[‘left‘] = getMean(tree[‘left‘])
    return (tree[‘left‘]+tree[‘right‘])/2.0

#剪枝函数
def prune(tree, testData):
    
    #如果测试数据集为空,则坍塌处理
    if np.shape(testData)[0] == 0: return getMean(tree)   
    
    #如果左或者右是树,则把测试数据集根据决策树进行分割
    if (isTree(tree[‘right‘]) or isTree(tree[‘left‘])):
        lSet, rSet = binSplitDataSet(testData, tree[‘spInd‘], tree[‘spVal‘])
    
    #如果左侧是树,则把数据集和子树带入继续找
    if isTree(tree[‘left‘]): tree[‘left‘] = prune(tree[‘left‘], lSet)
    #同理
    if isTree(tree[‘right‘]): tree[‘right‘] =  prune(tree[‘right‘], rSet)
    #if they are now both leafs, see if we can merge them
    #如果左右都是节点,则计算节点误差
    if not isTree(tree[‘left‘]) and not isTree(tree[‘right‘]):
        lSet, rSet = binSplitDataSet(testData, tree[‘spInd‘], tree[‘spVal‘])
        #计算不合并的误差
        errorNoMerge = sum(np.power(lSet[:,-1] - tree[‘left‘],2)) + sum(np.power(rSet[:,-1] - tree[‘right‘],2))
        treeMean = (tree[‘left‘]+tree[‘right‘])/2.0
        #计算将当前两个叶子节点合并后的误差
        errorMerge = sum(np.power(testData[:,-1] - treeMean,2))
        if errorMerge < errorNoMerge:
            print "merging"
            #可以合并就返回平均值
            return treeMean
        #不可以合并就返回树,不变
        else: return tree
    else: return tree

一般来说都是预剪枝和后剪枝合并使用

模型树

每个节点是一个线性模型

其他基本一样:

#对数据集进行线性回归
def linearSolve(dataSet):
    m,n = np.shape(dataSet)
    X = np.mat(np.ones((m,n))); Y = np.mat(np.ones((m,1)))
    #有一列是常数项,因此要多出一列放置常数项
    X[:,1:n] = dataSet[:,0:n-1]; Y = dataSet[:,-1]
    xTx = X.T*X
    if np.linalg.det(xTx) == 0.0:
        raise NameError(‘This matrix is singular, cannot do inverse,\n        try increasing the second value of ops‘)
    ws = xTx.I * (X.T * Y)
    return ws,X,Y

#产生针对该数据集的线性模型
#相当于上面的regLeaf函数
def modelLeaf(dataSet):
    ws,X,Y = linearSolve(dataSet)
    return ws

#产生针对该数据集的线性模型,并计算误差返回
#相当于上面的regErr函数,计算模型的误差,如果分后和不分的误差差不多则选择不分
def modelErr(dataSet):
    ws,X,Y = linearSolve(dataSet)
    yHat = X * ws
    return sum(np.power(Y - yHat,2))

模型树回归很好,而且可以用作预测

时间: 2024-11-19 21:22:11

学习日志---树回归(回归树,模型树)的相关文章

机器学习——模型树

和回归树(在每个叶节点上使用各自的均值做预测)不同,模型树算法需要在每个叶节点上都构建出一个线性模型,这就是把叶节点设定为分段线性函数,这个所谓的分段线性(piecewise linear)是指模型由多个线性片段组成. #####################模型树##################### def linearSolve(dataSet): #模型树的叶节点生成函数 m,n = shape(dataSet) X = mat(ones((m,n))); Y = mat(ones

机器学习day14 机器学习实战树回归之CART与模型树

这几天完成了树回归的相关学习,这一部分内容挺多,收获也挺多,刚刚终于完成了全部内容,非常开心. 树回归这一章涉及了CART,CART树称作(classify and regression tree) 分类与回归树,既可以用于分类,也可以用于回归.这正是前面决策树没有说到的内容,在这里补充一下.正好也总结一下我们学的3种决策树. ID3:用信息增益来选择特性进行分类,只能处理分类问题.缺点是往往偏向于特性种类多的特性进行分解,比如特性A有2种选择,特性B有3种选择,混乱度差不多的情况下,ID3会偏

机器学习经典算法详解及Python实现--CART分类决策树、回归树和模型树

摘要: Classification And Regression Tree(CART)是一种很重要的机器学习算法,既可以用于创建分类树(Classification Tree),也可以用于创建回归树(Regression Tree),本文介绍了CART用于离散标签分类决策和连续特征回归时的原理.决策树创建过程分析了信息混乱度度量Gini指数.连续和离散特征的特殊处理.连续和离散特征共存时函数的特殊处理和后剪枝:用于回归时则介绍了回归树和模型树的原理.适用场景和创建过程.个人认为,回归树和模型树

模型树——就是回归树的分段常数预测修改为线性回归 对于非线性回归有较好的预测效果

说完了树回归,再简单的提下模型树,因为树回归每个节点是一些特征和特征值,选取的原则是根据特征方差最小.如果把叶子节点换成分段线性函数,那么就变成了模型树,如(图六)所示: (图六) (图六)中明显是两个直线组成,以X坐标(0.0-0.3)和(0.3-1.0)分成的两个线段.如果我们用两个叶子节点保存两个线性回归模型,就完成了这部分数据的拟合.实现也比较简单,代码如下: [python] view plain copy def linearSolve(dataSet):   #helper fun

支配树学习日志

支配树 学习日志 给定一张有向图 $G=(V, E)$,其中 $\lvert V \rvert=n, \lvert E \rvert=m$,以及根 $r \in V$. 我们称顶点 $x\ (x \ne r)$ 可达,当且仅当存在一条从 $r$ 到 $x$ 的路径. 对于 $x \ne r$ 且可达的 $x$,如果 $y \ne x$,且删去 $y$ 后 $x$ 不可达,那么就说 $y$ 支配 $x$.特别地,$r$ 一定支配 $x$. 不可达的点的支配点没有定义,因此我们不妨设 $G$ 的所有

树状结构Java模型、层级关系Java模型、上下级关系Java模型与html页面展示

树状结构Java模型.层级关系Java模型.上下级关系Java模型与html页面展示 一.业务原型:公司的组织结构.传销关系网 二.数据库模型 很简单,创建 id 与 pid 关系即可.(pid:parent_id) 三.Java模型 (我们把这张网撒在html的一张表里.其实用ul来展示会简单N多,自己思考为什么LZ会选择放在表里) private class Table {        private Long id; // 当前对象的id         private int x; /

吴裕雄 python 机器学习——集成学习随机森林RandomForestRegressor回归模型

import numpy as np import matplotlib.pyplot as plt from sklearn import datasets,ensemble from sklearn.model_selection import train_test_split def load_data_regression(): ''' 加载用于回归问题的数据集 ''' #使用 scikit-learn 自带的一个糖尿病病人的数据集 diabetes = datasets.load_di

数据结构学习笔记04树(堆 哈夫曼树 并查集)

一.堆(heap) 优先队列(Priority Queue):特殊的“队列”,取出元素的顺序是依照元素的优先权(关键字)大小,而不是元素进入队列的先后顺序. 数组 : 插入 — 元素总是插入尾部 ~ O ( 1 ) 删除 — 查找最大(或最小)关键字 ~ O ( n ) 从数组中删去需要移动元素 ~ O( n ) 链表: 插入 — 元素总是插入链表的头部 ~ O ( 1 ) 删除 — 查找最大(或最小)关键字 ~ O ( n ) 删去结点 ~ O( 1 ) 有序数组: 插入 — 找到合适的位置

浅谈二维中的树状数组与线段树

一般来说,树状数组可以实现的东西线段树均可胜任,实际应用中也是如此.但是在二维中,线段树的操作变得太过复杂,更新子矩阵时第一维的lazy标记更是麻烦到不行. 但是树状数组在某些询问中又无法胜任,如最值等不符合区间减法的询问.此时就需要根据线段树与树状数组的优缺点来选择了. 做一下基本操作的对比,如下图. 因为线段树为自上向下更新,从而可以使用lazy标记使得矩阵的更新变的高校起来,几个不足就是代码长,代码长和代码长. 对于将将矩阵内元素变为某个值,因为树状数组自下向上更新,且要满足区间加法等限制