OpenCV码源笔记——Decision Tree决策树

来自OpenCV2.3.1 sample/c/mushroom.cpp

1.首先读入agaricus-lepiota.data的训练样本。

样本中第一项是e或p代表有毒或无毒的标志位;其他是特征,可以把每个样本看做一个特征向量;

cvSeqPush( seq, el_ptr );读入序列seq中,每一项都存储一个样本即特征向量;

之后,把特征向量与标志位分别读入CvMat* data与CvMat* reponses中

还有一个CvMat* missing保留丢失位当前小于0位置;

2.训练样本

[cpp] view plain copy

print?

  1. dtree = new CvDTree;
  2. dtree->train( data, CV_ROW_SAMPLE, responses, 0, 0, var_type, missing,
  3. CvDTreeParams( 8, // max depth
  4. 10, // min sample count 样本数小于10时,停止分裂
  5. 0, // regression accuracy: N/A here;回归树的限制精度
  6. true, // compute surrogate split, as we have missing data;;为真时,计算missing data和变量的重要性
  7. 15, // max number of categories (use sub-optimal algorithm for larger numbers)类型上限以保证计算速度。树会以次优分裂(suboptimal split)的形式生长。只对2种取值以上的树有意义
  8. 10, // the number of cross-validation folds;If cv_folds > 1 then prune a tree with K-fold cross-validation where K is equal to cv_folds
  9. true, // use 1SE rule => smaller tree;If true 修剪树. 这将使树更紧凑,更能抵抗训练数据噪声,但有点不太准确
  10. true, // throw away the pruned tree branches
  11. priors //错分类的代价我们判断的:有毒VS无毒 错误的代价比 the array of priors, the bigger p_weight, the more attention
  12. // to the poisonous mushrooms
  13. // (a mushroom will be judjed to be poisonous with bigger chance)
  14. ));

3.

[cpp] view plain copy

print?

  1. double r = dtree->predict( &sample, &mask )->value;//使用predict来预测样本,结果为 CvDTreeNode结构,dtree->predict(sample,mask)->value是分类情况下的类别或回归情况下的函数估计值;

4.interactive_classification通过人工输入特征来判断。

[cpp] view plain copy

