4. K-Means和K-Means++实现

1. K-Means原理解析

2. K-Means的优化

3. sklearn的K-Means的使用

4. K-Means和K-Means++实现

1. 前言

前面3篇K-Means的博文从原理、优化、使用几个方面详细的介绍了K-Means算法,本文用python算法,详细的为读者实现一下K-Means。代码是本人修改完成,效率虽远不及sklearn,但是它的作用是在帮助同学们能从代码中去理解K-Means算法。后面我会慢慢的把所有的机器学习方面的算法,尽我所能的去实现一遍。

2. KMeans基本框架实现

先实现一个基本的kmeans,代码如下,需要查看完整代码的同学请移步至我的github

class KMeansBase(object):

    def __init__(self, n_clusters = 8, init = "random", max_iter = 300, random_state = None, n_init = 10, tol = 1e-4):
        self.k = n_clusters # 聚类个数
        self.init = init # 输出化方式
        self.max_iter = max_iter # 最大迭代次数
        self.random_state = check_random_state(random_state) #随机数
        self.n_init = n_init # 进行多次聚类,选择最好的一次
        self.tol = tol # 停止聚类的阈值

    # fit对train建立模型
    def fit(self, dataset):
        self.tol = self._tolerance(dataset, self.tol)

        bestError = None
        bestCenters = None
        bestLabels = None
        for i in range(self.n_init):
            labels, centers, error = self._kmeans(dataset)
            if bestError == None or error < bestError:
                bestError = error
                bestCenters = centers
                bestLabels = labels
        self.centers = bestCenters
        return bestLabels, bestCenters, bestError

    # predict根据训练好的模型预测新的数据
    def predict(self, X):
        return self.update_labels_error(X, self.centers)[0]

    # 合并fit和predict
    def fit_predict(self, dataset):
        self.fit(dataset)
        return self.predict(dataset)

    # kmeans的主要方法,完成一次聚类的过程
    def _kmeans(self, dataset):
        self.dataset = np.array(dataset)
        bestError = None
        bestCenters = None
        bestLabels = None
        centerShiftTotal = 0
        centers = self._init_centroids(dataset)

        for i in range(self.max_iter):
            oldCenters = centers.copy()
            labels, error = self.update_labels_error(dataset, centers)
            centers = self.update_centers(dataset, labels)

            if bestError == None or error < bestError:
                bestLabels = labels.copy()
                bestCenters = centers.copy()
                bestError = error

            ## oldCenters和centers的偏移量
            centerShiftTotal = np.linalg.norm(oldCenters - centers) ** 2
            if centerShiftTotal <= self.tol:
                break

        #由于上面的循环,最后一步更新了centers,所以如果和旧的centers不一样的话,再更新一次labels,error
        if centerShiftTotal > 0:
            bestLabels, bestError = self.update_labels_error(dataset, bestCenters)

        return bestLabels, bestCenters, bestError

    # k个数据点,随机生成
    def _init_centroids(self, dataset):
        n_samples = dataset.shape[0]
        centers = []
        if self.init == "random":
            seeds = self.random_state.permutation(n_samples)[:self.k]
            centers = dataset[seeds]
        elif self.init == "k-means++":
            pass
        return np.array(centers)

    # 把tol和dataset相关联
    def _tolerance(self, dataset, tol):
        variances = np.var(dataset, axis=0)
        return np.mean(variances) * tol

    # 更新每个点的标签,和计算误差
    def update_labels_error(self, dataset, centers):
        labels = self.assign_points(dataset, centers)
        new_means = defaultdict(list)
        error = 0
        for assignment, point in zip(labels, dataset):
            new_means[assignment].append(point)

        for points in new_means.values():
            newCenter = np.mean(points, axis=0)
            error += np.sqrt(np.sum(np.square(points - newCenter)))

        return labels, error

    # 更新中心点
    def update_centers(self, dataset, labels):
        new_means = defaultdict(list)
        centers = []
        for assignment, point in zip(labels, dataset):
            new_means[assignment].append(point)

        for points in new_means.values():
            newCenter = np.mean(points, axis=0)
            centers.append(newCenter)

        return np.array(centers)

    # 分配每个点到最近的center
    def assign_points(self, dataset, centers):
        labels = []
        for point in dataset:
            shortest = float("inf")  # 正无穷
            shortest_index = 0
            for i in range(len(centers)):
                val = distance(point[np.newaxis], centers[i])
                if val < shortest:
                    shortest = val
                    shortest_index = i
            labels.append(shortest_index)
        return labels

