libsvm代码阅读:关于Solver类分析(一)(转)

如果你看完了上篇博文的伪代码,那么我们就可以开始谈谈它的源代码了。

下面先贴出它的类定义,一些成员函数的具体实现先忽略。

[cpp]   view plain copy  

<EMBED id=ZeroClipboardMovie_1 height=18 name=ZeroClipboardMovie_1 type=application/x-shockwave-flash align=middle pluginspage=http://www.macromedia.com/go/getflashplayer width=18 src=http://static.blog.csdn.net/scripts/ZeroClipboard/ZeroClipboard.swf wmode="transparent" flashvars="id=1&width=18&height=18" allowfullscreen="false" allowscriptaccess="always" bgcolor="#ffffff" quality="best" menu="false" loop="false">

  1. // An SMO algorithm in Fan et al., JMLR 6(2005), p. 1889--1918
  2. // Solves:
  3. //  min 0.5(\alpha^T Q \alpha) + p^T \alpha
  4. //
  5. //      y^T \alpha = \delta
  6. //      y_i = +1 or -1
  7. //      0 <= alpha_i <= Cp for y_i = 1
  8. //      0 <= alpha_i <= Cn for y_i = -1
  9. //
  10. // Given:
  11. //  Q, p, y, Cp, Cn, and an initial feasible point \alpha
  12. //  l is the size of vectors and matrices
  13. //  eps is the stopping tolerance
  14. // solution will be put in \alpha, objective value will be put in obj
  15. //
  16. class Solver {
  17. public:
  18. Solver() {};
  19. virtual ~Solver() {};//用虚析构函数的原因是:保证根据实际运行适当的析构函数
  20. struct SolutionInfo {
  21. double obj;
  22. double rho;
  23. double upper_bound_p;
  24. double upper_bound_n;
  25. double r;   // for Solver_NU
  26. };
  27. void Solve(int l, const QMatrix& Q, const double *p_, const schar *y_,
  28. double *alpha_, double Cp, double Cn, double eps,
  29. SolutionInfo* si, int shrinking);
  30. protected:
  31. int active_size;//计算时实际参加运算的样本数目,经过shrink处理后,该数目小于全部样本数
  32. schar *y;       //样本所属类别,该值只能取-1或+1。
  33. double *G;      // gradient of objective function = (Q alpha + p)
  34. enum { LOWER_BOUND, UPPER_BOUND, FREE };
  35. char *alpha_status; // LOWER_BOUND, UPPER_BOUND, FREE
  36. double *alpha;      //
  37. const QMatrix *Q;
  38. const double *QD;
  39. double eps;     //误差限
  40. double Cp,Cn;
  41. double *p;
  42. int *active_set;
  43. double *G_bar;      // gradient, if we treat free variables as 0
  44. int l;
  45. bool unshrink;  // XXX
  46. //返回对应于样本的C。设置不同的Cp和Cn是为了处理数据的不平衡
  47. double get_C(int i)
  48. {
  49. return (y[i] > 0)? Cp : Cn;
  50. }
  51. void update_alpha_status(int i)
  52. {
  53. if(alpha[i] >= get_C(i))
  54. alpha_status[i] = UPPER_BOUND;
  55. else if(alpha[i] <= 0)
  56. alpha_status[i] = LOWER_BOUND;
  57. else alpha_status[i] = FREE;
  58. }
  59. bool is_upper_bound(int i) { return alpha_status[i] == UPPER_BOUND; }
  60. bool is_lower_bound(int i) { return alpha_status[i] == LOWER_BOUND; }
  61. bool is_free(int i) { return alpha_status[i] == FREE; }
  62. void swap_index(int i, int j);//交换样本i和j的内容,包括申请的内存的地址
  63. void reconstruct_gradient();  //重新计算梯度。
  64. virtual int select_working_set(int &i, int &j);//选择工作集
  65. virtual double calculate_rho();
  66. virtual void do_shrinking();//对样本集做缩减。
  67. private:
  68. bool be_shrunk(int i, double Gmax1, double Gmax2);
  69. };

下面我们来看看SMO如何选择工作集(working set B),选择的约束如下:

[cpp]   view plain copy  