print?

    1. #include "opencv2/core/core_c.h"
    2. #include "opencv2/ml/ml.hpp"
    3. #include <stdio.h>
    4. void help()
    5. {
    6. printf("\nThis program demonstrated the use of OpenCV‘s decision tree function for learning and predicting data\n"
    7. "Usage :\n"
    8. "./mushroom <path to agaricus-lepiota.data>\n"
    9. "\n"
    10. "The sample demonstrates how to build a decision tree for classifying mushrooms.\n"
    11. "It uses the sample base agaricus-lepiota.data from UCI Repository, here is the link:\n"
    12. "\n"
    13. "Newman, D.J. & Hettich, S. & Blake, C.L. & Merz, C.J. (1998).\n"
    14. "UCI Repository of machine learning databases\n"
    15. "[http://www.ics.uci.edu/~mlearn/MLRepository.html].\n"
    16. "Irvine, CA: University of California, Department of Information and Computer Science.\n"
    17. "\n"
    18. "// loads the mushroom database, which is a text file, containing\n"
    19. "// one training sample per row, all the input variables and the output variable are categorical,\n"
    20. "// the values are encoded by characters.\n\n");
    21. }
    22. int mushroom_read_database( const char* filename, CvMat** data, CvMat** missing, CvMat** responses )
    23. {
    24. const int M = 1024;
    25. FILE* f = fopen( filename, "rt" );
    26. CvMemStorage* storage;
    27. CvSeq* seq;
    28. char buf[M+2], *ptr;
    29. float* el_ptr;
    30. CvSeqReader reader;
    31. int i, j, var_count = 0;
    32. if( !f )
    33. return 0;
    34. // read the first line and determine the number of variables
    35. if( !fgets( buf, M, f ))
    36. {
    37. fclose(f);
    38. return 0;
    39. }
    40. for( ptr = buf; *ptr != ‘\0‘; ptr++ )
    41. var_count += *ptr == ‘,‘;//计算每个样本的数量,每个样本一个“,”,样本数量=var_count+1;
    42. assert( ptr - buf == (var_count+1)*2 );
    43. // create temporary memory storage to store the whole database
    44. //把样本存入seq中,存储空间是storage;
    45. el_ptr = new float[var_count+1];
    46. storage = cvCreateMemStorage();
    47. seq = cvCreateSeq( 0, sizeof(*seq), (var_count+1)*sizeof(float), storage );//
    48. for(;;)
    49. {
    50. for( i = 0; i <= var_count; i++ )
    51. {
    52. int c = buf[i*2];
    53. el_ptr[i] = c == ‘?‘ ? -1.f : (float)c;
    54. }
    55. if( i != var_count+1 )
    56. break;
    57. cvSeqPush( seq, el_ptr );
    58. if( !fgets( buf, M, f ) || !strchr( buf, ‘,‘ ) )
    59. break;
    60. }
    61. fclose(f);
    62. // allocate the output matrices and copy the base there
    63. *data = cvCreateMat( seq->total, var_count, CV_32F );//行数:样本数量;列数:样本大小;
    64. *missing = cvCreateMat( seq->total, var_count, CV_8U );
    65. *responses = cvCreateMat( seq->total, 1, CV_32F );//样本标志;
    66. cvStartReadSeq( seq, &reader );
    67. for( i = 0; i < seq->total; i++ )
    68. {
    69. const float* sdata = (float*)reader.ptr + 1;
    70. float* ddata = data[0]->data.fl + var_count*i;
    71. float* dr = responses[0]->data.fl + i;
    72. uchar* dm = missing[0]->data.ptr + var_count*i;
    73. for( j = 0; j < var_count; j++ )
    74. {
    75. ddata[j] = sdata[j];
    76. dm[j] = sdata[j] < 0;
    77. }
    78. *dr = sdata[-1];//样本的第一个位置是标志;
    79. CV_NEXT_SEQ_ELEM( seq->elem_size, reader );
    80. }
    81. cvReleaseMemStorage( &storage );
    82. delete el_ptr;
    83. return 1;
    84. }
    85. CvDTree* mushroom_create_dtree( const CvMat* data, const CvMat* missing,
    86. const CvMat* responses, float p_weight )
    87. {
    88. CvDTree* dtree;
    89. CvMat* var_type;
    90. int i, hr1 = 0, hr2 = 0, p_total = 0;
    91. float priors[] = { 1, p_weight };
    92. var_type = cvCreateMat( data->cols + 1, 1, CV_8U );
    93. cvSet( var_type, cvScalarAll(CV_VAR_CATEGORICAL) ); // all the variables are categorical
    94. dtree = new CvDTree;
    95. dtree->train( data, CV_ROW_SAMPLE, responses, 0, 0, var_type, missing,
    96. CvDTreeParams( 8, // max depth
    97. 10, // min sample count样本数小于10时,停止分裂
    98. 0, // regression accuracy: N/A here;回归树的限制精度
    99. true, // compute surrogate split, as we have missing data;为真时,计算missing data和可变的重要性正确度
    100. 15, // max number of categories (use sub-optimal algorithm for larger numbers)类型上限以保证计算速度。树会以次优分裂(suboptimal split)的形式生长。只对2种取值以上的树有意义
    101. 10, // the number of cross-validation folds;If cv_folds > 1 then prune a tree with K-fold cross-validation
    102. true, // use 1SE rule => smaller treeIf true 修剪树. 这将使树更紧凑,更能抵抗训练数据噪声,但有点不太准确
    103. true, // throw away the pruned tree branches
    104. priors // the array of priors, the bigger p_weight, the more attention
    105. // to the poisonous mushrooms
    106. // (a mushroom will be judjed to be poisonous with bigger chance)
    107. ));
    108. // compute hit-rate on the training database, demonstrates predict usage.
    109. for( i = 0; i < data->rows; i++ )
    110. {
    111. CvMat sample, mask;
    112. cvGetRow( data, &sample, i );
    113. cvGetRow( missing, &mask, i );
    114. double r = dtree->predict( &sample, &mask )->value;//使用predict来预测样本,结果为 CvDTreeNode结构,dtree->predict(sample,mask)->value是分类情况下的类别或回归情况下的函数估计值;
    115. int d = fabs(r - responses->data.fl[i]) >= FLT_EPSILON;//大于阈值FLT_EPSILON被判断为误检
    116. if( d )
    117. {
    118. if( r != ‘p‘ )
    119. hr1++;
    120. else
    121. hr2++;
    122. }
    123. p_total += responses->data.fl[i] == ‘p‘;
    124. }
    125. printf( "Results on the training database:\n"
    126. "\tPoisonous mushrooms mis-predicted: %d (%g%%)\n"
    127. "\tFalse-alarms: %d (%g%%)\n", hr1, (double)hr1*100/p_total,
    128. hr2, (double)hr2*100/(data->rows - p_total) );
    129. cvReleaseMat( &var_type );
    130. return dtree;
    131. }
    132. static const char* var_desc[] =
    133. {
    134. "cap shape (bell=b,conical=c,convex=x,flat=f)",
    135. "cap surface (fibrous=f,grooves=g,scaly=y,smooth=s)",
    136. "cap color (brown=n,buff=b,cinnamon=c,gray=g,green=r,\n\tpink=p,purple=u,red=e,white=w,yellow=y)",
    137. "bruises? (bruises=t,no=f)",
    138. "odor (almond=a,anise=l,creosote=c,fishy=y,foul=f,\n\tmusty=m,none=n,pungent=p,spicy=s)",
    139. "gill attachment (attached=a,descending=d,free=f,notched=n)",
    140. "gill spacing (close=c,crowded=w,distant=d)",
    141. "gill size (broad=b,narrow=n)",
    142. "gill color (black=k,brown=n,buff=b,chocolate=h,gray=g,\n\tgreen=r,orange=o,pink=p,purple=u,red=e,white=w,yellow=y)",
    143. "stalk shape (enlarging=e,tapering=t)",
    144. "stalk root (bulbous=b,club=c,cup=u,equal=e,rhizomorphs=z,rooted=r)",
    145. "stalk surface above ring (ibrous=f,scaly=y,silky=k,smooth=s)",
    146. "stalk surface below ring (ibrous=f,scaly=y,silky=k,smooth=s)",
    147. "stalk color above ring (brown=n,buff=b,cinnamon=c,gray=g,orange=o,\n\tpink=p,red=e,white=w,yellow=y)",
    148. "stalk color below ring (brown=n,buff=b,cinnamon=c,gray=g,orange=o,\n\tpink=p,red=e,white=w,yellow=y)",
    149. "veil type (partial=p,universal=u)",
    150. "veil color (brown=n,orange=o,white=w,yellow=y)",
    151. "ring number (none=n,one=o,two=t)",
    152. "ring type (cobwebby=c,evanescent=e,flaring=f,large=l,\n\tnone=n,pendant=p,sheathing=s,zone=z)",
    153. "spore print color (black=k,brown=n,buff=b,chocolate=h,green=r,\n\torange=o,purple=u,white=w,yellow=y)",
    154. "population (abundant=a,clustered=c,numerous=n,\n\tscattered=s,several=v,solitary=y)",
    155. "habitat (grasses=g,leaves=l,meadows=m,paths=p\n\turban=u,waste=w,woods=d)",
    156. 0
    157. };
    158. void print_variable_importance( CvDTree* dtree, const char** var_desc )
    159. {
    160. const CvMat* var_importance = dtree->get_var_importance();
    161. int i;
    162. char input[1000];
    163. if( !var_importance )
    164. {
    165. printf( "Error: Variable importance can not be retrieved\n" );
    166. return;
    167. }
    168. printf( "Print variable importance information? (y/n) " );
    169. scanf( "%1s", input );
    170. if( input[0] != ‘y‘ && input[0] != ‘Y‘ )
    171. return;
    172. for( i = 0; i < var_importance->cols*var_importance->rows; i++ )
    173. {
    174. double val = var_importance->data.db[i];
    175. if( var_desc )
    176. {
    177. char buf[100];
    178. int len = strchr( var_desc[i], ‘(‘ ) - var_desc[i] - 1;
    179. strncpy( buf, var_desc[i], len );
    180. buf[len] = ‘\0‘;
    181. printf( "%s", buf );
    182. }
    183. else
    184. printf( "var #%d", i );
    185. printf( ": %g%%\n", val*100. );
    186. }
    187. }
    188. void interactive_classification( CvDTree* dtree, const char** var_desc )
    189. {
    190. char input[1000];
    191. const CvDTreeNode* root;
    192. CvDTreeTrainData* data;
    193. if( !dtree )
    194. return;
    195. root = dtree->get_root();
    196. data = dtree->get_data();
    197. for(;;)
    198. {
    199. const CvDTreeNode* node;
    200. printf( "Start/Proceed with interactive mushroom classification (y/n): " );
    201. scanf( "%1s", input );
    202. if( input[0] != ‘y‘ && input[0] != ‘Y‘ )
    203. break;
    204. printf( "Enter 1-letter answers, ‘?‘ for missing/unknown value...\n" );
    205. // custom version of predict
    206. //传统的预测方式;
    207. node = root;
    208. for(;;)
    209. {
    210. CvDTreeSplit* split = node->split;
    211. int dir = 0;
    212. if( !node->left || node->Tn <= dtree->get_pruned_tree_idx() || !node->split )
    213. break;
    214. for( ; split != 0; )
    215. {
    216. int vi = split->var_idx, j;
    217. int count = data->cat_count->data.i[vi];
    218. const int* map = data->cat_map->data.i + data->cat_ofs->data.i[vi];
    219. printf( "%s: ", var_desc[vi] );
    220. scanf( "%1s", input );
    221. if( input[0] == ‘?‘ )
    222. {
    223. split = split->next;
    224. continue;
    225. }
    226. // convert the input character to the normalized value of the variable
    227. for( j = 0; j < count; j++ )
    228. if( map[j] == input[0] )
    229. break;
    230. if( j < count )
    231. {
    232. dir = (split->subset[j>>5] & (1 << (j&31))) ? -1 : 1;
    233. if( split->inversed )
    234. dir = -dir;
    235. break;
    236. }
    237. else
    238. printf( "Error: unrecognized value\n" );
    239. }
    240. if( !dir )
    241. {
    242. printf( "Impossible to classify the sample\n");
    243. node = 0;
    244. break;
    245. }
    246. node = dir < 0 ? node->left : node->right;
    247. }
    248. if( node )
    249. printf( "Prediction result: the mushroom is %s\n",
    250. node->class_idx == 0 ? "EDIBLE" : "POISONOUS" );
    251. printf( "\n-----------------------------\n" );
    252. }
    253. }
    254. int main( int argc, char** argv )
    255. {
    256. CvMat *data = 0, *missing = 0, *responses = 0;
    257. CvDTree* dtree;
    258. const char* base_path = argc >= 2 ? argv[1] : "agaricus-lepiota.data";
    259. help();
    260. if( !mushroom_read_database( base_path, &data, &missing, &responses ) )
    261. {
    262. printf( "\nUnable to load the training database\n\n");
    263. help();
    264. return -1;
    265. }
    266. dtree = mushroom_create_dtree( data, missing, responses,
    267. 10 // poisonous mushrooms will have 10x higher weight in the decision tree
    268. );
    269. cvReleaseMat( &data );
    270. cvReleaseMat( &missing );
    271. cvReleaseMat( &responses );
    272. print_variable_importance( dtree, var_desc );
    273. interactive_classification( dtree, var_desc );
    274. delete dtree;
    275. return 0;
    276. }
    277. //from: http://blog.csdn.net/yangtrees/article/details/7490852
时间: 2024-11-08 21:34:22

OpenCV码源笔记——Decision Tree决策树的相关文章

OpenCV码源笔记——RandomTrees (二)(Forest)

源码细节: ● 训练函数 bool CvRTrees::train( const CvMat* _train_data, int _tflag,                        const CvMat* _responses, const CvMat* _var_idx,                        const CvMat* _sample_idx, const CvMat* _var_type,                        const CvMa

OpenCV码源笔记——RandomTrees (一)

OpenCV2.3中Random Trees(R.T.)的继承结构: API: CvRTParams 定义R.T.训练用参数,CvDTreeParams的扩展子类,但并不用到CvDTreeParams(单一决策树)所需的所有参数.比如说,R.T.通常不需要剪枝,因此剪枝参数就不被用到.max_depth  单棵树所可能达到的最大深度min_sample_count  树节点持续分裂的最小样本数量,也就是说,小于这个数节点就不持续分裂,变成叶子了regression_accuracy  回归树的终

决策树Decision Tree 及实现

Decision Tree 及实现 标签: 决策树熵信息增益分类有监督 2014-03-17 12:12 15010人阅读 评论(41) 收藏 举报  分类: Data Mining(25)  Python(24)  Machine Learning(46)  版权声明:本文为博主原创文章,未经博主允许不得转载. 本文基于python逐步实现Decision Tree(决策树),分为以下几个步骤: 加载数据集 熵的计算 根据最佳分割feature进行数据分割 根据最大信息增益选择最佳分割feat

机器学习分类实例——SVM(修改)/Decision Tree/Naive Bayes

机器学习分类实例--SVM(修改)/Decision Tree/Naive Bayes 20180427-28笔记.30总结 已经5月了,毕设告一段落了,该准备论文了.前天开会老师说,希望我以后做关于语义分析那一块内容,会议期间还讨论了学姐的知识图谱的知识推理内容,感觉也挺有趣的,但是感觉应该会比较复杂.有时间的话希望对这块了解一下.其实吧,具体怎么展示我还是不太清楚... 大概就是图表那个样子.我先做一个出来,让老师看看,两个礼拜写论文.24/25答辩,6月就可以去浪哈哈哈哈哈哈. 一.工作

【3】Decision tree(决策树)

前言 Decision tree is one of the most popular classification tools 它用一个训练数据集学到一个映射,该映射以未知类别的新实例作为输入,输出对这个实例类别的预测. 决策树相当于将一系列问题组织成树,具体说,每个问题对应一个属性,根据属性值来生成判断分支,一直到决策树的叶节点就产生了类别. 那么,接下来的问题就是怎么选择最佳的属性作为当前的判断分支,这就引出了用信息论划分数据集的方式. 在信息论中,划分数据之前和之后信息发生的信息变化成为

决策树(decision tree)

决策树 ID3,C4.5,CART,决策树的生成,剪枝. 一.概述 决策树(decision tree)是一种基本的分类与回归方法(这里是分类的决策树).决策树模型呈树形结构,在分类问题中,表示基于特征对实例进行分类的过程.它可以认为是if-then规则的集合,也可以认为是定义在特征空间与类空间上的条件概率分布.其主要优点是模型具有可读性,分类速度快.学习时,利用训练数据,根据损失函数最小化的原则建立决策树模型.预测时,对新的数据利用决策树模型进行分类.决策树学习通常包括三个步骤:特征选择.决策

CI框架源码阅读笔记3 全局函数Common.php

从本篇开始,将深入CI框架的内部,一步步去探索这个框架的实现.结构和设计. Common.php文件定义了一系列的全局函数(一般来说,全局函数具有最高的加载优先权,因此大多数的框架中BootStrap引导文件都会最先引入全局函数,以便于之后的处理工作). 打开Common.php中,第一行代码就非常诡异: if ( ! defined('BASEPATH')) exit('No direct script access allowed'); 上一篇(CI框架源码阅读笔记2 一切的入口 index

机器学习中的算法:决策树模型组合之GBDT(Gradient Boost Decision Tree)

[转载自:http://www.cnblogs.com/LeftNotEasy/archive/2011/03/07/random-forest-and-gbdt.html] 前言 决策树这种算法有着很多良好的特性,比如说训练时间复杂度较低,预测的过程比较快速,模型容易展示(容易将得到的决策树做成图片展示出来)等.但是同时,单决策树又有一些不好的地方,比如说容易over-fitting,虽然有一些方法,如剪枝可以减少这种情况,但是还是不够的. 模型组合(比如说有Boosting,Bagging等

源码阅读笔记 - 1 MSVC2015中的std::sort

大约寒假开始的时候我就已经把std::sort的源码阅读完毕并理解其中的做法了,到了寒假结尾,姑且把它写出来 这是我的第一篇源码阅读笔记,以后会发更多的,包括算法和库实现,源码会按照我自己的代码风格格式化,去掉或者展开用于条件编译或者debug检查的宏,依重要程度重新排序函数,但是不会改变命名方式(虽然MSVC的STL命名实在是我不能接受的那种),对于代码块的解释会在代码块前(上面)用注释标明. template<class _RanIt, class _Diff, class _Pr> in