上面是我实现的基本的以EM算法为基础的一个KMeans的算法过程,我接口设计和参数形式尽量模范sklearn的方式,方面熟悉sklearn的同学接受起来比较快。

3. KMeans++实现

kmeans++的原理在之前有介绍。这里为了配合代码,再介绍一遍。

  1. 从输入的数据点集合中随机选择一个点作为第一个聚类中心\(\mu_1\).
  2. 对于数据集中的每一个点\(x_i\),计算它与已选择的聚类中心中最近聚类中心的距离.
    \[
    D(x_i) = arg\;min|x_i-\mu_r|^2\;\;r=1,2,...k_{selected}
    \]
  3. 选择一个新的数据点作为新的聚类中心,选择的原则是:\(D(x)\)较大的点,被选取作为聚类中心的概率较大
  4. 重复2和3直到选择出k个聚类质心。
  5. 利用这k个质心来作为初始化质心去运行标准的K-Means算法。

其中比较关键的是第2、3步,请看具体实现过程:

# kmeans++的初始化方式,加速聚类速度
def _k_means_plus_plus(self, dataset):
    n_samples, n_features = dataset.shape
    centers = np.empty((self.k, n_features))
    # n_local_trials是每次选择候选点个数
    n_local_trials = None
    if n_local_trials is None:
        n_local_trials = 2 + int(np.log(self.k))

    # 第一个随机点
    center_id = self.random_state.randint(n_samples)
    centers[0] = dataset[center_id]

    # closest_dist_sq是每个样本,到所有中心点最近距离
    # 假设现在有3个中心点,closest_dist_sq = [min(样本1到3个中心距离),min(样本2到3个中心距离),...min(样本n到3个中心距离)]
    closest_dist_sq = distance(centers[0, np.newaxis], dataset)

    # current_pot所有最短距离的和
    current_pot = closest_dist_sq.sum()

    for c in range(1, self.k):
        # 选出n_local_trials随机址,并映射到current_pot的长度
        rand_vals = self.random_state.random_sample(n_local_trials) * current_pot
        # np.cumsum([1,2,3,4]) = [1, 3, 6, 10],就是累加当前索引前面的值
        # np.searchsorted搜索随机出的rand_vals落在np.cumsum(closest_dist_sq)中的位置。
        # candidate_ids候选节点的索引
        candidate_ids = np.searchsorted(np.cumsum(closest_dist_sq), rand_vals)

        # best_candidate最好的候选节点
        # best_pot最好的候选节点计算出的距离和
        # best_dist_sq最好的候选节点计算出的距离列表
        best_candidate = None
        best_pot = None
        best_dist_sq = None
        for trial in range(n_local_trials):
            # 计算每个样本到候选节点的欧式距离
            distance_to_candidate = distance(dataset[candidate_ids[trial], np.newaxis], dataset)

            # 计算每个候选节点的距离序列new_dist_sq, 距离总和new_pot
            new_dist_sq = np.minimum(closest_dist_sq, distance_to_candidate)
            new_pot = new_dist_sq.sum()

            # 选择最小的new_pot
            if (best_candidate is None) or (new_pot < best_pot):
                best_candidate = candidate_ids[trial]
                best_pot = new_pot
                best_dist_sq = new_dist_sq

        centers[c] = dataset[best_candidate]
        current_pot = best_pot
        closest_dist_sq = best_dist_sq

    return centers

4. 效果比较

用kmeans_base和kmeans++和sklearn的kmeans对sklearn中自带的数据集iris、boston房价、digits进行聚类,比较速度和聚类效果比较。





5. 总结

Kmeans的算法讲解靠一段落,有兴趣的同学们可以去实践下我在优化中提到的另外两个优化方法,elkan减少计算距离的次数,Mini Batch处理大样本的情况下,计算的速度。

原文地址:https://www.cnblogs.com/huangyc/p/10274001.html

时间: 2024-10-30 16:12:15

4. K-Means和K-Means++实现的相关文章

imshow(K)和imshow(K,[]) 的区别

参考文献 imshow(K)直接显示K:imshow(K,[])显示K,并将K的最大值和最小值分别作为纯白(255)和纯黑(0),中间的K值映射为0到255之间的标准灰度值.

