【Java】K-means算法Java实现以及图像分割

1.K-means算法简述以及代码原型

数据挖掘中一个重要算法是K-means,我这里就不做详细介绍。如果感兴趣的话可以移步陈皓的博客:

http://www.csdn.net/article/2012-07-03/2807073-k-means 讲得很好

总的来讲,k-means聚类需要以下几个步骤:

①.初始化数据

②.计算初始的中心点,可以随机选择

③.计算每个点到每个聚类中心的距离,并且划分到距离最短的聚类中心簇中

④.计算每个聚类簇的平均值,这个均值作为新的聚类中心,重复步骤3

⑤.如果达到最大循环或者是聚类中心不再变化或者聚类中心变化幅度小于一定范围时,停止循环。

恩,原理就是这样,超级简单。但是Java算法实现起来代码量并不小。这个代码也不算是完全自己写的啦,也有些借鉴。我把k-means实现封装在了一个类里面,这样就可以随时调用了呢。

import java.util.ArrayList;
import java.util.Random;

public class kmeans {
	private int k;//簇数
	private int m;//迭代次数
	private int dataSetLength;//数据集长度
	private ArrayList<double[]> dataSet;//数据集合
	private ArrayList<double[]> center;//中心链表
	private ArrayList<ArrayList<double[]>> cluster;//簇
	private ArrayList<Float> jc;//误差平方和,这个是用来计算中心聚点的移动哦
	private Random random;

	//设置原始数据集合
	public void setDataSet(ArrayList<double[]> dataSet){
		this.dataSet=dataSet;
	}
	//获得簇分组
	public  ArrayList<ArrayList<double[]>> getCluster(){
		return this.cluster;
	}
	//构造函数,传入要分的簇的数量
	public kmeans(int k){
		if(k<=0)
			k=1;
		this.k=k;
	}
	//初始化
	private void init(){
		m=0;
		random=new Random();
		if(dataSet==null||dataSet.size()==0)
			initDataSet();
		dataSetLength=dataSet.size();
		if(k>dataSetLength)
			k=dataSetLength;
		center=initCenters();
		cluster=initCluster();
		jc=new ArrayList<Float>();
	}
	//初始化数据集合
	private void initDataSet(){
		dataSet=new ArrayList<double[]>();
		double[][] dataSetArray=new double[][]{{8,2},{3,4},{2,5},{4,2},
				{7,3},{6,2},{4,7},{6,3},{5,3},{6,3},{6,9},
				{1,6},{3,9},{4,1},{8,6}};
		for(int i=0;i<dataSetArray.length;i++)
			dataSet.add(dataSetArray[i]);
	}
	//初始化中心链表,分成几簇就有几个中心
	private ArrayList<double[]> initCenters(){
		ArrayList<double[]> center= new ArrayList<double[]>();
		//生成一个随机数列,
		int[] randoms=new int[k];
		boolean flag;
		int temp=random.nextInt(dataSetLength);
		randoms[0]=temp;
		for(int i=1;i<k;i++){
			flag=true;
			while(flag){
				temp=random.nextInt(dataSetLength);
				int j=0;
				while(j<i){
					if(temp==randoms[j])
						break;
					j++;
				}
				if(j==i)
					flag=false;
			}
			randoms[i]=temp;
		}
		for(int i=0;i<k;i++)
			center.add(dataSet.get(randoms[i]));
		return center;
	}
	//初始化簇集合
	private ArrayList<ArrayList<double[]>> initCluster(){
		ArrayList<ArrayList<double[]>> cluster=
				new ArrayList<ArrayList<double[]>>();
		for(int i=0;i<k;i++)
			cluster.add(new ArrayList<double[]>());
		return cluster;
	}
	//计算距离
	private double distance(double[] element,double[] center){
		double distance=0.0f;
		double x=element[0]-center[0];
		double y=element[1]-center[1];
		double z=element[2]-center[2];
		double sum=x*x+y*y+z*z;
		distance=(double)Math.sqrt(sum);
		return distance;
	}
	//计算最短的距离
	private int minDistance(double[] distance){
		double minDistance=distance[0];
		int minLocation=0;
		for(int i=0;i<distance.length;i++){
			if(distance[i]<minDistance){
				minDistance=distance[i];
				minLocation=i;
			}else if(distance[i]==minDistance){
				if(random.nextInt(10)<5){
					minLocation=i;
				}
			}
		}
		return minLocation;
	}
	//每个点分类
	private void clusterSet(){
		double[] distance=new double[k];
		for(int i=0;i<dataSetLength;i++){
			//计算到每个中心店的距离
			for(int j=0;j<k;j++)
				distance[j]=distance(dataSet.get(i),center.get(j));
			//计算最短的距离
			int minLocation=minDistance(distance);
			//把他加到聚类里
			cluster.get(minLocation).add(dataSet.get(i));
		}
	}
	//计算新的中心
	private void setNewCenter(){
		for(int i=0;i<k;i++){
			int n=cluster.get(i).size();
			if(n!=0){
				double[] newcenter={0,0};
				for(int j=0;j<n;j++){
					newcenter[0]+=cluster.get(i).get(j)[0];
					newcenter[1]+=cluster.get(i).get(j)[1];
				}
				newcenter[0]=newcenter[0]/n;
				newcenter[1]=newcenter[1]/n;
				center.set(i, newcenter);
			}
		}
	}
	//求2点的误差平方
	private double errosquare(double[] element,double[] center){
		double x=element[0]-center[0];
		double y=element[1]-center[1];
		double errosquare=x*x+y*y;
		return errosquare;
	}
	//计算误差平方和准则函数
	private void countRule(){
		float jcf=0;
		for(int i=0;i<cluster.size();i++){
			for(int j=0;j<cluster.get(i).size();j++)
				jcf+=errosquare(cluster.get(i).get(j),center.get(i));
		jc.add(jcf);
		}
	}
	//核心算法
	private void Kmeans(){
		//初始化各种变量,随机选定中心,初始化聚类
		init();
		//开始循环
		while(true){
			//把每个点分到聚类中去
			clusterSet();
			//计算目标函数
			countRule();
			//检查误差变化,因为我规定的计算循环次数为50次,所以就不用计算这个啦,你要愿意用也可以,就是慢一点
			/*
			if(m!=0){
				if(jc.get(m)-jc.get(m-1)==0)
					break;
			}*/
			if(m>=50)
				break;
			//否则继续生成新的中心
			setNewCenter();
			m++;
			cluster.clear();
			cluster=initCluster();

		}
	}
    //只暴露一个接口给外部类
	public void execute(){
		System.out.print("start kmeans\n");
		Kmeans();
		System.out.print("kmeans end\n");
	}
        //用来在外面打印出来已经分好的聚类
	public void printDataArray(ArrayList<double[]> data,String dataArrayName){
		for(int i=0;i<data.size();i++){
			System.out.print("print:"+dataArrayName+"["+i+"]={"+data.get(i)[0]+","+data.get(i)[1]+"}\n");
		}
		System.out.print("==========================");
	}
}

