EM最大期望算法

参考资料:http://blog.csdn.net/zouxy09/article/details/8537620

http://www.cnblogs.com/jerrylead/archive/2011/04/06/2006936.html

我的数据挖掘算法代码实现:https://github.com/linyiqun/DataMiningAlgorithm

介绍

em算法是一种迭代算法,用于含有隐变量的参数模型的最大似然估计或极大后验概率估计。EM算法,作为一个框架思想,它可以应用在很多领域,比如说数据聚类领域----模糊聚类的处理,待会儿也会给出一个这样的实现例子。

EM算法原理

EM算法从名称上就能看出他可以被分成2个部分,E-Step和M-Step。E-Step叫做期望化步骤,M-Step为最大化步骤。

整体算法的步骤如下所示:

1、初始化分布参数。

2、(E-Step)计算期望E,利用对隐藏变量的现有估计值,计算其最大似然估计值,以此实现期望化的过程。

3、(M-Step)最大化在E-步骤上的最大似然估计值来计算参数的值

4、重复2,3步骤直到收敛。

以上就是EM算法的核心原理,也许您会想,真的这么简单,其实事实是我省略了其中复杂的数据推导的过程,因为如果不理解EM的算法原理,去看其中的数据公式的推导,会让人更加晕的。好,下面给出数据的推导过程,本人数学也不好,于是用了别人的推导过程,人家已经写得非常详细了。

EM算法的推导过程

jensen不等式

在介绍推导过程的时候,需要明白jensen不等式,他是一个关于凸函数的一个定理,直接上公式定义;

如果f是凸函数,X是随机变量,那么

特别地,如果f是严格凸函数,那么当且仅当,也就是说X是常量。

这里我们将简写为

如果用图表示会很清晰:

这里需要解释的是E(X)的值为什么是(a+b)/2,因为有0.5 的概率是a,0.5的概率是b,于是他的期望就是a,b的和的中间值了。同理在y轴上的值也是如此。

EM算法的公式表达形式

EM算法转化为公式的表达形式为:

给定的训练样本是,样例间独立,我们想找到每个样例隐含的类别z,能使得p(x,z)最大。p(x,z)的最大似然估计如下:

然后对这个公式做一点变化,就可以用上jensen不等式了,神奇的一笔来了:

可以由前面阐述的内容得到下面的公式:

(1)到(2)比较直接,就是分子分母同乘以一个相等的函数。(2)到(3)利用了Jensen不等式。对于每一个样例i,让表示该样例隐含变量z的某种分布,满足的条件是。于是就来到了问题的关键,通过上面的不等式,我们就可以确定式子的下界,然后我们就可以不断的提高此下界达到逼近最后真实值的目的值,那么什么时候达到想到的时候呢,没错,就是这个不等式变成等式的时候,然后再依据之前描述的jensen不等式的说明,当不等式变为等式的时候,当且仅当,也就是说X是常量,推出就是下面的公式:

再推导下,由于(因为Q是随机变量z(i)的概率密度函数),则可以得到:分子的和等于c(分子分母都对所有z(i)求和:多个等式分子分母相加不变,这个认为每个样例的两个概率比值都是c),再次继续推导;

最后就得出了EM算法的一般过程了:

循环重复直到收敛

(E步)对于每一个i,计算

(M步)计算

也许你看完这个数学推导的过程已经开始头昏了,没有关系,下面给出一个实例,让大家真切的感受一下EM算法的神奇。

EM算法的模糊聚类实现

在这里我会给出一个自己实现的基于EM算法的计算模糊聚类。

输入测试的数据文件,里面包含了a-f 7个点坐标:

3 3
4 10
9 6
14 8
18 11
21 7

开始时默认簇中心点C1, C2为a和b。这就算是参数的初始赋值,然后是主要的操作;

1、E-Step:期望步根据当前的的模糊聚类或概率簇的参数,把对象指派到簇中。

2、M-Step:最大化步发现新的聚类或参数,最小化模糊聚类的SSE(对象的误差平方和,这个在程序中会有所体现)。在M步中会用到这个公式,根据划分矩阵重新调整计算簇的中心。

