Weka算法Classifier-trees-REPTree源码分析(一)

一、算法

关于REPTree我实在是没找到什么相关其算法的资料,或许是Weka自创的一个关于决策树的改进,也许是其它某种决策树方法的别名,根据类的注释:Fast decision tree learner. Builds a decision/regression tree using information gain/variance and prunes it using reduced-error pruning (with backfitting).
 Only sorts values for numeric attributes once. Missing values are dealt with by splitting the corresponding instances into pieces (i.e. as in C4.5).

我们大概知道和C4.5相比,大概多了backfitting过程,并且数值型排序只进行一次(回想一下J48也就是C4.5算法是每个数据子集都要进行排序),并且缺失值的处理方式和C4.5一样,走不同的path再把结果进行加权。

具体和C4.5的比较将在代码分析之后给出一个总结。

二、buildClassifier

“大名鼎鼎”的分类器训练主入口,几乎每篇分析分类器源码都从这个方法入手。

 public void buildClassifier(Instances data) throws Exception {

    // 首先例行公事看一下给定数据集是否能使用REPTree进行分类,REPTREE基本能支持所有类型
    getCapabilities().testWithFail(data);

    // 把classIndex上没有数据的instance干掉,这些数据既不能用于训练也不能用于backfit
    data = new Instances(data);
    data.deleteWithMissingClass();

    Random random = new Random(m_Seed);

    m_zeroR = null;
    if (data.numAttributes() == 1) {
      m_zeroR = new ZeroR();//如果只有一列的话,就是用m_ZerO作为分类器,很直观只有一列的话肯定就是结果列了,只有结果列无法训练分类器,只能使用最基本的米ZerO作为分类器,mZerO的分类方法再上篇日志有说到。
      m_zeroR.buildClassifier(data);
      return;
    }

    // Randomize and stratify
    data.randomize(random);//进行随机排列
    if (data.classAttribute().isNominal()) {
      data.stratify(m_NumFolds);//如果枚举型还要进行一下分层,目的是
    }

    // 如果需要剪枝,则分为train集合和prune集合,否则只要train集合就行了
    Instances train = null;
    Instances prune = null;
    if (!m_NoPruning) {
      train = data.trainCV(m_NumFolds, 0, random);//这里是用了多折交叉验证的方法取得train和test
      prune = data.testCV(m_NumFolds, 0);
    } else {
      train = data;
    }

    // 建立了两个数组,第一维数据无意义,只是把三维数组当二维数组用而已,第二维代表各属性,第三维代表排序的index(顺序统计量)
    int[][][] sortedIndices = new int[1][train.numAttributes()][0];//这个里面存放的是各instance的下标
    double[][][] weights = new double[1][train.numAttributes()][0];//这个里面存放的是下标对应的instance的weight
    double[] vals = new double[train.numInstances()];//这个是临时数组,用于排序用的
    for (int j = 0; j < train.numAttributes(); j++) {
      if (j != train.classIndex()) {
	weights[0][j] = new double[train.numInstances()];
	if (train.attribute(j).isNominal()) {

	  //如果是枚举类型,所做的排序工作就是简单的把Missing放到最后面
	  sortedIndices[0][j] = new int[train.numInstances()];
	  int count = 0;
	  for (int i = 0; i < train.numInstances(); i++) {
	    Instance inst = train.instance(i);
	    if (!inst.isMissing(j)) {
	      sortedIndices[0][j][count] = i;
	      weights[0][j][count] = inst.weight();
	      count++;
	    }
	  }
	  for (int i = 0; i < train.numInstances(); i++) {
	    Instance inst = train.instance(i);
	    if (inst.isMissing(j)) {
	      sortedIndices[0][j][count] = i;
	      weights[0][j][count] = inst.weight();
	      count++;
	    }
	  }
	} else {

	  // 如果是数值类型,则进行排序
	  for (int i = 0; i < train.numInstances(); i++) {
	    Instance inst = train.instance(i);
	    vals[i] = inst.value(j);
	  }
	  sortedIndices[0][j] = Utils.sort(vals);
	  for (int i = 0; i < train.numInstances(); i++) {
	    weights[0][j][i] = train.instance(sortedIndices[0][j][i]).weight();
	  }
	}
      }
    }

    // 这里建立数组存放训练集中每个类的分布
    double[] classProbs = new double[train.numClasses()];
    double totalWeight = 0, totalSumSquared = 0;
    for (int i = 0; i < train.numInstances(); i++) {
      Instance inst = train.instance(i);
      if (data.classAttribute().isNominal()) {
	classProbs[(int)inst.classValue()] += inst.weight();//如果是枚举类型,就进行简单的统计
	totalWeight += inst.weight();
      } else {
	classProbs[0] += inst.classValue() * inst.weight();//如果是数值型,就相加,到后面进行取平均的操作
	totalSumSquared += inst.classValue() * inst.classValue() * inst.weight();
	totalWeight += inst.weight();
      }
    }
    m_Tree = new Tree();//建立决策树节点
    double trainVariance = 0;//训练集的方差
    if (data.classAttribute().isNumeric()) {
      trainVariance = m_Tree.
	singleVariance(classProbs[0], totalSumSquared, totalWeight) / totalWeight;
      classProbs[0] /= totalWeight;//这里取平均操作
    }

    // Build tree
    m_Tree.buildTree(sortedIndices, weights, train, totalWeight, classProbs,
		     new Instances(train, 0), m_MinNum, m_MinVarianceProp *
		     trainVariance, 0, m_MaxDepth);//执行具体树上的构建操作,这参数还真多

    // Insert pruning data and perform reduced error pruning
    if (!m_NoPruning) {
      m_Tree.insertHoldOutSet(prune);//传入剪枝数据
      m_Tree.reducedErrorPrune();//进行剪枝
      m_Tree.backfitHoldOutSet();//backfit
    }
  }

