k近邻法的C++实现:kd树

1.k近邻算法的思想

给定一个训练集,对于新的输入实例,在训练集中找到与该实例最近的k个实例,这k个实例中的多数属于某个类,就把该输入实例分为这个类。

因为要找到最近的k个实例,所以计算输入实例与训练集中实例之间的距离是关键!

k近邻算法最简单的方法是线性扫描,这时要计算输入实例与每一个训练实例的距离,当训练集很大时,非常耗时,这种方法不可行,为了提高k近邻的搜索效率,常常考虑使用特殊的存储结构存储训练数据,以减少计算距离的次数,具体方法很多,这里介绍实现经典的kd树方法。

2.构造kd树

kd树是一种对k维空间中的实例点进行存储以便对其进行快速检索的树形数据结构,kd树是二叉树。

下面举例说明:

给定一个二维空间的数据集: T = {(2,3), (5,4), (9,6), (4,7), (8,1), (7,2)},构造一个平衡kd树。

  • 根结点对应包含数据集T的矩形选择x(1) 轴,6个数据点的x(1) 坐标的中位数是7,以超平面x(1) = 7将空间分为左右两个子矩形(子结点)
  • 左矩形以x(2) = 4为中位数分为两个子矩形
  • 右矩形以x(2) = 6 分为两个子矩形
  • 如此递归,直到两个子区域没有实例存在时停止

构造的kd树如下:

3.利用kd树搜索最近邻

输入:已构造的kd树;目标点x;

输出:x的最近邻

  • 在kd树中找出包含目标点x的叶结点:从根结点出发,递归的向下访问kd树,若目标点x的当前维的坐标小于切分点的坐标,则移动到左子结点,否则移动到右子结点,直到子结点为叶结点为止。
  • 以此叶结点为“当前最近点”
  • 递归地向上回退,在每个结点进行以下操作:(a)如果该结点保存的实例点比当前最近点距离目标点更近,则以该实例点为“当前最近点”;

    (b)当前最近点一定存在于某结点一个子结点对应的区域,检查该子结点的父结点的另
    一子结点对应区域是否有更近的点(即检查另一子结点对应的区域是否与以目标点为球
    心、以目标点与“当前最近点”间的距离为半径的球体相交);如果相交,可能在另一
    个子结点对应的区域内存在距目标点更近的点,移动到另一个子结点,接着递归进行最
    近邻搜索;如果不相交,向上回退
  • 当回退到根结点时,搜索结束,最后的“当前最近点”即为x的最近邻点。