求 区间[a,b]内满足p^k*q*^m(k&gt;m)的数的个数

题目描述: 1<=a,b<=10^18,p,q都是素数  2<=p,q<=10^9; 求在[a,b]内可以表示为  x*p^k*q^m  k > m   的数的个数 分析: 由于要小于b,因此m一定小于 log10(b)/log10(p*q); 因此我们可以枚举m,中间计数的时候需要用到容斥. 具体看代码: #include <iostream> #include <cstdio> #include <cmath> #include <

C++链表K个节点K个节点的反转((1,2,3,4),如果k是2,反转结果是(2,1,4,3))

#include <iostream> using namespace std; struct Node { int val; struct Node *next; Node(int x = int()):val(x),next(NULL){} }; struct List { List() { head=NULL; } void Insert(int x) { if(head==NULL) { head = new Node(x); } else { Node *p = head; Node

约瑟夫环 数学解法 f(n,k)=(f(n-1,k)+k)%n 公式讲解

问题:有n个人站成环 从1开始报数,报k的人去死,之后下一个人报1,问当你是第几个的时候可以活下来? 这篇文章主要是讲解  f(n,k)=(f(n-1,k)+k)%n 这个公式是什么意思为什么是对的 虽然公式是使用数学解法 但开始时我会手动的模拟过程 其是有意义的 十分有助于理解 首先我们看样一个问题 n=2, k=3 a b 我们首先使用人力来数 a b a 很好 a死 接下来在试一遍 n=2 k=4 a b 人力:a b a b 很好b死 n=2 k=5 人力 a b a b a 很好a死

给定一个非负索引 k,其中 k ≤ 33,返回杨辉三角的第 k 行。

从第0行开始,输出第k行,传的参数为第几行,所以在方法中先将所传参数加1,然后将最后一行加入集合中返回. 代码如下: public static List<Integer> generateII(int row){ ++row; List<Integer> list = new ArrayList<Integer>(); int[][] arr = new int[row][row]; for(int j = 0;j<row;j++) { for(int k =

SortedList&lt;T,K&gt;,SortedDictionary&lt;T,K&gt;,Dictionay&lt;T,K&gt;用法区别

这三货都是键值对,都可以通过Key获取Value.   Dictionay<T,K> SortedDictionary<T,K> SortedList<T,K> 支持通过Index获取元素? 否 否 是 遍历时的排序方式 随机,与hash算法有关 默认用Key的值排序,而非插入顺序.可通过构造器传入自定义的排序方法. 每次插入新值都会与现有项比较,可能导致列表重置. 可查找索引 否 否 是 内存使用   多 少 插入.移除性能   慢 快 抛开性能和内部实现,Sorte

【开源】专业K线绘制[K线主副图、趋势图、成交量、滚动、放大缩小、MACD、KDJ等)

这是最近一个iOS项目需要使用的K线的绘制,在网上大量查阅资料无果,只好自行绘制. 实时数据使用来源API: https://www.btc123.com/kline/klineapi 返回数据说明: 1.时间戳 2.开盘价 3.最高价 4.最低价 5.收盘价 6.成交量 实现功能包括K线主副图.趋势图.成交量.滚动.放大缩小.MACD.KDJ,长按显示辅助线等功能 预览图 最后的最后,这是项目的开源地址:https://github.com/yate1996/Y_KLine,如果帮到了你,麻烦

[LeetCode] Remove K Digits 去掉K位数字

Given a non-negative integer num represented as a string, remove k digits from the number so that the new number is the smallest possible. Note: The length of num is less than 10002 and will be ≥ k. The given num does not contain any leading zero. Ex

Merge k Sorted Lists, k路归并

import java.util.Arrays; import java.util.List; import java.util.PriorityQueue; /* class ListNode { ListNode next; int val; ListNode(int x) { val = x; } } */ //k路归并问题 public class MergKSortedLists { //二路归并,这个算法时间复杂度o(2n) public ListNode mergeTwoLists

android 股票数据通过日K获取周K的数据 算法 源码

目前的数据是从新浪接口获取的, http://biz.finance.sina.com.cn/stock/flash_hq/kline_data.php?symbol=sh600000&end_date=20141120&begin_date=20120101 返回数据为XML格式: 1 <?xml version="1.0" encoding="UTF-8"?> 2 <control> 3 <content d=&qu