<EMBED id=ZeroClipboardMovie_2 height=18 name=ZeroClipboardMovie_2 type=application/x-shockwave-flash align=middle pluginspage=http://www.macromedia.com/go/getflashplayer width=18 src=http://static.blog.csdn.net/scripts/ZeroClipboard/ZeroClipboard.swf wmode="transparent" flashvars="id=2&width=18&height=18" allowfullscreen="false" allowscriptaccess="always" bgcolor="#ffffff" quality="best" menu="false" loop="false">

  1. // return i,j such that
  2. // i: maximizes -y_i * grad(f)_i, i in I_up(\alpha)
  3. // j: minimizes the decrease of obj value
  4. //    (if quadratic coefficeint <= 0, replace it with tau)
  5. //    -y_j*grad(f)_j < -y_i*grad(f)_i, j in I_low(\alpha)

论文中的公式如下:

[cpp]   view plain copy  

<EMBED id=ZeroClipboardMovie_3 height=18 name=ZeroClipboardMovie_3 type=application/x-shockwave-flash align=middle pluginspage=http://www.macromedia.com/go/getflashplayer width=18 src=http://static.blog.csdn.net/scripts/ZeroClipboard/ZeroClipboard.swf wmode="transparent" flashvars="id=3&width=18&height=18" allowfullscreen="false" allowscriptaccess="always" bgcolor="#ffffff" quality="best" menu="false" loop="false">

  1. int Solver::select_working_set(int &out_i, int &out_j)
  2. {
  3. // return i,j such that
  4. // i: maximizes -y_i * grad(f)_i, i in I_up(\alpha)
  5. // j: minimizes the decrease of obj value
  6. //    (if quadratic coefficeint <= 0, replace it with tau)
  7. //    -y_j*grad(f)_j < -y_i*grad(f)_i, j in I_low(\alpha)
  8. //select i
  9. double Gmax = -INF;
  10. double Gmax2 = -INF;
  11. int Gmax_idx = -1;
  12. int Gmin_idx = -1;
  13. double obj_diff_min = INF;
  14. for(int t=0;t<active_size;t++)
  15. if(y[t]==+1)    //若类别为1
  16. {
  17. if(!is_upper_bound(t))//若alpha<C
  18. if(-G[t] >= Gmax)
  19. {
  20. Gmax = -G[t];// -y[t]*G[t]=-1*G[t]
  21. Gmax_idx = t;
  22. }
  23. }
  24. else
  25. {
  26. if(!is_lower_bound(t))
  27. if(G[t] >= Gmax)
  28. {
  29. Gmax = G[t];
  30. Gmax_idx = t;
  31. }
  32. }
  33. int i = Gmax_idx;
  34. const Qfloat *Q_i = NULL;
  35. if(i != -1) // NULL Q_i not accessed: Gmax=-INF if i=-1
  36. Q_i = Q->get_Q(i,active_size);
  37. //select j
  38. for(int j=0;j<active_size;j++)
  39. {
  40. if(y[j]==+1)
  41. {
  42. if (!is_lower_bound(j))
  43. {
  44. double grad_diff=Gmax+G[j];
  45. if (G[j] >= Gmax2)
  46. Gmax2 = G[j];
  47. if (grad_diff > 0)
  48. {
  49. double obj_diff;
  50. double quad_coef = QD[i]+QD[j]-2.0*y[i]*Q_i[j];
  51. if (quad_coef > 0)
  52. obj_diff = -(grad_diff*grad_diff)/quad_coef;
  53. else
  54. obj_diff = -(grad_diff*grad_diff)/TAU;
  55. if (obj_diff <= obj_diff_min)
  56. {
  57. Gmin_idx=j;
  58. obj_diff_min = obj_diff;
  59. }
  60. }
  61. }
  62. }
  63. else
  64. {
  65. if (!is_upper_bound(j))
  66. {
  67. double grad_diff= Gmax-G[j];
  68. if (-G[j] >= Gmax2)
  69. Gmax2 = -G[j];
  70. if (grad_diff > 0)
  71. {
  72. double obj_diff;
  73. double quad_coef = QD[i]+QD[j]+2.0*y[i]*Q_i[j];
  74. if (quad_coef > 0)
  75. obj_diff = -(grad_diff*grad_diff)/quad_coef;
  76. else
  77. obj_diff = -(grad_diff*grad_diff)/TAU;
  78. if (obj_diff <= obj_diff_min)
  79. {
  80. Gmin_idx=j;
  81. obj_diff_min = obj_diff;
  82. }
  83. }
  84. }
  85. }
  86. }
  87. if(Gmax+Gmax2 < eps)
  88. return 1;
  89. out_i = Gmax_idx;
  90. out_j = Gmin_idx;
  91. return 0;
  92. }

配合上面几个公式看,这段代码还是很清晰了。

下面来看看它的构造函数,这个构造函数是solver类的核心。这个算法也结合上一篇博文的algorithm2来看。其中要注意的是get_Q是获取核函数。

