决策树归纳(ID3属性选择度量)Java实现

一般的决策树归纳框架见之前的博文:http://blog.csdn.net/zhyoulun/article/details/41978381

ID3属性选择度量原理

ID3使用信息增益作为属性选择度量。该度量基于香农在研究消息的值或”信息内容“的信息论方面的先驱工作。该结点N代表或存放分区D的元组。选择具有最高信息增益的属性作为结点N的分裂属性。该属性使结果分区中对元祖分类所需要的信息量最小,并反映这些分区中的最小随机性或”不纯性“。这种方法使得对一个对象分类所需要的期望测试数目最小,并确保找到一颗简单的(但不必是最简单的)树。

对D中的元组分类所需要的期望信息由下式给出,

其中pi是D忠任意元组属于类Ci的非零概率。使用以2为底的对数函数是因为信息用二进制编码。Info(D)是识别D中元组的类标号所需要的平均信息量。注意,此时我们所有的信息只是每个类的元组所占的百分比。

现在假设我们要按照某属性A划分D中的元组,其中属性A根据训练数据的观测具有v个不同的值{a1,a2,...av}。可以用属性A将D划分为v个分区或子集{D1,D2,...,Dv},其中Dj包含D中的元组,它们的A值为aj。这些分区对应于从节点N生长出来的分支。理想情况下,我们希望该划分产生元组的准确分类。即希望每个分区都是纯的(实际情况多半是不纯的,如分区可能包含来自不同类的元组)。在此划分之后,为了得到准确的分类,我们还需要多少信息?这个量由下式度量:

其中|Dj|/|D|充当第j个分区的权重。Info_A(D)是基于按A划分对D的元组分类所需要的期望值信息需要的期望信息越小,分区的纯度越高

信息增益定义为原来的信息需求(仅基于类比例)与新的信息需求(对A划分后)之前的差。即

换言之,Gain(A)告诉我们通过A上的划分我们得到了多少。它是知道A的值而导致的信息需求的期望减少。选择具有最高信息增益Gain(A)的属性A作为结点N的分裂属性。

以下为例子。

数据

data.txt

youth,high,no,fair,no
youth,high,no,excellent,no
middle_aged,high,no,fair,yes
senior,medium,no,fair,yes
senior,low,yes,fair,yes
senior,low,yes,excellent,no
middle_aged,low,yes,excellent,yes
youth,medium,no,fair,no
youth,low,yes,fair,yes
senior,medium,yes,fair,yes
youth,medium,yes,excellent,yes
middle_aged,medium,no,excellent,yes
middle_aged,high,yes,fair,yes
senior,medium,no,excellent,no

attr.txt

age,income,student,credit_rating,buys_computer

运算结果

age(1:youth; 2:middle_aged; 3:senior; )
	credit_rating(1:fair; 2:excellent; )
		leaf:no()
		leaf:yes()
	leaf:yes()
	student(1:no; 2:yes; )
		leaf:no()
		leaf:yes()

最后附上java代码

DecisionTree.java

package com.zhyoulun.decision;

import java.io.BufferedReader;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.ArrayList;
import java.util.Map;

/**
 * 负责数据的读入和写出,以及生成决策树
 *
 * @author zhyoulun
 *
 */
public class DecisionTree
{
	private ArrayList<ArrayList<String>> allDatas;
	private ArrayList<String> allAttributes;

	/**
	 * 从文件中读取所有相关数据
	 * @param dataFilePath
	 * @param attrFilePath
	 */
	public DecisionTree(String dataFilePath,String attrFilePath)
	{
		super();

		try
		{
			this.allDatas = new ArrayList<>();
			this.allAttributes = new ArrayList<>();

			InputStreamReader inputStreamReader = new InputStreamReader(new FileInputStream(new File(dataFilePath)));
			BufferedReader bufferedReader = new BufferedReader(inputStreamReader);
			String line = null;
			while((line=bufferedReader.readLine())!=null)
			{
				String[] strings = line.split(",");
				ArrayList<String> data = new ArrayList<>();
				for(int i=0;i<strings.length;i++)
					data.add(strings[i]);
				this.allDatas.add(data);
			}

			inputStreamReader = new InputStreamReader(new FileInputStream(new File(attrFilePath)));
			bufferedReader = new BufferedReader(inputStreamReader);
			while((line=bufferedReader.readLine())!=null)
			{
				String[] strings = line.split(",");
				for(int i=0;i<strings.length;i++)
					this.allAttributes.add(strings[i]);
			}

			inputStreamReader.close();
			bufferedReader.close();

		} catch (FileNotFoundException e)
		{
			// TODO Auto-generated catch block
			e.printStackTrace();
		} catch (IOException e)
		{
			// TODO Auto-generated catch block
			e.printStackTrace();
		}

//		for(int i=0;i<this.allAttributes.size();i++)
//		{
//			System.out.print(this.allAttributes.get(i)+" ");
//		}
//		System.out.println();
//
//		for(int i=0;i<this.allDatas.size();i++)
//		{
//			for(int j=0;j<this.allDatas.get(i).size();j++)
//			{
//				System.out.print(this.allDatas.get(i).get(j)+" ");
//			}
//			System.out.println();
//		}

	}

