机器学习之KNN算法思想及其实现

从一个例子来直观感受KNN思想

如下图 , 绿色圆要被决定赋予哪个类,是红色三角形还是蓝色四方形?如果K=3,由于红色三角形所占比例为2/3,绿色圆将被赋予红色三角形那个类,如果K=5,由于蓝色四方形比例为3/5,因此绿色圆被赋予蓝色四方形类。

                    

从这个例子中,我们再来看KNN思想:

1, 计算已知类别数据集合中的点与当前点之间的距离(使用欧式距离公司: d =sqrt(pow(x-x1),2)+pow(y-y1),2)

2, 按照距离递增次序排序(由近到远)

3, 选取与当前点距离最小的的K个点(如上题中的 k=3,k=5)

4, 确定前K个点所在类别的出现频率

5, 将频率最高的那组,作为该点的预测分类

实现代码:

 1 package com.data.knn;
 2
 3 /**
 4  * *********************************************************
 5  * <p/>
 6  * Author:     XiJun.Gong
 7  * Date:       2016-09-06 12:02
 8  * Version:    default 1.0.0
 9  * Class description:
10  * <p/>
11  * *********************************************************
12  */
13 public class Point {
14
15     private double x;  //x坐标
16     private double y;  //y坐标
17     private double dist; //距离另一个点的距离
18
19
20
21     private String label; //所属类别
22
23     public Point() {
24         this(0d, 0d, "");
25     }
26
27     public Point(double x, double y, String label) {
28         this.x = x;
29         this.y = y;
30         this.label = label;
31     }
32
33     /*计算两点之间的距离*/
34     public double distance(final Point a) {
35         return Math.sqrt((a.x - x) * (a.x - x) + (a.y - y) * (a.y - y));
36     }
37
38     public double getX() {
39         return x;
40     }
41
42     public void setX(double x) {
43         this.x = x;
44     }
45
46     public double getY() {
47         return y;
48     }
49
50     public void setY(double y) {
51         this.y = y;
52     }
53
54     public String getLabel() {
55         return label;
56     }
57
58     public void setLabel(String label) {
59         this.label = label;
60     }
61
62
63     public double getDist() {
64         return dist;
65     }
66
67     public void setDist(double dist) {
68         this.dist = dist;
69     }
70 }

KNN实现

 1 package com.data.knn;
 2
 3 import com.google.common.base.Preconditions;
 4 import com.google.common.collect.Maps;
 5
 6 import java.util.Collections;
 7 import java.util.Comparator;
 8 import java.util.List;
 9 import java.util.Map;
10
11 /**
12  * *********************************************************
13  * <p/>
14  * Author:     XiJun.Gong
15  * Date:       2016-09-06 11:59
16  * Version:    default 1.0.0
17  * Class description:
18  * <p/>
19  * *********************************************************
20  */
21 public class knn {
22
23     private List<Point> dataSet;    //统计频率
24     private Point newPoint;         //当前点
25
26
27     //进行KNN分类
28     public String classify(List<Point> dataSet, final Point newPoint, Integer K) {
29
30         Preconditions.checkArgument(K < dataSet.size(), "K的值超过了dataSet的元素");
31         //求解每一个点到新的点的距离
32         for (Point point : dataSet) {
33             point.setDist(newPoint.distance(point));
34         }
35         //进行排序
36         Collections.sort(dataSet, new Comparator<Point>() {
37             @Override
38             public int compare(Point o1, Point o2) {
39                 //return o1.distance(newPoint) < o2.distance(newPoint) ? 1 : -1;
40                 return o1.getDist() < o2.getDist() ? 1 : -1;
41             }
42         });
43
44         //统计前K个标签的频率
45         Map<String, Integer> map = Maps.newHashMap();
46         Integer maxCnt = -9999; //最高频率
47         String label = "";  //最高频率标签
48         Integer currentCnt = 0; //当前标签的频率
49         Integer times = 0;
50         for (Point point : dataSet) {
51             currentCnt = 1;
52             if (map.containsKey(point.getLabel())) {
53                 currentCnt += map.get(point);
54             }
55             if (maxCnt < currentCnt) {
56                 maxCnt = currentCnt;
57                 label = point.getLabel();
58             }
59             map.put(point.getLabel(), currentCnt);
60             times++;
61             if (times > K) break;
62         }
63         return label;
64     }
65
66
67 }
 1 package com.data.knn;
 2
 3 import com.google.common.collect.Lists;
 4
 5 import java.util.List;
 6
 7 /**
 8  * *********************************************************
 9  * <p/>
10  * Author:     XiJun.Gong
11  * Date:       2016-09-06 14:45
12  * Version:    default 1.0.0
13  * Class description:
14  * <p/>
15  * *********************************************************
16  */
17 public class Main {
18
19     public static void main(String args[]) {
20         List<Point> list = Lists.newArrayList();
21         list.add(new Point(1., 1.1, "A"));
22         list.add(new Point(1., 1., "A"));
23         list.add(new Point(0., 0., "B"));
24         list.add(new Point(0., 0.1, "B"));
25         Point point = new Point(0.5, 0.5, null);
26         KNN knn = new KNN();
27         System.out.println(knn.classify(list, point, 3));
28     }
29 }

结果:

A

  

时间: 2024-11-06 15:02:52

机器学习之KNN算法思想及其实现的相关文章

机器学习之KNN算法

