Weka算法Classifier-tree-J48源码分析(三)ModelSelection

ModelSelection主要是用于选择合适的列对数据集进行分割,结合上一篇J48的主流程,发现用到的ModelSelection有 C45ModelSelection以及BinC45ModelSelection,先来分析C45ModelSelection。

一、C45ModelSelection

首先作为一个ModelSelection接口,实现的主要方法有两个,分别是selectModel(Instances)和selectionModel(Instances,Instances)。C45ModelSelection的后一个方法如下:

  public final ClassifierSplitModel selectModel(Instances train, Instances test) {

    return selectModel(train);
  }

可以看到就是忽略了test测试集直接调用selectModel方法而已,因此主要分词selectModel方法。

先放出整段代码,然后对该段代码进行分析:

public final ClassifierSplitModel selectModel(Instances data){

    double minResult;
    double currentResult;
    C45Split [] currentModel;
    C45Split bestModel = null;
    NoSplit noSplitModel = null;
    double averageInfoGain = 0;
    int validModels = 0;
    boolean multiVal = true;
    Distribution checkDistribution;
    Attribute attribute;
    double sumOfWeights;
    int i;

    try{

      // Check if all Instances belong to one class or if not
      // enough Instances to split.
      checkDistribution = new Distribution(data);
      noSplitModel = new NoSplit(checkDistribution);
      if (Utils.sm(checkDistribution.total(),2*m_minNoObj) ||
	  Utils.eq(checkDistribution.total(),
		   checkDistribution.perClass(checkDistribution.maxClass())))
	return noSplitModel;

      // Check if all attributes are nominal and have a
      // lot of values.
      if (m_allData != null) {
	Enumeration enu = data.enumerateAttributes();
	while (enu.hasMoreElements()) {
	  attribute = (Attribute) enu.nextElement();
	  if ((attribute.isNumeric()) ||
	      (Utils.sm((double)attribute.numValues(),
			(0.3*(double)m_allData.numInstances())))){
	    multiVal = false;
	    break;
	  }
	}
      } 

      currentModel = new C45Split[data.numAttributes()];
      sumOfWeights = data.sumOfWeights();

      // For each attribute.
      for (i = 0; i < data.numAttributes(); i++){

	// Apart from class attribute.
	if (i != (data).classIndex()){

	  // Get models for current attribute.
	  currentModel[i] = new C45Split(i,m_minNoObj,sumOfWeights);
	  currentModel[i].buildClassifier(data);

	  // Check if useful split for current attribute
	  // exists and check for enumerated attributes with
	  // a lot of values.
	  if (currentModel[i].checkModel())
	    if (m_allData != null) {
	      if ((data.attribute(i).isNumeric()) ||
		  (multiVal || Utils.sm((double)data.attribute(i).numValues(),
					(0.3*(double)m_allData.numInstances())))){
		averageInfoGain = averageInfoGain+currentModel[i].infoGain();
		validModels++;
	      }
	    } else {
	      averageInfoGain = averageInfoGain+currentModel[i].infoGain();
	      validModels++;
	    }
	}else
	  currentModel[i] = null;
      }

      // Check if any useful split was found.
      if (validModels == 0)
	return noSplitModel;
      averageInfoGain = averageInfoGain/(double)validModels;

      // Find "best" attribute to split on.
      minResult = 0;
      for (i=0;i<data.numAttributes();i++){
	if ((i != (data).classIndex()) &&
	    (currentModel[i].checkModel()))

	  // Use 1E-3 here to get a closer approximation to the original
	  // implementation.
	  if ((currentModel[i].infoGain() >= (averageInfoGain-1E-3)) &&
	      Utils.gr(currentModel[i].gainRatio(),minResult)){
	    bestModel = currentModel[i];
	    minResult = currentModel[i].gainRatio();
	  }
      }

      // Check if useful split was found.
      if (Utils.eq(minResult,0))
	return noSplitModel;

      // Add all Instances with unknown values for the corresponding
      // attribute to the distribution for the model, so that
      // the complete distribution is stored with the model.
      bestModel.distribution().
	  addInstWithUnknown(data,bestModel.attIndex());

      // Set the split point analogue to C45 if attribute numeric.
      if (m_allData != null)
	bestModel.setSplitPoint(m_allData);
      return bestModel;
    }catch(Exception e){
      e.printStackTrace();
    }
    return null;
  }