[cpp]   view plain copy  

<EMBED id=ZeroClipboardMovie_4 height=18 name=ZeroClipboardMovie_4 type=application/x-shockwave-flash align=middle pluginspage=http://www.macromedia.com/go/getflashplayer width=18 src=http://static.blog.csdn.net/scripts/ZeroClipboard/ZeroClipboard.swf wmode="transparent" flashvars="id=4&width=18&height=18" allowfullscreen="false" allowscriptaccess="always" bgcolor="#ffffff" quality="best" menu="false" loop="false">

  1. void Solver::Solve(int l, const QMatrix& Q, const double *p_, const schar *y_,
  2. double *alpha_, double Cp, double Cn, double eps,
  3. SolutionInfo* si, int shrinking)
  4. {
  5. this->l = l;
  6. this->Q = &Q;
  7. QD=Q.get_QD();//这个是获取核函数(如果分类的话在SVC_Q中定义)
  8. clone(p, p_,l);
  9. clone(y, y_,l);
  10. clone(alpha,alpha_,l);
  11. this->Cp = Cp;
  12. this->Cn = Cn;
  13. this->eps = eps;
  14. unshrink = false;
  15. // initialize alpha_status
  16. {
  17. alpha_status = new char[l];
  18. for(int i=0;i<l;i++)
  19. update_alpha_status(i);
  20. }
  21. // initialize active set (for shrinking)
  22. {
  23. active_set = new int[l];
  24. for(int i=0;i<l;i++)
  25. active_set[i] = i;
  26. active_size = l;
  27. }
  28. // initialize gradient
  29. {
  30. G = new double[l];
  31. G_bar = new double[l];
  32. int i;
  33. for(i=0;i<l;i++)
  34. {
  35. G[i] = p[i];
  36. G_bar[i] = 0;
  37. }
  38. for(i=0;i<l;i++)
  39. if(!is_lower_bound(i))
  40. {
  41. const Qfloat *Q_i = Q.get_Q(i,l);
  42. double alpha_i = alpha[i];
  43. int j;
  44. for(j=0;j<l;j++)
  45. G[j] += alpha_i*Q_i[j];
  46. if(is_upper_bound(i))
  47. for(j=0;j<l;j++)
  48. G_bar[j] += get_C(i) * Q_i[j]; //这里见文献LIBSVM: A Library for SVM公式(33)
  49. }
  50. }
  51. // optimization step
  52. int iter = 0;
  53. int max_iter = max(10000000, l>INT_MAX/100 ? INT_MAX : 100*l);
  54. int counter = min(l,1000)+1;
  55. while(iter < max_iter)
  56. {
  57. // show progress and do shrinking
  58. if(--counter == 0)
  59. {
  60. counter = min(l,1000);
  61. if(shrinking) do_shrinking();
  62. info(".");
  63. }
  64. int i,j;
  65. if(select_working_set(i,j)!=0)
  66. {
  67. // reconstruct the whole gradient
  68. reconstruct_gradient();
  69. // reset active set size and check
  70. active_size = l;
  71. info("*");
  72. if(select_working_set(i,j)!=0)
  73. break;
  74. else
  75. counter = 1;    // do shrinking next iteration
  76. }
  77. ++iter;
  78. // update alpha[i] and alpha[j], handle bounds carefully
  79. const Qfloat *Q_i = Q.get_Q(i,active_size);
  80. const Qfloat *Q_j = Q.get_Q(j,active_size);
  81. double C_i = get_C(i);
  82. double C_j = get_C(j);
  83. double old_alpha_i = alpha[i];
  84. double old_alpha_j = alpha[j];
  85. if(y[i]!=y[j])
  86. {
  87. double quad_coef = QD[i]+QD[j]+2*Q_i[j];
  88. if (quad_coef <= 0)
  89. quad_coef = TAU;
  90. double delta = (-G[i]-G[j])/quad_coef;
  91. double diff = alpha[i] - alpha[j];
  92. alpha[i] += delta;
  93. alpha[j] += delta;
  94. if(diff > 0)
  95. {
  96. if(alpha[j] < 0)
  97. {
  98. alpha[j] = 0;
  99. alpha[i] = diff;
  100. }
  101. }
  102. else
  103. {
  104. if(alpha[i] < 0)
  105. {
  106. alpha[i] = 0;
  107. alpha[j] = -diff;
  108. }
  109. }
  110. if(diff > C_i - C_j)
  111. {
  112. if(alpha[i] > C_i)
  113. {
  114. alpha[i] = C_i;
  115. alpha[j] = C_i - diff;
  116. }
  117. }
  118. else
  119. {
  120. if(alpha[j] > C_j)
  121. {
  122. alpha[j] = C_j;
  123. alpha[i] = C_j + diff;
  124. }
  125. }
  126. }
  127. else
  128. {
  129. double quad_coef = QD[i]+QD[j]-2*Q_i[j];
  130. if (quad_coef <= 0)
  131. quad_coef = TAU;
  132. double delta = (G[i]-G[j])/quad_coef;
  133. double sum = alpha[i] + alpha[j];
  134. alpha[i] -= delta;
  135. alpha[j] += delta;
  136. if(sum > C_i)
  137. {
  138. if(alpha[i] > C_i)
  139. {
  140. alpha[i] = C_i;
  141. alpha[j] = sum - C_i;
  142. }
  143. }
  144. else
  145. {
  146. if(alpha[j] < 0)
  147. {
  148. alpha[j] = 0;
  149. alpha[i] = sum;
  150. }
  151. }
  152. if(sum > C_j)
  153. {
  154. if(alpha[j] > C_j)
  155. {
  156. alpha[j] = C_j;
  157. alpha[i] = sum - C_j;
  158. }
  159. }
  160. else
  161. {
  162. if(alpha[i] < 0)
  163. {
  164. alpha[i] = 0;
  165. alpha[j] = sum;
  166. }
  167. }
  168. }
  169. // update G
  170. double delta_alpha_i = alpha[i] - old_alpha_i;
  171. double delta_alpha_j = alpha[j] - old_alpha_j;
  172. for(int k=0;k<active_size;k++)
  173. {
  174. G[k] += Q_i[k]*delta_alpha_i + Q_j[k]*delta_alpha_j;
  175. }
  176. // update alpha_status and G_bar
  177. {
  178. bool ui = is_upper_bound(i);
  179. bool uj = is_upper_bound(j);
  180. update_alpha_status(i);
  181. update_alpha_status(j);
  182. int k;
  183. if(ui != is_upper_bound(i))
  184. {
  185. Q_i = Q.get_Q(i,l);
  186. if(ui)
  187. for(k=0;k<l;k++)
  188. G_bar[k] -= C_i * Q_i[k];
  189. else
  190. for(k=0;k<l;k++)
  191. G_bar[k] += C_i * Q_i[k];
  192. }
  193. if(uj != is_upper_bound(j))
  194. {
  195. Q_j = Q.get_Q(j,l);
  196. if(uj)
  197. for(k=0;k<l;k++)
  198. G_bar[k] -= C_j * Q_j[k];
  199. else
  200. for(k=0;k<l;k++)
  201. G_bar[k] += C_j * Q_j[k];
  202. }
  203. }
  204. }
  205. if(iter >= max_iter)
  206. {
  207. if(active_size < l)
  208. {
  209. // reconstruct the whole gradient to calculate objective value
  210. reconstruct_gradient();
  211. active_size = l;
  212. info("*");
  213. }
  214. fprintf(stderr,"\nWARNING: reaching max number of iterations\n");
  215. }
  216. // calculate rho
  217. si->rho = calculate_rho();
  218. // calculate objective value
  219. {
  220. double v = 0;
  221. int i;
  222. for(i=0;i<l;i++)
  223. v += alpha[i] * (G[i] + p[i]);
  224. si->obj = v/2;
  225. }
  226. // put back the solution
  227. {
  228. for(int i=0;i<l;i++)
  229. alpha_[active_set[i]] = alpha[i];
  230. }
  231. // juggle everything back
  232. /*{
  233. for(int i=0;i<l;i++)
  234. while(active_set[i] != i)
  235. swap_index(i,active_set[i]);
  236. // or Q.swap_index(i,active_set[i]);
  237. }*/
  238. si->upper_bound_p = Cp;
  239. si->upper_bound_n = Cn;
  240. info("\noptimization finished, #iter = %d\n",iter);
  241. delete[] p;
  242. delete[] y;
  243. delete[] alpha;
  244. delete[] alpha_status;
  245. delete[] active_set;
  246. delete[] G;
  247. delete[] G_bar;
  248. }
