深入学习高级非线性回归算法 --- 树回归系列算法

前言

  前文讨论的回归算法都是全局且针对线性问题的回归,即使是其中的局部加权线性回归法,也有其弊端(具体请参考前文:)

  采用全局模型会导致模型非常的臃肿,因为需要计算所有的样本点,而且现实生活中很多样本都有大量的特征信息。

  另一方面,实际生活中更多的问题都是非线性问题。

  针对这些问题,有了树回归系列算法。

回归树

  在先前决策树 (链接) 的学习中,构建树是采用的 ID3 算法。在回归领域,该算法就有个问题,就是派生子树是按照所有可能值来进行派生。

  因此 ID3 算法无法处理连续性数据。

  故可使用二元切分法,以某个特定值为界进行切分。在这种切分法下,子树个数小于等于2。

  除此之外,再修改择优原则香农熵 (因为数据变为连续型的了),便可将树构建成一棵可用于回归的树,这样一棵树便叫做回归树。

  构建回归树的伪代码:

1 找到最佳的待切分特征:
2     如果该节点不能再分,将此节点存为叶节点。
3     执行二元切分
4     左右子树分别递归调用此函数

  二元切分的伪代码:

1 对每个特征:
2     对每个特征值:
3         将数据集切成两份
4         计算切分误差
5         如果当前误差小于最小误差,则更新最佳切分以及最小误差。

  特别说明,终止划分 (并直接建立叶节点)有三种情况:

  1. 剩余特征值仅 1

  2. 划分子集太小

  3. 划分后误差改进不大

  这几个操作被称做 "预剪枝"。

  下面给出一个完整的回归树的小程序:

  1 #!/usr/bin/env python
  2 # -*- coding:UTF-8 -*-
  3
  4 ‘‘‘
  5 Created on 2015-01-05
  6
  7 @author: fangmeng
  8 ‘‘‘
  9
 10 from numpy import *
 11
 12 def loadDataSet(fileName):
 13     ‘载入测试数据‘
 14
 15     dataMat = []
 16     fr = open(fileName)
 17     for line in fr.readlines():
 18         curLine = line.strip().split(‘\t‘)
 19         # 所有元素转换为浮点类型(函数编程)
 20         fltLine = map(float,curLine)
 21         dataMat.append(fltLine)
 22     return dataMat
 23
 24 #============================
 25 # 输入:
 26 #        dataSet: 待切分数据集
 27 #        feature: 切分特征序号
 28 #        value:    切分值
 29 # 输出:
 30 #        mat0,mat1: 切分结果
 31 #============================
 32 def binSplitDataSet(dataSet, feature, value):
 33     ‘切分数据集‘
 34
 35     mat0 = dataSet[nonzero(dataSet[:,feature] > value)[0],:][0]
 36     mat1 = dataSet[nonzero(dataSet[:,feature] <= value)[0],:][0]
 37     return mat0,mat1
 38
 39 #========================================
 40 # 输入:
 41 #        dataSet: 数据集
 42 # 输出:
 43 #        mean(dataSet[:,-1]): 均值(也就是叶节点的内容)
 44 #========================================
 45 def regLeaf(dataSet):
 46     ‘生成叶节点‘
 47
 48     return mean(dataSet[:,-1])
 49
 50 #========================================
 51 # 输入:
 52 #        dataSet: 数据集
 53 # 输出:
 54 #        var(dataSet[:,-1]) * shape(dataSet)[0]: 平方误差
 55 #========================================
 56 def regErr(dataSet):
 57     ‘计算平方误差‘
 58
 59     return var(dataSet[:,-1]) * shape(dataSet)[0]
 60
 61 #========================================
 62 # 输入:
 63 #        dataSet: 数据集
 64 #        leafType: 叶子节点生成器
 65 #        errType: 误差统计器
 66 #        ops: 相关参数
 67 # 输出:
 68 #        bestIndex: 最佳划分特征
 69 #        bestValue: 最佳划分特征值
 70 #========================================
 71 def chooseBestSplit(dataSet, leafType=regLeaf, errType=regErr, ops=(1,4)):
 72     ‘选择最优划分‘
 73
 74     # 获得相关参数中的最大样本数和最小误差效果提升值
 75     tolS = ops[0];
 76     tolN = ops[1]
 77
 78     # 如果所有样本点的值一致,那么直接建立叶子节点。
 79     if len(set(dataSet[:,-1].T.tolist()[0])) == 1:
 80         return None, leafType(dataSet)
 81
 82     m,n = shape(dataSet)
 83     # 当前误差
 84     S = errType(dataSet)
 85     # 最小误差
 86     bestS = inf;
 87     # 最小误差对应的划分方式
 88     bestIndex = 0;
 89     bestValue = 0
 90
 91     # 对于所有特征
 92     for featIndex in range(n-1):
 93         # 对于某个特征的所有特征值
 94         for splitVal in set(dataSet[:,featIndex]):
 95             # 划分
 96             mat0, mat1 = binSplitDataSet(dataSet, featIndex, splitVal)
 97             # 如果划分后某个子集中的个数不达标
 98             if (shape(mat0)[0] < tolN) or (shape(mat1)[0] < tolN): continue
 99             # 当前划分方式的误差
