C4.5算法(一)代码实现

入门学习机器学习的十大算法,第一站就是C4.5算法。C4.5是一种决策树算法,属于监督学习。先给一个样本集,从而建立一棵决策树,然后根据这个决策树来对后续的数据做决策。

作为没有相关背景知识和系统学习过的人,当然要边学边记啦。C4.5算法我的学习步骤是这样:

step 1: 了解清楚算法的逻辑,以及编程实现

step 2: 其中对连续变量的离散化处理

step 3: C4.5的剪枝

step 4: C4.5算法的spark实现

因为个人认为C4.5算法中比较难和重要的两个点就是对连续变量的离散化,和剪枝策略,所以会单独着重学习下。因为我终归是做hadoop和spark的,所以还会看看C4.5在spark上的应用和实现(C4.5显然不适合MapReduce模型)。本文只是step1,算法逻辑和编程实现的总结。

算法逻辑

1. 先明确几个概念:

: 朴素点说,就是信息的不确定性,多样性,包含的信息量的大小,需要用多少bit来传递这个信息。比如,抛一枚银币3次,得到的可能结果有8种,我们知道计算机要用3bit来传递,所以熵就是log2(8)=3。wiki上这样解释“你需要用 log2(n) 位来表示一个可以取
n 个值的变量。”

信息增益: 熵的减小量。决策树的期望是尽快定位,也就是说我们希望数据集的多样性越小越好,越小说明结果越稳定,越能定位到准确的结果。信息增益越大,则熵会变的越小,说明结果越好。信息增益的计算方式,是原数据集的熵,减去依照属性划分后,每个属性值的概率 * 对应的子数据集的熵。

信息增益率:对信息增益进行修正。信息增益会优先选择那些属性值多的属性,为了克服这种倾向,用一个属性计算出的信息增益,除以该属性本身的熵(SplitInfo),得到信息增益率。

2. C4.5算法逻辑:

先给一个来自网上的算法步骤:

我的概括: 

(1) 先查看是否为“纯”数据集(即结果一致)

(2) 选择信息增益率最大的属性bestAttr

(3) 根据bestAttr属性,把数据集划分成几个子数据集

(4) 对每个子数据集,递归C4.5算法

把整个C4.5算法的属性划分轨迹记录下来,就形成了一棵C4.5决策树。然后就能用这棵树做决策了。

Java代码实现

把这四段代码拷贝到四个java文件中,然后就直接可以运行了。

下面的代码实现决策树的主要逻辑。

<span style="font-size:14px;">import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

public class DecisionTree {

	InfoGainRatio infoGainRatio = new InfoGainRatio();

	public TreeNode createDecisionTree(List<String> attribute, List<ArrayList<String>> dataset) {
		TreeNode tree = new TreeNode();

		//check if it is pure
		if(DataSetUtil.isPure(DataSetUtil.getTarget(dataset))) {
			tree.setLeaf(true);
			tree.setTargetValue(DataSetUtil.getTarget(dataset).get(0));
			return tree;
		}
		//choose the best attribute
		int bestAttr = getBestAttribute(attribute, dataset);
		//create a decision tree
		tree.setAttribute(attribute.get(bestAttr));
		tree.setLeaf(false);
		List<String> attrValueList = DataSetUtil.getAttributeValueOfUnique(bestAttr, dataset);
		List<String> subAttribute = new ArrayList<String>();
		subAttribute.addAll(attribute);
		subAttribute.remove(bestAttr);
		for(String attrValue : attrValueList) {
			//更新数据集dataset
			List<ArrayList<String>> subDataSet = DataSetUtil.getSubDataSetByAttribute(dataset, bestAttr, attrValue);
			//递归构建子树
			TreeNode childTree = createDecisionTree(subAttribute, subDataSet);
			tree.addAttributeValue(attrValue);
			tree.addChild(childTree);
		}

		return tree;
	}