时间: 2024-10-07 19:11:09

libsvm代码阅读:关于Solver类分析(一)(转)的相关文章

libsvm代码阅读:关于Solver类分析(二)(转)

如果你看完了上篇博文的伪代码,那么我们就可以开始谈谈它的源代码了. 下面先贴出它的类定义,一些成员函数的具体实现先忽略. [cpp]   view plain copy   <EMBED id=ZeroClipboardMovie_1 height=18 name=ZeroClipboardMovie_1 type=application/x-shockwave-flash align=middle pluginspage=http://www.macromedia.com/go/getflas

libsvm代码阅读(3):关于Cache类的分析(转)

下面来分析Cache类的源码,该类位于svm.cpp中.这个类的主要功能是:负责运算所涉及的内存管理,包括申请.释放等. 简单来说:这个Cache类,首先通过Cache构造函数申请一块空间,这块空间的大小是:L个head_t大小的空间.然后get_data函数保证结构head_t中至少有len个float的内存,并且将可以使用的内存块的指针放在data指针中:而swap_index函数则是用于交换head[i]和head[j]. Cache类的定义如下: [cpp]       view pla

libsvm代码阅读:关于Kernel类分析(转)

这一篇博文来分析下Kernel类,代码上很简单,一般都能看懂.Kernel类主要是为SVM的核函数服务的,里面实现了SVM常用的核函数,通过函数指针来使用这些核函数. 其中几个常用核函数如下所示:(一般情况下,使用RBF核函数能取得很好的效果) 关于基类QMatrix在Kernel中的作用并不明显,只是定义了一些纯虚函数,Kernel继承这些函数,Kernel只对swap_index进行了定义.其余的get_Q和get_QD在Kernel并没有用到. [cpp]   view plain cop