(2)Tree.buildTree

Tree是REPTree的一个子对象,训练用参数较多。

 protected void buildTree(int[][][] sortedIndices, double[][][] weights,
			     Instances data, double totalWeight,
			     double[] classProbs, Instances header,
			     double minNum, double minVariance,
			     int depth, int maxDepth)
      throws Exception {
      //第一个参数是按属性排好序的下标,第二个是这些下标对应的weight,第三个是训练数据
<span style="white-space:pre">	</span>//第四个是总权重,第五个是各类的分布,第六个是表头,第七个是每个节点最小instance数量
<span style="white-space:pre">	</span>//第八个是最小的方差 ,第九个是当前深度(0 base),第十个是最大深度
      

      m_Info = header;//首先存下表头
      if (data.classAttribute().isNumeric()) {
        m_HoldOutDist = new double[2];//这个数组用于存放分布
      } else {
        m_HoldOutDist = new double[data.numClasses()];
      }

      // 看看是否有有效数据
      int helpIndex = 0;
      if (data.classIndex() == 0) {
	helpIndex = 1;//传入的数据至少两列,因为一列的话上层就用m_zerO模型了,这个if是为了保证helpIndex对应的肯定是训练数据
      }
      if (sortedIndices[0][helpIndex].length == 0) {//如果没数据,就直接反悔了
	if (data.classAttribute().isNumeric()) {
	  m_Distribution = new double[2];//为什么是二维的?第一维存放方差,第二维存放weight,基于约定的编程方式
	} else {
	  m_Distribution = new double[data.numClasses()];
	}
	m_ClassProbs = null;
        sortedIndices[0] = null;
        weights[0] = null;
	return;
      }

      double priorVar = 0;//存放class的方差(其实是方差*num),只有class是数值才有意义,下面就是计算方差的过程。
      if (data.classAttribute().isNumeric()) {

	// 每个sortedIndices[0][i]里面的都是一个Instances的index不同排列而已,使用helpIndex只是为了保证别对应到classIndex上
	double totalSum = 0, totalSumSquared = 0, totalSumOfWeights = 0;
	for (int i = 0; i < sortedIndices[0][helpIndex].length; i++) {
	  Instance inst = data.instance(sortedIndices[0][helpIndex][i]);
	  totalSum += inst.classValue() * weights[0][helpIndex][i];
	  totalSumSquared +=
	    inst.classValue() * inst.classValue() * weights[0][helpIndex][i];
	  totalSumOfWeights += weights[0][helpIndex][i];
	}
	priorVar = singleVariance(totalSum, totalSumSquared,
				  totalSumOfWeights);
      }

      //把分布拷贝一下
      m_ClassProbs = new double[classProbs.length];
      System.arraycopy(classProbs, 0, m_ClassProbs, 0, classProbs.length);
      if ((//退出条件有4个
<span style="white-space:pre">	</span>//第一个是instances里面的totalweight总量(可以理解成里面的instance数量,因为weight默认都是1)小于两倍的minNum,minNum默认是2.
<span style="white-space:pre">	</span>totalWeight < (2 * minNum)) ||

	  // 如果是枚举类型,并且都在一类中
	  (data.classAttribute().isNominal() &&
	   Utils.eq(m_ClassProbs[Utils.maxIndex(m_ClassProbs)],
		    Utils.sum(m_ClassProbs))) ||

	  // 数值型则比较方差是否小于minVariance,这个minVariance默认是原始方差的0.001,从上层代码可以得知
	  (data.classAttribute().isNumeric() &&
	   ((priorVar / totalWeight) < minVariance)) ||

	  // 达到最大深度
	  ((m_MaxDepth >= 0) && (depth >= maxDepth))) {

	// 设置成叶子
	m_Attribute = -1;
	if (data.classAttribute().isNominal()) {

	  // 设置枚举类型的分布
	  m_Distribution = new double[m_ClassProbs.length];
	  for (int i = 0; i < m_ClassProbs.length; i++) {
	    m_Distribution[i] = m_ClassProbs[i];
	  }
	  Utils.normalize(m_ClassProbs);
	} else {

	  // 设置数值类型的“分布”
	  m_Distribution = new double[2];
	  m_Distribution[0] = priorVar;
	  m_Distribution[1] = totalWeight;
	}
        sortedIndices[0] = null;
        weights[0] = null;
	return;
      }

      // 下面是寻找分裂点的过程
      double[] vals = new double[data.numAttributes()];//每个属性产生的信息增益
      double[][][] dists = new double[data.numAttributes()][0][0];//每个属性下每个类的分布
      double[][] props = new double[data.numAttributes()][0];//每个属性下class的概率,也就是根据上面这个数组的分布求概率
      double[][] totalSubsetWeights = new double[data.numAttributes()][0];//每个属性下每个subset的数量
      double[] splits = new double[data.numAttributes()];//每个属性的分裂点,如果是枚举型则为NaN
      if (data.classAttribute().isNominal()) { 

	// 首先来看classAttribute是枚举类型的情况
	for (int i = 0; i < data.numAttributes(); i++) {
	  if (i != data.classIndex()) {
	    splits[i] = distribution(props, dists, i, sortedIndices[0][i],
				     weights[0][i], totalSubsetWeights, data);//得到分裂点、概率和分布
	    vals[i] = gain(dists[i], priorVal(dists[i]));//得到信息增益
	  }
	}
      } else {

	// 如果是数值类型则不算信息增益(为什么数值类型不算增益?只有因为枚举型才算的出信息熵)(吐个槽:话说这个if-else为啥不放在循环里面??)
	for (int i = 0; i < data.numAttributes(); i++) {
	  if (i != data.classIndex()) {
	    splits[i] =
	      numericDistribution(props, dists, i, sortedIndices[0][i],
				  weights[0][i], totalSubsetWeights, data,
				  vals);
	  }
	}
      }

      // 选出信息增益最大的作为分裂属性
      m_Attribute = Utils.maxIndex(vals);
      int numAttVals = dists[m_Attribute].length;

      // 每个subset都要多于minNum,这样才算一个有效subset
      int count = 0;
      for (int i = 0; i < numAttVals; i++) {
	if (totalSubsetWeights[m_Attribute][i] >= minNum) {
	  count++;
	}
	if (count > 1) {
	  break;
	}
      }

      // 至少存在2个有效subset,才算是一个有效的split
      if (Utils.gr(vals[m_Attribute], 0) && (count > 1)) {      

        // Set split point, proportions, and temp arrays
	m_SplitPoint = splits[m_Attribute];
	m_Prop = props[m_Attribute];
        double[][] attSubsetDists = dists[m_Attribute];
        double[] attTotalSubsetWeights = totalSubsetWeights[m_Attribute];

        // 释放内存
        vals = null;
        dists = null;
        props = null;
        totalSubsetWeights = null;
        splits = null;

	// 得到subSet的有序index
	int[][][][] subsetIndices =
	  new int[numAttVals][1][data.numAttributes()][0];
	double[][][][] subsetWeights =
	  new double[numAttVals][1][data.numAttributes()][0];
	splitData(subsetIndices, subsetWeights, m_Attribute, m_SplitPoint,
		  sortedIndices[0], weights[0], data);

        // 释放内存
        sortedIndices[0] = null;
        weights[0] = null;

        //释放内存
	m_Successors = new Tree[numAttVals];
	for (int i = 0; i < numAttVals; i++) {
	  m_Successors[i] = new Tree();//构建孩子节点
	  m_Successors[i].
	    buildTree(subsetIndices[i], subsetWeights[i],
		      data, attTotalSubsetWeights[i],
		      attSubsetDists[i], header, minNum,
		      minVariance, depth + 1, maxDepth);

          // 还是释放内存
          attSubsetDists[i] = null;
	}
      } else {

	// 如果不存在2个有效的subset,就直接当叶子节点了
	m_Attribute = -1;
        sortedIndices[0] = null;
        weights[0] = null;
      }

      // 构建attribute用于之后的分类过程(当然这是在没有prune和backfit情况下用的)
      if (data.classAttribute().isNominal()) {
	m_Distribution = new double[m_ClassProbs.length];
	for (int i = 0; i < m_ClassProbs.length; i++) {
	    m_Distribution[i] = m_ClassProbs[i];
	}
	Utils.normalize(m_ClassProbs);
      } else {
	m_Distribution = new double[2];
	m_Distribution[0] = priorVar;
	m_Distribution[1] = totalWeight;
      }
    }
