java实现K-means聚类算法

import java.util.ArrayList;
import java.util.Random;

/**
 * K均值聚类算法
 */
public class Kmeans {
    private int numOfCluster;// 分成多少簇
    private int timeOfIteration;// 迭代次数
    private int dataSetLength;// 数据集元素个数,即数据集的长度
    private ArrayList<float[]> dataSet;// 数据集链表
    private ArrayList<float[]> center;// 中心链表
    private ArrayList<ArrayList<float[]>> cluster; //簇
    private ArrayList<Float> sumOfErrorSquare;// 误差平方和
    private Random random;

    /**
     * 设置需分组的原始数据集
     *
     * @param dataSet
     */

    public void setDataSet(ArrayList<float[]> dataSet) {
        this.dataSet = dataSet;
    }

    /**
     * 获取结果分组
     *
     * @return 结果集
     */

    public ArrayList<ArrayList<float[]>> getCluster() {
        return cluster;
    }

    /**
     * 构造函数,传入需要分成的簇数量
     *
     * @param numOfCluster
     *    簇数量,若numOfCluster<=0时,设置为1,若numOfCluster大于数据源的长度时,置为数据源的长度
     */
    public Kmeans(int numOfCluster) {
        if (numOfCluster <= 0) {
            numOfCluster = 1;
        }
        this.numOfCluster = numOfCluster;
    }

    /**
     * 初始化
     */
    private void init() {
        timeOfIteration = 0;
        random = new Random();
        //如果调用者未初始化数据集,则采用内部测试数据集
        if (dataSet == null || dataSet.size() == 0) {
            initDataSet();
        }
        dataSetLength = dataSet.size();
        //若numOfCluster大于数据源的长度时,置为数据源的长度
        if (numOfCluster > dataSetLength) {
            numOfCluster = dataSetLength;
        }
        center = initCenters();
        cluster = initCluster();
        sumOfErrorSquare = new ArrayList<Float>();
    }

    /**
     * 如果调用者未初始化数据集,则采用内部测试数据集
     */
    private void initDataSet() {
        dataSet = new ArrayList<float[]>();
        // 其中{6,3}是一样的,所以长度为15的数据集分成14簇和15簇的误差都为0
        float[][] dataSetArray = new float[][] { { 8, 2 }, { 3, 4 }, { 2, 5 },
                { 4, 2 }, { 7, 3 }, { 6, 2 }, { 4, 7 }, { 6, 3 }, { 5, 3 },
                { 6, 3 }, { 6, 9 }, { 1, 6 }, { 3, 9 }, { 4, 1 }, { 8, 6 } };

        for (int i = 0; i < dataSetArray.length; i++) {
            dataSet.add(dataSetArray[i]);
        }
    }

    /**
     * 初始化中心数据链表,分成多少簇就有多少个中心点
     *
     * @return 中心点集
     */
    private ArrayList<float[]> initCenters() {
        ArrayList<float[]> center = new ArrayList<float[]>();
        int[] randoms = new int[numOfCluster];
        boolean flag;
        int temp = random.nextInt(dataSetLength);
        randoms[0] = temp;
        //randoms数组中存放数据集的不同的下标
        for (int i = 1; i < numOfCluster; i++) {
            flag = true;
            while (flag) {
                temp = random.nextInt(dataSetLength);

                int j=0;
                for(j=0; j<i; j++){
                    if(randoms[j] == temp){
                        break;
                    }
                }

                if (j == i) {
                    flag = false;
                }
            }
            randoms[i] = temp;
        }

        // 测试随机数生成情况
        // for(int i=0;i<numOfCluster;i++)
        // {
        // System.out.println("test1:randoms["+i+"]="+randoms[i]);
        // }

        // System.out.println();

        for (int i = 0; i < numOfCluster; i++) {
            center.add(dataSet.get(randoms[i]));// 生成初始化中心链表
        }
        return center;
    }