1 KNN算法 1.1 KNN算法简介 KNN(K-Nearest Neighbor)工作原理:存在一个样本数据集合,也称为训练样本集,并且样本集中每个数据都存在标签,即我们知道样本集中每一数据与所属分类对应的关系.输入没有标签的数据后,将新数据中的每个特征与样本集中数据对应的特征进行比较,提取出样本集中特征最相似数据(最近邻)的分类标签.一般来说,我们只选择样本数据集中前k个最相似的数据,这就是k近邻算法中k的出处,通常k是不大于20的整数.最后选择k个最相似数据中出现次数最多的分类作为新数据

菜鸟之路——机器学习之KNN算法个人理解及Python实现

KNN(K Nearest Neighbor) 还是先记几个关键公式 距离:一般用Euclidean distance   E(x,y)√∑(xi-yi)2 .名字这么高大上,就是初中学的两点间的距离嘛. 还有其他距离的衡量公式,余弦值(cos),相关度(correlation) 曼哈顿距离(manhatann distance).我觉得针对于KNN算法还是Euclidean distance最好,最直观. 然后就选择最近的K个点.根据投票原则分类出结果. 首先利用sklearn自带的的iris

Python 手写数字识别-knn算法应用

在上一篇博文中,我们对KNN算法思想及流程有了初步的了解,KNN是采用测量不同特征值之间的距离方法进行分类,也就是说对于每个样本数据,需要和训练集中的所有数据进行欧氏距离计算.这里简述KNN算法的特点: 优点:精度高,对异常值不敏感,无数据输入假定 缺点:计算复杂度高,空间复杂度高 适用数据范围:数值型和标称型(具有有穷多个不同值,值之间无序)    knn算法代码: #-*- coding: utf-8 -*- from numpy import * import operatorimport

【转】常见面试之机器学习算法思想简单梳理

转:http://www.chinakdd.com/article-oyU85v018dQL0Iu.html 前言: 找工作时(IT行业),除了常见的软件开发以外,机器学习岗位也可以当作是一个选择,不少计算机方向的研究生都会接触这个,如果你的研究方向是机器学习/数据挖掘之类,且又对其非常感兴趣的话,可以考虑考虑该岗位,毕竟在机器智能没达到人类水平之前,机器学习可以作为一种重要手段,而随着科技的不断发展,相信这方面的人才需求也会越来越大. 纵观IT行业的招聘岗位,机器学习之类的岗位还是挺少的,国内

机器学习&amp;数据挖掘笔记_16(常见面试之机器学习算法思想简单梳理)

http://www.cnblogs.com/tornadomeet/p/3395593.html 机器学习&数据挖掘笔记_16(常见面试之机器学习算法思想简单梳理) 前言: 找工作时(IT行业),除了常见的软件开发以外,机器学习岗位也可以当作是一个选择,不少计算机方向的研究生都会接触这个,如果你的研究方向是机器学习/数据挖掘之类,且又对其非常感兴趣的话,可以考虑考虑该岗位,毕竟在机器智能没达到人类水平之前,机器学习可以作为一种重要手段,而随着科技的不断发展,相信这方面的人才需求也会越来越大.

常见面试之机器学习算法思想简单梳理

http://www.cnblogs.com/tornadomeet/p/3395593.html (转) 前言: 找工作时(IT行业),除了常见的软件开发以外,机器学习岗位也可以当作是一个选择,不少计算机方向的研究生都会接触这个,如果你的研究方向是机器学习/数据挖掘之类,且又对其非常感兴趣的话,可以考虑考虑该岗位,毕竟在机器智能没达到人类水平之前,机器学习可以作为一种重要手段,而随着科技的不断发展,相信这方面的人才需求也会越来越大. 纵观IT行业的招聘岗位,机器学习之类的岗位还是挺少的,国内大

常见面试之机器学习算法思想简单梳理【转】

前言: 找工作时(IT行业),除了常见的软件开发以外,机器学习岗位也可以当作是一个选择,不少计算机方向的研究生都会接触这个,如果你的研究方向是机器学习/数据挖掘之类,且又对其非常感兴趣的话,可以考虑考虑该岗位,毕竟在机器智能没达到人类水平之前,机器学习可以作为一种重要手段,而随着科技的不断发展,相信这方面的人才需求也会越来越大. 纵观IT行业的招聘岗位,机器学习之类的岗位还是挺少的,国内大点的公司里百度,阿里,腾讯,网易,搜狐,华为(华为的岗位基本都是随机分配,机器学习等岗位基本面向的是博士)等

机器学习十大算法之KNN(K最近邻,k-NearestNeighbor)算法

机器学习十大算法之KNN算法 前段时间一直在搞tkinter,机器学习荒废了一阵子.如今想重新写一个,发现遇到不少问题,不过最终还是解决了.希望与大家共同进步. 闲话少说,进入正题. KNN算法也称最近邻居算法,是一种分类算法. 算法的基本思想:假设已存在一个数据集,数据集有多个数值属性和一个标签属性,输入一个新数据,求新数据的标签. 步骤如下: 先将新数据拷贝n份,形成一个新的数据集: 逐行计算新数据集与原数据集的距离: 按距离长度排序后,统计前K个数据里,那个标签出现的次数最多,新数据就标记

【机器学习算法实现】kNN算法__手写识别——基于Python和NumPy函数库

[机器学习算法实现]系列文章将记录个人阅读机器学习论文.书籍过程中所碰到的算法,每篇文章描述一个具体的算法.算法的编程实现.算法的具体应用实例.争取每个算法都用多种语言编程实现.所有代码共享至github:https://github.com/wepe/MachineLearning-Demo     欢迎交流指正! (1)kNN算法_手写识别实例--基于Python和NumPy函数库 1.kNN算法简介 kNN算法,即K最近邻(k-NearestNeighbor)分类算法,是最简单的机器学习算