100             newS = errType(mat0) + errType(mat1)
101             # 如果这种划分方式的误差小于最小误差
102             if newS < bestS:
103                 bestIndex = featIndex
104                 bestValue = splitVal
105                 bestS = newS
106
107     # 如果当前划分方式还不如不划分时候的误差效果
108     if (S - bestS) < tolS:
109         return None, leafType(dataSet)
110     # 按照最优划分方式进行划分
111     mat0, mat1 = binSplitDataSet(dataSet, bestIndex, bestValue)
112     # 如果划分后某个子集中的个数不达标
113     if (shape(mat0)[0] < tolN) or (shape(mat1)[0] < tolN):
114         return None, leafType(dataSet)
115
116     return bestIndex,bestValue
117
118 #========================================
119 # 输入:
120 #        dataSet: 数据集
121 #        leafType: 叶子节点生成器
122 #        errType: 误差统计器
123 #        ops: 相关参数
124 # 输出:
125 #        retTree: 回归树
126 #========================================
127 def createTree(dataSet, leafType=regLeaf, errType=regErr, ops=(1,4)):
128     ‘构建回归树‘
129
130     # 选择最佳划分方式
131     feat, val = chooseBestSplit(dataSet, leafType, errType, ops)
132     # feat为None的时候无需划分返回叶子节点
133     if feat == None: return val #if the splitting hit a stop condition return val
134
135     # 递归调用构建函数并更新树
136     retTree = {}
137     retTree[‘spInd‘] = feat
138     retTree[‘spVal‘] = val
139     lSet, rSet = binSplitDataSet(dataSet, feat, val)
140     retTree[‘left‘] = createTree(lSet, leafType, errType, ops)
141     retTree[‘right‘] = createTree(rSet, leafType, errType, ops)
142
143     return retTree
144
145 def test():
146     ‘展示结果‘
147
148     # 载入数据
149     myDat = loadDataSet(‘/home/fangmeng/ex0.txt‘)
150     # 构建回归树
151     myDat = mat(myDat)
152
153     print createTree(myDat)
154
155
156 if __name__ == ‘__main__‘:
157     test()

  测试结果:

  

回归树的优化工作 - 剪枝

  在上面的代码中,终止递归的条件中已经加入了重重的 "剪枝" 工作。

  这些在建树的时候的剪枝操作通常被成为预剪枝。这是很有很有必要的,经过预剪枝的树几乎就是没有预剪枝树的大小的百分之一甚至更小,而性能相差无几。

  而在树建立完毕之后,基于训练集和测试集能做更多更高效的剪枝工作,这些工作叫做 "后剪枝"。

  可见,剪枝是一项较大的工作量,是对树非常关键的优化过程。

  后剪枝过程的伪代码如下:

