[Machine Learning]kNN代码实现(Kd tree)

具体描述见《统计学习方法》第三章。

  1 //
  2 //  main.cpp
  3 //  kNN
  4 //
  5 //  Created by feng on 15/10/24.
  6 //  Copyright © 2015年 ttcn. All rights reserved.
  7 //
  8
  9 #include <iostream>
 10 #include <vector>
 11 #include <algorithm>
 12 #include <cmath>
 13 using namespace std;
 14
 15 template<typename T>
 16 struct KdTree {
 17     // ctor
 18     KdTree():parent(nullptr), leftChild(nullptr), rightChild(nullptr) {}
 19
 20     // KdTree是否为空
 21     bool isEmpty() { return root.empty(); }
 22
 23     // KdTree是否为叶子节点
 24     bool isLeaf() { return !root.empty() && !leftChild && !rightChild;}
 25
 26     // KdTree是否为根节点
 27     bool isRoot() { return !isEmpty() && !parent;}
 28
 29     // 判断KdTree是否为根节点的左儿子
 30     bool isLeft() { return parent->leftChild->root == root; }
 31
 32     // 判断KdTree是否为根节点的右儿子
 33     bool isRight() { return parent->rightChild->root == root; }
 34
 35     // 存放根节点的数据
 36     vector<T> root;
 37
 38     // 父节点
 39     KdTree<T> *parent;
 40
 41     // 左儿子
 42     KdTree<T> *leftChild;
 43
 44     // 右儿子
 45     KdTree<T> *rightChild;
 46 };
 47
 48
 49 /**
 50  *  矩阵转置
 51  *
 52  *  @param matrix 原矩阵
 53  *
 54  *  @return 原矩阵的转置矩阵
 55  */
 56 template<typename T>
 57 vector<vector<T>> transpose(const vector<vector<T>> &matrix) {
 58     size_t rows = matrix.size();
 59     size_t cols = matrix[0].size();
 60     vector<vector<T>> trans(cols, vector<T>(rows, 0));
 61     for (size_t i = 0; i < cols; ++i) {
 62         for (size_t j = 0; j < rows; ++j) {
 63             trans[i][j] = matrix[j][i];
 64         }
 65     }
 66
 67     return trans;
 68 }
 69
 70 /**
 71  *  找中位数
 72  *
 73  *  @param vec 数组
 74  *
 75  *  @return 数组中的中位数
 76  */
 77 template<typename T>
 78 T findMiddleValue(vector<T> vec) {
 79     sort(vec.begin(), vec.end());
 80     size_t pos = vec.size() / 2;
 81     return vec[pos];
 82 }
 83
 84 /**
 85  *  递归构造KdTree
 86  *
 87  *  @param tree  KdTree根节点
 88  *  @param data  数据矩阵
 89  *  @param depth 当前节点深度
 90  *
 91  *  @return void
 92  */
 93 template<typename T>
 94 void buildKdTree(KdTree<T> *tree, vector<vector<T>> &data, size_t depth) {
 95     // 输入数据个数
 96     size_t samplesNum = data.size();
 97
 98     if (samplesNum == 0) {
 99         return;
100     }
101
102     if (samplesNum == 1) {
103         tree->root = data[0];
104         return;
105     }
106
107     // 每一个输入数据的维度,属性个数
108     size_t k = data[0].size();
109     vector<vector<T>> transData = transpose(data);
110
111     // 找到当前切分点
112     size_t splitAttributeIndex = depth % k;
113     vector<T> splitAttributes = transData[splitAttributeIndex];
114     T splitValue = findMiddleValue(splitAttributes);
115
116     vector<vector<T>> leftSubSet;
117     vector<vector<T>> rightSubset;
118
119     for (size_t i = 0; i < samplesNum; ++i) {
120         if (splitAttributes[i] == splitValue && tree->isEmpty()) {
121             tree->root = data[i];
122         } else if (splitAttributes[i] < splitValue) {
123             leftSubSet.push_back(data[i]);
124         } else {
125             rightSubset.push_back(data[i]);
126         }
127     }
128
129     tree->leftChild = new KdTree<T>;
130     tree->leftChild->parent = tree;
131     tree->rightChild = new KdTree<T>;
132     tree->rightChild->parent = tree;
133     buildKdTree(tree->leftChild, leftSubSet, depth + 1);
134     buildKdTree(tree->rightChild, rightSubset, depth + 1);
135 }
136
137 /**
138  *  递归打印KdTree
139  *
140  *  @param tree  KdTree
141  *  @param depth 当前深度
142  *
143  *  @return void
144  */
145 template<typename T>
146 void printKdTree(const KdTree<T> *tree, size_t depth) {
147     for (size_t i = 0; i < depth; ++i) {
148         cout << "\t";
149     }
150
151     for (size_t i = 0; i < tree->root.size(); ++i) {
152         cout << tree->root[i] << " ";
153     }
154     cout << endl;
155
156     if (tree->leftChild == nullptr && tree->rightChild == nullptr) {
157         return;
158     } else {
159         if (tree->leftChild) {
160             for (int i = 0; i < depth + 1; ++i) {
161                 cout << "\t";
162             }
163             cout << "left : ";
164             printKdTree(tree->leftChild, depth + 1);
165         }
166
167         cout << endl;
168
169         if (tree->rightChild) {
170             for (size_t i = 0; i < depth + 1; ++i) {
171                 cout << "\t";
172             }
173             cout << "right : ";
174             printKdTree(tree->rightChild, depth + 1);
175         }
176         cout << endl;
177     }
178 }
179
180 /**
181  *  节点之间的欧氏距离
182  *
183  *  @param p1 节点1
184  *  @param p2 节点2
185  *
186  *  @return 节点之间的欧式距离
187  */
188 template<typename T>
189 T calDistance(const vector<T> &p1, const vector<T> &p2) {
190     T res = 0;
191     for (size_t i = 0; i < p1.size(); ++i) {
192         res += pow(p1[i] - p2[i], 2);
193     }
194
195     return res;
196 }
197
198 /**
199  *  搜索目标节点的最近邻
200  *
201  *  @param tree KdTree
202  *  @param goal 待分类的节点
203  *
204  *  @return 最近邻节点
205  */
206 template <typename T>
207 vector<T> searchNearestNeighbor(KdTree<T> *tree, const vector<T> &goal ) {
208     // 节点数属性个数
209     size_t k = tree->root.size();
210     // 划分的索引
211     size_t d = 0;
212     KdTree<T> *currentTree = tree;
213     vector<T> currentNearest = currentTree->root;
214     // 找到目标节点的最叶节点
215     while (!currentTree->isLeaf()) {
216         size_t index = d % k;
217         if (currentTree->rightChild->isEmpty() || goal[index] < currentNearest[index]) {
218             currentTree = currentTree->leftChild;
219         } else {
220             currentTree = currentTree->rightChild;
221         }
222
223         ++d;
224     }
225     currentNearest = currentTree->root;
226     T currentDistance = calDistance(goal, currentTree->root);
227
228     KdTree<T> *searchDistrict;
229     if (currentTree->isLeft()) {
230         if (!(currentTree->parent->rightChild)) {
231             searchDistrict = currentTree;
232         } else {
233             searchDistrict = currentTree->parent->rightChild;
234         }
235     } else {
236         searchDistrict = currentTree->parent->leftChild;
237     }
238
239     while (!(searchDistrict->parent)) {
240         T districtDistance = abs(goal[(d + 1) % k] - searchDistrict->parent->root[(d + 1) % k]);
241
242         if (districtDistance < currentDistance) {
243             T parentDistance = calDistance(goal, searchDistrict->parent->root);
244
245             if (parentDistance < currentDistance) {
246                 currentDistance = parentDistance;
247                 currentTree = searchDistrict->parent;
248                 currentNearest = currentTree->root;
249             }
250
251             if (!searchDistrict->isEmpty()) {
252                 T rootDistance = calDistance(goal, searchDistrict->root);
253                 if (rootDistance < currentDistance) {
254                     currentDistance = rootDistance;
255                     currentTree = searchDistrict;
256                     currentNearest = currentTree->root;
257                 }
258             }
259
260             if (!(searchDistrict->leftChild)) {
261                 T leftDistance = calDistance(goal, searchDistrict->leftChild->root);
262                 if (leftDistance < currentDistance) {
263                     currentDistance = leftDistance;
264                     currentTree = searchDistrict;
265                     currentNearest = currentTree->root;
266                 }
267             }
268
269             if (!(searchDistrict->rightChild)) {
270                 T rightDistance = calDistance(goal, searchDistrict->rightChild->root);
271                 if (rightDistance < currentDistance) {
272                     currentDistance = rightDistance;
273                     currentTree = searchDistrict;
274                     currentNearest = currentTree->root;
275                 }
276             }
277
278         }
279
280         if (!(searchDistrict->parent->parent)) {
281             searchDistrict = searchDistrict->parent->isLeft()? searchDistrict->parent->parent->rightChild : searchDistrict->parent->parent->leftChild;
282         } else {
283             searchDistrict = searchDistrict->parent;
284         }
285         ++d;
286     }
287
288     return currentNearest;
289 }
290
291 int main(int argc, const char * argv[]) {
292     vector<vector<double>> trainDataSet{{2,3},{5,4},{9,6},{4,7},{8,1},{7,2}};
293     KdTree<double> *kdTree = new KdTree<double>;
294     buildKdTree(kdTree, trainDataSet, 0);
295     printKdTree(kdTree, 0);
296
297     vector<double> goal{3, 4.5};
298     vector<double> nearestNeighbor = searchNearestNeighbor(kdTree, goal);
299
300     for (auto i : nearestNeighbor) {
301         cout << i << " ";
302     }
303     cout << endl;
304
305     return 0;
306 }
时间: 2024-11-10 03:06:55