	/**
	 * 选出最优属性
	 * @param attribute
	 * @param dataset
	 * @return
	 */
	public int getBestAttribute(List<String> attribute, List<ArrayList<String>> dataset) {
		//calculate the gainRatio of each attribute, choose the max
		int bestAttr = 0;
		double maxGainRatio = 0;

		for(int i = 0; i < attribute.size(); i++) {
			double thisGainRatio = infoGainRatio.getGainRatio(i, dataset);
			if(thisGainRatio > maxGainRatio) {
				maxGainRatio = thisGainRatio;
				bestAttr = i;
			}
		}

		System.out.println("The best attribute is \"" + attribute.get(bestAttr) + "\"");
		return bestAttr;
	}

	public static void main(String args[]) {
		//eg 1
		String attr = "age income student credit_rating";
		String[] set = new String[12];
		set[0] = "youth high no fair no";
		set[1] = "youth high no excellent no";
		set[2] = "middle_aged high no fair yes";
		set[3] = "senior low yes fair yes";
		set[4] = "senior low yes excellent no";
		set[5] = "middle_aged low yes excellent yes";
		set[6] = "youth medium no fair no";
		set[7] = "youth low yes fair yes";
		set[8] = "senior medium yes fair yes";
		set[9] = "youth medium yes excellent yes";
		set[10] = "middle_aged high yes fair yes";
		set[11] = "senior medium no excellent no";

		List<ArrayList<String>> dataset = new ArrayList<ArrayList<String>>();
		List<String> attribute = Arrays.asList(attr.split(" "));
		for(int i = 0; i < set.length; i++) {
			String[] s = set[i].split(" ");
			ArrayList<String> list = new ArrayList<String>();
			for(int j = 0; j < s.length; j++) {
				list.add(s[j]);
			}
			dataset.add(list);
		}

		DecisionTree dt = new DecisionTree();
		TreeNode tree = dt.createDecisionTree(attribute, dataset);
		tree.print("");
	}

}</span>

下面的代码用来计算信息增益率。

<span style="font-size:14px;">import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Set;

public class InfoGainRatio {

	/**
	 * 获取某个属性的熵
	 *  	= -∑ p(xi)log(2,p(xi))
	 * @param list
	 * @return
	 */
	@SuppressWarnings("rawtypes")
	public double getEntropy(List<String> list) {
		//概率统计
		Map<String, Double> probability = DataSetUtil.getProbability(list);

		//熵计算
		double entropy = 0;
		Set set = probability.entrySet();
		Iterator iterator = set.iterator();
		while(iterator.hasNext()) {
			Map.Entry entry = (Entry) iterator.next();
			double prob = (double) entry.getValue();
			entropy -= prob * (Math.log(prob) / Math.log(2));
		}

		return entropy;
	}

	/**
	 * 获取某个属性的信息增益 = Entropy(U) ? ∑(|Di|/|D|)Entropy(Di)
	 * <br/> 离散属性
	 * @param attrId
	 * @param dataset
	 * @return
	 */
	@SuppressWarnings("rawtypes")
	public double getGain(int attrId, List<ArrayList<String>> dataset) {
		List<String> targetList = DataSetUtil.getTarget(dataset);
		List<String> attrValueList = DataSetUtil.getAttributeValue(attrId, dataset);

		double totalEntropy = getEntropy(targetList);

		Map<String, Double> probability = DataSetUtil.getProbability(attrValueList);

		double subEntropy = 0;

		Set set = probability.entrySet();
		Iterator iterator = set.iterator();
		while(iterator.hasNext()) {
			Map.Entry entry = (Entry) iterator.next();
			double prob = (double) entry.getValue();
			List<String> subTargetList = DataSetUtil.getTargetByAttribute((String) entry.getKey(), attrValueList, targetList);
			double entropy = getEntropy(subTargetList);

			subEntropy += prob * entropy;
		}

		return totalEntropy - subEntropy;
	}