1 基于已有的树切分测试数据:
2     如果存在任一子集是一棵树,则在该子集上递归该过程。
3     计算将当前两个叶节点合并后的误差
4     计算不合并的误差
5     如果合并会降低误差,则将叶节点合并。

  具体实现函数如下:

 1 #===================================
 2 # 输入:
 3 #        obj: 判断对象
 4 # 输出:
 5 #        (type(obj).__name__==‘dict‘): 判断结果
 6 #===================================
 7 def isTree(obj):
 8     ‘判断对象是否为树类型‘
 9
10     return (type(obj).__name__==‘dict‘)
11
12 #===================================
13 # 输入:
14 #        tree: 处理对象
15 # 输出:
16 #        (tree[‘left‘]+tree[‘right‘])/2.0: 坍塌后的替代值
17 #===================================
18 def getMean(tree):
19     ‘坍塌处理‘
20
21     if isTree(tree[‘right‘]): tree[‘right‘] = getMean(tree[‘right‘])
22     if isTree(tree[‘left‘]): tree[‘left‘] = getMean(tree[‘left‘])
23
24     return (tree[‘left‘]+tree[‘right‘])/2.0
25
26 #===================================
27 # 输入:
28 #        tree: 处理对象
29 #        testData: 测试数据集
30 # 输出:
31 #        tree: 剪枝后的树
32 #===================================
33 def prune(tree, testData):
34     ‘后剪枝‘
35
36     # 无测试数据则坍塌此树
37     if shape(testData)[0] == 0:
38         return getMean(tree)
39
40     # 若左/右子集为树类型
41     if (isTree(tree[‘right‘]) or isTree(tree[‘left‘])):
42         # 划分测试集
43         lSet, rSet = binSplitDataSet(testData, tree[‘spInd‘], tree[‘spVal‘])
44     # 在新树新测试集上递归进行剪枝
45     if isTree(tree[‘left‘]): tree[‘left‘] = prune(tree[‘left‘], lSet)
46     if isTree(tree[‘right‘]): tree[‘right‘] =  prune(tree[‘right‘], rSet)
47
48     # 如果两个子集都是叶子的话,则在进行误差评估后决定是否进行合并。
49     if not isTree(tree[‘left‘]) and not isTree(tree[‘right‘]):
50         lSet, rSet = binSplitDataSet(testData, tree[‘spInd‘], tree[‘spVal‘])
51         errorNoMerge = sum(power(lSet[:,-1] - tree[‘left‘],2)) +sum(power(rSet[:,-1] - tree[‘right‘],2))
52         treeMean = (tree[‘left‘]+tree[‘right‘])/2.0
53         errorMerge = sum(power(testData[:,-1] - treeMean,2))
54         if errorMerge < errorNoMerge:
55             return treeMean
56         else: return tree
57     else: return tree

模型树

  这也是一种很棒的树回归算法。

  该算法将所有的叶子节点不是表述成一个值,而是对叶子部分节点建立线性模型。比如可以是最小二乘法的基本线性回归模型。

  这样在叶子节点里存放的就是一组线性回归系数了。非叶子节点部分构造就和回归树一样。

  这个是上面建立回归树算法的函数头:

  createTree(dataSet, leafType=regLeaf, errType=regErr, ops=(1,4)):

  对于模型树,只需要修改修改 leafType(叶节点构造器) 和 errType(误差分析器) 的实现即可,分别对应如下modelLeaf 函数和 modelErr 函数:

 1 #=========================
 2 # 输入:
 3 #        dataSet: 测试集
 4 # 输出:
 5 #        ws,X,Y: 回归模型
 6 #=========================
 7 def linearSolve(dataSet):
 8     ‘辅助函数,用于构建线性回归模型。‘
 9