4.C++实现

  1 #include <iostream>
  2 #include <vector>
  3 #include <algorithm>
  4 #include <string>
  5 #include <cmath>
  6 using namespace std;
  7
  8
  9
 10
 11 struct KdTree{
 12     vector<double> root;
 13     KdTree* parent;
 14     KdTree* leftChild;
 15     KdTree* rightChild;
 16     //默认构造函数
 17     KdTree(){parent = leftChild = rightChild = NULL;}
 18     //判断kd树是否为空
 19     bool isEmpty()
 20     {
 21         return root.empty();
 22     }
 23     //判断kd树是否只是一个叶子结点
 24     bool isLeaf()
 25     {
 26         return (!root.empty()) &&
 27             rightChild == NULL && leftChild == NULL;
 28     }
 29     //判断是否是树的根结点
 30     bool isRoot()
 31     {
 32         return (!isEmpty()) && parent == NULL;
 33     }
 34     //判断该子kd树的根结点是否是其父kd树的左结点
 35     bool isLeft()
 36     {
 37         return parent->leftChild->root == root;
 38     }
 39     //判断该子kd树的根结点是否是其父kd树的右结点
 40     bool isRight()
 41     {
 42         return parent->rightChild->root == root;
 43     }
 44 };
 45
 46 int data[6][2] = {{2,3},{5,4},{9,6},{4,7},{8,1},{7,2}};
 47
 48 template<typename T>
 49 vector<vector<T> > Transpose(vector<vector<T> > Matrix)
 50 {
 51     unsigned row = Matrix.size();
 52     unsigned col = Matrix[0].size();
 53     vector<vector<T> > Trans(col,vector<T>(row,0));
 54     for (unsigned i = 0; i < col; ++i)
 55     {
 56         for (unsigned j = 0; j < row; ++j)
 57         {
 58             Trans[i][j] = Matrix[j][i];
 59         }
 60     }
 61     return Trans;
 62 }
 63
 64 template <typename T>
 65 T findMiddleValue(vector<T> vec)
 66 {
 67     sort(vec.begin(),vec.end());
 68     auto pos = vec.size() / 2;
 69     return vec[pos];
 70 }
 71
 72
 73 //构建kd树
 74 void buildKdTree(KdTree* tree, vector<vector<double> > data, unsigned depth)
 75 {
 76
 77     //样本的数量
 78     unsigned samplesNum = data.size();
 79     //终止条件
 80     if (samplesNum == 0)
 81     {
 82         return;
 83     }
 84     if (samplesNum == 1)
 85     {
 86         tree->root = data[0];
 87         return;
 88     }
 89     //样本的维度
 90     unsigned k = data[0].size();
 91     vector<vector<double> > transData = Transpose(data);
 92     //选择切分属性
 93     unsigned splitAttribute = depth % k;
 94     vector<double> splitAttributeValues = transData[splitAttribute];
 95     //选择切分值
 96     double splitValue = findMiddleValue(splitAttributeValues);
 97     //cout << "splitValue" << splitValue  << endl;
 98
 99     // 根据选定的切分属性和切分值,将数据集分为两个子集
100     vector<vector<double> > subset1;
101     vector<vector<double> > subset2;
102     for (unsigned i = 0; i < samplesNum; ++i)
103     {
104         if (splitAttributeValues[i] == splitValue && tree->root.empty())
105             tree->root = data[i];
106         else
107         {
108             if (splitAttributeValues[i] < splitValue)
109                 subset1.push_back(data[i]);
110             else
111                 subset2.push_back(data[i]);
112         }
113     }
114
115     //子集递归调用buildKdTree函数
116
117     tree->leftChild = new KdTree;
118     tree->leftChild->parent = tree;
119     tree->rightChild = new KdTree;
120     tree->rightChild->parent = tree;
121     buildKdTree(tree->leftChild, subset1, depth + 1);
122     buildKdTree(tree->rightChild, subset2, depth + 1);
123 }
124
125 //逐层打印kd树
126 void printKdTree(KdTree *tree, unsigned depth)
127 {
128     for (unsigned i = 0; i < depth; ++i)
129         cout << "\t";
130
131     for (vector<double>::size_type j = 0; j < tree->root.size(); ++j)
132         cout << tree->root[j] << ",";
133     cout << endl;
134     if (tree->leftChild == NULL && tree->rightChild == NULL )//叶子节点
135         return;
136     else //非叶子节点
137     {
138         if (tree->leftChild != NULL)
139         {
140             for (unsigned i = 0; i < depth + 1; ++i)
141                 cout << "\t";
142             cout << " left:";
143             printKdTree(tree->leftChild, depth + 1);
144         }
145
146         cout << endl;
147         if (tree->rightChild != NULL)
148         {
149             for (unsigned i = 0; i < depth + 1; ++i)
150                 cout << "\t";
151             cout << "right:";
152             printKdTree(tree->rightChild, depth + 1);
153         }
154         cout << endl;
155     }
156 }
157
158
159 //计算空间中两个点的距离
160 double measureDistance(vector<double> point1, vector<double> point2, unsigned method)
161 {
162     if (point1.size() != point2.size())
163     {
164         cerr << "Dimensions don‘t match!!" ;
165         exit(1);
166     }
167     switch (method)
168     {
169         case 0://欧氏距离
170             {
171                 double res = 0;
172                 for (vector<double>::size_type i = 0; i < point1.size(); ++i)
173                 {
174                     res += pow((point1[i] - point2[i]), 2);
175                 }
176                 return sqrt(res);
177             }
178         case 1://曼哈顿距离
179             {
180                 double res = 0;
181                 for (vector<double>::size_type i = 0; i < point1.size(); ++i)
182                 {
183                     res += abs(point1[i] - point2[i]);
184                 }
185                 return res;
186             }
187         default:
188             {
189                 cerr << "Invalid method!!" << endl;
190                 return -1;
191             }
192     }
193 }
194 //在kd树tree中搜索目标点goal的最近邻
195 //输入:目标点;已构造的kd树
196 //输出:目标点的最近邻
197 vector<double> searchNearestNeighbor(vector<double> goal, KdTree *tree)
198 {
199     /*第一步:在kd树中找出包含目标点的叶子结点:从根结点出发,
200     递归的向下访问kd树,若目标点的当前维的坐标小于切分点的
201     坐标,则移动到左子结点,否则移动到右子结点,直到子结点为
202     叶结点为止,以此叶子结点为“当前最近点”
203     */
204     unsigned k = tree->root.size();//计算出数据的维数
205     unsigned d = 0;//维度初始化为0,即从第1维开始
206     KdTree* currentTree = tree;
207     vector<double> currentNearest = currentTree->root;
208     while(!currentTree->isLeaf())
209     {
210         unsigned index = d % k;//计算当前维
211         if (currentTree->rightChild->isEmpty() || goal[index] < currentNearest[index])
212         {
213             currentTree = currentTree->leftChild;
214         }
215         else
216         {
217             currentTree = currentTree->rightChild;
218         }
219         ++d;
220     }
221     currentNearest = currentTree->root;
222
223     /*第二步:递归地向上回退, 在每个结点进行如下操作:
224     (a)如果该结点保存的实例比当前最近点距离目标点更近,则以该例点为“当前最近点”
225     (b)当前最近点一定存在于某结点一个子结点对应的区域,检查该子结点的父结点的另
226     一子结点对应区域是否有更近的点(即检查另一子结点对应的区域是否与以目标点为球
227     心、以目标点与“当前最近点”间的距离为半径的球体相交);如果相交,可能在另一
228     个子结点对应的区域内存在距目标点更近的点,移动到另一个子结点,接着递归进行最
229     近邻搜索;如果不相交,向上回退*/
230
231     //当前最近邻与目标点的距离
232     double currentDistance = measureDistance(goal, currentNearest, 0);
233
234     //如果当前子kd树的根结点是其父结点的左孩子,则搜索其父结点的右孩子结点所代表
235     //的区域,反之亦反
236     KdTree* searchDistrict;
237     if (currentTree->isLeft())
238     {
239         if (currentTree->parent->rightChild == NULL)
240             searchDistrict = currentTree;
241         else
242             searchDistrict = currentTree->parent->rightChild;
243     }
244     else
245     {
246         searchDistrict = currentTree->parent->leftChild;
247     }
248
249     //如果搜索区域对应的子kd树的根结点不是整个kd树的根结点,继续回退搜索
250     while (searchDistrict->parent != NULL)
251     {
252         //搜索区域与目标点的最近距离
253         double districtDistance = abs(goal[(d+1)%k] - searchDistrict->parent->root[(d+1)%k]);
254
255         //如果“搜索区域与目标点的最近距离”比“当前最近邻与目标点的距离”短,表明搜索
256         //区域内可能存在距离目标点更近的点
257         if (districtDistance < currentDistance )//&& !searchDistrict->isEmpty()
258         {
259
260             double parentDistance = measureDistance(goal, searchDistrict->parent->root, 0);
261
262             if (parentDistance < currentDistance)
263             {
264                 currentDistance = parentDistance;
265                 currentTree = searchDistrict->parent;
266                 currentNearest = currentTree->root;
267             }
268             if (!searchDistrict->isEmpty())
269             {
270                 double rootDistance = measureDistance(goal, searchDistrict->root, 0);
271                 if (rootDistance < currentDistance)
272                 {
273                     currentDistance = rootDistance;
274                     currentTree = searchDistrict;
275                     currentNearest = currentTree->root;
276                 }
277             }
278             if (searchDistrict->leftChild != NULL)
279             {
280                 double leftDistance = measureDistance(goal, searchDistrict->leftChild->root, 0);
281                 if (leftDistance < currentDistance)
282                 {
283                     currentDistance = leftDistance;
284                     currentTree = searchDistrict;
285                     currentNearest = currentTree->root;
286                 }
287             }
288             if (searchDistrict->rightChild != NULL)
289             {
290                 double rightDistance = measureDistance(goal, searchDistrict->rightChild->root, 0);
291                 if (rightDistance < currentDistance)
292                 {
293                     currentDistance = rightDistance;
294                     currentTree = searchDistrict;
295                     currentNearest = currentTree->root;
296                 }
297             }
298         }//end if
299
300         if (searchDistrict->parent->parent != NULL)
301         {
302             searchDistrict = searchDistrict->parent->isLeft()?
303                             searchDistrict->parent->parent->rightChild:
304                             searchDistrict->parent->parent->leftChild;
305         }
306         else
307         {
308             searchDistrict = searchDistrict->parent;
309         }
310         ++d;
311     }//end while
312     return currentNearest;
313 }
314
315 int main()
316 {
317     vector<vector<double> > train(6, vector<double>(2, 0));
318     for (unsigned i = 0; i < 6; ++i)
319         for (unsigned j = 0; j < 2; ++j)
320             train[i][j] = data[i][j];
321
322     KdTree* kdTree = new KdTree;
323     buildKdTree(kdTree, train, 0);
324
325     printKdTree(kdTree, 0);
326
327     vector<double> goal;
328     goal.push_back(3);
329     goal.push_back(4.5);
330     vector<double> nearestNeighbor = searchNearestNeighbor(goal, kdTree);
331     vector<double>::iterator beg = nearestNeighbor.begin();
332     cout << "The nearest neighbor is: ";
333     while(beg != nearestNeighbor.end()) cout << *beg++ << ",";
334     cout << endl;
335     return 0;
336 }

