Expectation Maximization-EM(期望最大化)-算法以及源码

在统计计算中,最大期望(EM)算法是在概率(probabilistic)模型中寻找参数最大似然估计的算法,其中概率模型依赖于无法观测的隐藏变量(Latent Variable)。最大期望经常用在机器学习和计算机视觉的数据聚类(Data Clustering) 领域。最大期望算法经过两个步骤交替进行计算,第一步是计算期望(E),利用对隐藏变量的现有估计值,计算其最大似然估计值;第二步是最大化(M),最大 化在 E 步上求得的最大似然值来计算参数的值。M 步上找到的参数估计值被用于下一个 E 步计算中,这个过程不断交替进行。

最大期望值算法由 Arthur Dempster,Nan LairdDonald Rubin在他们1977年发表的经典论文中提出。他们指出此方法之前其实已经被很多作者"在他们特定的研究领域中多次提出过"。

我们用  表示能够观察到的不完整的变量值,用  表示无法观察到的变量值,这样  和  一起组成了完整的数据。 可能是实际测量丢失的数据,也可能是能够简化问题的隐藏变量,如果它的值能够知道的话。例如,在混合模型(Mixture Model)中,如果“产生”样本的混合元素成分已知的话最大似然公式将变得更加便利(参见下面的例子)。

估计无法观测的数据

让  代表矢量 θ:  定义的参数的全部数据的概率分布(连续情况下)或者概率聚类函数(离散情况下),那么从这个函数就可以得到全部数据的最大似然值,另外,在给定的观察到的数据条件下未知数据的条件分布可以表示为:

EM算法有这么两个步骤E和M:

Expectation step: Choose q to maximize F:

Maximization step: Choose θ to maximize F:

举个例子吧:高斯混合

假设 x = (x1,x2,…,xn) 是一个独立的观测样本,来自两个多元d维正态分布的混合, 让z=(z1,z2,…,zn)是潜在变量,确定其中的组成部分,是观测的来源.

即:

 and 

where

 and 

目标呢就是估计下面这些参数了,包括混合的参数以及高斯的均值很方差:

似然函数:

where  是一个指示函数 ,f 是 一个多元正态分布的概率密度函数. 可以写成指数形式:

下面就进入两个大步骤了:
E-step

给定目前的参数估计 θ(t),  Zi 的条件概率分布是由贝叶斯理论得出,高斯之间用参数 τ加权:

.

因此,E步骤的结果:

M步骤

Q(θ|θ(t))的二次型表示可以使得 最大化θ相对简单.  τ, (μ1,Σ1) and (μ2,Σ2) 可以单独的进行最大化.

首先考虑 τ, 有条件τ1 + τ2=1:

和MLE的形式是类似的,二项分布 , 因此:

下一步估计 (μ1,Σ1):

和加权的 MLE就正态分布来说类似

 and 

对称的:

 and .

这个例子来自Answers.com的Expectation-maximization algorithm,由于还没有深入体验,心里还说不出一些更通俗易懂的东西来,等研究了并且应用了可能就有所理解和消化。另外,liuxqsmile也做了一些理解和翻译。

============

在网上的源码不多,有一个很好的EM_GM.m,是滑铁卢大学的Patrick P. C. Tsui写的,拿来分享一下:

运行的时候可以如下进行初始化:

1 % matlab code
2 X = zeros(600,2);
3 X(1:200,:) = normrnd(0,1,200,2);
4 X(201:400,:) = normrnd(0,2,200,2);
5 X(401:600,:) = normrnd(0,3,200,2);
6 [W,M,V,L] = EM_GM(X,3,[],[],1,[])

