机器学习算法:决策树

决策树的核心思想是:根据训练样本构建这样一棵树,使得其叶节点是分类标签,非叶节点是判断条件,这样对于一个未知样本,能在树上找到一条路径到达叶节点,就得到了它的分类。

举个简单的例子,如何识别有毒的蘑菇?如果能够得到一棵这样的决策树,那么对于一个未知的蘑菇就很容易判断出它是否有毒了。

                          
                          它是什么颜色的?
                               |
                 -------鲜艳---------浅色----
                |                           |
              有毒                      有什么气味?
                                            |
                              -----刺激性--------无味-----
                             |                           |
                            有毒                        安全

构建决策树有很多算法,常用的有ID3、C4.5、CART等。本篇以ID3为研究算法。

构建决策树的关键在于每一次分支时选择哪个特征作为分界条件。这里的原则是:选择最能把数据变得有序的特征作为分界条件。所谓有序,是指划分后,每一个分支集合的标签尽可能一致。用信息论的方式表述,就是选择信息增益最大的方式划分集合。

所谓信息增益(information gain),是指变化前后熵(entropy)的增加量。为了计算熵,需要计算所有类别所有可能值包含的信息期望值,通过下面的公式得到:

其中H为熵,n为分类数目,p(xi)是选择该分类的概率。

根据公式,计算一个集合熵的方式为:

计算每个分类出现的次数
foreach(每一个分类)
{
    计算出现概率
    根据概率计算熵
    累加熵
}
return 累加结果

判断如何划分集合,方式为:

foreach(每一个特征)
{
    计算按此特征切分时的熵
    计算与切分前相比的信息增益
    保留能产生最大增益的特征为切分方式
}
return 选定的特征

构建树节点的方法为:

if(集合没有特征可用了)
{
    按多数原则决定此节点的分类
}
else if(集合中所有样本的分类都一致)
{
    此标签就是节点分类
}
else
{
    以最佳方式切分集合
    每一种可能形成当前节点的一个分支
    递归
}

OK,上C#版代码,DataVector和上篇文章一样,不放了,只放核心算法:

using System;
using System.Collections.Generic;

namespace MachineLearning
{
    /// <summary>
    /// 决策树节点
    /// </summary>
    public class DecisionNode
    {
        /// <summary>
        /// 此节点的分类标签,为空表示此节点不是叶节点
        /// </summary>
        public string Label { get; set; }
        /// <summary>
        /// 此节点的划分特征,为-1表示此节点是叶节点
        /// </summary>
        public int FeatureIndex { get; set; }
        /// <summary>
        /// 分支
        /// </summary>
        public Dictionary<string, DecisionNode> Child { get; set; }

        public DecisionNode()
        {
            this.FeatureIndex = -1;
            this.Child = new Dictionary<string, DecisionNode>();
        }
    }
}
using System;
using System.Collections.Generic;
using System.Linq;

namespace MachineLearning
{
    /// <summary>
    /// 决策树(ID3算法)
    /// </summary>
    public class DecisionTree
    {
        private DecisionNode m_Tree;

        /// <summary>
        /// 训练
        /// </summary>
        /// <param name="trainingSet"></param>
        public void Train(List<DataVector<string>> trainingSet)
        {
            var features = new List<int>(trainingSet[0].Dimension);
            for(int i = 0;i < trainingSet[0].Dimension;++i)
                features.Add(i);
                
            //生成决策树
            m_Tree = CreateTree(trainingSet, features);
        }

        /// <summary>
        /// 分类
        /// </summary>
        /// <param name="vector"></param>
        /// <returns></returns>
        public string Classify(DataVector<string> vector)
        {
            return Classify(vector, m_Tree);
        }