5. 运行

下面是用上面举例构造的kd树求点(3,4.5)的最近邻:

参考文献:李航《统计学习方法》,维基百科

时间: 2024-10-25 15:33:21

k近邻法的C++实现:kd树的相关文章

K近邻法(KNN)原理小结

K近邻法(k-nearst neighbors,KNN)是一种很基本的机器学习方法了,在我们平常的生活中也会不自主的应用.比如,我们判断一个人的人品,只需要观察他来往最密切的几个人的人品好坏就可以得出了.这里就运用了KNN的思想.KNN方法既可以做分类,也可以做回归,这点和决策树算法相同. KNN做回归和分类的主要区别在于最后做预测时候的决策方式不同.KNN做分类预测时,一般是选择多数表决法,即训练集里和预测的样本特征最近的K个样本,预测为里面有最多类别数的类别.而KNN做回归时,一般是选择平均

统计学习方法 (第3章)K近邻法 学习笔记

第3章 K近邻法 k近邻算法简单.直观:给定一个训练数据集,对新的输入实例,在训练数据集中找到与该实例最邻近的k个实例,这k个实例的多数属于某个类,就把该输入实例分为这个类.当K=1时,又称为最近邻算法,这时候就是将训练数据集中与x最邻近点作为x的类. 3.1 k近邻模型 模型由三个基本要素--距离度量.k值得选择.和分类决策规则决定. 3.1.1 距离度量 p=2时,称为欧式距离,p=1时,称为曼哈顿距离. 3.1.2 k值的选择 k 值的选择会对k 近邻法的结果产生重大影响.如果选择较小的k