10     m,n = shape(dataSet)
11     X = mat(ones((m,n)));
12     Y = mat(ones((m,1)))
13     X[:,1:n] = dataSet[:,0:n-1];
14     Y = dataSet[:,-1]
15     xTx = X.T*X
16     if linalg.det(xTx) == 0.0:
17         raise NameError(‘系数矩阵不可逆‘)
18     ws = xTx.I * (X.T * Y)
19     return ws,X,Y
20
21 #=======================
22 # 输入:
23 #       dataSet: 数据集
24 # 输出:
25 #        ws: 回归系数
26 #=======================
27 def modelLeaf(dataSet):
28     ‘叶节点构造器‘
29
30     ws,X,Y = linearSolve(dataSet)
31     return ws
32
33 #=======================================
34 # 输入:
35 #       dataSet: 数据集
36 # 输出:
37 #        sum(power(Y - yHat,2)): 平方误差
38 #=======================================
39 def modelErr(dataSet):
40     ‘误差分析器‘
41
42     ws,X,Y = linearSolve(dataSet)
43     yHat = X * ws
44     return sum(power(Y - yHat,2))

回归树 / 模型树的使用

  前面的工作主要介绍了两种树 - 回归树,模型树的构建,下面进一步学习如何利用这些树来进行预测。

  当然,本质也就是递归遍历树。

  下为遍历代码,通过修改参数设置要使用并传递进来的是回归树还是模型树:

#==============================
# 输入:
#       model: 叶子
#       inDat: 测试数据
# 输出:
#        float(model): 叶子值
#==============================
def regTreeEval(model, inDat):
    ‘回归树预测‘

    return float(model)

#==============================
# 输入:
#       model: 叶子
#       inDat: 测试数据
# 输出:
#        float(X*model): 叶子值
#==============================
def modelTreeEval(model, inDat):
    ‘模型树预测‘
    n = shape(inDat)[1]
    X = mat(ones((1,n+1)))
    X[:,1:n+1]=inDat
    return float(X*model)

#==============================
# 输入:
#        tree: 待遍历树
#        inDat: 测试数据
#        modelEval: 叶子值获取器
# 输出:
#        分类结果
#==============================
def treeForeCast(tree, inData, modelEval=regTreeEval):
    ‘使用回归/模型树进行预测 (modelEval参数指定)‘

    # 如果非树类型,返回值。
    if not isTree(tree): return modelEval(tree, inData)

    # 左遍历
    if inData[tree[‘spInd‘]] > tree[‘spVal‘]:
        if isTree(tree[‘left‘]): return treeForeCast(tree[‘left‘], inData, modelEval)
        else: return modelEval(tree[‘left‘], inData)

    # 右遍历
    else:
        if isTree(tree[‘right‘]): return treeForeCast(tree[‘right‘], inData, modelEval)
        else: return modelEval(tree[‘right‘], inData)

  使用方法非常简单,将树和要分类的样本传递进去就可以了。如果是模型树就将分类函数 treeForeCast 的第三个参数改为modelTreeEval即可。

  这里就不再演示实验具体过程了。

小结

  1. 选择哪个回归方法,得看哪个方法的相关系数高。(可使用 corrcoef 函数计算)

  2. 树的回归和分类算法其实本质上都属于贪心算法,不断去寻找局部最优解。

  3. 关于回归的讨论就先告一段落,接下来将进入到无监督学习部分。

时间: 2024-07-31 18:22:22

深入学习高级非线性回归算法 --- 树回归系列算法的相关文章

数据结构算法 (树 的基本算法)

一.树的序列化 和反序列化 1) 将二叉树进行序列化  和反序列化; 使用的是前序. 1 package com.tree; 2 3 import java.util.LinkedList; 4 import java.util.Queue; 5 6 // 将一个两叉树 序列化成 字符串 ; 7 // 和将一字符串 反序列为一个树. 8 public class TreeNode_Serialization { 9 public static void main(String[] args) {

集成学习之梯度提升树(GBDT)算法

梯度提升树(GBDT)的全称是Gradient Boosting Decision Tree.GBDT还有很多的简称,例如GBT(Gradient Boosting Tree), GTB(Gradient Tree Boosting ),GBRT(Gradient Boosting Regression Tree), MART(Multiple Additive Regression Tree)等,其实都是指的同一种算法,本文统一简称GBDT. GBDT 也是 Boosting 算法的一种,但是

学习日志---树回归(回归树,模型树)

CART算法的树回归: 返回的每个节点最后是一个最终确定的平均值. #coding:utf-8 import numpy as np # 加载文件数据 def loadDataSet(fileName):      #general function to parse tab -delimited floats     dataMat = []                #assume last column is target value     fr = open(fileName)  

