k近邻算法-java实现

最近在看《机器学习实战》这本书,因为自己本身很想深入的了解机器学习算法,加之想学python,就在朋友的推荐之下选择了这本书进行学习。

一 . K-近邻算法(KNN)概述

最简单最初级的分类器是将全部的训练数据所对应的类别都记录下来,当测试对象的属性和某个训练对象的属性完全匹配时,便可以对其进行分类。但是怎么可能所有测试对象都会找到与之完全匹配的训练对象呢,其次就是存在一个测试对象同时与多个训练对象匹配,导致一个训练对象被分到了多个类的问题,基于这些问题呢,就产生了KNN。

KNN是通过测量不同特征值之间的距离进行分类。它的的思路是:如果一个样本在特征空间中的k个最相似(即特征空间中最邻近)的样本中的大多数属于某一个类别,则该样本也属于这个类别。K通常是不大于20的整数。KNN算法中,所选择的邻居都是已经正确分类的对象。该方法在定类决策上只依据最邻近的一个或者几个样本的类别来决定待分样本所属的类别。

下面通过一个简单的例子说明一下:如下图,绿色圆要被决定赋予哪个类,是红色三角形还是蓝色四方形?如果K=3,由于红色三角形所占比例为2/3,绿色圆将被赋予红色三角形那个类,如果K=5,由于蓝色四方形比例为3/5,因此绿色圆被赋予蓝色四方形类。

由此也说明了KNN算法的结果很大程度取决于K的选择。

在KNN中,通过计算对象间距离来作为各个对象之间的非相似性指标,避免了对象之间的匹配问题,在这里距离一般使用欧氏距离或曼哈顿距离:

同时,KNN通过依据k个对象中占优的类别进行决策,而不是单一的对象类别决策。这两点就是KNN算法的优势。

接下来对KNN算法的思想总结一下:就是在训练集中数据和标签已知的情况下,输入测试数据,将测试数据的特征与训练集中对应的特征进行相互比较,找到训练集中与之最为相似的前K个数据,则该测试数据对应的类别就是K个数据中出现次数最多的那个分类,其算法的描述为:

1)计算测试数据与各个训练数据之间的距离;

2)按照距离的递增关系进行排序;

3)选取距离最小的K个点;

4)确定前K个点所在类别的出现频率;

5)返回前K个点中出现频率最高的类别作为测试数据的预测分类。

代码实现:

import java.util.*;

/**
 * code by me
 * <p>
 * Data:2017/8/17 Time:16:40
 * User:lbh
 */
public class KNN {

    /**
     * KNN数据模型
     */
    public static class KNNModel implements Comparable<KNNModel> {
        public double a;
        public double b;
        public double c;
        public double distince;
        String type;

        public KNNModel(double a, double b, double c, String type) {
            this.a = a;
            this.b = b;
            this.c = c;
            this.type = type;
        }
        /**
         * 按距离排序
         *
         * @param arg
         * @return
         */
        @Override
        public int compareTo(KNNModel arg) {
            return Double.valueOf(this.distince).compareTo(Double.valueOf(arg.distince));
        }
    }

    /**
     * 计算距离
     *
     * @param knnModelList
     * @param i
     */
    private static void calDistince(List<KNNModel> knnModelList, KNNModel i) {
        double distince;
        for (KNNModel m : knnModelList) {
            distince = Math.sqrt((i.a - m.a) * (i.a - m.a) + (i.b - m.b) * (i.b - m.b) + (i.c - m.c) * (i.c - m.c));
            m.distince = distince;
        }
    }