	/**
	 * @param allDatas
	 * @param allAttributes
	 */
	public DecisionTree(ArrayList<ArrayList<String>> allDatas,
			ArrayList<String> allAttributes)
	{
		super();
		this.allDatas = allDatas;
		this.allAttributes = allAttributes;
	}

	public ArrayList<ArrayList<String>> getAllDatas()
	{
		return allDatas;
	}

	public void setAllDatas(ArrayList<ArrayList<String>> allDatas)
	{
		this.allDatas = allDatas;
	}

	public ArrayList<String> getAllAttributes()
	{
		return allAttributes;
	}

	public void setAllAttributes(ArrayList<String> allAttributes)
	{
		this.allAttributes = allAttributes;
	}

	/**
	 * 递归生成决策数
	 * @return
	 */
	public static TreeNode generateDecisionTree(ArrayList<ArrayList<String>> datas, ArrayList<String> attrs)
	{
		TreeNode treeNode = new TreeNode();

		//如果D中的元素都在同一类C中,then
		if(isInTheSameClass(datas))
		{
			treeNode.setName(datas.get(0).get(datas.get(0).size()-1));
//			rootNode.setName();
			return treeNode;
		}
		//如果attrs为空,then(这种情况一般不会出现,我们应该是要对所有的候选属性集合构建决策树)
		if(attrs.size()==0)
			return treeNode;

		CriterionID3 criterionID3 = new CriterionID3(datas, attrs);
		int splitingCriterionIndex = criterionID3.attributeSelectionMethod();

		treeNode.setName(attrs.get(splitingCriterionIndex));
		treeNode.setRules(getValueSet(datas, splitingCriterionIndex));

		attrs.remove(splitingCriterionIndex);

		Map<String, ArrayList<ArrayList<String>>> subDatasMap = criterionID3.getSubDatasMap(splitingCriterionIndex);
//		for(String key:subDatasMap.keySet())
//		{
//			System.out.println("===========");
//			System.out.println(key);
//			for(int i=0;i<subDatasMap.get(key).size();i++)
//			{
//				for(int j=0;j<subDatasMap.get(key).get(i).size();j++)
//				{
//					System.out.print(subDatasMap.get(key).get(i).get(j)+" ");
//				}
//				System.out.println();
//			}
//		}

		for(String key:subDatasMap.keySet())
		{
			ArrayList<TreeNode> treeNodes = treeNode.getChildren();
			treeNodes.add(generateDecisionTree(subDatasMap.get(key), attrs));
			treeNode.setChildren(treeNodes);
		}

		return treeNode;
	}

	/**
	 * 获取datas中index列的值域
	 * @param data
	 * @param index
	 * @return
	 */
	public static ArrayList<String> getValueSet(ArrayList<ArrayList<String>> datas,int index)
	{
		ArrayList<String> values = new ArrayList<String>();
		String r = "";
		for (int i = 0; i < datas.size(); i++) {
			r = datas.get(i).get(index);
			if (!values.contains(r)) {
				values.add(r);
			}
		}
		return values;
	}

	/**
	 * 最后一列是类标号,判断最后一列是否相同
	 * @param datas
	 * @return
	 */
	private static boolean isInTheSameClass(ArrayList<ArrayList<String>> datas)
	{
		String flag = datas.get(0).get(datas.get(0).size()-1);//第0行,最后一列赋初值
		for(int i=0;i<datas.size();i++)
		{
			if(!datas.get(i).get(datas.get(i).size()-1).equals(flag))
				return false;
		}
		return true;
	}