libsvm代码阅读:关于svm_group_classes函数分析(转)

目前libsvm最新的version是3.17,主要的改变是在svm_group_classes函数中加了几行代码.官方的说明如下: Version 3.17 released on April Fools' day, 2013. We slightly adjust the way class labels are handled internally. By default labels are ordered by their first occurrence in the trainin

libsvm代码阅读:关于svm_train函数分析(转)

在svm中,训练是一个十分重要的步骤,下面我们来看看svm的train部分. 在libsvm中的svm_train中分别有回归和分类两部分,我只对其中分类做介绍. 分类的步骤如下: 统计类别总数,同时记录类别的标号,统计每个类的样本数目 将属于相同类的样本分组,连续存放 计算权重C 训练n(n-1)/2 个模型 初始化nozero数组,便于统计SV //初始化概率数组 训练过程中,需要重建子数据集,样本的特征不变,但样本的类别要改为+1/-1 //如有必要,先调用svm_binary_svc_p

libsvm代码阅读(1):基础准备与svm.h头文件(转)

libsvm是国立台湾大学Chih-Jen Lin开发的一个SVM的函数库,是当前应用最广泛的svm函数库,从2000年到2010年,该函数库的下载量达到250000之多.它的最新版本是version 3.17,主要是对是svm_group_classes做了修改. 主页:LIBSVM -- A Library for Support Vector Machines 下载地址:zip.file ortar.gz 我下载后的解压文件如下所示: libsvm函数包的组织结构如下: 1.主文件路径:包

libsvm代码阅读(2):svm.cpp浅谈和函数指针(转)

svm.cpp浅谈 svm.cpp总共有3159行代码,实现了svm算法的核心功能,里面总共有Cache.Kernel.ONE_CLASS_Q.QMatrix.Solver.Solver_NU.SVC_Q.SVR_Q 8个类(如下图1所示),而它们之间的继承和组合关系如图2.图3所示.在这些类中Cache.Kernel.Solver是核心类,对整个算法起支撑作用.在以后的博文中我们将对这3个核心类做重点注解分析,另外还将对svm.cpp中的svm_train函数做一个注解分析. 图1 图2 图3

代码阅读分析工具Understand 2.0试用

Understand 2.0是一款源码阅读分析软件,功能强大.试用过一段时间后,感觉相当不错,确实能够大大提高代码阅读效率.因为Understand功能十分强大,本文不可能详尽地介绍它的全部功能,所以仅仅列举本人觉得比較重要或有特色的功能,以做抛砖引玉之举. Understand 2.0能够从http://www.scitools.com/下载到,安装后能够试用15天. 使用Understand阅读代码前,要先创建一个Project,然后把全部的源码文件增加到这个Project里.这里我创建了一

Caffe源码-Solver类

Solver类简介 Net类中实现了网络的前向/反向计算和参数更新,而Solver类中则是对此进行进一步封装,包含可用于逐次训练网络的Step()函数,和用于求解网络的优化解的Solve()函数,同时还实现了一些存储.读取网络模型快照的接口函数. solver.cpp源码 template<typename Dtype> void Solver<Dtype>::SetActionFunction(ActionCallback func) { action_request_funct