数据挖掘之决策树ID3算法(C#实现)

决策树是一种非常经典的分类器,它的作用原理有点类似于我们玩的猜谜游戏。比如猜一个动物:

问:这个动物是陆生动物吗?

答:是的。

问:这个动物有鳃吗?

答:没有。

这样的两个问题顺序就有些颠倒,因为一般来说陆生动物是没有鳃的(记得应该是这样的,如有错误欢迎指正)。所以玩这种游戏,提问的顺序很重要,争取每次都能够获得尽可能多的信息量。

AllElectronics顾客数据库标记类的训练元组
RID age income student credit_rating Class: buys_computer
1 youth high no fair no
2 youth high no excellent no
3 middle_aged high no fair yes
4 senior medium no fair yes
5 senior low yes fair yes
6 senior low yes excellent no
7 middle_aged low yes excellent yes
8 youth medium no fair no
9 youth low yes fair yes
10 senior medium yes fair yes
11 youth medium yes excellent yes
12 middle_aged medium no excellent yes
13 middle_aged high yes fair yes
14 senior medium no excellent no

以AllElectronics顾客数据库标记类的训练元组为例。我们想要以这些样本为训练集,训练我们的决策树模型,以此来挖掘出顾客是否会购买电脑的决策模式。

在决策树ID3算法中,计算信息度的公式如下:

$$Info_A(D) = \sum_{j=1}^v\frac{|D_j|}{D} \times Info(D_j)$$

计算信息增益的公式如下:

$$Gain(A) = Info(D) - Info_A(D)$$

按照公式,在要进行分类的类别变量中,有5个“no”和9个“yes”,因此期望信息为:

$$Info(D)=-\frac{9}{14}log_2\frac{9}{14}-\frac{5}{14}log_2\frac{5}{14}=0.940$$

首先计算特征age的期望信息:

$$Info_{age}(D)=\frac{5}{14} \times (-\frac{2}{5}log_2\frac{2}{5} - \frac{3}{5}log_2\frac{3}{5})+\frac{4}{14} \times (-\frac{4}{4}log_2\frac{4}{4} - \frac{0}{4}log_2\frac{0}{4})+\frac{5}{14} \times (-\frac{3}{5}log_2\frac{3}{5} - \frac{2}{5}log_2\frac{2}{5})$$

因此,如果按照age进行划分,则获得的信息增益为:

$$Gain(age) = Info(D)-Info_{age}(D) = 0.940-0.694=0.246$$

依次计算以income、student和credit_rating来分裂的信息增益,由此选择能够带来最大信息增益的变量,在当

前结点选择以以该变量的取值进行分裂。递归地进行执行即可生成决策树。更加详细的内容可以参考:

https://en.wikipedia.org/wiki/Decision_tree

C#代码的实现如下:

  1 using System;
  2 using System.Collections.Generic;
  3 using System.Linq;
  4 namespace MachineLearning.DecisionTree
  5 {
  6     public class DecisionTreeID3<T> where T : IEquatable<T>
  7     {
  8         T[,] Data;
  9         string[] Names;
 10         int Category;
 11         T[] CategoryLabels;
 12         DecisionTreeNode<T> Root;
 13         public DecisionTreeID3(T[,] data, string[] names, T[] categoryLabels)
 14         {
 15             Data = data;
 16             Names = names;
 17             Category = data.GetLength(1) - 1;//类别变量需要放在最后一列
 18             CategoryLabels = categoryLabels;
 19         }
 20         public void Learn()
 21         {
 22             int nRows = Data.GetLength(0);
 23             int nCols = Data.GetLength(1);
 24             int[] rows = new int[nRows];
 25             int[] cols = new int[nCols];
 26             for (int i = 0; i < nRows; i++) rows[i] = i;
 27             for (int i = 0; i < nCols; i++) cols[i] = i;
 28             Root = new DecisionTreeNode<T>(-1, default(T));
 29             Learn(rows, cols, Root);
 30             DisplayNode(Root);
 31         }
 32         public void DisplayNode(DecisionTreeNode<T> Node, int depth = 0)
 33         {
 34             if (Node.Label != -1)
 35                 Console.WriteLine("{0} {1}: {2}", new string(‘-‘, depth * 3), Names[Node.Label], Node.Value);
 36             foreach (var item in Node.Children)
 37                 DisplayNode(item, depth + 1);
 38         }
 39         private void Learn(int[] pnRows, int[] pnCols, DecisionTreeNode<T> Root)
 40         {
 41             var categoryValues = GetAttribute(Data, Category, pnRows);
 42             var categoryCount = categoryValues.Distinct().Count();
 43             if (categoryCount == 1)
 44             {
 45                 var node = new DecisionTreeNode<T>(Category, categoryValues.First());
 46                 Root.Children.Add(node);
 47             }
 48             else
 49             {
 50                 if (pnRows.Length == 0) return;
 51                 else if (pnCols.Length == 1)
 52                 {
 53                     //投票~
 54                     //多数票表决制
 55                     var Vote = categoryValues.GroupBy(i => i).OrderBy(i => i.Count()).First();
 56                     var node = new DecisionTreeNode<T>(Category, Vote.First());
 57                     Root.Children.Add(node);
 58                 }
 59                 else
 60                 {
 61                     var maxCol = MaxEntropy(pnRows, pnCols);
 62                     var attributes = GetAttribute(Data, maxCol, pnRows).Distinct();
 63                     string currentPrefix = Names[maxCol];
 64                     foreach (var attr in attributes)
 65                     {
 66                         int[] rows = pnRows.Where(irow => Data[irow, maxCol].Equals(attr)).ToArray();
 67                         int[] cols = pnCols.Where(i => i != maxCol).ToArray();
 68                         var node = new DecisionTreeNode<T>(maxCol, attr);
 69                         Root.Children.Add(node);
 70                         Learn(rows, cols, node);//递归生成决策树
 71                     }
 72                 }
 73             }
 74         }
 75         public double AttributeInfo(int attrCol, int[] pnRows)
 76         {
 77             var tuples = AttributeCount(attrCol, pnRows);
 78             var sum = (double)pnRows.Length;
 79             double Entropy = 0.0;
 80             foreach (var tuple in tuples)
 81             {
 82                 int[] count = new int[CategoryLabels.Length];
 83                 foreach (var irow in pnRows)
 84                     if (Data[irow, attrCol].Equals(tuple.Item1))
 85                     {
 86                         int index = Array.IndexOf(CategoryLabels, Data[irow, Category]);
 87                         count[index]++;
 88                     }
 89                 double k = 0.0;
 90                 for (int i = 0; i < count.Length; i++)
 91                 {
 92                     double frequency = count[i] / (double)tuple.Item2;
 93                     double t = -frequency * Log2(frequency);
 94                     k += t;
 95                 }
 96                 double freq = tuple.Item2 / sum;
 97                 Entropy += freq * k;
 98             }
 99             return Entropy;
100         }
101         public double CategoryInfo(int[] pnRows)
102         {
103             var tuples = AttributeCount(Category, pnRows);
104             var sum = (double)pnRows.Length;
105             double Entropy = 0.0;
106             foreach (var tuple in tuples)
107             {
108                 double frequency = tuple.Item2 / sum;
109                 double t = -frequency * Log2(frequency);
110                 Entropy += t;
111             }
112             return Entropy;
113         }
114         private static IEnumerable<T> GetAttribute(T[,] data, int col, int[] pnRows)
115         {
116             foreach (var irow in pnRows)
117                 yield return data[irow, col];
118         }
119         private static double Log2(double x)
120         {
121             return x == 0.0 ? 0.0 : Math.Log(x, 2.0);
122         }
123         public int MaxEntropy(int[] pnRows, int[] pnCols)
124         {
125             double cateEntropy = CategoryInfo(pnRows);
126             int maxAttr = 0;
127             double max = double.MinValue;
128             foreach (var icol in pnCols)
129                 if (icol != Category)
130                 {
131                     double Gain = cateEntropy - AttributeInfo(icol, pnRows);
132                     if (max < Gain)
133                     {
134                         max = Gain;
135                         maxAttr = icol;
136                     }
137                 }
138             return maxAttr;
139         }
140         public IEnumerable<Tuple<T, int>> AttributeCount(int col, int[] pnRows)
141         {
142             var tuples = from n in GetAttribute(Data, col, pnRows)
143                          group n by n into i
144                          select Tuple.Create(i.First(), i.Count());
145             return tuples;
146         }
147     }
148 }

调用方法如下:

 1 using System;
 2 using System.Collections.Generic;
 3 using System.Linq;
 4 using System.Text;
 5 using System.Threading.Tasks;
 6 using MachineLearning.DecisionTree;
 7 namespace MachineLearning
 8 {
 9     class Program
10     {
11         static void Main(string[] args)
12         {
13             var data = new string[,]
14             {
15                 {"youth","high","no","fair","no"},
16                 {"youth","high","no","excellent","no"},
17                 {"middle_aged","high","no","fair","yes"},
18                 {"senior","medium","no","fair","yes"},
19                 {"senior","low","yes","fair","yes"},
20                 {"senior","low","yes","excellent","no"},
21                 {"middle_aged","low","yes","excellent","yes"},
22                 {"youth","medium","no","fair","no"},
23                 {"youth","low","yes","fair","yes"},
24                 {"senior","medium","yes","fair","yes"},
25                 {"youth","medium","yes","excellent","yes"},
26                 {"middle_aged","medium","no","excellent","yes"},
27                 {"middle_aged","high","yes","fair","yes"},
28                 {"senior","medium","no","excellent","no"}
29             };
30             var names = new string[] { "age", "income", "student", "credit_rating", "Class: buys_computer" };
31             var tree = new DecisionTreeID3<string>(data, names, new string[] { "yes", "no" });
32             tree.Learn();
33             Console.ReadKey();
34         }
35     }
36 }

运行结果:

注:作者本人也在学习中,能力有限,如有错漏还请不吝指正。转载请注明作者。

时间: 2024-08-03 07:59:24

数据挖掘之决策树ID3算法(C#实现)的相关文章

决策树ID3算法[分类算法]

ID3分类算法的编码实现 1 <?php 2 /* 3 *决策树ID3算法(分类算法的实现) 4 */ 5 6 /* 7 *把.txt中的内容读到数组中保存 8 *$filename:文件名称 9 */ 10 11 //-------------------------------------------------------------------- 12 function gerFileContent($filename) 13 { 14 $array = array(NULL); 15

决策树---ID3算法(介绍及Python实现)

决策树---ID3算法   决策树: 以天气数据库的训练数据为例. Outlook Temperature Humidity Windy PlayGolf? sunny 85 85 FALSE no sunny 80 90 TRUE no overcast 83 86 FALSE yes rainy 70 96 FALSE yes rainy 68 80 FALSE yes rainy 65 70 TRUE no overcast 64 65 TRUE yes sunny 72 95 FALSE

机器学习—— 决策树(ID3算法)的分析与实现

KNN算法请参考:http://blog.csdn.net/gamer_gyt/article/details/47418223 一.简介         决策树是一个预测模型:他代表的是对象属性与对象值之间的一种映射关系.树中每个节点表示某个对象,而每个分叉路径则代表的某个可能的属性值,而每个叶结点则对应从根节点到该叶节点所经历的路径所表示的对象的值.决策树仅有单一输出,若欲有复数输出,可以建立独立的决策树以处理不同输出. 数据挖掘中决策树是一种经常要用到的技术,可以用于分析数据,同样也可以用

决策树ID3算法的java实现(基本试用所有的ID3)

已知:流感训练数据集,预定义两个类别: 求:用ID3算法建立流感的属性描述决策树 流感训练数据集 No. 头痛 肌肉痛 体温 患流感 1 是(1) 是(1) 正常(0) 否(0) 2 是(1) 是(1) 高(1) 是(1) 3 是(1) 是(1) 很高(2) 是(1) 4 否(0) 是(1) 正常(0) 否(0) 5 否(0) 否(0) 高(1) 否(0) 6 否(0) 是(1) 很高(2) 是(1) 7 是(1) 否(0) 高(1) 是(1) 原理分析: 在决策树的每一个非叶子结点划分之前,先

决策树 -- ID3算法小结

ID3算法(Iterative Dichotomiser 3 迭代二叉树3代),是一个由Ross Quinlan发明的用于决策树的算法:简单理论是越是小型的决策树越优于大的决策树. 算法归纳: 1.使用所有没有使用的属性并计算与之相关的样本熵值: 2.选取其中熵值最小的属性 3.生成包含该属性的节点 4.使用新的分支表继续前面步骤 ID3算法以信息论为基础,以信息熵和信息增益为衡量标准,从而实现对数据的归纳分类:所以归根结底,是为了从一堆数据中生成决策树而采取的一种归纳方式: 具体介绍: 1.信

决策树ID3算法的java实现

决策树的分类过程和人的决策过程比较相似,就是先挑“权重”最大的那个考虑,然后再往下细分.比如你去看医生,症状是流鼻涕,咳嗽等,那么医生就会根据你的流鼻涕这个权重最大的症状先认为你是感冒,接着再根据你咳嗽等症状细分你是否为病毒性感冒等等.决策树的过程其实也是基于极大似然估计.那么我们用一个什么标准来衡量某个特征是权重最大的呢,这里有信息增益和基尼系数两个.ID3算法采用的是信息增益这个量. 根据<统计学习方法>中的描述,G(D,A)表示数据集D在特征A的划分下的信息增益.具体公式: G(D,A)

决策树ID3算法实现,详细分步实现,参考机器学习实战

一.编写计算历史数据的经验熵函数 In [1]: from math import log def calcShannonEnt(dataSet): numEntries = len(dataSet) labelCounts = {} for elem in dataSet: #遍历数据集中每条样本的类别标签,统计每类标签的数量 currentLabel = elem[-1] if currentLabel not in labelCounts.keys(): #如果当前标签不在字典的key值中

初识分类算法(2)------决策树ID3算法

例子:分类:play or not ?(是/否)             目的:根据训练样本集S构建出一个决策树,然后未知分类样本通过决策树就得出分类. 问题:怎么构建决策树,从哪个节点开始(选择划分属性的问题) 方法:ID3(信息增益),C4.5(信息增益率),它们都是用来衡量给定属性区分训练样例的能力. 1. 为了理解信息增益,先来理解熵Entropy.熵是来度量样本的均一性的,即样本的纯度.样本的熵越大,代表信息的不确定性越大,信息就越混杂. 训练样本S:     S相对与目标分类的熵为:

决策树ID3算法预测隐形眼睛类型--python实现

本节讲解如何预测患者需要佩戴的隐形眼镜类型. 1.使用决策树预测隐形眼镜类型的一般流程 (1)收集数据:提供的文本文件(数据来源于UCI数据库) (2)准备数据:解析tab键分隔的数据行 (3)分析数据:快速检查数据,确保正确地解析数据内容,使用createPlot()函数绘制最终的树形图 (4)训练算法:createTree()函数 (5)测试算法:编写测试函数验证决策树可以正确分类给定的数据实例 (6)使用算法:存储数的数据结构,以使下次使用时无需重新构造树 trees.py如下: #!/u