        /// <summary>
        /// 分类
        /// </summary>
        /// <param name="vector"></param>
        /// <param name="node"></param>
        /// <returns></returns>
        private string Classify(DataVector<string> vector, DecisionNode node)
        {
            var label = string.Empty;
            
            if(!string.IsNullOrEmpty(node.Label))
            {
                //是叶节点,直接返回结果
                label = node.Label;
            }
            else
            {
                //取需要分类的字段,继续深入
                var key = vector.Data[node.FeatureIndex];
                if(node.Child.ContainsKey(key))
                    label = Classify(vector, node.Child[key]);
                else
                    label = "[UNKNOWN]";
            }
            return label;
        }
        
        /// <summary>
        /// 创建决策树
        /// </summary>
        /// <param name="dataSet"></param>
        /// <param name="features"></param>
        /// <returns></returns>
        private DecisionNode CreateTree(List<DataVector<string>> dataSet, List<int> features)
        {
            var node = new DecisionNode();
            
            if(dataSet[0].Dimension == 0)
            {
                //所有字段已用完,按多数原则决定Label,结束分类
                node.Label = GetMajorLabel(dataSet);
            }
            else if(dataSet.Count == dataSet.Count(d => string.Equals(d.Label, dataSet[0].Label)))
            {
                //如果数据集中的Label相同,结束分类
                node.Label = dataSet[0].Label;
            }
            else
            {
                //挑选一个最佳分类,分割集合,递归
                int featureIndex = ChooseBestFeature(dataSet);
                node.FeatureIndex = features[featureIndex];
                var uniqueValues = GetUniqueValues(dataSet, featureIndex);
                features.RemoveAt(featureIndex);
                foreach(var value in uniqueValues)
                {
                    node.Child[value.ToString()] = CreateTree(SplitDataSet(dataSet, featureIndex, value), new List<int>(features));
                }
            }
            
            return node;
        }
        
        /// <summary>
        /// 计算给定集合的香农熵
        /// </summary>
        /// <param name="dataSet"></param>
        /// <returns></returns>
        private double ComputeShannon(List<DataVector<string>> dataSet)
        {
            double shannon = 0.0;
            
            var dict = new Dictionary<string, int>();
            foreach(var item in dataSet)
            {
                if(!dict.ContainsKey(item.Label))
                    dict[item.Label] = 0;
                dict[item.Label] += 1;
            }
            
            foreach(var label in dict.Keys)
            {
                double prob = dict[label] * 1.0 / dataSet.Count;
                shannon -= prob * Math.Log(prob, 2);
            }
            
            return shannon;
        }
        
        /// <summary>
        /// 用给定的方式切分出数据子集
        /// </summary>
        /// <param name="dataSet"></param>
        /// <param name="splitIndex"></param>
        /// <param name="value"></param>
        /// <returns></returns>
        private List<DataVector<string>> SplitDataSet(List<DataVector<string>> dataSet, int splitIndex, string value)
        {
            var newDataSet = new List<DataVector<string>>();
            
            foreach(var item in dataSet)
            {
                //只保留指定维度上符合给定值的项
                if(item.Data[splitIndex] == value)
                {
                    var newItem = new DataVector<string>(item.Dimension - 1);
                    newItem.Label = item.Label;
                    Array.Copy(item.Data, 0, newItem.Data, 0, splitIndex - 0);
                    Array.Copy(item.Data, splitIndex + 1, newItem.Data, splitIndex, item.Dimension - splitIndex - 1);
                    newDataSet.Add(newItem);
                }
            }
            
            return newDataSet;
        }