    /**
     * 初始化簇集合
     *
     * @return 一个分为k簇的空数据的簇集合
     */
    private ArrayList<ArrayList<float[]>> initCluster() {
        ArrayList<ArrayList<float[]>> cluster = new ArrayList<ArrayList<float[]>>();
        for (int i = 0; i < numOfCluster; i++) {
            cluster.add(new ArrayList<float[]>());
        }

        return cluster;
    }

    /**
     * 计算两个点之间的距离
     *
     * @param element
     *            点1
     * @param center
     *            点2
     * @return 距离
     */
    private float distance(float[] element, float[] center) {
        float distance = 0.0f;
        float x = element[0] - center[0];
        float y = element[1] - center[1];
        float z = x * x + y * y;
        distance = (float) Math.sqrt(z);

        return distance;
    }

    /**
     * 获取距离集合中最小距离的位置
     *
     * @param distance
     *            距离数组
     * @return 最小距离在距离数组中的位置
     */
    private int minDistance(float[] distance) {
        float minDistance = distance[0];
        int minLocation = 0;
        for (int i = 1; i < distance.length; i++) {
            if (distance[i] <= minDistance) {
                minDistance = distance[i];
                minLocation = i;
            }
        }
        return minLocation;
    }

    /**
     * 核心,将当前元素放到最小距离中心相关的簇中
     */
    private void clusterSet() {
        float[] distance = new float[numOfCluster];
        for (int i = 0; i < dataSetLength; i++) {
            for (int j = 0; j < numOfCluster; j++) {
                distance[j] = distance(dataSet.get(i), center.get(j));
                // System.out.println("test2:"+"dataSet["+i+"],center["+j+"],distance="+distance[j]);
            }
            int minLocation = minDistance(distance);
            // System.out.println("test3:"+"dataSet["+i+"],minLocation="+minLocation);
            // System.out.println();

            cluster.get(minLocation).add(dataSet.get(i));// 核心,将当前元素放到最小距离中心相关的簇中

        }
    }

    /**
     * 求两点误差平方的方法
     *
     * @param element
     *            点1
     * @param center
     *            点2
     * @return 误差平方
     */
    private float errorSquare(float[] element, float[] center) {
        float x = element[0] - center[0];
        float y = element[1] - center[1];

        float errSquare = x * x + y * y;

        return errSquare;
    }

    /**
     * 计算误差平方和准则函数方法
     */
    private void countRule() {
        float jcF = 0;
        for (int i = 0; i < cluster.size(); i++) {
            for (int j = 0; j < cluster.get(i).size(); j++) {
                jcF += errorSquare(cluster.get(i).get(j), center.get(i));

            }
        }
        sumOfErrorSquare.add(jcF);
    }

    /**
     * 设置新的簇中心方法
     */
    private void setNewCenter() {
        for (int i = 0; i < numOfCluster; i++) {
            int n = cluster.get(i).size();
            if (n != 0) {
                float[] newCenter = { 0, 0 };
                for (int j = 0; j < n; j++) {
                    newCenter[0] += cluster.get(i).get(j)[0];
                    newCenter[1] += cluster.get(i).get(j)[1];
                }
                // 设置一个平均值
                newCenter[0] = newCenter[0] / n;
                newCenter[1] = newCenter[1] / n;
                center.set(i, newCenter);
            }
        }
    }

    /**
     * 打印数据,测试用
     *
     * @param dataArray
     *            数据集
     * @param dataArrayName
     *            数据集名称
     */
    public void printDataArray(ArrayList<float[]> dataArray,
                               String dataArrayName) {
        for (int i = 0; i < dataArray.size(); i++) {
            System.out.println("print:" + dataArrayName + "[" + i + "]={"
                    + dataArray.get(i)[0] + "," + dataArray.get(i)[1] + "}");
        }
        System.out.println("===================================");
    }