第一部分,主要是对局部变量的一些定义。

    double minResult;//最小的信息增益率
    double currentResult;//当前信息增益率
    C45Split [] currentModel;//存放所有未分类属性产生的模型
    C45Split bestModel = null;//目前为止的最好模型
    NoSplit noSplitModel = null;//代表不用分的模型
    double averageInfoGain = 0;//各模型(currentModel)的平均信息增益
    int validModels = 0;//是否存在有效模型
    boolean multiVal = true;//是否多值
    Distribution checkDistribution;//训练数据集的分布
    Attribute attribute;//属性列集合
    double sumOfWeights;//训练数据集的weight的和
    int i;//循环变量

第二部分,递归出口。

 checkDistribution = new Distribution(data);
      noSplitModel = new NoSplit(checkDistribution);
      if (Utils.sm(checkDistribution.total(),2*m_minNoObj) ||
	  Utils.eq(checkDistribution.total(),
		   checkDistribution.perClass(checkDistribution.maxClass())))
	return noSplitModel;

可以看到,如果当前数据集数量小于2*m_minNoObj(这个值默认是2),或者当前数据集已经全在同一个分类中,就返回noSplitModel代表不用分,这就是整个C45分类树节点停止分裂的条件。

第三部分,判断是否是多值:

      if (m_allData != null) {
	Enumeration enu = data.enumerateAttributes();
	while (enu.hasMoreElements()) {
	  attribute = (Attribute) enu.nextElement();
	  if ((attribute.isNumeric()) ||
	      (Utils.sm((double)attribute.numValues(),
			(0.3*(double)m_allData.numInstances())))){
	    multiVal = false;
	    break;
	  }
	}
      } 

如果属性中,任意一列是数值型,或者其取值的数量小于训练集数量*0.3,则不是多值,否则按多值处理。是否是多值影响到后面某些逻辑。

第四部分,对于每一列属性构造Spliter。

    for (i = 0; i < data.numAttributes(); i++){

	// Apart from class attribute.
	if (i != (data).classIndex()){

	  // Get models for current attribute.
	  currentModel[i] = new C45Split(i,m_minNoObj,sumOfWeights);
	  currentModel[i].buildClassifier(data);

	  // Check if useful split for current attribute
	  // exists and check for enumerated attributes with
	  // a lot of values.
	  if (currentModel[i].checkModel())
	    if (m_allData != null) {
	      if ((data.attribute(i).isNumeric()) ||
		  (multiVal || Utils.sm((double)data.attribute(i).numValues(),
					(0.3*(double)m_allData.numInstances())))){
		averageInfoGain = averageInfoGain+currentModel[i].infoGain();
		validModels++;
	      }
	    } else {
	      averageInfoGain = averageInfoGain+currentModel[i].infoGain();
	      validModels++;
	    }
	}else
	  currentModel[i] = null;
      }

对于每一列属性,如果不是存放分类的值得话,则构造C45Split对象,在该对象上进行分类,然后算出信息增益,相加到averageInfoGain上。对于C45Split的构造,稍后再看。