下面是程序源码:

  1 %matlab code
  2
  3 function [W,M,V,L] = EM_GM(X,k,ltol,maxiter,pflag,Init)
  4 % [W,M,V,L] = EM_GM(X,k,ltol,maxiter,pflag,Init)
  5 %
  6 % EM algorithm for k multidimensional Gaussian mixture estimation
  7 %
  8 % Inputs:
  9 %   X(n,d) - input data, n=number of observations, d=dimension of variable
 10 %   k - maximum number of Gaussian components allowed
 11 %   ltol - percentage of the log likelihood difference between 2 iterations ([] for none)
 12 %   maxiter - maximum number of iteration allowed ([] for none)
 13 %   pflag - 1 for plotting GM for 1D or 2D cases only, 0 otherwise ([] for none)
 14 %   Init - structure of initial W, M, V: Init.W, Init.M, Init.V ([] for none)
 15 %
 16 % Ouputs:
 17 %   W(1,k) - estimated weights of GM
 18 %   M(d,k) - estimated mean vectors of GM
 19 %   V(d,d,k) - estimated covariance matrices of GM
 20 %   L - log likelihood of estimates
 21 %
 22 % Written by
 23 %   Patrick P. C. Tsui,
 24 %   PAMI research group
 25 %   Department of Electrical and Computer Engineering
 26 %   University of Waterloo,
 27 %   March, 2006
 28 %
 29
 30 %%%% Validate inputs %%%%
 31 if nargin <= 1,
 32  disp(‘EM_GM must have at least 2 inputs: X,k!/n‘)
 33  return
 34 elseif nargin == 2,
 35  ltol = 0.1; maxiter = 1000; pflag = 0; Init = [];
 36  err_X = Verify_X(X);
 37  err_k = Verify_k(k);
 38  if err_X | err_k, return; end
 39 elseif nargin == 3,
 40  maxiter = 1000; pflag = 0; Init = [];
 41  err_X = Verify_X(X);
 42  err_k = Verify_k(k);
 43  [ltol,err_ltol] = Verify_ltol(ltol);
 44  if err_X | err_k | err_ltol, return; end
 45 elseif nargin == 4,
 46  pflag = 0;  Init = [];
 47  err_X = Verify_X(X);
 48  err_k = Verify_k(k);
 49  [ltol,err_ltol] = Verify_ltol(ltol);
 50  [maxiter,err_maxiter] = Verify_maxiter(maxiter);
 51  if err_X | err_k | err_ltol | err_maxiter, return; end
 52 elseif nargin == 5,
 53  Init = [];
 54  err_X = Verify_X(X);
 55  err_k = Verify_k(k);
 56  [ltol,err_ltol] = Verify_ltol(ltol);
 57  [maxiter,err_maxiter] = Verify_maxiter(maxiter);
 58  [pflag,err_pflag] = Verify_pflag(pflag);
 59  if err_X | err_k | err_ltol | err_maxiter | err_pflag, return; end
 60 elseif nargin == 6,
 61  err_X = Verify_X(X);
 62  err_k = Verify_k(k);
 63  [ltol,err_ltol] = Verify_ltol(ltol);
 64  [maxiter,err_maxiter] = Verify_maxiter(maxiter);
 65  [pflag,err_pflag] = Verify_pflag(pflag);
 66  [Init,err_Init]=Verify_Init(Init);
 67  if err_X | err_k | err_ltol | err_maxiter | err_pflag | err_Init, return; end
 68 else
 69  disp(‘EM_GM must have 2 to 6 inputs!‘);
 70  return
 71 end
 72
 73 %%%% Initialize W, M, V,L %%%%
 74 t = cputime;
 75 if isempty(Init),
 76  [W,M,V] = Init_EM(X,k); L = 0;
 77 else
 78  W = Init.W;
 79  M = Init.M;
 80  V = Init.V;
 81 end
 82 Ln = Likelihood(X,k,W,M,V); % Initialize log likelihood
 83 Lo = 2*Ln;
 84
 85 %%%% EM algorithm %%%%
 86 niter = 0;
 87 while (abs(100*(Ln-Lo)/Lo)>ltol) & (niter<=maxiter),
 88  E = Expectation(X,k,W,M,V); % E-step
 89  [W,M,V] = Maximization(X,k,E);  % M-step
 90  Lo = Ln;
 91  Ln = Likelihood(X,k,W,M,V);
 92  niter = niter + 1;
 93 end
 94 L = Ln;
 95
 96 %%%% Plot 1D or 2D %%%%
 97 if pflag==1,
 98  [n,d] = size(X);
 99  if d>2,