    /**
     * Kmeans算法核心过程方法
     */
    private void kmeans() {
        init();
        // printDataArray(dataSet,"initDataSet");
        // printDataArray(center,"initCenter");

        // 循环分组,直到误差不变为止
        while (true) {
            clusterSet();
            // for(int i=0;i<cluster.size();i++)
            // {
            // printDataArray(cluster.get(i),"cluster["+i+"]");
            // }

            countRule();

            // System.out.println("count:"+"sumOfErrorSquare["+timeOfIteration+"]="+sumOfErrorSquare.get(timeOfIteration));

            // System.out.println();
            // 误差不变了,分组完成
            if (timeOfIteration != 0) {
                if (sumOfErrorSquare.get(timeOfIteration) - sumOfErrorSquare.get(timeOfIteration - 1) == 0) {
                    break;
                }
            }

            setNewCenter();
            // printDataArray(center,"newCenter");
            timeOfIteration++;
            cluster.clear();
            cluster = initCluster();
        }

        // System.out.println("note:the times of repeat:timeOfIteration="+timeOfIteration);//输出迭代次数
    }

    /**
     * 执行算法
     */
    public void execute() {
        long startTime = System.currentTimeMillis();
        System.out.println("kmeans begins");
        kmeans();
        long endTime = System.currentTimeMillis();
        System.out.println("kmeans running time=" + (endTime - startTime)
                + "ms");
        System.out.println("kmeans ends");
        System.out.println();
    }
}
package com.bianlifeng.algorithm;

import java.util.ArrayList;

public class KmeansTest {
    public  static void main(String[] args)
    {
        //初始化一个Kmean对象,将k置为10
        Kmeans k=new Kmeans(3);
        ArrayList<float[]> dataSet=new ArrayList<float[]>();

        dataSet.add(new float[]{1,2});
        dataSet.add(new float[]{3,3});
        dataSet.add(new float[]{3,4});
        dataSet.add(new float[]{5,6});
        dataSet.add(new float[]{8,9});
        dataSet.add(new float[]{4,5});
        dataSet.add(new float[]{6,4});
        dataSet.add(new float[]{3,9});
        dataSet.add(new float[]{5,9});
        dataSet.add(new float[]{4,2});
        dataSet.add(new float[]{1,9});
        dataSet.add(new float[]{7,8});
        dataSet.add(new float[]{1,1});
        dataSet.add(new float[]{2,1});
        dataSet.add(new float[]{2,2});
        dataSet.add(new float[]{6,9});

        //设置原始数据集
        k.setDataSet(dataSet);
        //执行算法
        k.execute();
        //得到聚类结果
        ArrayList<ArrayList<float[]>> cluster=k.getCluster();
        //查看结果
        for(int i=0;i<cluster.size();i++)
        {
            k.printDataArray(cluster.get(i), "cluster["+i+"]");
        }

    }
}

From:

http://blog.csdn.net/cyxlzzs/article/details/7416491

原文地址:https://www.cnblogs.com/Allen-win/p/8298181.html

时间: 2024-12-20 20:20:12

java实现K-means聚类算法的相关文章

k-均值聚类算法;二分k均值聚类算法

