决策树的核心思想是:根据训练样本构建这样一棵树,使得其叶节点是分类标签,非叶节点是判断条件,这样对于一个未知样本,能在树上找到一条路径到达叶节点,就得到了它的分类。
举个简单的例子,如何识别有毒的蘑菇?如果能够得到一棵这样的决策树,那么对于一个未知的蘑菇就很容易判断出它是否有毒了。
它是什么颜色的? | -------鲜艳---------浅色---- | | 有毒 有什么气味? | -----刺激性--------无味----- | | 有毒 安全
构建决策树有很多算法,常用的有ID3、C4.5、CART等。本篇以ID3为研究算法。
构建决策树的关键在于每一次分支时选择哪个特征作为分界条件。这里的原则是:选择最能把数据变得有序的特征作为分界条件。所谓有序,是指划分后,每一个分支集合的标签尽可能一致。用信息论的方式表述,就是选择信息增益最大的方式划分集合。
所谓信息增益(information gain),是指变化前后熵(entropy)的增加量。为了计算熵,需要计算所有类别所有可能值包含的信息期望值,通过下面的公式得到:
其中H为熵,n为分类数目,p(xi)是选择该分类的概率。
根据公式,计算一个集合熵的方式为:
计算每个分类出现的次数 foreach(每一个分类) { 计算出现概率 根据概率计算熵 累加熵 } return 累加结果
判断如何划分集合,方式为:
foreach(每一个特征) { 计算按此特征切分时的熵 计算与切分前相比的信息增益 保留能产生最大增益的特征为切分方式 } return 选定的特征
构建树节点的方法为:
if(集合没有特征可用了) { 按多数原则决定此节点的分类 } else if(集合中所有样本的分类都一致) { 此标签就是节点分类 } else { 以最佳方式切分集合 每一种可能形成当前节点的一个分支 递归 }
OK,上C#版代码,DataVector和上篇文章一样,不放了,只放核心算法:
using System; using System.Collections.Generic; namespace MachineLearning { /// <summary> /// 决策树节点 /// </summary> public class DecisionNode { /// <summary> /// 此节点的分类标签,为空表示此节点不是叶节点 /// </summary> public string Label { get; set; } /// <summary> /// 此节点的划分特征,为-1表示此节点是叶节点 /// </summary> public int FeatureIndex { get; set; } /// <summary> /// 分支 /// </summary> public Dictionary<string, DecisionNode> Child { get; set; } public DecisionNode() { this.FeatureIndex = -1; this.Child = new Dictionary<string, DecisionNode>(); } } }
using System; using System.Collections.Generic; using System.Linq; namespace MachineLearning { /// <summary> /// 决策树(ID3算法) /// </summary> public class DecisionTree { private DecisionNode m_Tree; /// <summary> /// 训练 /// </summary> /// <param name="trainingSet"></param> public void Train(List<DataVector<string>> trainingSet) { var features = new List<int>(trainingSet[0].Dimension); for(int i = 0;i < trainingSet[0].Dimension;++i) features.Add(i); //生成决策树 m_Tree = CreateTree(trainingSet, features); } /// <summary> /// 分类 /// </summary> /// <param name="vector"></param> /// <returns></returns> public string Classify(DataVector<string> vector) { return Classify(vector, m_Tree); } /// <summary> /// 分类 /// </summary> /// <param name="vector"></param> /// <param name="node"></param> /// <returns></returns> private string Classify(DataVector<string> vector, DecisionNode node) { var label = string.Empty; if(!string.IsNullOrEmpty(node.Label)) { //是叶节点,直接返回结果 label = node.Label; } else { //取需要分类的字段,继续深入 var key = vector.Data[node.FeatureIndex]; if(node.Child.ContainsKey(key)) label = Classify(vector, node.Child[key]); else label = "[UNKNOWN]"; } return label; } /// <summary> /// 创建决策树 /// </summary> /// <param name="dataSet"></param> /// <param name="features"></param> /// <returns></returns> private DecisionNode CreateTree(List<DataVector<string>> dataSet, List<int> features) { var node = new DecisionNode(); if(dataSet[0].Dimension == 0) { //所有字段已用完,按多数原则决定Label,结束分类 node.Label = GetMajorLabel(dataSet); } else if(dataSet.Count == dataSet.Count(d => string.Equals(d.Label, dataSet[0].Label))) { //如果数据集中的Label相同,结束分类 node.Label = dataSet[0].Label; } else { //挑选一个最佳分类,分割集合,递归 int featureIndex = ChooseBestFeature(dataSet); node.FeatureIndex = features[featureIndex]; var uniqueValues = GetUniqueValues(dataSet, featureIndex); features.RemoveAt(featureIndex); foreach(var value in uniqueValues) { node.Child[value.ToString()] = CreateTree(SplitDataSet(dataSet, featureIndex, value), new List<int>(features)); } } return node; } /// <summary> /// 计算给定集合的香农熵 /// </summary> /// <param name="dataSet"></param> /// <returns></returns> private double ComputeShannon(List<DataVector<string>> dataSet) { double shannon = 0.0; var dict = new Dictionary<string, int>(); foreach(var item in dataSet) { if(!dict.ContainsKey(item.Label)) dict[item.Label] = 0; dict[item.Label] += 1; } foreach(var label in dict.Keys) { double prob = dict[label] * 1.0 / dataSet.Count; shannon -= prob * Math.Log(prob, 2); } return shannon; } /// <summary> /// 用给定的方式切分出数据子集 /// </summary> /// <param name="dataSet"></param> /// <param name="splitIndex"></param> /// <param name="value"></param> /// <returns></returns> private List<DataVector<string>> SplitDataSet(List<DataVector<string>> dataSet, int splitIndex, string value) { var newDataSet = new List<DataVector<string>>(); foreach(var item in dataSet) { //只保留指定维度上符合给定值的项 if(item.Data[splitIndex] == value) { var newItem = new DataVector<string>(item.Dimension - 1); newItem.Label = item.Label; Array.Copy(item.Data, 0, newItem.Data, 0, splitIndex - 0); Array.Copy(item.Data, splitIndex + 1, newItem.Data, splitIndex, item.Dimension - splitIndex - 1); newDataSet.Add(newItem); } } return newDataSet; } /// <summary> /// 在给定的数据集上选择一个最好的切分方式 /// </summary> /// <param name="dataSet"></param> /// <returns></returns> private int ChooseBestFeature(List<DataVector<string>> dataSet) { int bestFeature = -1; double bestInfoGain = 0.0; double baseShannon = ComputeShannon(dataSet); //遍历每一个维度来寻找 for(int i = 0;i < dataSet[0].Dimension;++i) { var uniqueValues = GetUniqueValues(dataSet, i); double newShannon = 0.0; //遍历此维度下的每一个可能值,切分数据集并计算熵 foreach(var value in uniqueValues) { var subSet = SplitDataSet(dataSet, i, value); double prob = subSet.Count * 1.0 / dataSet.Count; newShannon += prob * ComputeShannon(subSet); } //计算信息增益,保留最佳切分方式 double infoGain = baseShannon - newShannon; if(infoGain > bestInfoGain) { bestInfoGain = infoGain; bestFeature = i; } } return bestFeature; } /// <summary> /// 数据去重 /// </summary> /// <param name="dataSet"></param> /// <param name="index"></param> /// <returns></returns> private List<string> GetUniqueValues(List<DataVector<string>> dataSet, int index) { var dict = new Dictionary<string, int>(); foreach(var item in dataSet) { dict[item.Data[index]] = 0; } return dict.Keys.ToList<string>(); } /// <summary> /// 取多数标签 /// </summary> /// <param name="dataSet"></param> /// <returns></returns> private string GetMajorLabel(List<DataVector<string>> dataSet) { var dict = new Dictionary<string, int>(); foreach(var item in dataSet) { if(!dict.ContainsKey(item.Label)) dict[item.Label] = 0; dict[item.Label]++; } string label = string.Empty; int count = -1; foreach(var key in dict.Keys) { if(dict[key] > count) { label = key; count = dict[key]; } } return label; } } }
拿个例子实际检验一下,还是以毒蘑菇的识别为例,从这里找了点数据,http://archive.ics.uci.edu/ml/machine-learning-databases/mushroom/agaricus-lepiota.data ,它整理了8000多个样本,每个样本描述了蘑菇的22个属性,比如形状、气味等等,然后给出了这个蘑菇是否可食用。
比如一行数据:p,x,s,n,t,p,f,c,n,k,e,e,s,s,w,w,p,w,o,p,k,s,u
第0个元素p表示poisonous(有毒),其它22个元素分别是蘑菇的属性,可以参见agaricus-lepiota.names的描述,但实际上根本不用关心具体含义。以此构建样本并测试错误率:
public void TestDecisionTree() { var trainingSet = new List<DataVector<string>>(); //训练数据集 var testSet = new List<DataVector<string>>(); //测试数据集 //读取数据 var file = new StreamReader("agaricus-lepiota.data", Encoding.Default); string line = string.Empty; int count = 0; while((line = file.ReadLine()) != null) { var parts = line.Split(‘,‘); var p = new DataVector<string>(22); p.Label = parts[0]; for(int i = 0;i < p.Dimension;++i) p.Data[i] = parts[i + 1]; //前7000作为训练样本,其余作为测试样本 if(++count <= 7000) trainingSet.Add(p); else testSet.Add(p); } file.Close(); //检验 var dt = new DecisionTree(); dt.Train(trainingSet); int error = 0; foreach(var p in testSet) { //做猜测分类,并与实际结果比较 var label = dt.Classify(p); if(label != p.Label) ++error; } Console.WriteLine("Error = {0}/{1}, {2}%", error, testSet.Count, (error * 100.0 / testSet.Count)); }
使用7000个样本做训练,1124个样本做测试,只有4个猜测出错,错误率仅有0.35%,相当不错的结果。
生成的决策树是这样的:
{ "FeatureIndex": 4, //按第4个特征划分 "Child": { "p": {"Label": "p"}, //如果第4个特征是p,则分类为p "a": {"Label": "e"}, //如果第4个特征是a,则分类是e "l": {"Label": "e"}, "n": { "FeatureIndex": 19, //如果第4个特征是n,要继续按第19个特征划分 "Child": { "n": {"Label": "e"}, "k": {"Label": "e"}, "w": { "FeatureIndex": 21, "Child": { "w": {"Label": "e"}, "l": { "FeatureIndex": 2, "Child": { "c": {"Label": "e"}, "n": {"Label": "e"}, "w": {"Label": "p"}, "y": {"Label": "p"} } }, "d": { "FeatureIndex": 1, "Child": { "y": {"Label": "p"}, "f": {"Label": "p"}, "s": {"Label": "e"} } }, "g": {"Label": "e"}, "p": {"Label": "e"} } }, "h": {"Label": "e"}, "r": {"Label": "p"}, "o": {"Label": "e"}, "y": {"Label": "e"}, "b": {"Label": "e"} } }, "f": {"Label": "p"}, "c": {"Label": "p"}, "y": {"Label": "p"}, "s": {"Label": "p"}, "m": {"Label": "p"} } }
可以看到,实际只使用了其中的5个特征,就能做出比较精确的判断了。