最后的收敛条件为,计算出的簇中心点的坐标的横纵坐标轴的误差和不超过1.0,意味着基本不再变化了。

主程序类:

package DataMining_EM;

import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.IOException;
import java.text.MessageFormat;
import java.util.ArrayList;

/**
 * EM最大期望算法工具类
 *
 * @author lyq
 *
 */
public class EMTool {
	// 测试数据文件地址
	private String dataFilePath;
	// 测试坐标点数据
	private String[][] data;
	// 测试坐标点数据列表
	private ArrayList<Point> pointArray;
	// 目标C1点
	private Point p1;
	// 目标C2点
	private Point p2;

	public EMTool(String dataFilePath) {
		this.dataFilePath = dataFilePath;
		pointArray = new ArrayList<>();
	}

	/**
	 * 从文件中读取数据
	 */
	public void readDataFile() {
		File file = new File(dataFilePath);
		ArrayList<String[]> dataArray = new ArrayList<String[]>();

		try {
			BufferedReader in = new BufferedReader(new FileReader(file));
			String str;
			String[] tempArray;
			while ((str = in.readLine()) != null) {
				tempArray = str.split(" ");
				dataArray.add(tempArray);
			}
			in.close();
		} catch (IOException e) {
			e.getStackTrace();
		}

		data = new String[dataArray.size()][];
		dataArray.toArray(data);

		// 开始时默认取头2个点作为2个簇中心
		p1 = new Point(Integer.parseInt(data[0][0]),
				Integer.parseInt(data[0][1]));
		p2 = new Point(Integer.parseInt(data[1][0]),
				Integer.parseInt(data[1][1]));

		Point p;
		for (String[] array : data) {
			// 将数据转换为对象加入列表方便计算
			p = new Point(Integer.parseInt(array[0]),
					Integer.parseInt(array[1]));
			pointArray.add(p);
		}
	}

	/**
	 * 计算坐标点对于2个簇中心点的隶属度
	 *
	 * @param p
	 *            待测试坐标点
	 */
	private void computeMemberShip(Point p) {
		// p点距离第一个簇中心点的距离
		double distance1 = 0;
		// p距离第二个中心点的距离
		double distance2 = 0;

		// 用欧式距离计算
		distance1 = Math.pow(p.getX() - p1.getX(), 2)
				+ Math.pow(p.getY() - p1.getY(), 2);
		distance2 = Math.pow(p.getX() - p2.getX(), 2)
				+ Math.pow(p.getY() - p2.getY(), 2);

		// 计算对于p1点的隶属度,与距离成反比关系,距离靠近越小,隶属度越大,所以要用大的distance2另外的距离来表示
		p.setMemberShip1(distance2 / (distance1 + distance2));
		// 计算对于p2点的隶属度
		p.setMemberShip2(distance1 / (distance1 + distance2));
	}

	/**
	 * 执行期望最大化步骤
	 */
	public void exceptMaxStep() {
		// 新的优化过的簇中心点
		double p1X = 0;
		double p1Y = 0;
		double p2X = 0;
		double p2Y = 0;
		double temp1 = 0;
		double temp2 = 0;
		// 误差值
		double errorValue1 = 0;
		double errorValue2 = 0;
		// 上次更新的簇点坐标
		Point lastP1 = null;
		Point lastP2 = null;

		// 当开始计算的时候,或是中心点的误差值超过1的时候都需要再次迭代计算
		while (lastP1 == null || errorValue1 > 1.0 || errorValue2 > 1.0) {
			for (Point p : pointArray) {
				computeMemberShip(p);
				p1X += p.getMemberShip1() * p.getMemberShip1() * p.getX();
				p1Y += p.getMemberShip1() * p.getMemberShip1() * p.getY();
				temp1 += p.getMemberShip1() * p.getMemberShip1();

				p2X += p.getMemberShip2() * p.getMemberShip2() * p.getX();
				p2Y += p.getMemberShip2() * p.getMemberShip2() * p.getY();
				temp2 += p.getMemberShip2() * p.getMemberShip2();
			}

			lastP1 = new Point(p1.getX(), p1.getY());
			lastP2 = new Point(p2.getX(), p2.getY());

			// 套公式计算新的簇中心点坐标,最最大化处理
			p1.setX(p1X / temp1);
			p1.setY(p1Y / temp1);
			p2.setX(p2X / temp2);
			p2.setY(p2Y / temp2);

			errorValue1 = Math.abs(lastP1.getX() - p1.getX())
					+ Math.abs(lastP1.getY() - p1.getY());
			errorValue2 = Math.abs(lastP2.getX() - p2.getX())
					+ Math.abs(lastP2.getY() - p2.getY());
		}

		System.out.println(MessageFormat.format(
				"簇中心节点p1({0}, {1}), p2({2}, {3})", p1.getX(), p1.getY(),
				p2.getX(), p2.getY()));
	}

}