	public static void main(String[] args)
	{
		String dataPath = "files/data.txt";
		String attrPath = "files/attr.txt";

		//初始化原始数据
		DecisionTree decisionTree = new DecisionTree(dataPath,attrPath);

		//生成决策树
		TreeNode treeNode = generateDecisionTree(decisionTree.getAllDatas(), decisionTree.getAllAttributes());

		print(treeNode,0);
	}

	private static void print(TreeNode treeNode,int level)
	{
		for(int i=0;i<level;i++)
			System.out.print("\t");
		System.out.print(treeNode.getName());
		System.out.print("(");
		for(int i=0;i<treeNode.getRules().size();i++)
			System.out.print((i+1)+":"+treeNode.getRules().get(i)+"; ");
		System.out.println(")");

		ArrayList<TreeNode> treeNodes = treeNode.getChildren();
		for(int i=0;i<treeNodes.size();i++)
		{
			print(treeNodes.get(i),level+1);
		}
	}

}

CriterionID3.java

package com.zhyoulun.decision;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.Map;

/**
 * ID3,选择分裂准则
 *
 * @author zhyoulun
 *
 */
public class CriterionID3
{
	private ArrayList<ArrayList<String>> datas;
	private ArrayList<String> attributes;

	private Map<String, ArrayList<ArrayList<String>>> subDatasMap;

	/**
	 * 计算所有的信息增益,获取最大的一项作为分裂属性
	 * @return
	 */
	public int attributeSelectionMethod()
	{
		double gain = -1.0;
		int maxIndex = 0;
		for(int i=0;i<this.attributes.size()-1;i++)
		{
			double tempGain = this.calcGain(i);
			if(tempGain>gain)
			{
				gain = tempGain;
				maxIndex = i;
			}
		}

		return maxIndex;
	}

	/**
	 * 计算 Gain(age)=Info(D)-Info_age(D) 等
	 * @param index
	 * @return
	 */
	/**
	 * @param index
	 * @param isCalcSubDatasMap
	 * @return
	 */
	private double calcGain(int index)
	{
		double result = 0;

		//计算Info(D)
		int lastIndex = datas.get(0).size()-1;
		ArrayList<String> valueSet = DecisionTree.getValueSet(this.datas,lastIndex);
		for(String value:valueSet)
		{
			int count = 0;
			for(int i=0;i<datas.size();i++)
			{
				if(datas.get(i).get(lastIndex).equals(value))
					count++;
			}

			result += -(1.0*count/datas.size())*Math.log(1.0*count/datas.size())/Math.log(2);
//			System.out.println(result);
		}
//		System.out.println("==========");

		//计算Info_a(D)
		valueSet = DecisionTree.getValueSet(this.datas,index);

//		for(String temp:valueSet)
//			System.out.println(temp);
//		System.out.println("==========");

		for(String value:valueSet)
		{
			ArrayList<ArrayList<String>> subDatas = new ArrayList<>();
			for(int i=0;i<datas.size();i++)
			{
				if(datas.get(i).get(index).equals(value))
					subDatas.add(datas.get(i));
			}

//			for(ArrayList<String> temp:subDatas)
//			{
//				for(String temp2:temp)
//					System.out.print(temp2+" ");
//				System.out.println();
//			}

			ArrayList<String> subValueSet = DecisionTree.getValueSet(subDatas, lastIndex);

//			System.out.print("subValueSet:");
//			for(String temp:subValueSet)
//				System.out.print(temp+" ");
//			System.out.println();

			for(String subValue:subValueSet)
			{
//				System.out.println("+++++++++++++++");
//				System.out.println(subValue);
				int count = 0;
				for(int i=0;i<subDatas.size();i++)
				{
					if(subDatas.get(i).get(lastIndex).equals(subValue))
						count++;
				}
//				System.out.println(count);
				result += -1.0*subDatas.size()/datas.size()*(-(1.0*count/subDatas.size())*Math.log(1.0*count/subDatas.size())/Math.log(2));
//				System.out.println(result);
			}

		}

		return result;

	}

	public CriterionID3(ArrayList<ArrayList<String>> datas,
			ArrayList<String> attributes)
	{
		super();
		this.datas = datas;
		this.attributes = attributes;
	}

	public ArrayList<ArrayList<String>> getDatas()
	{
		return datas;
	}

	public void setDatas(ArrayList<ArrayList<String>> datas)
	{
		this.datas = datas;
	}