时间: 2024-10-17 15:33:11

Weka算法Classifier-trees-REPTree源码分析(一)的相关文章

Opencv2.4.9源码分析——Gradient Boosted Trees

一.原理 梯度提升树(GBT,Gradient Boosted Trees,或称为梯度提升决策树)算法是由Friedman于1999年首次完整的提出,该算法可以实现回归.分类和排序.GBT的优点是特征属性无需进行归一化处理,预测速度快,可以应用不同的损失函数等. 从它的名字就可以看出,GBT包括三个机器学习的优化算法:决策树方法.提升方法和梯度下降法.前两种算法在我以前的文章中都有详细的介绍,在这里我只做简单描述. 决策树是一个由根节点.中间节点.叶节点和分支构成的树状模型,分支代表着数据的走向

Opencv2.4.9源码分析——Random Trees

一.原理 随机森林(Random Forest)的思想最早是由Ho于1995年首次提出,后来Breiman完整系统的发展了该算法,并命名为随机森林,而且他和他的博士学生兼同事Cutler把Random Forest注册成了商标,这可能也是Opencv把该算法命名为Random Trees的原因吧. 一片森林是由许多棵树木组成,森林中的每棵树可以说是彼此不相关,也就是说每棵树木的生长完全是由自身条件决定的,只有保持森林的多样性,森林才能更好的生长下去.随机森林算法与真实的森林相类似,它是由许多决策