	/**
	 * 获取某个属性的信息增益率 = Gain(A) / SplitInfo(A)
	 * <br/> 离散属性
	 * @param attrId
	 * @param dataset
	 * @return
	 */
	public double getGainRatio(int attrId, List<ArrayList<String>> dataset) {

		List<String> attrValueList = DataSetUtil.getAttributeValue(attrId, dataset);
		double gain = getGain(attrId, dataset);
		double splitInfo = getEntropy(attrValueList);

		return splitInfo == 0 ? 0 : gain/splitInfo;
	}

}</span>

下面的代码是数据集处理的相关操作。

<span style="font-size:14px;">import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;

public class DataSetUtil {

	/**
	 * 获取数据集中的结果列
	 * @param dataset
	 * @return
	 */
	public static List<String> getTarget(List<ArrayList<String>> dataset) {
		List<String> target = new ArrayList<String>();
		int targetId = dataset.get(0).size() - 1;

		for(List<String> element : dataset) {
			target.add(element.get(targetId));
		}

		return target;
	}

	/**
	 * 获取属性值
	 * @param attrId
	 * @param dataset
	 * @return
	 */
	public static List<String> getAttributeValue(int attrId, List<ArrayList<String>> dataset) {
		List<String> attrValue = new ArrayList<String>();

		for(List<String> element : dataset) {
			attrValue.add(element.get(attrId));
		}

		return attrValue;
	}

	/**
	 * 获取属性值,唯一值
	 * @param bestAttr
	 * @param dataset
	 * @return
	 */
	@SuppressWarnings({ "rawtypes", "unchecked" })
	public static List<String> getAttributeValueOfUnique(int attrId, List<ArrayList<String>> dataset) {
		Set attrSet = new HashSet();
		List<String> attrValue = new ArrayList<String>();
		for(List<String> element : dataset) {
			attrSet.add(element.get(attrId));
		}

		Iterator iterator = attrSet.iterator();
		while(iterator.hasNext()) {
			attrValue.add((String) iterator.next());
		}

		return attrValue;
	}

	/**
	 * for test <br/>
	 * 输出数据集
	 * @param attribute
	 * @param dataset
	 */
	public static void printDataset(List<String> attribute, List<ArrayList<String>> dataset) {
		System.out.println(attribute);
		for(List<String> element : dataset) {
			System.out.println(element);
		}
	}

	/**
	 * 数据集纯度检测
	 */
	public static boolean isPure(List<String> data) {
		String result = data.get(0);
		for(int i = 1; i < data.size(); i++) {
			if(!data.get(i).equals(result))
				return false;
		}
		return true;
	}

	/**
	 * 对一列进行概率统计
	 * @param list
	 * @return
	 */
	public static Map<String, Double> getProbability(List<String> list) {
		double unitProb = 1.00/list.size();
		Map<String, Double> probability = new HashMap<String, Double>();
		for(String key : list) {
			if(probability.containsKey(key)) {
				probability.put(key, unitProb + probability.get(key));
			}else{
				probability.put(key, unitProb);
			}
		}
		return probability;
	}

	/**
	 * 根据属性值,分离出结果列target
	 * @param attrValue
	 * @param attrValueList
	 * @param targetList
	 * @return
	 */
	public static List<String> getTargetByAttribute(String attrValue,
							List<String> attrValueList, List<String> targetList) {
		List<String> result = new ArrayList<String>();
		for(int i=0; i<attrValueList.size(); i++) {
			if(attrValueList.get(i).equals(attrValue))
				result.add(targetList.get(i));
		}
		return result;
	}

	/**
	 * 拿出指定属性值对应的子数据集
	 * @param dataset
	 * @param bestAttr
	 * @param attrValue
	 * @return
	 */
	public static List<ArrayList<String>> getSubDataSetByAttribute(
			List<ArrayList<String>> dataset, int attrId, String attrValue) {
		List<ArrayList<String>> subDataset = new ArrayList<ArrayList<String>>();
		for(ArrayList<String> list : dataset) {
			if(list.get(attrId).equals(attrValue)) {
				ArrayList<String> cutList = new ArrayList<String>();
				cutList.addAll(list);
				cutList.remove(attrId);
				subDataset.add(cutList);
			}
		}
		System.out.println(subDataset);

		return subDataset;
	}

}</span>