《统计学习方法》第三章,k 近邻法

? k 近邻法来分类,用到了 kd 树的建立和搜索 ● 代码 1 import numpy as np 2 import matplotlib.pyplot as plt 3 from mpl_toolkits.mplot3d import Axes3D 4 from mpl_toolkits.mplot3d.art3d import Poly3DCollection 5 from matplotlib.patches import Rectangle 6 import operator 7 i

scikit-learn K近邻法类库使用小结

在K近邻法(KNN)原理小结这篇文章,我们讨论了KNN的原理和优缺点,这里我们就从实践出发,对scikit-learn 中KNN相关的类库使用做一个小结.主要关注于类库调参时的一个经验总结. 一.scikit-learn 中KNN相关的类库概述 在scikit-learn 中,与近邻法这一大类相关的类库都在sklearn.neighbors包之中.KNN分类树的类是KNeighborsClassifier,KNN回归树的类是KNeighborsRegressor.除此之外,还有KNN的扩展,即限

统计学习方法与Python实现(二)——k近邻法

统计学习方法与Python实现(二)——k近邻法 iwehdio的博客园:https://www.cnblogs.com/iwehdio/ 1.定义 k近邻法假设给定一个训练数据集,其中的实例类别已定.分类时,对新的实例,根据其k个最近邻的训练实例的类别,通过多数表决的方式进行预测.k近邻法不具有显式的学习过程,而实际上是利用训练数据集对特征空间进行划分,并作为其分类的模型.k近邻法的三个基本要素是 k值的选择.距离度量和分类决策规则. k近邻法的模型是将特征空间划分成一些称为单元的子空间,并且