K-近邻算法的Python实现 : 源码分析

网上介绍K-近邻算法的例子很多,其Python实现版本基本都是来自于机器学习的入门书籍<机器学习实战>,虽然K-近邻算法本身很简单,但很多初学者对其Python版本的源代码理解不够,所以本文将对其源代码进行分析. 什么是K-近邻算法? 简单的说,K-近邻算法采用不同特征值之间的距离方法进行分类.所以它是一个分类算法. 优点:无数据输入假定,对异常值不敏感 缺点:复杂度高 好了,直接先上代码,等会在分析:(这份代码来自<机器学习实战>) def classify0(inx, data

OpenStack_Swift源码分析——Ring的rebalance算法源代码详细分析

今天有同学去百度,带回一道面试题,和大家分享一下: 打印: n=1 1 n=2 3 3 2 4 1 1 4 5 5 n=3 7 7 7 7 6 8 3 3 2 6 8 4 1 1 6 8 4 5 5 5 8 9 9 9 9 提供一段参考程序: <pre name="code" class="cpp">// ConsoleApplication1.cpp: 主项目文件. #include "stdafx.h" #include &quo

Openck_Swift源码分析——增加、删除设备时算法具体的实现过程

1 初始添加设备后.上传Object的具体流程 前几篇博客中,我们讲到环的基本原理即具体的实现过程,加入我们在初始创建Ring是执行如下几条命令: ?swift-ring-builder object.builder create 5 3 1 ?swift-ring-builder object.builder add z1-127.0.0.1:6010/sdb1 100 ?swift-ring-builder object.builder add z2-127.0.0.1:6020/sdb2 

OpenCV学习笔记(27)KAZE 算法原理与源码分析(一)非线性扩散滤波

http://blog.csdn.net/chenyusiyuan/article/details/8710462 OpenCV学习笔记(27)KAZE 算法原理与源码分析(一)非线性扩散滤波 2013-03-23 17:44 16963人阅读 评论(28) 收藏 举报 分类: 机器视觉(34) 版权声明:本文为博主原创文章,未经博主允许不得转载. 目录(?)[+] KAZE系列笔记: OpenCV学习笔记(27)KAZE 算法原理与源码分析(一)非线性扩散滤波 OpenCV学习笔记(28)KA

WEKA学习——CSVLoader 实例训练 和 源码分析

简介: Weka支持多种数据导入方式,CSVLoader是能从csv文件加载数据集,也可以保存为arff格式文件.官方介绍文件:Converting CSV to ARFF ( http://weka.wikispaces.com/Converting+CSV+to+ARFF) CSVLoader加载文件,关键是对文件字段属性名称和属性的类型需要自己定义,这样才能得到满足自己需要的数据集. CSVLoader通过options设置,可以设置每一列的属性为Nominal,String,Date类型

OpenStack_Swift源码分析——Ring基本原理及一致性Hash算法

1.Ring的基本概念 Ring是swfit中最重要的组件,用于记录存储对象与物理位置之间的映射关系,当用户需要对Account.Container.Object操作时,就需要查询对应的Ring文件(Account.Container.Object都有自己对应的Ring),Ring 使用Region(最近几个版本中新加入的).Zone.Device.Partition和Replica来维护这些信息,对于每一个对象,根据你在部署swift设置的Replica数量,集群中会存有Replica个对象.

SURF算法与源码分析、下

上一篇文章 SURF算法与源码分析.上 中主要分析的是SURF特征点定位的算法原理与相关OpenCV中的源码分析,这篇文章接着上篇文章对已经定位到的SURF特征点进行特征描述.这一步至关重要,这是SURF特征点匹配的基础.总体来说算法思路和SIFT相似,只是每一步都做了不同程度的近似与简化,提高了效率. 1. SURF特征点方向分配 为了保证特征矢量具有旋转不变性,与SIFT特征一样,需要对每个特征点分配一个主方向.为些,我们需要以特征点为中心,以$6s$($s = 1.2 *L /9$为特征点

Mahout源码分析:并行化FP-Growth算法

FP-Growth是一种常被用来进行关联分析,挖掘频繁项的算法.与Aprior算法相比,FP-Growth算法采用前缀树的形式来表征数据,减少了扫描事务数据库的次数,通过递归地生成条件FP-tree来挖掘频繁项.参考资料[1]详细分析了这一过程.事实上,面对大数据量时,FP-Growth算法生成的FP-tree非常大,无法放入内存,挖掘到的频繁项也可能有指数多个.本文将分析如何并行化FP-Growth算法以及Mahout中并行化FP-Growth算法的源码. 1. 并行化FP-Growth 并行