嗯,代码就是这样。注释写的很详细,也都能看得懂。下面我给一个测试例子。

import java.util.ArrayList;

public class Test {
	public static void main(String[] args){
		kmeans k=new kmeans(2);
		ArrayList<double[]> dataSet=new ArrayList<double[]>();
		dataSet.add(new double[]{2,2,2});
		dataSet.add(new double[]{1,2,2});
		dataSet.add(new double[]{2,1,2});
		dataSet.add(new double[]{1,3,2});
		dataSet.add(new double[]{3,1,2});
		dataSet.add(new double[]{-2,-2,-2});
		dataSet.add(new double[]{-1,-2,-2});
		dataSet.add(new double[]{-2,-1,-2});
		dataSet.add(new double[]{-3,-1,-2});
		dataSet.add(new double[]{-1,-3,-2});

		k.setDataSet(dataSet);
		k.execute();
		ArrayList<ArrayList<double[]>> cluster=k.getCluster();
		for(int i=0;i<cluster.size();i++){
			k.printDataArray(cluster.get(i), "cluster["+i+"]");
		}
	}
}

没啥难度,也就是输入写初始数据,然后执行k-means在进行分类,最后打印一下。这个原型代码很粗糙,没有添加聚类个数以及循环次数的变量,这些需要自己动手啦。

2.k-means应用图像分割

我们可以把k-means聚类放在图像分割上,也就是说把一个颜色的像素分为一类,然后再涂一个颜色。像这样。

左边就是聚类之前的,右边是聚类之后的,看起来还是满炫酷的。其实聚类算法也是很容易扩展到这里的。

有下面四个提示(因为是作业,我决定先不放马,不然到时候作业雷同我的学分就咖喱gaygay了):

①.上面的原型代码是对二维的数据进行分类,那我们也知道,一个颜色有RGB三种原色构成,也就是说我们只需要 在二维的基础上,加上一维数据就吼啦。很简单有木有,改变下数组结构,在距离计算编程三维欧式距离就吼。