        /// <summary>
        /// 在给定的数据集上选择一个最好的切分方式
        /// </summary>
        /// <param name="dataSet"></param>
        /// <returns></returns>
        private int ChooseBestFeature(List<DataVector<string>> dataSet)
        {
            int bestFeature = -1;
            double bestInfoGain = 0.0;
            double baseShannon = ComputeShannon(dataSet);
            
            //遍历每一个维度来寻找
            for(int i = 0;i < dataSet[0].Dimension;++i)
            {
                var uniqueValues = GetUniqueValues(dataSet, i);
                double newShannon = 0.0;

                //遍历此维度下的每一个可能值,切分数据集并计算熵
                foreach(var value in uniqueValues)
                {
                    var subSet = SplitDataSet(dataSet, i, value);
                    double prob = subSet.Count * 1.0 / dataSet.Count;
                    newShannon += prob * ComputeShannon(subSet);
                }

                //计算信息增益,保留最佳切分方式
                double infoGain = baseShannon - newShannon;
                if(infoGain > bestInfoGain)
                {
                    bestInfoGain = infoGain;
                    bestFeature = i;
                }
            }
            
            return bestFeature;
        }

        /// <summary>
        /// 数据去重
        /// </summary>
        /// <param name="dataSet"></param>
        /// <param name="index"></param>
        /// <returns></returns>
        private List<string> GetUniqueValues(List<DataVector<string>> dataSet, int index)
        {
            var dict = new Dictionary<string, int>();
            foreach(var item in dataSet)
            {
                dict[item.Data[index]] = 0;
            }
            return dict.Keys.ToList<string>();
        }

        /// <summary>
        /// 取多数标签
        /// </summary>
        /// <param name="dataSet"></param>
        /// <returns></returns>
        private string GetMajorLabel(List<DataVector<string>> dataSet)
        {
            var dict = new Dictionary<string, int>();
            foreach(var item in dataSet)
            {
                if(!dict.ContainsKey(item.Label))
                    dict[item.Label] = 0;
                dict[item.Label]++;
            }

            string label = string.Empty;
            int count = -1;
            foreach(var key in dict.Keys)
            {
                if(dict[key] > count)
                {
                    label = key;
                    count = dict[key];
                }
            }
            
            return label;
        }
    }
}

拿个例子实际检验一下,还是以毒蘑菇的识别为例,从这里找了点数据,http://archive.ics.uci.edu/ml/machine-learning-databases/mushroom/agaricus-lepiota.data ,它整理了8000多个样本,每个样本描述了蘑菇的22个属性,比如形状、气味等等,然后给出了这个蘑菇是否可食用。

比如一行数据:p,x,s,n,t,p,f,c,n,k,e,e,s,s,w,w,p,w,o,p,k,s,u

第0个元素p表示poisonous(有毒),其它22个元素分别是蘑菇的属性,可以参见agaricus-lepiota.names的描述,但实际上根本不用关心具体含义。以此构建样本并测试错误率:

public void TestDecisionTree()
{
    var trainingSet = new List<DataVector<string>>();    //训练数据集
    var testSet = new List<DataVector<string>>();        //测试数据集
    
    //读取数据
    var file = new StreamReader("agaricus-lepiota.data", Encoding.Default);
    string line = string.Empty;
    int count = 0;
    while((line = file.ReadLine()) != null)
    {
        var parts = line.Split(‘,‘);
        
        var p = new DataVector<string>(22);
        p.Label = parts[0];
        for(int i = 0;i < p.Dimension;++i)
            p.Data[i] = parts[i + 1];
            
        //前7000作为训练样本,其余作为测试样本
        if(++count <= 7000)
            trainingSet.Add(p);
        else
            testSet.Add(p);
    }
    file.Close();

    //检验
    var dt = new DecisionTree();
    dt.Train(trainingSet);
    int error = 0;
    foreach(var p in testSet)
    {
        //做猜测分类,并与实际结果比较
        var label = dt.Classify(p);
        if(label != p.Label)
            ++error;
    }

    Console.WriteLine("Error = {0}/{1}, {2}%", error, testSet.Count, (error * 100.0 / testSet.Count));
}

使用7000个样本做训练,1124个样本做测试,只有4个猜测出错,错误率仅有0.35%,相当不错的结果。

生成的决策树是这样的:

{
    "FeatureIndex": 4,              //按第4个特征划分
    "Child": {
        "p": {"Label": "p"},        //如果第4个特征是p,则分类为p
        "a": {"Label": "e"},        //如果第4个特征是a,则分类是e
        "l": {"Label": "e"},
        "n": {
            "FeatureIndex": 19,            //如果第4个特征是n,要继续按第19个特征划分
            "Child": {
                "n": {"Label": "e"},
                "k": {"Label": "e"},
                "w": {
                    "FeatureIndex": 21,
                    "Child": {
                        "w": {"Label": "e"},
                        "l": {
                            "FeatureIndex": 2,
                            "Child": {
                                "c": {"Label": "e"},
                                "n": {"Label": "e"},
                                "w": {"Label": "p"},
                                "y": {"Label": "p"}
                            }
                        },
                        "d": {
                            "FeatureIndex": 1,
                            "Child": {
                                "y": {"Label": "p"},
                                "f": {"Label": "p"},
                                "s": {"Label": "e"}
                            }
                        },
                        "g": {"Label": "e"},
                        "p": {"Label": "e"}
                    }
                },
                "h": {"Label": "e"},
                "r": {"Label": "p"},
                "o": {"Label": "e"},
                "y": {"Label": "e"},
                "b": {"Label": "e"}
            }
        },
        "f": {"Label": "p"},
        "c": {"Label": "p"},
        "y": {"Label": "p"},
        "s": {"Label": "p"},
        "m": {"Label": "p"}
    }
}

可以看到,实际只使用了其中的5个特征,就能做出比较精确的判断了。

时间: 2024-10-09 16:32:04

机器学习算法:决策树的相关文章

[转载]简单易学的机器学习算法-决策树之ID3算的

一.决策树分类算法概述 决策树算法是从数据的属性(或者特征)出发,以属性作为基础,划分不同的类.例如对于如下数据集 (数据集) 其中,第一列和第二列为属性(特征),最后一列为类别标签,1表示是,0表示否.决策树算法的思想是基于属性对数据分类,对于以上的数据我们可以得到以下的决策树模型 (决策树模型) 先是根据第一个属性将一部份数据区分开,再根据第二个属性将剩余的区分开. 实现决策树的算法有很多种,有ID3.C4.5和CART等算法.下面我们介绍ID3算法. 二.ID3算法的概述 ID3算法是由Q

机器学习---算法---决策树

转自:https://blog.csdn.net/qq_43208303/article/details/84837412 决策树是一种机器学习的方法.决策树的生成算法有ID3, C4.5和CART等.决策树是一种树形结构,其中每个内部节点表示一个属性上的判断,每个分支代表一个判断结果的输出,最后每个叶节点代表一种分类结果.决策树是一种十分常用的分类方法,需要监管学习(有教师的Supervised Learning),监管学习就是给出一堆样本,每个样本都有一组属性和一个分类结果,也就是分类结果已

简单易学的机器学习算法——AdaBoost

一.集成方法(Ensemble Method) 集成方法主要包括Bagging和Boosting两种方法,随机森林算法是基于Bagging思想的机器学习算法,在Bagging方法中,主要通过对训练数据集进行随机采样,以重新组合成不同的数据集,利用弱学习算法对不同的新数据集进行学习,得到一系列的预测结果,对这些预测结果做平均或者投票做出最终的预测.AdaBoost算法和GBDT(Gradient Boost Decision Tree,梯度提升决策树)算法是基于Boosting思想的机器学习算法.

机器学习算法之决策树

机器学习算法之决策树 什么是决策树 决策树(Decision Tree)是一种简单但是广泛使用的分类器.通过训练数据构建决策树,可以高效的对未知的数据进行分类.决策数有两大优点:1)决策树模型可以读性好,具有描述性,有助于人工分析:2)效率高,决策树只需要一次构建,反复使用,每一次预测的最大计算次数不超过决策树的深度. 决策树是一个树结构(可以是二叉树或者非二叉树),非叶节点表示一个特征属性上的测试,每个分支代表在某个值域上的输出,每个叶节点存放一个类别. 测试就是按照从根节点往下走,直到叶节点

