从一个例子来直观感受KNN思想
如下图,绿色圆应该被归入哪个类:是红色三角形还是蓝色四方形?如果K=3,由于红色三角形所占比例为2/3,绿色圆将被赋予红色三角形那个类;如果K=5,由于蓝色四方形比例为3/5,因此绿色圆被赋予蓝色四方形类。
从这个例子中,我们再来看KNN思想:
1, 计算已知类别数据集合中的点与当前点之间的距离(使用欧氏距离公式:$d = \sqrt{(x - x_1)^2 + (y - y_1)^2}$) 2, 按照距离递增次序排序(由近到远) 3, 选取与当前点距离最小的K个点(如上例中的 K=3、K=5) 4, 确定前K个点所在类别的出现频率 5, 将频率最高的那个类别,作为该点的预测分类
实现代码:
1 package com.data.knn; 2 3 /** 4 * ********************************************************* 5 * <p/> 6 * Author: XiJun.Gong 7 * Date: 2016-09-06 12:02 8 * Version: default 1.0.0 9 * Class description: 10 * <p/> 11 * ********************************************************* 12 */ 13 public class Point { 14 15 private double x; //x坐标 16 private double y; //y坐标 17 private double dist; //距离另一个点的距离 18 19 20 21 private String label; //所属类别 22 23 public Point() { 24 this(0d, 0d, ""); 25 } 26 27 public Point(double x, double y, String label) { 28 this.x = x; 29 this.y = y; 30 this.label = label; 31 } 32 33 /*计算两点之间的距离*/ 34 public double distance(final Point a) { 35 return Math.sqrt((a.x - x) * (a.x - x) + (a.y - y) * (a.y - y)); 36 } 37 38 public double getX() { 39 return x; 40 } 41 42 public void setX(double x) { 43 this.x = x; 44 } 45 46 public double getY() { 47 return y; 48 } 49 50 public void setY(double y) { 51 this.y = y; 52 } 53 54 public String getLabel() { 55 return label; 56 } 57 58 public void setLabel(String label) { 59 this.label = label; 60 } 61 62 63 public double getDist() { 64 return dist; 65 } 66 67 public void setDist(double dist) { 68 this.dist = dist; 69 } 70 }
KNN实现
package com.data.knn;

import com.google.common.base.Preconditions;
import com.google.common.collect.Maps;

import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import java.util.Map;

/**
 * A minimal k-nearest-neighbours classifier over labelled {@link Point}s.
 *
 * <p>The classifier is stateless: all inputs are passed to
 * {@link #classify(List, Point, Integer)}.
 *
 * @author XiJun.Gong
 * @version default 1.0.0 (2016-09-06 11:59)
 */
public class KNN {

    /**
     * Predicts the label of {@code newPoint} by majority vote among its
     * {@code K} nearest samples in {@code dataSet}.
     *
     * <p>Side effect: the distance from {@code newPoint} to every sample is
     * stored on that sample via {@link Point#setDist(double)}. The order of
     * {@code dataSet} itself is left untouched; sorting happens on a copy.
     *
     * <p>Ties: samples at equal distance keep their relative order from
     * {@code dataSet} (the sort is stable), and when several labels reach the
     * same vote count, the label that reached it first while walking from the
     * nearest neighbour outwards wins.
     *
     * @param dataSet  labelled samples to vote with
     * @param newPoint the point to classify
     * @param K        number of nearest neighbours that vote; must satisfy
     *                 {@code 1 <= K <= dataSet.size()}
     * @return the most frequent label among the {@code K} nearest samples
     * @throws NullPointerException     if any argument is {@code null}
     * @throws IllegalArgumentException if {@code K} is not in
     *                                  {@code [1, dataSet.size()]}
     */
    public String classify(List<Point> dataSet, final Point newPoint, Integer K) {
        Preconditions.checkNotNull(dataSet, "dataSet不能为null");
        Preconditions.checkNotNull(newPoint, "newPoint不能为null");
        Preconditions.checkNotNull(K, "K不能为null");
        Preconditions.checkArgument(K > 0, "K必须为正整数");
        // K == dataSet.size() is legal: every sample votes.
        Preconditions.checkArgument(K <= dataSet.size(), "K的值超过了dataSet的元素");

        // Cache each sample's distance to the query point on the sample itself.
        for (Point point : dataSet) {
            point.setDist(newPoint.distance(point));
        }

        // Sort a copy, nearest first, so the caller's list is not reordered.
        // Double.compare returns 0 for equal distances, which the Comparator
        // contract requires.
        List<Point> nearestFirst = new ArrayList<Point>(dataSet);
        Collections.sort(nearestFirst, new Comparator<Point>() {
            @Override
            public int compare(Point o1, Point o2) {
                return Double.compare(o1.getDist(), o2.getDist());
            }
        });

        // Count labels among exactly the K nearest samples, tracking the leader.
        Map<String, Integer> votes = Maps.newHashMap();
        String bestLabel = "";
        int bestCount = 0;
        for (Point neighbour : nearestFirst.subList(0, K)) {
            String neighbourLabel = neighbour.getLabel();
            Integer previous = votes.get(neighbourLabel);
            int count = (previous == null) ? 1 : previous + 1;
            votes.put(neighbourLabel, count);
            if (count > bestCount) {
                bestCount = count;
                bestLabel = neighbourLabel;
            }
        }
        return bestLabel;
    }
}
package com.data.knn;

import com.google.common.collect.Lists;

import java.util.List;

/**
 * Demo entry point: builds four labelled samples, classifies one unlabelled
 * query point with K = 3 and prints the predicted label.
 *
 * @author XiJun.Gong
 * @version default 1.0.0 (2016-09-06 14:45)
 */
public class Main {

    /**
     * Runs the demo.
     *
     * @param args command-line arguments (unused)
     */
    public static void main(String args[]) {
        // The point whose label is to be predicted; it carries no label itself.
        final Point query = new Point(0.5, 0.5, null);

        // Two samples of class "A" near (1, 1) and two of class "B" near (0, 0).
        final List<Point> samples = Lists.newArrayList(
                new Point(1.0, 1.1, "A"),
                new Point(1.0, 1.0, "A"),
                new Point(0.0, 0.0, "B"),
                new Point(0.0, 0.1, "B"));

        final KNN classifier = new KNN();
        final String predicted = classifier.classify(samples, query, 3);
        System.out.println(predicted);
    }
}
结果(K=3 时,距离 (0.5, 0.5) 最近的三个点依次为 (0, 0.1)[B]、(1, 1)[A]、(0, 0)[B],B 占 2/3):
B
时间: 2024-11-06 15:02:52