下面代码是决策树的树节点对象实现。

<span style="font-size:14px;">import java.util.ArrayList;
import java.util.List;

public class TreeNode {
	public String attribute;
	public List<String> attributeValue;
	public List<TreeNode> child;
	//for leaf node
	public boolean isLeaf;
	public String targetValue;

	TreeNode() {
		attributeValue = new ArrayList<String>();
		child = new ArrayList<TreeNode>();
	}

	public String getAttribute() {
		return attribute;
	}

	public void setAttribute(String attribute) {
		this.attribute = attribute;
	}

	public List<String> getAttributeValue() {
		return attributeValue;
	}

	public void setAttributeValue(List<String> attributeValue) {
		this.attributeValue = attributeValue;
	}

	public void addAttributeValue(String attributeValue) {
		this.attributeValue.add(attributeValue);
	}

	public List<TreeNode> getChild() {
		return child;
	}

	public void setChild(List<TreeNode> child) {
		this.child = child;
	}

	public void addChild(TreeNode child) {
		this.child.add(child);
	}

	public boolean isLeaf() {
		return isLeaf;
	}

	public void setLeaf(boolean isLeaf) {
		this.isLeaf = isLeaf;
	}

	public String getTargetValue() {
		return targetValue;
	}

	public void setTargetValue(String targetValue) {
		this.targetValue = targetValue;
	}

	public void print(String depth) {
		if(!this.isLeaf){
			System.out.println(depth + this.attribute);
			depth += "\t";
			for(int i = 0; i < this.attributeValue.size(); i++) {
				System.out.println(depth + "---(" + this.attributeValue.get(i) + ")---" );
				this.child.get(i).print(depth + "\t");
			}
		} else {
			System.out.println(depth + "[" + this.targetValue + "]");
		}

	}

}</span>

这是很简单的实现,当然代码还是有很需要完善的地方。比如,数据集和相关的操作其实可以放在一个类里来实现,这里没有添加对连续变量的处理,剪枝也还没实现。anyway,C4.5算法的主要逻辑毕竟已经实现了,其余的用到的时候再慢慢扩充吧。

版权声明:本文为博主原创文章,未经博主允许不得转载。

时间: 2024-11-07 23:17:26

C4.5算法(一)代码实现的相关文章

决策树-预测隐形眼镜类型 (ID3算法,C4.5算法,CART算法,GINI指数,剪枝,随机森林)

1. 1.问题的引入 2.一个实例 3.基本概念 4.ID3 5.C4.5 6.CART 7.随机森林 2. 我们应该设计什么的算法,使得计算机对贷款申请人员的申请信息自动进行分类,以决定能否贷款? 一个女孩的母亲要给这个女孩介绍男朋友,于是有了下面的对话: 女儿:多大年纪了? 母亲:26. 女儿:长的帅不帅? 母亲:挺帅的. 女儿:收入高不? 母亲:不算很高,中等情况. 女儿:是公务员不? 母亲:是,在税务局上班呢. 女儿:那好,我去见见. 决策过程: 这个女孩的决策过程就是典型的分类树决策.

决策分类树算法之ID3,C4.5算法系列

一.引言 在最开始的时候,我本来准备学习的是C4.5算法,后来发现C4.5算法的核心还是ID3算法,所以又辗转回到学习ID3算法了,因为C4.5是他的一个改进.至于是什么改进,在后面的描述中我会提到. 二.ID3算法 ID3算法是一种分类决策树算法.他通过一系列的规则,将数据最后分类成决策树的形式.分类的根据是用到了熵这个概念.熵在物理这门学科中就已经出现过,表示是一个物质的稳定度,在这里就是分类的纯度的一个概念.公式为: 在ID3算法中,是采用Gain信息增益来作为一个分类的判定标准的.他的定

C4.5算法总结