②.Java有自带的图像处理类,所以读取数据敲击方便。我给一点代码提示哦

//读取指定目录的图片数据,并且写入数组,这个数据要继续处理
	private int[][] getImageData(String path){
		BufferedImage bi=null;
		try{
			bi=ImageIO.read(new File(path));
		}catch (IOException e){
			e.printStackTrace();
		}
		int width=bi.getWidth();
		int height=bi.getHeight();
		int [][] data=new int[width][height];
		for(int i=0;i<width;i++)
			for(int j=0;j<height;j++)
				data[i][j]=bi.getRGB(i, j);
		/*测试输出
		for(int i=0;i<data.length;i++)
			for(int j=0;j<data[0].length;j++)
				System.out.println(data[i][j]);*/
		return data;
	}
	//用来处理获取的像素数据,提取我们需要的写入dataItem数组
	private dataItem[][] InitData(int [][] data){
		dataItem[][] dataitems=new dataItem[data.length][data[0].length];
		for(int i=0;i<data.length;i++){
			for(int j=0;j<data[0].length;j++){
				dataItem di=new dataItem();
				Color c=new Color(data[i][j]);
				di.r=(double)c.getRed();
				di.g=(double)c.getGreen();
				di.b=(double)c.getBlue();
				di.group=1;
				dataitems[i][j]=di;
			}
		}
		return dataitems;
	}
          //介货是用来输出图像的
<pre name="code" class="java">           private void ImagedataOut(String path){
		Color c0=new Color(255,0,0);
		Color c1=new Color(0,255,0);
		Color c2=new Color(0,0,255);
		Color c3=new Color(128,128,128);
		BufferedImage nbi=new BufferedImage(source.length,source[0].length,BufferedImage.TYPE_INT_RGB);
		for(int i=0;i<source.length;i++){
			for(int j=0;j<source[0].length;j++){
				if(source[i][j].group==0)
					nbi.setRGB(i, j, c0.getRGB());
				else if(source[i][j].group==1)
					nbi.setRGB(i, j, c1.getRGB());
				else if(source[i][j].group==2)
					nbi.setRGB(i, j, c2.getRGB());
				else if (source[i][j].group==3)
					nbi.setRGB(i, j, c3.getRGB());
				//Color c=new Color((int)center[source[i][j].group].r,
				//		(int)center[source[i][j].group].g,(int)center[source[i][j].group].b);
				//nbi.setRGB(i, j, c.getRGB());
			}
		}
		try{
			ImageIO.write(nbi, "jpg", new File(path));
		}catch(IOException e){
			e.printStackTrace();
			}
	}

很舒爽,你问我dataItem是啥?等我交完作业我就告诉你。

③.有一点不同的是,注意数据格式。胖胖开始用的就是int类型,结果在计算新的聚类中心的时候溢出了呢。。。所幸鹏鹏改成了double,但是鹏鹏在计算距离的时候又写错了,最后还是机智的胖胖鹏解决掉了所有的bug。

④.注意读取图片的时候保护好数据的顺序,也就是用一个二维数组来存储,这样在写的时候就不用记录像素点的位置,输出的时候也很方便。

就是这些。。。。等我作业交完就来一次完整的代码讲解!

时间: 2024-11-03 03:47:29

【Java】K-means算法Java实现以及图像分割的相关文章

k近邻算法-java实现

最近在看<机器学习实战>这本书,因为自己本身很想深入的了解机器学习算法,加之想学python,就在朋友的推荐之下选择了这本书进行学习. 一 . K-近邻算法(KNN)概述 最简单最初级的分类器是将全部的训练数据所对应的类别都记录下来,当测试对象的属性和某个训练对象的属性完全匹配时,便可以对其进行分类.但是怎么可能所有测试对象都会找到与之完全匹配的训练对象呢,其次就是存在一个测试对象同时与多个训练对象匹配,导致一个训练对象被分到了多个类的问题,基于这些问题呢,就产生了KNN. KNN是通过测量不

java每日小算法(4)

[程序4] 题目:将一个正整数分解质因数.例如:输入90,打印出90=2*3*3*5. 程序分析:对n进行分解质因数,应先找到一个最小的质数k,然后按下述步骤完成: (1)如果这个质数恰等于n,则说明分解质因数的过程已经结束,打印出即可. (2)如果n<>k,但n能被k整除,则应打印出k的值,并用n除以k的商,作为新的正整数你n,重复执行第一步. (3)如果n不能被k整除,则用k+1作为k的值,重复执行第一步. package test; import java.util.ArrayList;