100  disp(‘Can only plot 1 or 2 dimensional applications!/n‘);
101  else
102  Plot_GM(X,k,W,M,V);
103  end
104  elapsed_time = sprintf(‘CPU time used for EM_GM: %5.2fs‘,cputime-t);
105  disp(elapsed_time);
106  disp(sprintf(‘Number of iterations: %d‘,niter-1));
107 end
108 %%%%%%%%%%%%%%%%%%%%%%
109 %%%% End of EM_GM %%%%
110 %%%%%%%%%%%%%%%%%%%%%%
111
112 function E = Expectation(X,k,W,M,V)
113 [n,d] = size(X);
114 a = (2*pi)^(0.5*d);
115 S = zeros(1,k);
116 iV = zeros(d,d,k);
117 for j=1:k,
118  if V(:,:,j)==zeros(d,d), V(:,:,j)=ones(d,d)*eps; end
119  S(j) = sqrt(det(V(:,:,j)));
120  iV(:,:,j) = inv(V(:,:,j));
121 end
122 E = zeros(n,k);
123 for i=1:n,
124  for j=1:k,
125  dXM = X(i,:)‘-M(:,j);
126  pl = exp(-0.5*dXM‘*iV(:,:,j)*dXM)/(a*S(j));
127  E(i,j) = W(j)*pl;
128  end
129  E(i,:) = E(i,:)/sum(E(i,:));
130 end
131 %%%%%%%%%%%%%%%%%%%%%%%%%%%%
132 %%%% End of Expectation %%%%
133 %%%%%%%%%%%%%%%%%%%%%%%%%%%%
134
135 function [W,M,V] = Maximization(X,k,E)
136 [n,d] = size(X);
137 W = zeros(1,k); M = zeros(d,k);
138 V = zeros(d,d,k);
139 for i=1:k,  % Compute weights
140  for j=1:n,
141  W(i) = W(i) + E(j,i);
142  M(:,i) = M(:,i) + E(j,i)*X(j,:)‘;
143  end
144  M(:,i) = M(:,i)/W(i);
145 end
146 for i=1:k,
147  for j=1:n,
148  dXM = X(j,:)‘-M(:,i);
149  V(:,:,i) = V(:,:,i) + E(j,i)*dXM*dXM‘;
150  end
151  V(:,:,i) = V(:,:,i)/W(i);
152 end
153 W = W/n;
154 %%%%%%%%%%%%%%%%%%%%%%%%%%%%%
155 %%%% End of Maximization %%%%
156 %%%%%%%%%%%%%%%%%%%%%%%%%%%%%
157
158 function L = Likelihood(X,k,W,M,V)
159 % Compute L based on K. V. Mardia, "Multivariate Analysis", Academic Press, 1979, PP. 96-97
160 % to enchance computational speed
161 [n,d] = size(X);
162 U = mean(X)‘;
163 S = cov(X);
164 L = 0;
165 for i=1:k,
166  iV = inv(V(:,:,i));
167  L = L + W(i)*(-0.5*n*log(det(2*pi*V(:,:,i))) ...
168  -0.5*(n-1)*(trace(iV*S)+(U-M(:,i))‘*iV*(U-M(:,i))));
169 end
170 %%%%%%%%%%%%%%%%%%%%%%%%%%%
171 %%%% End of Likelihood %%%%
172 %%%%%%%%%%%%%%%%%%%%%%%%%%%
173
174 function err_X = Verify_X(X)
175 err_X = 1;
176 [n,d] = size(X);
177 if n<d,
178  disp(‘Input data must be n x d!/n‘);
179  return
180 end
181 err_X = 0;
182 %%%%%%%%%%%%%%%%%%%%%%%%%
183 %%%% End of Verify_X %%%%
184 %%%%%%%%%%%%%%%%%%%%%%%%%
185
186 function err_k = Verify_k(k)
187 err_k = 1;
188 if ~isnumeric(k) | ~isreal(k) | k<1,
189  disp(‘k must be a real integer >= 1!/n‘);
190  return
191 end
192 err_k = 0;
193 %%%%%%%%%%%%%%%%%%%%%%%%%
194 %%%% End of Verify_k %%%%
195 %%%%%%%%%%%%%%%%%%%%%%%%%
196
197 function [ltol,err_ltol] = Verify_ltol(ltol)
198 err_ltol = 1;
199 if isempty(ltol),
200  ltol = 0.1;
201 elseif ~isreal(ltol) | ltol<=0,
202  disp(‘ltol must be a positive real number!‘);
203  return
204 end
205 err_ltol = 0;
206 %%%%%%%%%%%%%%%%%%%%%%%%%%%%
207 %%%% End of Verify_ltol %%%%
208 %%%%%%%%%%%%%%%%%%%%%%%%%%%%
209
210 function [maxiter,err_maxiter] = Verify_maxiter(maxiter)
211 err_maxiter = 1;
212 if isempty(maxiter),
213  maxiter = 1000;
214 elseif ~isreal(maxiter) | maxiter<=0,
215  disp(‘ltol must be a positive real number!‘);
216  return
217 end
218 err_maxiter = 0;
219 %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
220 %%%% End of Verify_maxiter %%%%
221 %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
222
223 function [pflag,err_pflag] = Verify_pflag(pflag)
224 err_pflag = 1;
225 if isempty(pflag),
226  pflag = 0;
227 elseif pflag~=0 & pflag~=1,
228  disp(‘Plot flag must be either 0 or 1!/n‘);
229  return
230 end
231 err_pflag = 0;
232 %%%%%%%%%%%%%%%%%%%%%%%%%%%%%
233 %%%% End of Verify_pflag %%%%
234 %%%%%%%%%%%%%%%%%%%%%%%%%%%%%
235
236 function [Init,err_Init] = Verify_Init(Init)
237 err_Init = 1;
238 if isempty(Init),
239  % Do nothing;
240 elseif isstruct(Init),
241  [Wd,Wk] = size(Init.W);
242  [Md,Mk] = size(Init.M);
243  [Vd1,Vd2,Vk] = size(Init.V);
244  if Wk~=Mk | Wk~=Vk | Mk~=Vk,
245  disp(‘k in Init.W(1,k), Init.M(d,k) and Init.V(d,d,k) must equal!/n‘)
246  return
247  end
248  if Md~=Vd1 | Md~=Vd2 | Vd1~=Vd2,
249  disp(‘d in Init.W(1,k), Init.M(d,k) and Init.V(d,d,k) must equal!/n‘)
250  return
251  end
252 else
253  disp(‘Init must be a structure: W(1,k), M(d,k), V(d,d,k) or []!‘);
254  return
255 end
256 err_Init = 0;
257 %%%%%%%%%%%%%%%%%%%%%%%%%%%%
258 %%%% End of Verify_Init %%%%
259 %%%%%%%%%%%%%%%%%%%%%%%%%%%%
260
261 function [W,M,V] = Init_EM(X,k)
262 [n,d] = size(X);
263 [Ci,C] = kmeans(X,k,‘Start‘,‘cluster‘, ...
264  ‘Maxiter‘,100, ...
265  ‘EmptyAction‘,‘drop‘, ...
266  ‘Display‘,‘off‘); % Ci(nx1) - cluster indeices; C(k,d) - cluster centroid (i.e. mean)
267 while sum(isnan(C))>0,
268  [Ci,C] = kmeans(X,k,‘Start‘,‘cluster‘, ...
269  ‘Maxiter‘,100, ...
270  ‘EmptyAction‘,‘drop‘, ...
271  ‘Display‘,‘off‘);
272 end
273 M = C‘;
274 Vp = repmat(struct(‘count‘,0,‘X‘,zeros(n,d)),1,k);
275 for i=1:n, % Separate cluster points
276  Vp(Ci(i)).count = Vp(Ci(i)).count + 1;
277  Vp(Ci(i)).X(Vp(Ci(i)).count,:) = X(i,:);
278 end
279 V = zeros(d,d,k);
280 for i=1:k,
281  W(i) = Vp(i).count/n;
282  V(:,:,i) = cov(Vp(i).X(1:Vp(i).count,:));
283 end
284 %%%%%%%%%%%%%%%%%%%%%%%%
285 %%%% End of Init_EM %%%%
286 %%%%%%%%%%%%%%%%%%%%%%%%
287
288 function Plot_GM(X,k,W,M,V)
289 [n,d] = size(X);
290 if d>2,
291  disp(‘Can only plot 1 or 2 dimensional applications!/n‘);
292  return
293 end
294 S = zeros(d,k);
295 R1 = zeros(d,k);
296 R2 = zeros(d,k);
297 for i=1:k,  % Determine plot range as 4 x standard deviations
298  S(:,i) = sqrt(diag(V(:,:,i)));
299  R1(:,i) = M(:,i)-4*S(:,i);
300  R2(:,i) = M(:,i)+4*S(:,i);
301 end
302 Rmin = min(min(R1));
303 Rmax = max(max(R2));
304 R = [Rmin:0.001*(Rmax-Rmin):Rmax];
305 clf, hold on
306 if d==1,
307  Q = zeros(size(R));
308  for i=1:k,
309  P = W(i)*normpdf(R,M(:,i),sqrt(V(:,:,i)));
310  Q = Q + P;
311  plot(R,P,‘r-‘); grid on,
312  end
313  plot(R,Q,‘k-‘);
314  xlabel(‘X‘);
315  ylabel(‘Probability density‘);
316 else % d==2
317  plot(X(:,1),X(:,2),‘r.‘);
318  for i=1:k,
319  Plot_Std_Ellipse(M(:,i),V(:,:,i));
320  end
321  xlabel(‘1^{st} dimension‘);
322  ylabel(‘2^{nd} dimension‘);
323  axis([Rmin Rmax Rmin Rmax])
324 end
325 title(‘Gaussian Mixture estimated by EM‘);
326 %%%%%%%%%%%%%%%%%%%%%%%%
327 %%%% End of Plot_GM %%%%
328 %%%%%%%%%%%%%%%%%%%%%%%%
329
330 function Plot_Std_Ellipse(M,V)
331 [Ev,D] = eig(V);
332 d = length(M);
333 if V(:,:)==zeros(d,d),
334  V(:,:) = ones(d,d)*eps;
335 end
336 iV = inv(V);
337 % Find the larger projection
338 P = [1,0;0,0];  % X-axis projection operator
339 P1 = P * 2*sqrt(D(1,1)) * Ev(:,1);
340 P2 = P * 2*sqrt(D(2,2)) * Ev(:,2);
341 if abs(P1(1)) >= abs(P2(1)),
342  Plen = P1(1);
343 else
344  Plen = P2(1);
345 end
346 count = 1;
347 step = 0.001*Plen;
348 Contour1 = zeros(2001,2);
349 Contour2 = zeros(2001,2);
350 for x = -Plen:step:Plen,
351  a = iV(2,2);
352  b = x * (iV(1,2)+iV(2,1));
353  c = (x^2) * iV(1,1) - 1;
354  Root1 = (-b + sqrt(b^2 - 4*a*c))/(2*a);
355  Root2 = (-b - sqrt(b^2 - 4*a*c))/(2*a);
356  if isreal(Root1),
357  Contour1(count,:) = [x,Root1] + M‘;
358  Contour2(count,:) = [x,Root2] + M‘;
359  count = count + 1;
360  end
361 end
362 Contour1 = Contour1(1:count-1,:);
363 Contour2 = [Contour1(1,:);Contour2(1:count-1,:);Contour1(count-1,:)];
364 plot(M(1),M(2),‘k+‘);
365 plot(Contour1(:,1),Contour1(:,2),‘k-‘);
366 plot(Contour2(:,1),Contour2(:,2),‘k-‘);
367 %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
368 %%%% End of Plot_Std_Ellipse %%%%
369 %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

