Weka算法Classifier-meta-AdditiveRegression源码分析

博主最近迷上了打怪物猎人,这片文章拖了很久才开始动笔

一、算法

AdditiveRegression,换个更出名一点的叫法可以称作GBDT(Grandient Boosting Decision Tree)梯度下降分类树,或者GBRT(Grandient Boosting Regression Tree)梯度下降回归树,是一种多分类器组合的算法,更确切的说,是属于Boosting算法。

谈到Boosting算法,就不能不提AdaBoost,参见之前我写的博客,可以看到AdaBoost的核心是级联分类器,使后一级分类器更加“关注”较为容易分错的数据,即后一级的分类器更有在易出错的数据集上进行训练。。

而GBDT作为Boosting算法,也是将多分类器进行级联训练,后一级的分类器则更多关注前面所有分类器预测结果与实际结果的残差,在这个残差上训练新的分类器,最终预测时将残差级联相加。

关于GBDT相关算法的公式推导可参考:

http://en.wikipedia.org/wiki/Gradient_boosting#Gradient_tree_boosting

http://www.360doc.com/content/12/0428/15/5874309_207282768.shtml

扯了这么多,下面简单说一下算法训练流程。

(1)输入训练集Data和基分类器的数量N

(2)使用训练集Data训练第1个基分类器

(3)for (int i=2;i<N;i++)

(4)使用前i-1个分类器进行预测,计算预测结果和训练数据的残差

(5)如果残差小于某个阈值,则退出循环。

(5)使用此残差训练第i个分类器

(6)转(3)

预测流程:

(1)根据输入数据,计算N个分类器的预测结果。

(2)将预测结果相加并返回。

可以看到,GBDT从原理上来讲并不复杂,“残差”的概念就用梯度来进行标示,抓住这一个线索看懂Wiki中的推导公式也并不是难事。复杂的是“如何证明其有效性”,这远超过本文可论证的范畴。

二、源码实现

就像之前所有的分类器一样,依然从buildClassifier入手。

(1)buildClassifier

public void buildClassifier(Instances data) throws Exception {

    super.buildClassifier(data);

    //additiveRegerssion只支持数值型数据。
    getCapabilities().testWithFail(data);

    //如果训练数据的class列为空,则去掉
    Instances newData = new Instances(data);
    newData.deleteWithMissingClass();

    double sum = 0;
    double temp_sum = 0;
    // 第一个分类器使用ZeroR,也就是预测的值是训练值的均值,没有使用基分类器(默认的基分类器为weka.classifiers.trees.DecisionStump(),也就是单层决策树(决策桩)
    m_zeroR = new ZeroR();
    m_zeroR.buildClassifier(newData);

    // 如果只有一列,则没法训练
    if (newData.numAttributes() == 1) {
      System.err.println(
	  "Cannot build model (only class attribute present in data!), "
	  + "using ZeroR model instead!");
      m_SuitableData = false;
      return;
    }
    else {
      m_SuitableData = true;
    }
    //这个residualReplace函数会将数据集用某个分类器进行分类后,再将其class列替换为残差,这个稍后详细分析一下。
    newData = residualReplace(newData, m_zeroR, false);
    for (int i = 0; i < newData.numInstances(); i++) {
      sum += newData.instance(i).weight() *
	newData.instance(i).classValue() * newData.instance(i).classValue();//这里计算了加权的残差平方和
    }
    if (m_Debug) {
      System.err.println("Sum of squared residuals "
			 +"(predicting the mean) : " + sum);
    }

    m_NumIterationsPerformed = 0;
    do {
      temp_sum = sum;

      // Build the classifier
      m_Classifiers[m_NumIterationsPerformed].buildClassifier(newData);//在新的数据集上训练,注意新的数据集的class已经替换为残差了,体现了gradient boosting思想

      newData = residualReplace(newData, m_Classifiers[m_NumIterationsPerformed], true);//再重新替换为残差
      sum = 0;
      for (int i = 0; i < newData.numInstances(); i++) {
	sum += newData.instance(i).weight() *
	  newData.instance(i).classValue() * newData.instance(i).classValue();//重新计算残差平方和
      }
      if (m_Debug) {
	System.err.println("Sum of squared residuals : "+sum);
      }
      m_NumIterationsPerformed++;
    } while (((temp_sum - sum) > Utils.SMALL) &&
	     (m_NumIterationsPerformed < m_Classifiers.length));//退出条件有2个,第一个是两次迭代残差平方没有明显变化,第二个是已训练完所有分类器。
  }