JVM学习(4)——全面总结Java的GC算法和回收机制---转载自http://www.cnblogs.com/kubixuesheng/p/5208647.html

俗话说,自己写的代码,6个月后也是别人的代码--复习!复习!复习!涉及到的知识点总结如下: 一些JVM的跟踪参数的设置 Java堆的分配参数 -Xmx 和 –Xms 应该保持一个什么关系,可以让系统的性能尽可能的好呢?是不是虚拟机内存越大越好? Java 7之前和Java 8的堆内存结构 Java栈的分配参数 GC算法思想介绍 –GC ROOT可达性算法 –标记清除 –标记压缩 –复制算法 可触及性含义和在Java中的体现 finalize方法理解 Java的强引用,软引用,弱引用,虚引用 GC

Java密码学原型算法实现——第一部分:标准Hash算法

题注 从博客中看出来我是个比较钟爱Java的应用密码学研究者.虽然C在密码学中有不可替代的优势:速度快,但是,Java的可移植性使得开发人员可以很快地将代码移植到各个平台,这比C实现要方便的多.尤其是Android平台的出现,Java的应用也就越来越广.因此,我本人在密码学研究过程中实际上也在逐渐使用和封装一些知名的Java密码学库,主要是方便自己使用. Java JDK实际上自带了密码学库,支持几乎所有通用密码学原型的实现.然而,自带密码库有几个缺点:第一,由于版权问题,其并不支持全部的密码学

银行家算法java实现

关于银行家算法的理论知识,课本或者百度上有好多资料,我就不再多说了,这里把我最近写的银行家算法的实现带码贴出来. 由于这是我们的一个实验,对系统资源数和进程数都指定了,所以这里也将其指定了,其中系统资源数为3,进程数为5. import java.util.Scanner; import javax.swing.plaf.basic.BasicInternalFrameTitlePane.MaximizeAction; import javax.swing.text.StyledEditorKi

使用Java实现一则算法

[Java] 使用Java实现一则算法 前情提要 在学习Java的过程中,我的一个基友扔给了我一道算法题,为了检验自己对Java的学习情况我决定使用Java解决这道算法题. 具体问题 现有一株K叉树,我们知道其前序遍历与后序遍历,也知道K的值,求该K叉树有多少种可能形态.如一13叉树,前序遍历为abejkcfghid,后序遍历为jkebfghicda,则其可能形态有207352860种. 问题分析 根据遍历的定义我们可以知道: 前序遍历的第一个字母和后序遍历的最后一个字母是他的根. 前序遍历的根

Java常用排序算法+程序员必须掌握的8大排序算法+二分法查找法

Java 常用排序算法/程序员必须掌握的 8大排序算法 本文由网络资料整理转载而来,如有问题,欢迎指正! 分类: 1)插入排序(直接插入排序.希尔排序) 2)交换排序(冒泡排序.快速排序) 3)选择排序(直接选择排序.堆排序) 4)归并排序 5)分配排序(基数排序) 所需辅助空间最多:归并排序 所需辅助空间最少:堆排序 平均速度最快:快速排序 不稳定:快速排序,希尔排序,堆排序. 先来看看 8种排序之间的关系: 1.直接插入排序 (1)基本思想:在要排序的一组数中,假设前面(n-1)[n>=2]

排序算法Java实现

排序算法Java实现 排序算法的分类: 内部排序,在排序过程中,全部记录放在内存中,称为内部排序: 外部排序,在排序过程中需要使用外部存储(磁盘),则称为外部排序. 主要介绍内部排序: 插入排序:直接插入排序.二分法插入排序.希尔排序 选择排序:简单选择排序.堆排序 交换排序:冒泡排序.快速排序 归并排序 基数排序 插入排序 直接插入排序 基本思想:对于给定的一组记录,初始时假设第一个记录自成一个有序序列,其余记录为无序序列.接着从第二个记录开始,按照记录的大小依次将当前处理的记录插入到其之前的

排序算法 Java实现版

8种排序之间的关系: 1. 直接插入排序 (1)基本思想: 在要排序的一组数中,假设前面(n-1)[n>=2] 个数已经是排好顺序的,现在要把第n个数插到前面的有序数中,使得这n个数也是排好顺序的.如此反复循环,直到全部排好顺序. (2)实例 (3)用java实现 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 package com.njue; public class insertSort { public insertSort(){     inta[]