[Machine Learning]kNN代码实现(Kd tree)的相关文章

[Machine :Learning] kNN近邻算法

from numpy import * import operator def createDataSet() : group = array([[1.0, 1.1], [1.0, 1.0], [0, 0], [0, 1.1]]) labels = ['A', 'A', 'B', 'B'] return group, labels ''' tile(array, (intR, intC): 对矩阵进行组合,纵向复制intR次, 横向复制intC次 比如 : tile([1,2,3], (3, 2

机器学习算法之旅A Tour of Machine Learning Algorithms

In this post we take a tour of the most popular machine learning algorithms. It is useful to tour the main algorithms in the field to get a feeling of what methods are available. There are so many algorithms available and it can feel overwhelming whe

Machine Learning In Action 第二章学习笔记: kNN算法

本文主要记录<Machine Learning In Action>中第二章的内容.书中以两个具体实例来介绍kNN(k nearest neighbors),分别是: 约会对象预测 手写数字识别 通过“约会对象”功能,基本能够了解到kNN算法的工作原理.“手写数字识别”与“约会对象预测”使用完全一样的算法代码,仅仅是数据集有变化. 约会对象预测 1 约会对象预测功能需求 主人公“张三”喜欢结交新朋友.“系统A”上面注册了很多类似于“张三”的用户,大家都想结交心朋友.“张三”最开始通过自己筛选的

【machine learning】KNN算法

适逢学习机器学习基础知识,就将书中内容读读记记,本博文代码参考书本Machine Learning in Action(<机器学习实战>). 一.概述 kNN算法又称为k近邻分类(k-nearest neighbor classification)算法. kNN算法则是从训练集中找到和新数据最接近的k条记录,然后根据他们的主要分类来决定新数据的类别.该算法涉及3个主要因素:训练集.距离或相似的衡量.k的大小. 二.算法要点 1.指导思想 kNN算法的指导思想是"近朱者赤,近墨者黑&q

【用Python玩Machine Learning】KNN * 序

这段时间工作太忙,很久没学习了.这两天,工作之余,偶尔在家翻翻书,权且当做休息了. 我一直是c/c++的忠实用户,尤其是c的粉丝--概念简洁.运行高效--计算机专业的人,不用c语言,不了解程序底层的运行机制和过程,那和那些外专业的只会调用接口.函数的同学有什么区别呢?不过,最近一年还是慢慢去了解.尝试python了.原因很简单,开发成本太低了.c/c++就像复杂的吸尘器.洗碗机,优点是高效,缺点是笨重,且对不同的场景要不同的适配:python就像是一块脏抹布,哪儿都能用,用完就扔,再用的时候再捡

【机器学习实战】Machine Learning in Action 代码 视频 项目案例

MachineLearning 欢迎任何人参与和完善:一个人可以走的很快,但是一群人却可以走的更远 Machine Learning in Action (机器学习实战) | ApacheCN(apache中文网) 视频每周更新:如果你觉得有价值,请帮忙点 Star[后续组织学习活动:sklearn + tensorflow] ApacheCN - 学习机器学习群[629470233] 第一部分 分类 1.) 机器学习基础 2.) k-近邻算法 3.) 决策树 4.) 基于概率论的分类方法:朴素

【用Python玩Machine Learning】KNN * 测试

样本我就用的<machine learning in action>中提供的数据样例,据说是婚恋网站上各个候选人的特征,以及当前人对这些人的喜欢程度.一共1k条数据,前900条作为训练样本,后100条作为测试样本. 数据格式如下: 46893 3.562976 0.445386 didntLike 8178 3.230482 1.331698 smallDoses 55783 3.612548 1.551911 didntLike 1148 0.000000 0.332365 smallDos

[Machine Learning] 国外程序员整理的机器学习资源大全

本文汇编了一些机器学习领域的框架.库以及软件(按编程语言排序). 1. C++ 1.1 计算机视觉 CCV —基于C语言/提供缓存/核心的机器视觉库,新颖的机器视觉库 OpenCV—它提供C++, C, Python, Java 以及 MATLAB接口,并支持Windows, Linux, Android and Mac OS操作系统. 1.2 机器学习 MLPack DLib ecogg shark 2. Closure Closure Toolbox—Clojure语言库与工具的分类目录 3

OpenCV的machine learning模块使用

opencv中提供的了较为完善的machine learning 模块,包含多种ml算法,极大了简化了实验过程.然而目前网上大部分的资料(包括官方文档)中关于ml模块的使用均是针对1.0风格的旧代码的,这对我们的学习造成了极大的困扰.本文将简单介绍一下如何使用opencv的ml模块进行实验. 首先,准备实验数据,我这里使用的是<模式分类>一书中第二章上机习题的部分数据,旨在进行一个简单的调用过程进行实验.实验数据如下表所示,在实际实验过程中,使用txt文档保存数据,并且没有文件头信息(实际上o