	public ArrayList<String> getAttributes()
	{
		return attributes;
	}

	public void setAttributes(ArrayList<String> attributes)
	{
		this.attributes = attributes;
	}

	public Map<String, ArrayList<ArrayList<String>>> getSubDatasMap(int index)
	{
		ArrayList<String> valueSet = DecisionTree.getValueSet(this.datas, index);
		this.subDatasMap = new HashMap<String, ArrayList<ArrayList<String>>>();

		for(String value:valueSet)
		{
			ArrayList<ArrayList<String>> subDatas = new ArrayList<>();
			for(int i=0;i<this.datas.size();i++)
			{
				if(this.datas.get(i).get(index).equals(value))
					subDatas.add(this.datas.get(i));
			}
			for(int i=0;i<subDatas.size();i++)
			{
				subDatas.get(i).remove(index);
			}
			this.subDatasMap.put(value, subDatas);
		}

		return subDatasMap;
	}

	public void setSubDatasMap(Map<String, ArrayList<ArrayList<String>>> subDatasMap)
	{
		this.subDatasMap = subDatasMap;
	}

}

TreeNode.java

package com.zhyoulun.decision;

import java.util.ArrayList;

public class TreeNode
{
	private String name; 								// 该结点的名称(分裂属性)
	private ArrayList<String> rules; 				// 结点的分裂规则(假设均为离散值)
//	private ArrayList<ArrayList<String>> datas; 	// 划分到该结点的训练元组(datas.get(i)表示一个训练元组)
//	private ArrayList<String> candidateAttributes; // 划分到该结点的候选属性(与训练元组的个数一致)
	private ArrayList<TreeNode> children; 			// 子结点

	public TreeNode()
	{
		this.name = "";
		this.rules = new ArrayList<String>();
		this.children = new ArrayList<TreeNode>();
//		this.datas = null;
//		this.candidateAttributes = null;
	}

	public String getName()
	{
		return name;
	}

	public void setName(String name)
	{
		this.name = name;
	}

	public ArrayList<String> getRules()
	{
		return rules;
	}

	public void setRules(ArrayList<String> rules)
	{
		this.rules = rules;
	}

	public ArrayList<TreeNode> getChildren()
	{
		return children;
	}

	public void setChildren(ArrayList<TreeNode> children)
	{
		this.children = children;
	}

//	public ArrayList<ArrayList<String>> getDatas()
//	{
//		return datas;
//	}
//
//	public void setDatas(ArrayList<ArrayList<String>> datas)
//	{
//		this.datas = datas;
//	}
//
//	public ArrayList<String> getCandidateAttributes()
//	{
//		return candidateAttributes;
//	}
//
//	public void setCandidateAttributes(ArrayList<String> candidateAttributes)
//	{
//		this.candidateAttributes = candidateAttributes;
//	}

}

参考:《数据挖掘概念与技术(第3版)》

转载请注明出处:

时间: 2024-11-04 14:58:15

决策树归纳(ID3属性选择度量)Java实现的相关文章

决策树归纳算法解析之ID3

学习是一个循序渐进的过程,我们首先来认识一下,什么是决策树.顾名思义,决策树就是拿来对一个事物做决策,作判断.那如何判断呢?凭什么判断呢?都是值得我们去思考的问题. 请看以下两个简单例子: 第一个例子 现想象一个女孩的母亲要给自己家的闺女介绍男朋友,女孩儿通过对方的一些情况来考虑要不要去,于是有了下面的对话: 女儿:多大年纪了?       母亲:26.       女儿:长的帅不帅?       母亲:挺帅的.       女儿:收入高不?       母亲:不算很高,中等情况.      

决策树之ID3、C4.5、C5.0 、CART