第五部分,选出最优模型。

 if (validModels == 0)
	return noSplitModel;
      averageInfoGain = averageInfoGain/(double)validModels;

      // Find "best" attribute to split on.
      minResult = 0;
      for (i=0;i<data.numAttributes();i++){
	if ((i != (data).classIndex()) &&
	    (currentModel[i].checkModel()))

	  // Use 1E-3 here to get a closer approximation to the original
	  // implementation.
	  if ((currentModel[i].infoGain() >= (averageInfoGain-1E-3)) &&
	      Utils.gr(currentModel[i].gainRatio(),minResult)){
	    bestModel = currentModel[i];
	    minResult = currentModel[i].gainRatio();
	  } 

如果存在有效模型,则选出有效模型。注意这个选出最优模型的逻辑,并不是单纯的选出gainRatio最大的,而是在基础上必须还要大于平均信息增益,这也是和传统的c45算法不一样的一点。

从上述过程来看,Weka在实现C45的时候做了一个小的变动,并没有从“还没有使用的”属性列中找出最合理的列最为分割属性,而是在“所有的列”中找出最合理的列作为分割属性,虽然这二者在结果上肯定是等价的(之前是有过的属性不和能有很好的信息增益率),但效率上个人对Weka的做法持保留意见。

二、C45Spliter

在ModelSelection中真正根据属性对训练集进行分割、计算信息增益和信息增益率的是C45Spliter,首先也从其buildClassifier方法入手进行分析。

public void buildClassifier(Instances trainInstances)
       throws Exception {

    // Initialize the remaining instance variables.
    m_numSubsets = 0;
    m_splitPoint = Double.MAX_VALUE;
    m_infoGain = 0;
    m_gainRatio = 0;

    // Different treatment for enumerated and numeric
    // attributes.
    if (trainInstances.attribute(m_attIndex).isNominal()) {
      m_complexityIndex = trainInstances.attribute(m_attIndex).numValues();
      m_index = m_complexityIndex;
      handleEnumeratedAttribute(trainInstances);
    }else{
      m_complexityIndex = 2;
      m_index = 0;
      trainInstances.sort(trainInstances.attribute(m_attIndex));
      handleNumericAttribute(trainInstances);
    }
  }    

可以看到,对于枚举型和数值型的属性是分开处理的,枚举型调用handlEnumeratedAttribute,数值型调用handleNumericAttribute,值得注意的是,在处理数值型之前,按照相应列进行排序,同时设置m_complexityIndex也就是期望分裂的节点数设定为2。

首先来看枚举类型是如何处理的。

private void handleEnumeratedAttribute(Instances trainInstances)
       throws Exception {

    Instance instance;

    m_distribution = new Distribution(m_complexityIndex,
			      trainInstances.numClasses());

    // Only Instances with known values are relevant.
    Enumeration enu = trainInstances.enumerateInstances();
    while (enu.hasMoreElements()) {
      instance = (Instance) enu.nextElement();
      if (!instance.isMissing(m_attIndex))
	m_distribution.add((int)instance.value(m_attIndex),instance);
    }

    // Check if minimum number of Instances in at least two
    // subsets.
    if (m_distribution.check(m_minNoObj)) {
      m_numSubsets = m_complexityIndex;
      m_infoGain = infoGainCrit.
	splitCritValue(m_distribution,m_sumOfWeights);
      m_gainRatio =
	gainRatioCrit.splitCritValue(m_distribution,m_sumOfWeights,
				     m_infoGain);
    }
  }

大概流程是新建一个分布,遍历所有instance,如果该instance对应的分裂的属性不为空的话,则放到不同的bag里,之后检查一下这个分布是否满足要求,要求就是最多允许有一个bag里的数据数量小于m_minNoObj,如果通过检查,就设置subset的数量,计算信息增益和信息增益率,否则subset默认会是0,上层调用checkModel就会返回false代表这是一个无效模型。

接下来看数值型是如何处理的:

 private void handleNumericAttribute(Instances trainInstances)
       throws Exception {

    int firstMiss;//最后一个有效instance的下标
    int next = 1;//下一个instance的index
    int last = 0;//当前instance的index
    int splitIndex = -1;//分裂点
    double currentInfoGain;//当前信息增益
    double defaultEnt;//分割之前的信息熵
    double minSplit;
    Instance instance;
    int i;
//首先新建一个分布,数值型默认处理为2维分布,也就可以理解为小于某个值放到一个Bag里,其余的放到另外一个Bag里
    m_distribution = new Distribution(2,trainInstances.numClasses());
    Enumeration enu = trainInstances.enumerateInstances();
    i = 0;
<pre name="code" class="cpp">//注意instances传入的时候是排好序的,这个排序保证了missingValue放在最后面,所以读到了missingValue其之后肯定都是miss//ingValue,换言之,firstMiss在循环之后代表了最后一个有效的instance的下标。

while (enu.hasMoreElements()) { instance = (Instance) enu.nextElement(); if (instance.isMissing(m_attIndex))break; m_distribution.add(1,instance); i++;
} firstMiss = i;//循环结束后,m_distribution里放入了所有的有效instance,并全放入了bag1里。


//minSplit是最后分类好每个Bag里最小的数据的量,也就是0.1*每个类的均值。
    minSplit =  0.1*(m_distribution.total())/
      ((double)trainInstances.numClasses());
    if (Utils.smOrEq(minSplit,m_minNoObj))
      minSplit = m_minNoObj;
    else
      if (Utils.gr(minSplit,25))
	minSplit = 25;

//如果有效数据总量不到2*minSplit,换言之无论怎么分均不能保证2个bag里的数量大于minSplit,就直接返回。
    if (Utils.sm((double)firstMiss,2*minSplit))
      return;

//defaultEnt代表旧的信息熵,也就是对该属性进行分类之前,Indexclass对应的信息熵。
    defaultEnt = infoGainCrit.oldEnt(m_distribution);
    while (next < firstMiss) {

      if (trainInstances.instance(next-1).value(m_attIndex)+1e-5 <
	  trainInstances.instance(next).value(m_attIndex)) {
	<pre name="code" class="cpp">//Instances里的记录是升序排列的,加上这个条件默认把值相差很小的Instance就当做同一个instance处理了
//last代表当前,next代表下一个,默认next=1,last=0,所以shiftRange可以理解成把当前记录从bag1移动到bag0中
<span style="font-family: Arial, Helvetica, sans-serif;">//注意一开始初始化时候所有的都是在bag1里面的。	</span>

m_distribution.shiftRange(1,0,trainInstances,last,next);if (Utils.grOrEq(m_distribution.perBag(0),minSplit) && //如果两个bag都满足最小数据集的数量minSplit Utils.grOrEq(m_distribution.perBag(1),minSplit)) { currentInfoGain
= infoGainCrit. splitCritValue(m_distribution,m_sumOfWeights, //算一下信息增益 defaultEnt);


	  if (Utils.gr(currentInfoGain,m_infoGain)) {
	    m_infoGain = currentInfoGain;//如果信息增益比当前最大的要大,则替换当前最大的值,并记录splitIndex
	    splitIndex = next-1;
	  }
	  m_index++;
	}
	last = next;
      }
      next++;
    }

    if (m_index == 0)
      return; //执行到这里说明没找到一个合适的分裂点,直接返回。

    // 计算最佳信息增益
    m_infoGain = m_infoGain-(Utils.log2(m_index)/m_sumOfWeights);
    if (Utils.smOrEq(m_infoGain,0))
      return; //如果信息增益是0也说明没找到合适的分裂点,直接返回。

    //剩下的就是根据分裂点进行属性的划分。
    m_numSubsets = 2;
    m_splitPoint =
      (trainInstances.instance(splitIndex+1).value(m_attIndex)+
       trainInstances.instance(splitIndex).value(m_attIndex))/2;

    // In case we have a numerical precision problem we need to choose the
    // smaller value
    if (m_splitPoint == trainInstances.instance(splitIndex + 1).value(m_attIndex)) {
      m_splitPoint = trainInstances.instance(splitIndex).value(m_attIndex);
    }

    // Restore distributioN for best split.
    m_distribution = new Distribution(2,trainInstances.numClasses());
    m_distribution.addRange(0,trainInstances,0,splitIndex+1);
    m_distribution.addRange(1,trainInstances,splitIndex+1,firstMiss);

    // Compute modified gain ratio for best split.
    m_gainRatio = gainRatioCrit.
      splitCritValue(m_distribution,m_sumOfWeights,
		     m_infoGain);
  }

这个函数有点复杂,具体逻辑也写到代码注释里了。

三、BinC45ModelSelection

该函数只负责生成二元分类树的模型,selectModel方法和C45ModelSelection几乎一样,不在多说,不同点在于其使用BinC45Spliter而不是C45Spliter。

四、BinC45Spliter

handleNumericAttribute对于数值类型的属性处理和C45Spliter完全一样。下面只分析一下handleEnumeratedAttribute。

 private void handleEnumeratedAttribute(Instances trainInstances)
       throws Exception {

    Distribution newDistribution,secondDistribution;
    int numAttValues;
    double currIG,currGR;
    Instance instance;
    int i;

    numAttValues = trainInstances.attribute(m_attIndex).numValues();
    newDistribution = new Distribution(numAttValues,
				       trainInstances.numClasses());

    // Only Instances with known values are relevant.
    Enumeration enu = trainInstances.enumerateInstances();
    while (enu.hasMoreElements()) {
      instance = (Instance) enu.nextElement();
      if (!instance.isMissing(m_attIndex))
	newDistribution.add((int)instance.value(m_attIndex),instance);
    }
    m_distribution = newDistribution;

    // For all values
    for (i = 0; i < numAttValues; i++){

      if (Utils.grOrEq(newDistribution.perBag(i),m_minNoObj)){
	secondDistribution = new Distribution(newDistribution,i);

	// Check if minimum number of Instances in the two
	// subsets.
	if (secondDistribution.check(m_minNoObj)){
	  m_numSubsets = 2;
	  currIG = m_infoGainCrit.splitCritValue(secondDistribution,
					       m_sumOfWeights);
	  currGR = m_gainRatioCrit.splitCritValue(secondDistribution,
						m_sumOfWeights,
						currIG);
	  if ((i == 0) || Utils.gr(currGR,m_gainRatio)){
	    m_gainRatio = currGR;
	    m_infoGain = currIG;
	    m_splitPoint = (double)i;
	    m_distribution = secondDistribution;
	  }
	}
      }
    }

可以看出,上一段代码根据该属性的不同的取值,在已有分布基础上,建立一个新的分布secondeDistribution,

secondDistribution = new Distribution(newDistribution,i);

该分布包含两列,属性下标为i的,其余的,在这个分布的基础上计算信息增益和信息增益率,并选出最优的。

换句话说,离散值分类的二元化处理就是选出其中一列当做一个branch,其余的当做另外一个branch。虽然从结构上来讲这肯定不是最优的选择,但简单易用就够了。

到这里基本分析完了J48的两个ModelSelection,下一篇文章将对classifierInstance过程进行分析,并给出一个简单的总结。

时间: 2024-10-06 15:19:05

Weka算法Classifier-tree-J48源码分析(三)ModelSelection的相关文章

Nouveau源码分析(三):NVIDIA设备初始化之nouveau_drm_probe

Nouveau源码分析(三) 向DRM注册了Nouveau驱动之后,内核中的PCI模块就会扫描所有没有对应驱动的设备,然后和nouveau_drm_pci_table对照. 对于匹配的设备,PCI模块就调用对应的probe函数,也就是nouveau_drm_probe. // /drivers/gpu/drm/nouveau/nouveau_drm.c 281 static int nouveau_drm_probe(struct pci_dev *pdev, 282 const struct

[Android]Fragment源码分析(三) 事务

Fragment管理中,不得不谈到的就是它的事务管理,它的事务管理写的非常的出彩.我们先引入一个简单常用的Fragment事务管理代码片段: FragmentTransaction ft = this.getSupportFragmentManager().beginTransaction(); ft.add(R.id.fragmentContainer, fragment, "tag"); ft.addToBackStack("<span style="fo

baksmali和smali源码分析(三)

baksmali 的源码分析 在baksmali进行源码分析之前,需要读者掌握一条主线,因为本身笔者只是由于项目需要用到这套源码,在工作之余的时间里面来进行学习也没有时间和精力熟读源码的每个文件每个方法,但是依据这条主线,至少能够猜出并且猜对baksmali里面的源码的文件大概的作用是什么,这样在修改问题和移植的时候才能做到游刃有余. 这条主线是,baksmali其实只是利用了dexlib2提供的接口,将dex文件读入到一块内存中,这块内存或者说数据结构开辟的大小是跟输入的dex文件相关的,而这

横屏小游戏--萝莉快跑源码分析三

主角出场: 初始化主角 hero = new GameObjHero(); hero->setScale(0.5); hero->setPosition(ccp(100,160)); hero->setVisible(false); addChild(hero,1); 进入GameObjHero类ccp文件 创建主角及动作 this->setContentSize(CCSizeMake(85,90)); //接收触摸事件 CCDirector* pDirector = CCDire

哇!板球 源码分析三

守门员出场 守门员出场,每个守门员是从屏幕的右侧中间的位置随机方向向左侧移动 FielderSprite* fielderSprite1 = FielderSprite::create("pic/fielder.png"); //守门员精灵初始位置为右侧中间位置 fielderSprite1->setPosition(ccp(GOALKEEPER_X, GOALKEEPER_Y)); fielderSprite1->setAnchorPoint(ccp(0.5, 0.5))

ABP源码分析三十三:ABP.Web

ABP.Web模块并不复杂,主要完成ABP系统的初始化和一些基础功能的实现. AbpWebApplication : 继承自ASP.Net的HttpApplication类,主要完成下面三件事一,在Application_Start完成AbpBootstrapper的初始化.整个ABP系统的初始化就是通过AbpBootstrapper完成初始化的.二,在Application_BeginRequest设置根据request或cookie中的Culture信息,完成当前工作线程的CurrentCu

ABP源码分析三十:ABP.RedisCache

ABP 通过StackExchange.Redis类库来操作Redis数据库. AbpRedisCacheModule:完成ABP.RedisCache模块的初始化(完成常规的依赖注入) AbpRedisCacheConfig:定义了connectionStringKey和databaseIdAppSetting的值.这两个值对象redis 在web.config中的key值. ABP.RedisCache模块通过读取web.config来获取redis的配置. IAbpRedisConnect

ABP源码分析三十一:ABP.AutoMapper

这个模块封装了Automapper,使其更易于使用. 下图描述了改模块涉及的所有类之间的关系. AutoMapAttribute,AutoMapFromAttribute和AutoMapToAttribute:这三个attribute用于标注一个类到另外一个类的map方向. AutoMapperHelper: 通过调用Automapper的API,根据类的AutoMap的特性完成类型之间的Map. AbpAutoMapperModule: 1. 查找项目中所有标注了AutoMap特性的类型,并完

YARN源码分析(三)-----ResourceManager HA之应用状态存储与恢复

前言 任何系统即使做的再大,都会有可能出现各种各样的突发状况.尽管你可以说我在软件层面上已经做到所有情况的意外处理了,但是万一硬件出问题了或者说物理层面上出了问题,恐怕就不是多写几行代码能够立刻解决的吧,说了这么多,无非就是想强调HA,系统高可用性的重要性.在YARN中,NameNode的HA方式估计很多人都已经了解了,那本篇文章就来为大家梳理梳理RM资源管理器HA方面的知识,并不是指简单的RM的HA配置,确切的说是RM的应用状态存储于恢复. RM应用状态存储使用 RM应用状态存储是什么意思呢,

ABP源码分析三十六:ABP.Web.Api

这里的内容和ABP 动态webapi没有关系.除了动态webapi,ABP必然是支持使用传统的webApi.ABP.Web.Api模块中实现了一些同意的基础功能,以方便我们创建和使用asp.net webApi. AbpApiController:这是一个抽象基类,继承自ApiController,是AB WebApi系统中所有controller的基类.如下图中,其封装了ABP核心模块中提供的大多数的功能对象.同时实现了一些公共的方法.它有四个派生类:DynamicApiController<