from: http://www.zhizhihu.com/html/y2010/2109.html

时间: 2024-10-06 15:58:33

Expectation Maximization-EM(期望最大化)-算法以及源码的相关文章

CRC8算法DELPHI源码

unit Crc8; interface Uses Classes, Windows; Function Crc_8n(p : array of BYTE; len : BYTE) : Byte; implementation Function Crc_8n(p : array of BYTE; len : BYTE) : Byte; Var j, cbit, aout, crc, crc_a, crc_b : Byte; i : integer; begin crc := 0; i := 0;

SURF算法与源码分析、下

上一篇文章 SURF算法与源码分析.上 中主要分析的是SURF特征点定位的算法原理与相关OpenCV中的源码分析,这篇文章接着上篇文章对已经定位到的SURF特征点进行特征描述.这一步至关重要,这是SURF特征点匹配的基础.总体来说算法思路和SIFT相似,只是每一步都做了不同程度的近似与简化,提高了效率. 1. SURF特征点方向分配 为了保证特征矢量具有旋转不变性,与SIFT特征一样,需要对每个特征点分配一个主方向.为些,我们需要以特征点为中心,以$6s$($s = 1.2 *L /9$为特征点

Spark MLlib机器学习算法、源码及实战讲解pdf电子版下载

Spark MLlib机器学习算法.源码及实战讲解pdf电子版下载 链接:https://pan.baidu.com/s/1ruX9inG5ttOe_5lhpK_LQg 提取码:idcb <Spark MLlib机器学习:算法.源码及实战详解>书中讲解由浅入深慢慢深入,解析讲解了MLlib的底层原理:数据操作及矩阵向量计算操作,该部分是MLlib实现的基础:并对此延伸机器学习的算法,循序渐进的讲解其中的原理,是读者一点一点的理解和掌握书中的知识. 目录 · · · · · · 第一部分 Spa

Weka算法Classifier-meta-AdditiveRegression源码分析

博主最近迷上了打怪物猎人,这片文章拖了很久才开始动笔 一.算法 AdditiveRegression,换个更出名一点的叫法可以称作GBDT(Grandient Boosting Decision Tree)梯度下降分类树,或者GBRT(Grandient Boosting Regression Tree)梯度下降回归树,是一种多分类器组合的算法,更确切的说,是属于Boosting算法. 谈到Boosting算法,就不能不提AdaBoost,参见之前我写的博客,可以看到AdaBoost的核心是级联

Weka算法Clusterers-Xmeans源码分析

</pre><p></p><p><span style="font-size:18px">上几篇博客都是分析的分类器算法(有监督学习),这次就分析一个聚类算法(无监督学习).</span></p><p><span style="font-size:18px"></span></p><p><span style=&quo

Weka算法Classifier-trees-REPTree源码分析(一)

一.算法 关于REPTree我实在是没找到什么相关其算法的资料,或许是Weka自创的一个关于决策树的改进,也许是其它某种决策树方法的别名,根据类的注释:Fast decision tree learner. Builds a decision/regression tree using information gain/variance and prunes it using reduced-error pruning (with backfitting).  Only sorts values

SURF算法与源码分析、上

如果说SIFT算法中使用DOG对LOG进行了简化,提高了搜索特征点的速度,那么SURF算法则是对DoH的简化与近似.虽然SIFT算法已经被认为是最有效的,也是最常用的特征点提取的算法,但如果不借助于硬件的加速和专用图像处理器的配合,SIFT算法以现有的计算机仍然很难达到实时的程度.对于需要实时运算的场合,如基于特征点匹配的实时目标跟踪系统,每秒要处理8-24帧的图像,需要在毫秒级内完成特征点的搜索.特征矢量生成.特征矢量匹配.目标锁定等工作,这样SIFT算法就很难适应这种需求了.SURF借鉴了S

Weka算法Classifier-tree-J48源码分析(四)总结

一.ClassifyInstance 首先先说一下构造好的分类树是如何对一个新的Instance进行区分. 直观上,会对树进行一个检索,从根节点根据属性的不同,最终走到叶子节点,得到具体的分类. 但Weka在实现上,是遍历了这个Instance属于不同的class的可能性,并从中选出了一个最大的,代码如下: public double classifyInstance(Instance instance) throws Exception { double maxProb = -1; doubl

数据结构算法 - ConcurrentHashMap 源码解析

五个线程同时往 HashMap 中 put 数据会发生什么? ConcurrentHashMap 是怎么保证线程安全的? 在分析 HashMap 源码时还遗留这两个问题,这次我们站在 Java 多线程内存模型和 synchronized 的实现原理,这两个角度来彻底分析一下.至于 JDK 1.8 的红黑树不是本文探讨的内容. 640?wx_fmt=gif1. Java 多线程内存模型 五个线程同时往 HashMap 中 put 数据会出现两种现象,大概率会出现数据丢失,小概率会出现死循环,我们不