3.K近邻法

1. k 近邻算法k近邻法(k-nearest neighbor, k-NN) 是一种基本分类与回归方法.  k近邻法的输入为实例的特征向量, 对应于特征空间的点: 输出为实例的类别, 可以取多类. k近邻法假设给定一个训练数据集, 其中的实例类别已定. 分类时, 对新的实例, 根据其k个最近邻的训练实例的类别, 通过多数表决等方式进行预测.因此, k近邻法不具有显式的学习过程. k近邻法实际上利用训练数据集对特征向量空间进行划分, 并作为其分类的“模型”. k值的选择. 距离度量及分类决策规则

李航统计学习方法——算法2——k近邻法

一.K近邻算法 k近邻法(k-nearest neighbor,k-NN)是一种基本分类与回归方法,输入实例的特征向量,输出实例的类别,其中类别可取多类 二.k近邻模型 2.1 距离度量 距离定义: (1)当p=1,称为曼哈顿距离 (2)当p=2,称为欧式距离 (3)当p取无穷大时,它是各个坐标距离的最大值 max|xi-xj| 注意:p值的选择会影响分类结果,例如二维空间的三个点 x1=(1,1),x2=(5,1), x3=(4,4) 由于x1和x2只有第二维上不同,不管p值如何变化,Lp始终

【黎明传数==&gt;机器学习速成宝典】模型篇04——k近邻法【kNN】(附python代码)

目录 什么是k近邻算法 模型的三个基本要素 构造kd树 搜索kd树 Python代码(sklearn库) 什么K近邻算法(k-Nearest Neighbor,kNN) 引例 假设有数据集,其中前6部是训练集(有属性值和标记),我们根据训练集训练一个KNN模型,预测最后一部影片的电影类型. 首先,将训练集中的所有样例画入坐标系,也将待测样例画入 然后计算待测分类的电影与所有已知分类的电影的欧式距离 接着,将这些电影按照距离升序排序,取前k个电影,假设k=3,那么我们得到的电影依次是<He's N

K近邻法【机器学习】

K近邻模型的3个要素 1.距离度量(如欧式距离) 2.k值的选择 3.分类决策规则(如多数表决) 线性搜索时间复杂度较高,因而引入了KD树这一数据结构,加快搜索. 构造KD树 搜索KD树 如果实例点是随是随机分布的,kd树搜索复杂度是O(logN),这里N是训练实例数,kd树更适合于训练实例数远大于空间维数时的k近邻搜索. 当空间维数接近训练实例数时,它的效率会迅速下降,几乎接近线性扫描 原文地址:https://www.cnblogs.com/shengwang/p/9756309.html