算法思想很简单,代码也很直观。

下面分析一下residualReplace函数。

(2)residualReplace

private Instances residualReplace(Instances data, Classifier c,
				    boolean useShrinkage) throws Exception {
    double pred,residual;
    Instances newInst = new Instances(data);

    for (int i = 0; i < newInst.numInstances(); i++) {
      pred = c.classifyInstance(newInst.instance(i)); //进行预测
      if (useShrinkage) {
	pred *= getShrinkage();//使用shrinkage来防止过拟合
      }
      residual = newInst.instance(i).classValue() - pred;//算出残差
      newInst.instance(i).setClassValue(residual);//原始数据的class用残差替换
    }
    //    System.err.print(newInst);
    return newInst;
  }

什么是shrinkage?

shrinkage(缩减)的思想认为,每次走一小步逐渐逼近结果的效果,要比每次迈一大步很快逼近结果的方式更容易避免过拟合。即它不完全信任每一个棵残差树,它认为每棵树只学到了真理的一小部分,累加的时候只累加一小部分,通过多学几棵树弥补不足。(转自http://blog.csdn.net/w28971023/article/details/8240756)

可以看到,残差本身可以理解成“希望分类器结果前进的向量”,也就是梯度的含义,即包含了方向(分类器往哪个方向调整),也包含了长度(调整多少)。而shrinkage就是缩小这个长度到一定的比值,如10%,这样每次在这个向量方向上前进10%,以此来防止过拟合。

为什么shrinkage能防止过拟合?这又是一个看上去就复杂的不得了的问题啊。。。。

(3)classifyInstance

public double classifyInstance(Instance inst) throws Exception {

    double prediction = m_zeroR.classifyInstance(inst);

    if (!m_SuitableData) {
      return prediction;
    }

    for (int i = 0; i < m_NumIterationsPerformed; i++) {
      double toAdd = m_Classifiers[i].classifyInstance(inst);
      toAdd *= getShrinkage();
      prediction += toAdd;
    }

    return prediction;
  }

按照分类器顺序把残差相加得到最终结果。

四、总结

如果非要写个什么总结的话,那么我希望是以下几点:

(1)gbdt思想简单,实现起来也简单,效果非常理想。

(2)weka的additiveRegression是一个gbrt的简单实现,只能处理数值型数据。

(3)其实现的核心逻辑是用残差替换原有数据集的class列。

(4)可以选择性的使用shrinkage来防止过拟合。

时间: 2024-07-30 10:15:33

Weka算法Classifier-meta-AdditiveRegression源码分析的相关文章

K-近邻算法的Python实现 : 源码分析

网上介绍K-近邻算法的例子很多,其Python实现版本基本都是来自于机器学习的入门书籍<机器学习实战>,虽然K-近邻算法本身很简单,但很多初学者对其Python版本的源代码理解不够,所以本文将对其源代码进行分析. 什么是K-近邻算法? 简单的说,K-近邻算法采用不同特征值之间的距离方法进行分类.所以它是一个分类算法. 优点:无数据输入假定,对异常值不敏感 缺点:复杂度高 好了,直接先上代码,等会在分析:(这份代码来自<机器学习实战>) def classify0(inx, data

OpenStack_Swift源码分析——Ring的rebalance算法源代码详细分析

今天有同学去百度,带回一道面试题,和大家分享一下: 打印: n=1 1 n=2 3 3 2 4 1 1 4 5 5 n=3 7 7 7 7 6 8 3 3 2 6 8 4 1 1 6 8 4 5 5 5 8 9 9 9 9 提供一段参考程序: <pre name="code" class="cpp">// ConsoleApplication1.cpp: 主项目文件. #include "stdafx.h" #include &quo

Openck_Swift源码分析——增加、删除设备时算法具体的实现过程

1 初始添加设备后.上传Object的具体流程 前几篇博客中,我们讲到环的基本原理即具体的实现过程,加入我们在初始创建Ring是执行如下几条命令: ?swift-ring-builder object.builder create 5 3 1 ?swift-ring-builder object.builder add z1-127.0.0.1:6010/sdb1 100 ?swift-ring-builder object.builder add z2-127.0.0.1:6020/sdb2 

OpenCV学习笔记(27)KAZE 算法原理与源码分析(一)非线性扩散滤波

http://blog.csdn.net/chenyusiyuan/article/details/8710462 OpenCV学习笔记(27)KAZE 算法原理与源码分析(一)非线性扩散滤波 2013-03-23 17:44 16963人阅读 评论(28) 收藏 举报 分类: 机器视觉(34) 版权声明:本文为博主原创文章,未经博主允许不得转载. 目录(?)[+] KAZE系列笔记: OpenCV学习笔记(27)KAZE 算法原理与源码分析(一)非线性扩散滤波 OpenCV学习笔记(28)KA

WEKA学习——CSVLoader 实例训练 和 源码分析

简介: Weka支持多种数据导入方式,CSVLoader是能从csv文件加载数据集,也可以保存为arff格式文件.官方介绍文件:Converting CSV to ARFF ( http://weka.wikispaces.com/Converting+CSV+to+ARFF) CSVLoader加载文件,关键是对文件字段属性名称和属性的类型需要自己定义,这样才能得到满足自己需要的数据集. CSVLoader通过options设置,可以设置每一列的属性为Nominal,String,Date类型

OpenStack_Swift源码分析——Ring基本原理及一致性Hash算法

1.Ring的基本概念 Ring是swfit中最重要的组件,用于记录存储对象与物理位置之间的映射关系,当用户需要对Account.Container.Object操作时,就需要查询对应的Ring文件(Account.Container.Object都有自己对应的Ring),Ring 使用Region(最近几个版本中新加入的).Zone.Device.Partition和Replica来维护这些信息,对于每一个对象,根据你在部署swift设置的Replica数量,集群中会存有Replica个对象.

SURF算法与源码分析、下

上一篇文章 SURF算法与源码分析.上 中主要分析的是SURF特征点定位的算法原理与相关OpenCV中的源码分析,这篇文章接着上篇文章对已经定位到的SURF特征点进行特征描述.这一步至关重要,这是SURF特征点匹配的基础.总体来说算法思路和SIFT相似,只是每一步都做了不同程度的近似与简化,提高了效率. 1. SURF特征点方向分配 为了保证特征矢量具有旋转不变性,与SIFT特征一样,需要对每个特征点分配一个主方向.为些,我们需要以特征点为中心,以$6s$($s = 1.2 *L /9$为特征点

Mahout源码分析:并行化FP-Growth算法

FP-Growth是一种常被用来进行关联分析,挖掘频繁项的算法.与Aprior算法相比,FP-Growth算法采用前缀树的形式来表征数据,减少了扫描事务数据库的次数,通过递归地生成条件FP-tree来挖掘频繁项.参考资料[1]详细分析了这一过程.事实上,面对大数据量时,FP-Growth算法生成的FP-tree非常大,无法放入内存,挖掘到的频繁项也可能有指数多个.本文将分析如何并行化FP-Growth算法以及Mahout中并行化FP-Growth算法的源码. 1. 并行化FP-Growth 并行

Mahout源码分析-K-means聚类算法

一 算法描述 1.随机选取k个对象作为初始簇中心: 2.计算每个对象到簇中心的距离,将每个对象聚类到离该对象最近的聚簇中去: 3.计算每个聚簇中的簇均值,并将簇均值作为新的簇中心: 4.计算准则函数: 5.重复(2).(3)和(4),直到准则函数不再发生变化. 二 源码分析 Mahout源码分析-K-means聚类算法

【E2LSH源码分析】LSH算法框架分析

位置敏感哈希(Locality Sensitive Hashing,LSH)是近似最近邻搜索算法中最流行的一种,它有坚实的理论依据并且在高维数据空间中表现优异.由于网络上相关知识的介绍比较单一,现就LSH的相关算法和技术做一介绍总结,希望能给感兴趣的朋友提供便利,也希望有兴趣的同道中人多交流.多指正. 1.LSH原理 最近邻问题(nearest neighbor problem)可以定义如下:给定n个对象的集合并建立一个数据结构,当给定任意的要查询对象时,该数据结构返回针对查询对象的最相似的数据