机器学习算法的R语言实现(二):决策树

1.介绍 ?决策树(decision tree)是一种有监督的机器学习算法,是一个分类算法.在给定训练集的条件下,生成一个自顶而下的决策树,树的根为起点,树的叶子为样本的分类,从根到叶子的路径就是一个样本进行分类的过程. ?下图为一个决策树的例子,见http://zh.wikipedia.org/wiki/%E5%86%B3%E7%AD%96%E6%A0%91 ? 可见,决策树上的判断节点是对某一个属性进行判断,生成的路径数量为该属性可能的取值,最终到叶子节点时,就完成一个分类(或预测).决策树

【机器学习算法-python实现】决策树-Decision tree(1) 信息熵划分数据集

(转载请注明出处:http://blog.csdn.net/buptgshengod) 1.背景 决策书算法是一种逼近离散数值的分类算法,思路比較简单,并且准确率较高.国际权威的学术组织,数据挖掘国际会议ICDM (the IEEE International Conference on Data Mining)在2006年12月评选出了数据挖掘领域的十大经典算法中,C4.5算法排名第一.C4.5算法是机器学习算法中的一种分类决策树算法,其核心算法是ID3算法. 算法的主要思想就是将数据集依照特

【机器学习算法-python实现】决策树-Decision tree(2) 决策树的实现

(转载请注明出处:http://blog.csdn.net/buptgshengod) 1.背景 接着上一节说,没看到请先看一下上一节关于数据集的划分数据集划分.如今我们得到了每一个特征值得信息熵增益,我们依照信息熵增益的从大到校的顺序,安排排列为二叉树的节点.数据集和二叉树的图见下. (二叉树的图是用python的matplotlib库画出来的) 数据集: 决策树: 2.代码实现部分 由于上一节,我们通过chooseBestFeatureToSplit函数已经能够确定当前数据集中的信息熵最大的

【机器学习算法-python实现】Adaboost的实现(1)-单层决策树(decision stump)

(转载请注明出处:http://blog.csdn.net/buptgshengod) 1.背景 上一节学习支持向量机,感觉公式都太难理解了,弄得我有点头大.不过这一章的Adaboost线比较起来就容易得多.Adaboost是用元算法的思想进行分类的.什么事元算法的思想呢?就是根据数据集的不同的特征在决定结果时所占的比重来划分数据集.就是要对每个特征值都构建决策树,并且赋予他们不同的权值,最后集合起来比较. 比如说我们可以通过是否有胡子和身高的高度这两个特征来来决定一个人的性别,很明显是否有胡子

利用机器学习算法寻找网页的缩略图

博客中的文章均为meelo原创,请务必以链接形式注明本文地址 描述一个网页 现在的世界处于一个信息爆炸的时代.微信.微博.新闻网站,每天人们在大海捞针的信息海洋里挑选自己感兴趣的信息.我们是如何判断哪条信息可能会感兴趣?回想一下,你会发现是标题.摘要和缩略图.通过标题.摘要和缩略图,就能够很好地猜测到网页的内容.打开百度搜索引擎,随便搜索一个关键字,每一条搜索结果也正是这三要素构成的. 那么一个自然的问题是搜索引擎是如何找到网页的标题.摘要和缩略图的呢. 寻找网页的标题其实是一个非常简单的问题.

机器学习系列(9)_机器学习算法一览(附Python和R代码)

本文资源翻译@酒酒Angie:伊利诺伊大学香槟分校统计学同学,大四在读,即将开始计算机的研究生学习.希望认识更多喜欢大数据和机器学习的朋友,互相交流学习. 内容校正调整:寒小阳 && 龙心尘 时间:2016年4月 出处:http://blog.csdn.net/han_xiaoyang/article/details/51191386 http://blog.csdn.net/longxinchen_ml/article/details/51192086 声明:版权所有,转载请联系作者并注