网上有许多Kmeans写的java算法,当然依据个人编码风格的不同,导致编写出来的代码,各有不同。所以在理解原理的基础上,最好就是按照自己设计思路将代码自己写出来。
度娘搜Kmeans的基本原理吧,直接上代码,代码中都有注释:
package net.codeal.suanfa.kmeans; import java.util.Set; /** * * @ClassName: Distancable * @Description: TODO(可计算两点之间距离的可中心化的父类) * @author fuhuaguo * @date 2015年9月1日 上午11:41:23 * */ public class Kmeansable<E> { /** * 获取两点之间的距离 * @param other * @return */ public double getDistance(E other){ return 0; } /** * 获取新的中心点 * @param eSet * @return */ public E getNewCenter(Set<E> eSet){ return null; } }
package net.codeal.suanfa.kmeans; import java.util.Set; /** * * @ClassName: Point * @Description: TODO(聚类的维度信息bean,可以分为K个维度,相似度计算是自身行为,放在bean内部才合适,取消注解使用) * @author fuhuaguo * @email [email protected] * @date 2015年9月1日 上午10:43:25 * */ public class Point extends Kmeansable<Point>{ private String id; //维度1 private double k1; //维度2 private double k2; //维度3 private double k3; public Point() { } public Point(String id,double k1,double k2,double k3) { this.id = id; this.k1 = k1; this.k2 = k2; this.k3 = k3; } /** * 计算和另一个点的距离,采用欧几里得算法 ,计算维度算数平方和的sqrt值,即:相异度 * @param other * @return */ @Override public double getDistance(Point other){ return Math.sqrt((this.k1-other.getK1())*(this.k1-other.getK1()) + (this.k2-other.getK2())*(this.k2-other.getK2()) + (this.k3-other.getK3())*(this.k3-other.getK3())); } @Override public Point getNewCenter(Set<Point> eSet) { if(eSet == null || eSet.size() == 0){ return this; } Point temp = new Point(); int count = 0; for (Point p : eSet) { temp.setK1(temp.getK1() + p.getK1()); temp.setK2(temp.getK2() + p.getK2()); temp.setK3(temp.getK3() + p.getK3()); count++; } temp.setK1(temp.getK1()/count); temp.setK2(temp.getK2()/count); temp.setK3(temp.getK3()/count); return temp; } @Override public boolean equals(Object obj) { if(obj == null || !(obj instanceof Point)) return false; Point other = (Point) obj; return (this.k1 == other.getK1()) && (this.k2 == other.getK2()) && (this.k3 == other.getK3()); } @Override public int hashCode() { return new Double(k1+k2+k3).hashCode(); } @Override public String toString() { return "("+k1+","+k2+","+k3+")"; } public String getId() { return id; } public void setId(String id) { this.id = id; } public double getK1() { return k1; } public void setK1(double k1) { this.k1 = k1; } public double getK2() { return k2; } public void setK2(double k2) { this.k2 = k2; } public double getK3() { return k3; } public void setK3(double k3) { this.k3 = k3; } }
package net.codeal.suanfa.kmeans; import java.util.HashMap; import java.util.HashSet; import java.util.Map; import java.util.Set; public class KmeansAlgorithm<E extends Kmeansable<E>> { /** * 对Set进行K个值聚类,计算深度最大为depth */ public void kmeans(Set<E> dataSet, int k, int depth){ //分类数设置不合适 if(k <= 1 || dataSet.size() <= k){ return; } Set<E> kSet = new HashSet<E>(); int count = 0; //随机确定K个中心点 for (E e : dataSet) { if(count++ >= k) break; kSet.add(e); } //计算每个值距离各个中心点的距离,分配到距离最小的那个中心上 boolean flag = true; while(flag && depth > 0){ Map<E, Set<E>> kMap = new HashMap<E, Set<E>>(); for (E e : kSet) { kMap.put(e, new HashSet<E>()); } //完成聚类 for (E data : dataSet) { double d = Double.MAX_VALUE; E e = null; for (E center : kSet) { double d1 = data.getDistance(center); if (d > d1){ e = center; d = d1; } } kMap.get(e).add(data); } //第一组计算完毕,同时获取新的中心点 System.out.println("这是第"+depth+"次聚类"); for (Map.Entry<E, Set<E>> m : kMap.entrySet()) { System.out.println(m.getKey()+":"+m.getValue()); } //获取新的聚类中心 Set<E> oldSet = kSet; kSet = getNewCenters(kMap); flag = !isSameCenters(kSet,oldSet); depth--; } } /** * 获取新的中心点 列表 */ public Set<E> getNewCenters(Map<E, Set<E>> kMap){ Set<E> eSet = new HashSet<E>(); for (Map.Entry<E, Set<E>> m : kMap.entrySet()) { eSet.add(m.getKey().getNewCenter(m.getValue())); } return eSet; } /** * 判断是否为同一个中心列表 */ public boolean isSameCenters(Set<E> oldSet,Set<E> newSet){ //两个集合只要交集为0就是相同的 return oldSet.containsAll(newSet); } public static void main(String[] args) { Set<Point> dataSet = new HashSet<Point>(); dataSet.add(new Point("1",1,1,1)); dataSet.add(new Point("1",2,2,2)); dataSet.add(new Point("1",5,6,1)); dataSet.add(new Point("1",10,10,10)); dataSet.add(new Point("1",11,11,11)); new KmeansAlgorithm<Point>().kmeans(dataSet, 2,10); } }
结果:
这是第10次聚类
(1.0,1.0,1.0):[(1.0,1.0,1.0), (2.0,2.0,2.0), (5.0,6.0,1.0)]
(10.0,10.0,10.0):[(10.0,10.0,10.0), (11.0,11.0,11.0)]
这是第9次聚类
(10.5,10.5,10.5):[(10.0,10.0,10.0), (11.0,11.0,11.0)]
(2.6666666666666665,3.0,1.3333333333333333):[(1.0,1.0,1.0), (2.0,2.0,2.0), (5.0,6.0,1.0)]
版权声明:本文为博主原创文章,未经博主允许不得转载。
时间: 2024-10-28 11:05:51