    /**
     * 找出前k个数据中分类最多的数据
     *
     * @param knnModelList
     * @return
     */
    private static String findMostData(List<KNNModel> knnModelList) {
        Map<String, Integer> typeCountMap = new HashMap<String, Integer>();
        String type = "";
        Integer tempVal = 0;
        // 统计分类个数
        for (KNNModel model : knnModelList) {
            if (typeCountMap.containsKey(model.type)) {
                typeCountMap.put(model.type, typeCountMap.get(model.type) + 1);
            } else {
                typeCountMap.put(model.type, 1);
            }
        }
        // 找出最多分类
        for (Map.Entry<String, Integer> entry : typeCountMap.entrySet()) {
            if (entry.getValue() > tempVal) {
                tempVal = entry.getValue();
                type = entry.getKey();
            }
        }
        return type;
    }

    /**
     * KNN 算法的实现
     *
     * @param k
     * @param knnModelList
     * @param inputModel
     * @return
     */
    public static String calKNN(int k, List<KNNModel> knnModelList, KNNModel inputModel) {
        System.out.println("1.计算距离");
        calDistince(knnModelList, inputModel);
        System.out.println("2.按距离(近-->远)排序");
        Collections.sort(knnModelList);
        System.out.println("3.取前k个数据");
        while (knnModelList.size() > k) {
            knnModelList.remove(k);
        }
        System.out.println("4.找出前k个数据中分类出现频率最大的数据");
        String type = findMostData(knnModelList);
        return type;
    }

    /**
     * 测试KNN算法
     *
     * @param args
     */
    public static void main(String[] args) {
        // 准备数据
        List<KNNModel> knnModelList = new ArrayList<KNNModel>();
        knnModelList.add(new KNNModel(1.1, 1.1, 1.1, "A"));
        knnModelList.add(new KNNModel(1.2, 1.1, 1.0, "A"));
        knnModelList.add(new KNNModel(1.1, 1.0, 1.0, "A"));
        knnModelList.add(new KNNModel(3.0, 3.1, 1.0, "B"));
        knnModelList.add(new KNNModel(3.1, 3.0, 1.0, "B"));
        knnModelList.add(new KNNModel(5.4, 6.0, 4.0, "C"));
        knnModelList.add(new KNNModel(5.5, 6.3, 4.1, "C"));
        knnModelList.add(new KNNModel(6.0, 6.0, 4.0, "C"));
        knnModelList.add(new KNNModel(10.0, 12.0, 10.0, "M"));
        // 预测数据
        KNNModel predictionData = new KNNModel(5.1, 6.2, 2.0, "NB");
        // 计算
        String result = calKNN(3, knnModelList, predictionData);
        System.out.println("预测结果:"+result);
    }
}

结果:

时间: 2024-10-12 18:35:20

k近邻算法-java实现的相关文章

『cs231n』作业1问题1选讲_通过代码理解K近邻算法&amp;交叉验证选择超参数参数

通过K近邻算法探究numpy向量运算提速 茴香豆的"茴"字有... ... 使用三种计算图片距离的方式实现K近邻算法: 1.最为基础的双循环 2.利用numpy的broadca机制实现单循环 3.利用broadcast和矩阵的数学性质实现无循环 图片被拉伸为一维数组 X_train:(train_num, 一维数组) X:(test_num, 一维数组) 方法验证 import numpy as np a = np.array([[1,1,1],[2,2,2],[3,3,3]]) b

K 近邻算法

声明: 1,本篇为个人对<2012.李航.统计学习方法.pdf>的学习总结,不得用作商用,欢迎转载,但请注明出处(即:本帖地址). 2,因为本人在学习初始时有非常多数学知识都已忘记,所以为了弄懂当中的内容查阅了非常多资料.所以里面应该会有引用其它帖子的小部分内容,假设原作者看到能够私信我,我会将您的帖子的地址付到以下. 3.假设有内容错误或不准确欢迎大家指正. 4.假设能帮到你.那真是太好了. 描写叙述 给定一个训练数据集,对新的输入实例.在训练数据集中找到与该实例最邻近的K个实例,若这K个实

从K近邻算法、距离度量谈到KD树、SIFT+BBF算法