决策树是一种类似于流程图的树结构,其中,每个内部节点(非树叶节点)表示一个属性上的测试,每个分枝代表该测试的一个输出,而每个树叶节点(或终端节点存放一个类标号).树的最顶层节点是根节点.下图是一个典型的决策树(来自<数据挖掘:概念与技术>[韩家炜](中文第三版)第八章): 在构造决策树时,使用属性选择度量来选择将元祖划分成不同类的属性.这里我们介绍三种常用的属性选择度量-----信息增益.信息增益率和基尼指数.这里使用的符号如下.设数据分区\(D\)为标记类元组的训练集.假设类标号属性具有\(

决策树归纳一般框架(ID3,C4.5,CART)

感性认识决策树 构建决策树的目的是对已有的数据进行分类,得到一个树状的分类规则,然后就可以拿这个规则对未知的数据进行分类预测. 决策树归纳是从有类标号的训练元祖中学习决策树. 决策树是一种类似于流程图的树结构,其中每个内部节点(非树叶结点)表示一个属性上的测试,每个分支代表该测试上的一个输出,而每个树叶结点(或终端结点)存放一个类标号.树的最顶层结点是根结点.一个典型的决策树如下图所示, 该决策树是通过下表所示的训练元组和它们对应的类标号得到的, 为什么决策树如此流行 决策树分类器的构造不需要任

[梁山好汉说IT] 熵的概念 &amp; 决策树ID3如何选择子树

记录对概念的理解,用梁山好汉做例子来检验是否理解正确. 1. 事物的信息和信息熵 a. 事物的信息(信息量越大确定性越大): 信息会改变你对事物的未知度和好奇心.信息量越大,你对事物越了解,进而你对事物的好奇心也会降低,因为你对事物的确定性越高.如果你确定一件事件的发生概率是100%,你认为这件事情的信息量为0——可不是吗,既然都确定了,就没有信息量了:相反,如果你不确定这件事,你需要通过各种方式去了解,就说明这件事是有意义的,是有信息量的. b. 信息熵:为了抽象上述模型,聪明的香农总结出了信

决策树:ID3与C4.5算法

1.基本概念 1)定义: 决策树是一个预测模型:他代表的是对象属性与对象值之间的一种映射关系,树中每个节点代表的某个可能的属性值. 2)表示方法: 通过把实例从根结点排列到某个叶子结点来分类实例,叶子结点即为实例所属的分类.树上的每一个结点指定了对某个属性的测试,并在该结点的每一个后继分支对应于该属性的一个可能值. 3)决策树适用问题: a.实例是由‘属性-值’对表示的 b.目标函数具有离散的输出值 c.可能需要十析取的描述 d.训练数据可以包含错误 e.训练数据可以包含缺少属性值的实例 2.I

[转载]简单易学的机器学习算法-决策树之ID3算的

一.决策树分类算法概述 决策树算法是从数据的属性(或者特征)出发,以属性作为基础,划分不同的类.例如对于如下数据集 (数据集) 其中,第一列和第二列为属性(特征),最后一列为类别标签,1表示是,0表示否.决策树算法的思想是基于属性对数据分类,对于以上的数据我们可以得到以下的决策树模型 (决策树模型) 先是根据第一个属性将一部份数据区分开,再根据第二个属性将剩余的区分开. 实现决策树的算法有很多种,有ID3.C4.5和CART等算法.下面我们介绍ID3算法. 二.ID3算法的概述 ID3算法是由Q

JQuery 多属性选择节点

JQuery 1.6.0+以后用prop()代替attr(); 多属性选择节点 $("input[type=checkbox][name='first2'][value='first4']"). 1 <%@ page language="java" contentType="text/html; charset=UTF-8" 2 pageEncoding="UTF-8"%> 3 <!DOCTYPE html

数据仓库专题(8)-维度属性选择之维护历史是否应该保留

一.背景 数据仓库建模过程中,针对事务型事实表设计,经常会遇到维度属性选择的问题,比如客户维度,在操作型系统中,为了跟踪客户状态的变化,往往会附加客户记录的四个属性: 1.add time:添加时间: 2.add user:添加用户: 3.mod time:修改时间: 4.mod user:修改用户: 问题在于,当我们进行维度建模的时候,如果以客户作为维度,是否应该考虑以上四个属性? 二.观点 1.应该保留 (1)我觉得 添加时间 可以作为维度属性,以后可能进行相关的统计: 2.不应该保留 (1

决策树之ID3算法

一.决策树之ID3算法简述 1976年-1986年,J.R.Quinlan给出ID3算法原型并进行了总结,确定了决策树学习的理论.这可以看做是决策树算法的起点.1993,Quinlan将ID3算法改进成C4.5算法,称为机器学习的十大算法之一.ID3算法的另一个分支是CART(Classification adn Regression Tree, 分类回归决策树),用于预测.这样,决策树理论完全覆盖了机器学习中的分类和回归两个领域. 本文只做了ID3算法的回顾,所选数据的字段全部是有序多分类的分