深入浅出排序学习:写给程序员的算法系统开发实践

引言 我们正处在一个知识爆炸的时代,伴随着信息量的剧增和人工智能的蓬勃发展,互联网公司越发具有强烈的个性化.智能化信息展示的需求.而信息展示个性化的典型应用主要包括搜索列表.推荐列表.广告展示等等. 很多人不知道的是,看似简单的个性化信息展示背后,涉及大量的数据.算法以及工程架构技术,这些足以让大部分互联网公司望而却步.究其根本原因,个性化信息展示背后的技术是排序学习问题(Learning to Rank).市面上大部分关于排序学习的文章,要么偏算法.要么偏工程.虽然算法方面有一些系统性的介绍文

机器学习day14 机器学习实战树回归之CART与模型树

这几天完成了树回归的相关学习,这一部分内容挺多,收获也挺多,刚刚终于完成了全部内容,非常开心. 树回归这一章涉及了CART,CART树称作(classify and regression tree) 分类与回归树,既可以用于分类,也可以用于回归.这正是前面决策树没有说到的内容,在这里补充一下.正好也总结一下我们学的3种决策树. ID3:用信息增益来选择特性进行分类,只能处理分类问题.缺点是往往偏向于特性种类多的特性进行分解,比如特性A有2种选择,特性B有3种选择,混乱度差不多的情况下,ID3会偏

笔记︱集成学习Ensemble Learning与树模型、Bagging 和 Boosting

本杂记摘录自文章<开发 | 为什么说集成学习模型是金融风控新的杀手锏?> 基本内容与分类见上述思维导图. . . 一.机器学习元算法 随机森林:决策树+bagging=随机森林 梯度提升树:决策树Boosting=GBDT  . 1.随机森林 博客: R语言︱决策树族--随机森林算法 随机森林的原理是基于原始样本随机抽样获取子集,在此之上训练基于决策树的基学习器,然后对基学习器的结果求平均值,最终得到预测值.随机抽样的方法常用的有放回抽样的booststrap,也有不放回的抽样.RF的基学习器

机器学习&amp;深度学习基础(tensorflow版本实现的算法概述0)

tensorflow集成和实现了各种机器学习基础的算法,可以直接调用. 监督学习 1)决策树(Decision Tree) 决策树是一种树形结构,为人们提供决策依据,决策树可以用来回答yes和no问题,它通过树形结构将各种情况组合都表示出来,每个分支表示一次选择(选择yes还是no),直到所有选择都进行完毕,最终给出正确答案. 决策树(decision tree)是一个树结构(可以是二叉树或非二叉树).在实际构造决策树时,通常要进行剪枝,这时为了处理由于数据中的噪声和离群点导致的过分拟合问题.剪

php学习高级-提高PHP编程效率的几点建议

1.如果能将类的方法定义成static,就尽量定义成static,它的速度会提升将近4倍. 2.$row['id'] 的速度是$row[id]的7倍. 3.echo 比 print 快,并且使用echo的多重参数(译注:指用逗号而不是句点)代替字符串连接,比如echo $str1,$str2. 4.在执行for循环之前确定最大循环数,不要每循环一次都计算最大值,最好运用foreach代替. 5.注销那些不用的变量尤其是大数组,以便释放内存. 6.尽量避免使用__get,__set,__autol

NLTK学习笔记(四):自然语言处理的一些算法研究

自然语言处理中算法设计有两大部分:分而治之 和 转化 思想.一个是将大问题简化为小问题,另一个是将问题抽象化,向向已知转化.前者的例子:归并排序:后者的例子:判断相邻元素是否相同(与排序). 这次总结的自然语言中常用的一些基本算法,算是入个门了. 递归 使用递归速度上会受影响,但是便于理解算法深层嵌套对象.而一些函数式编程语言会将尾递归优化为迭代. 如果要计算n个词有多少种组合方式?按照阶乘定义:n! = n*(n-1)*...*1 def func(wordlist): length = le