坐标点Point类:

/**
 * 坐标点类
 *
 * @author lyq
 *
 */
public class Point {
	// 坐标点横坐标
	private double x;
	// 坐标点纵坐标
	private double y;
	// 坐标点对于P1的隶属度
	private double memberShip1;
	// 坐标点对于P2的隶属度
	private double memberShip2;

	public Point(double d, double e) {
		this.x = d;
		this.y = e;
	}

	public double getX() {
		return x;
	}

	public void setX(double x) {
		this.x = x;
	}

	public double getY() {
		return y;
	}

	public void setY(double y) {
		this.y = y;
	}

	public double getMemberShip1() {
		return memberShip1;
	}

	public void setMemberShip1(double memberShip1) {
		this.memberShip1 = memberShip1;
	}

	public double getMemberShip2() {
		return memberShip2;
	}

	public void setMemberShip2(double memberShip2) {
		this.memberShip2 = memberShip2;
	}

}

调用类;

/**
 * EM期望最大化算法场景调用类
 * @author lyq
 *
 */
public class Client {
	public static void main(String[] args){
		String filePath = "C:\\Users\\lyq\\Desktop\\icon\\input.txt";

		EMTool tool = new EMTool(filePath);
		tool.readDataFile();
		tool.exceptMaxStep();
	}
}

输出结果:

簇中心节点p1(7.608, 5.907), p2(14.208, 8.745)

在这个程序中,隐藏变量就是簇中心点,通过不断的迭代计算,最终无限的接近真实值,相当有意思的算法。

时间: 2024-10-10 15:07:55

EM最大期望算法的相关文章

EM最大期望算法-走读

打算抽时间走读一些算法,尽量通俗的记录下面,希望帮助需要的同学. overview: 基本思想: 通过初始化参数P1,P2,推断出隐变量Z的概率分布(E步): 通过隐变量Z的概率分布,最大似然推断参数P1,P2 (M步). 梯度下降也可以解决隐变量估计问题,但求和项会随隐变量个数指数增长,EM方法是一种非梯度下降优化方法. 一 例子参考 ------------------------------------------------------- 引入问题:两枚材质不均匀硬币模型:五次实验,每次

数据挖掘经典算法——最大期望算法