从K近邻算法.距离度量谈到KD树.SIFT+BBF算法 从K近邻算法.距离度量谈到KD树.SIFT+BBF算法 前言 前两日,在微博上说:“到今天为止,我至少亏欠了3篇文章待写:1.KD树:2.神经网络:3.编程艺术第28章.你看到,blog内的文章与你于别处所见的任何都不同.于是,等啊等,等一台电脑,只好等待..”.得益于田,借了我一台电脑(借他电脑的时候,我连表示感谢,他说“能找到工作全靠你的博客,这点儿小忙还说,不地道”,有的时候,稍许感受到受人信任也是一种压力,愿我不辜负大家对我的信任)

K近邻算法

1.1.什么是K近邻算法 何谓K近邻算法,即K-Nearest Neighbor algorithm,简称KNN算法,单从名字来猜想,可以简单粗暴的认为是:K个最近的邻居,当K=1时,算法便成了最近邻算法,即寻找最近的那个邻居.为何要找邻居?打个比方来说,假设你来到一个陌生的村庄,现在你要找到与你有着相似特征的人群融入他们,所谓入伙. 用官方的话来说,所谓K近邻算法,即是给定一个训练数据集,对新的输入实例,在训练数据集中找到与该实例最邻近的K个实例(也就是上面所说的K个邻居),这K个实例的多数属

K近邻算法-KNN

何谓K近邻算法,即K-Nearest Neighbor algorithm,简称KNN算法,单从名字来猜想,可以简单粗暴的认为是:K个最近的邻居,当K=1时,算法便成了最近邻算法,即寻找最近的那个邻居.为何要找邻居?打个比方来说,假设你来到一个陌生的村庄,现在你要找到与你有着相似特征的人群融入他们,所谓入伙. 用官方的话来说,所谓K近邻算法,即是给定一个训练数据集,对新的输入实例,在训练数据集中找到与该实例最邻近的K个实例(也就是上面所说的K个邻居),这K个实例的多数属于某个类,就把该输入实例分

k近邻算法理论(一)

时间 :2014.07.05 地点:基地 ----------------------------------------------------------------------------------- 一.简述 K近邻法(k-nearest neighbor,kNN)是一种基本分类与回归方法.k近邻的输入为实例的特征向量,对应特征空间中的点,输出为实例的类别.k近邻算法的基本思想是:给定训练数据集,实例类别已定,在对目标实例进行分类时,我们根据与目标实例k个最近邻居的训练实例的类别,通过

机器学习实战笔记--k近邻算法

1 #encoding:utf-8 2 from numpy import * 3 import operator 4 import matplotlib 5 import matplotlib.pyplot as plt 6 7 from os import listdir 8 9 def makePhoto(returnMat,classLabelVector): #创建散点图 10 fig = plt.figure() 11 ax = fig.add_subplot(111) #例如参数为

基本分类方法——KNN(K近邻)算法

在这篇文章 http://www.cnblogs.com/charlesblc/p/6193867.html 讲SVM的过程中,提到了KNN算法.有点熟悉,上网一查,居然就是K近邻算法,机器学习的入门算法. 参考内容如下:http://www.cnblogs.com/charlesblc/p/6193867.html 1.kNN算法又称为k近邻分类(k-nearest neighbor classification)算法. 最简单平凡的分类器也许是那种死记硬背式的分类器,记住所有的训练数据,对于

使用K近邻算法实现手写体识别系统

目录 1. 应用介绍 1.1实验环境介绍 1.2应用背景介绍 2. 数据来源及预处理 2.1数据来源及格式 2.2数据预处理 3. 算法设计与实现 3.1手写体识别系统算法实现过程 3.2 K近邻算法实现 3.3手写体识别系统实现 3.4算法改进与优化 4. 系统运行过程与结果展示 1.应用介绍 1.1实验环境介绍 本次实验主要使用Python语言开发完成,Python的版本为2.7,并且使用numpy函数库做一些数值计算和处理. 1.2应用背景介绍 本次实验实现的是简易的手写体识别系统,即根据