根据<机器学习实战>一书第十章学习k均值聚类算法和二分k均值聚类算法,自己把代码边敲边理解了一下,修正了一些原书中代码的细微差错.目前代码有时会出现如下4种报错信息,这有待继续探究和完善. 报错信息: Warning (from warnings module): File "F:\Python2.7.6\lib\site-packages\numpy\core\_methods.py", line 55 warnings.warn("Mean of empty

机器学习实战笔记-利用K均值聚类算法对未标注数据分组

聚类是一种无监督的学习,它将相似的对象归到同一个簇中.它有点像全自动分类.聚类方法几乎可以应用于所有对象,簇内的对象越相似,聚类的效果越好 簇识别给出聚类结果的含义.假定有一些数据,现在将相似数据归到一起,簇识别会告诉我们这些簇到底都是些什么.聚类与分类的最大不同在于,分类的目标事先巳知,而聚类则不一样.因为其产生的结果与分类相同,而只是类别没有预先定义,聚类有时也被称为无监督分类(unsupervised classification ). 聚类分析试图将相似对象归人同一簇,将不相似对象归到不

K均值聚类算法

k均值聚类算法(k-means clustering algorithm)是一种迭代求解的聚类分析算法,其步骤是随机选取K个对象作为初始的聚类中心,然后计算每个对象与各个种子聚类中心之间的距离,把每个对象分配给距离它最近的聚类中心.聚类中心以及分配给它们的对象就代表一个聚类.每分配一个样本,聚类的聚类中心会根据聚类中现有的对象被重新计算.这个过程将不断重复直到满足某个终止条件.终止条件可以是没有(或最小数目)对象被重新分配给不同的聚类,没有(或最小数目)聚类中心再发生变化,误差平方和局部最小.

基于改进人工蜂群算法的K均值聚类算法(附MATLAB版源代码)

其实一直以来也没有准备在园子里发这样的文章,相对来说,算法改进放在园子里还是会稍稍显得格格不入.但是最近邮箱收到的几封邮件让我觉得有必要通过我的博客把过去做过的东西分享出去更给更多需要的人.从论文刊登后,陆陆续续收到本科生.研究生还有博士生的来信和短信微信等,表示了对论文的兴趣以及寻求算法的效果和实现细节,所以,我也就通过邮件或者短信微信来回信,但是有时候也会忘记回复. 另外一个原因也是时间久了,我对于论文以及改进的算法的记忆也越来越模糊,或者那天无意间把代码遗失在哪个角落,真的很难想象我还会全

K均值聚类算法的MATLAB实现

1.K-均值聚类法的概述 之前在参加数学建模的过程中用到过这种聚类方法,但是当时只是简单知道了在matlab中如何调用工具箱进行聚类,并不是特别清楚它的原理.最近因为在学模式识别,又重新接触了这种聚类算法,所以便仔细地研究了一下它的原理.弄懂了之后就自己手工用matlab编程实现了,最后的结果还不错,嘿嘿~~~ 简单来说,K-均值聚类就是在给定了一组样本(x1, x2, ...xn) (xi, i = 1, 2, ... n均是向量) 之后,假设要将其聚为 m(<n) 类,可以按照如下的步骤实现

k means聚类过程

k-means是一种非监督 (从下图0 当中我们可以看到训练数据并没有标签标注类别)的聚类算法 0.initial 1.select centroids randomly 2.assign points 3.update centroids 4.reassign points 5.update centroids 6.reassign points 7.iteration reference: https://www.naftaliharris.com/blog/visualizing-k-me

聚类之K均值聚类和EM算法

这篇博客整理K均值聚类的内容,包括: 1.K均值聚类的原理: 2.初始类中心的选择和类别数K的确定: 3.K均值聚类和EM算法.高斯混合模型的关系. 一.K均值聚类的原理 K均值聚类(K-means)是一种基于中心的聚类算法,通过迭代,将样本分到K个类中,使得每个样本与其所属类的中心或均值的距离之和最小. 1.定义损失函数 假设我们有一个数据集{x1, x2,..., xN},每个样本的特征维度是m维,我们的目标是将数据集划分为K个类别.假定K的值已经给定,那么第k个类别的中心定义为μk,k=1

[聚类算法]K-means优缺点及其改进

[聚类算法]K-means优缺点及其改进 [转]:http://blog.csdn.net/u010536377/article/details/50884416 K-means聚类小述 大家接触的第一个聚类方法,十有八九都是K-means聚类啦.该算法十分容易理解,也很容易实现.其实几乎所有的机器学习和数据挖掘算法都有其优点和缺点.那么K-means的缺点是什么呢? 总结为下: (1)对于离群点和孤立点敏感: (2)k值选择; (3)初始聚类中心的选择: (4)只能发现球状簇. 对于这4点呢的

scikit-learn学习之K-means聚类算法与 Mini Batch K-Means算法

====================================================================== 本系列博客主要参考 Scikit-Learn 官方网站上的每一个算法进行,并进行部分翻译,如有错误,请大家指正 转载请注明出处 ====================================================================== K-means算法分析与Python代码实现请参考之前的两篇博客: <机器学习实战>k