C4.5是一系列用在机器学习和数据挖掘的分类问题中的算法.它的目标是监督学习:给定一个数据集,其中的每一个元组都能用一组属性值来描述,每一个元组属于一个互斥的类别中的某一类.C4.5的目标是通过学习,找到一个从属性值到类别的映射关系,并且这个映射能用于对新的类别未知的实体进行分类. C4.5由J.Ross Quinlan在ID3的基础上提出的.ID3算法用来构造决策树.决策树是一种类似流程图的树结构,其中每个内部节点(非树叶节点)表示在一个属性上的测试,每个分枝代表一个测试输出,而每个树叶节点存

C4.5算法(摘抄)

1. C4.5算法简介 C4.5是一系列用在机器学习和数据挖掘的分类问题中的算法.它的目标是监督学习:给定一个数据集,其中的每一个元组都能用一组属性值来描述,每一个元组属于一个互斥的类别中的某一类.C4.5的目标是通过学习,找到一个从属性值到类别的映射关系,并且这个映射能用于对新的类别未知的实体进行分类. C4.5由J.Ross Quinlan在ID3的基础上提出的.ID3算法用来构造决策树.决策树是一种类似流程图的树结构,其中每个内部节点(非树叶节点)表示在一个属性上的测试,每个分枝代表一个测

Python实现各种排序算法的代码示例总结

Python实现各种排序算法的代码示例总结 作者:Donald Knuth 字体:[增加 减小] 类型:转载 时间:2015-12-11我要评论 这篇文章主要介绍了Python实现各种排序算法的代码示例总结,其实Python是非常好的算法入门学习时的配套高级语言,需要的朋友可以参考下 在Python实践中,我们往往遇到排序问题,比如在对搜索结果打分的排序(没有排序就没有Google等搜索引擎的存在),当然,这样的例子数不胜数.<数据结构>也会花大量篇幅讲解排序.之前一段时间,由于需要,我复习了

决策树之C4.5算法学习

决策树<Decision Tree>是一种预测模型,它由决策节点,分支和叶节点三个部分组成.决策节点代表一个样本测试,通常代表待分类样本的某个属性,在该属性上的不同测试结果代表一个分支:分支表示某个决策节点的不同取值.每个叶节点代表一种可能的分类结果. 使用训练集对决策树算法进行训练,得到一个决策树模型,利用模型对未知样本(类别未知)的类别判断时,从决策树根节点开始,从上到下搜索,直到沿某分支到达叶节点,叶节点的类别标签就是该未知样本的类别. 网上有个例子可以很形象的说明利用决策树决策的过程(

排序算法总结---代码+性能

// data_sort_alg.cpp : 定义控制台应用程序的入口点. // #include "stdafx.h" #include "sort_alg.h" #include <iostream> #include <vector> void show(std::vector<int> &a) { std::vector<int>::iterator it=a.begin(); while(it!=a.

决策树-C4.5算法(三)

在上述两篇的文章中主要讲述了决策树的基础,但是在实际的应用中经常用到C4.5算法,C4.5算法是以ID3算法为基础,他在ID3算法上做了如下的改进: 1) 用信息增益率来选择属性,克服了用信息增益选择属性时偏向选择取值多的属性的不足,公式为GainRatio(A): 2) 在树构造过程中进行剪枝: 3) 能够完成对连续属性的离散化处理: 4) 能够对不完整数据进行处理. C4.5算法与其它分类算法如统计方法.神经网络等比较起来有如下优点:产生的分类规则易于理解,准确率较高.其缺点是:在构造树的过

C4.5算法

C4.5是一套用来处理分类问题的算法,属于有监督学习的类型,每个实例由一组属性来描述,每个实例仅属于一个类别. 如下是一个数据集 算法的发展历史 J.Ross Quinlan设计的C4.5算法源于名为ID3的一种决策树诱导算法. 而ID3是迭代分解器(iterative dichotomizers)系列算法的第3代. 除了可以分类之外,C4.5还可以具有良好可理解性的规则. Friedman的Original Tree算法. Breiman,Olshen和Stone的呢个人参与下发展为CART算