算法定义 最大期望算法(Exception Maximization Algorithm,后文简称EM算法)是一种启发式的迭代算法,用于实现用样本对含有隐变量的模型的参数做极大似然估计.已知的概率模型内部存在隐含的变量,导致了不能直接用极大似然法来估计参数,EM算法就是通过迭代逼近的方式用实际的值带入求解模型内部参数的算法. 算法描述 算法的形式如下: 随机对参数赋予初值: While(求解参数不稳定){ E步骤:求在当前参数值和样本下的期望函数Q: M步骤:利用期望函数重新计算模型中新的估计值

最大期望算法 Expectation Maximization概念

在统计计算中,最大期望(EM,Expectation–Maximization)算法是在概率(probabilistic)模型中寻找参数最大似然估计的算法,其中概率模型依赖于无法观测的隐藏变量(Latent Variabl).最大期望经常用在机器学习和计算机视觉的数据集聚(Data Clustering)领域. 可以有一些比较形象的比喻说法把这个算法讲清楚.比如说食堂的大师傅炒了一份菜,要等分成两份给两个人吃,显然没有必要拿来天平一点一点的精确的去称分量,最简单的办法是先随意的把菜分到两个碗中,

EM 算法

这个暂时还不太明白,先写一点明白的. EM:最大期望算法,属于基于模型的聚类算法.是对似然函数的进一步应用. 我们知道,当我们想要估计某个分布的未知值,可以使用样本结果来进行似然估计,进而求最大似然估计就可以估计出要求的参数. 但是有时候还会有未知参数,这样就不能使用极大似然估计.当然这个参数与我们要估计的参数是有关联的. 比如说调查 男生 女生身高的问题.身高肯定是服从高斯分布.以往我们可以通过对男生抽样进而求出高斯分布的参数,女生也是,但是如果我们只能知道某个人的高度,却不能知道他是男生或者

数据挖掘十大经典算法

一. C4.5  C4.5算法是机器学习算法中的一种分类决策树算法,其核心算法是ID3 算法.   C4.5算法继承了ID3算法的优点,并在以下几方面对ID3算法进行了改进: 1) 用信息增益率来选择属性,克服了用信息增益选择属性时偏向选择取值多的属性的不足: 2) 在树构造过程中进行剪枝: 3) 能够完成对连续属性的离散化处理: 4) 能够对不完整数据进行处理. C4.5算法有如下优点:产生的分类规则易于理解,准确率较高.其缺点是:在构造树的过程中,需要对数据集进行多次的顺序扫描和排序,因而导

机器学习算法集锦

机器学习 机器学习(Machine Learning, ML)是一门多领域交叉学科,涉及概率论.统计学.逼近论.凸分析.算法复杂度理论等多门学科.专门研究计算机怎样模拟或实现人类的学习行为,以获取新的知识或技能,重新组织已有的知识结构使之不断改善自身的性能. 严格的定义:机器学习是一门研究机器获取新知识和新技能,并识别现有知识的学问.这里所说的"机器",指的就是计算机,电子计算机,中子计算机.光子计算机或神经计算机等等. 机器学习概论 由上图所示:机器学习分为四大块: classifi

数据挖掘10大算法详细介绍

想初步了解下怎样数据挖掘,看到一篇不错的文章转载过来啦~ 转自:http://blog.jobbole.com/89037/ 在一份调查问卷中,三个独立专家小组投票选出的十大最有影响力的数据挖掘算法,今天我打算用简单的语言来解释一下. 一旦你知道了这些算法是什么.怎么工作.能做什么.在哪里能找到,我希望你能把这篇博文当做一个跳板,学习更多的数据挖掘知识. 还等什么?这就开始吧! 1.C4.5算法 C4.5是做什么的?C4.5 以决策树的形式构建了一个分类器.为了做到这一点,需要给定 C4.5 表

机器学习——应用场景 算法应用场景

常见的机器学习模型:感知机,线性回归,逻辑回归,支持向量机,决策树,随机森林,GBDT,XGBoost,贝叶斯,KNN,K-means等: 常见的机器学习理论:过拟合问题,交叉验证问题,模型选择问题,模型融合问题等: K近邻:算法采用测量不同特征值之间的距离的方法进行分类. 优点: 1.简单好用,容易理解,精度高,理论成熟,既可以用来做分类也可以用来做回归: 2.可用于数值型数据和离散型数据: 3.训练时间复杂度为O(n):无数据输入假定: 4.对异常值不敏感 缺点: 1.计算复杂性高:空间复杂

几个常用算法的适应场景及其优缺点!

机器学习算法太多了,分类.回归.聚类.推荐.图像识别领域等等,要想找到一个合适算法真的不容易,所以在实际应用中,我们一般都是采用启发式学习方式来实验.通常最开始我们都会选择大家普遍认同的算法,诸如SVM,GBDT,Adaboost,现在深度学习很火热,神经网络也是一个不错的选择. 假如你在乎精度(accuracy)的话,最好的方法就是通过交叉验证(cross-validation)对各个算法一个个地进行测试,进行比较,然后调整参数确保每个算法达到最优解,最后选择最好的一个.但是如果你只是在寻找一