Google interview question: k-nearest neighbor (k-d tree)

Question:

You are given information about hotels in a country/city. X and Y coordinates of each hotel are known. You need to suggest the list of nearest hotels to a user who is querying from a particular point (X and Y coordinates of the user are given). Distance is calculated as the straight line distance between the user and the hotel coordinates.

假设数据大小为N,需要寻找k个最近的酒店,最直接的做法就是计算每一家酒店离查询坐标的距离,用一个小堆来记录最近的k个酒店,时间复杂度为O(Nlog(k)),空间复杂度为O(k)。

我们可以通过对数据进行预处理来达到优化查询效率的方法。先对所有酒店的坐标按x坐标排序。对于查询坐标(x,y),给定a(一个猜测的值),通过二分查找区间[x-a,x+a],可以获得所有x坐标在区间内的酒店,再通过上一个方法的小堆记录最近的k个酒店。我们并不十分关心数据预处理的效率,时间复杂度为O(Nlog(N)),对于查询,二分查找的时间复杂度为O(log(N)),假设通过二分查找筛选出的结果有m个,第二步的时间复杂度为O(mlog(k))。因此,对于查询的时间复杂度为O(log(N))+O(mlog(k))。这个做法的问题在于如何确定a的值,不适当的选取会导致(1)m过大,使得查询效率降低;(2)结果不准确,因为可能的k最近酒店在x区间之外。如果不要求结果完全准确地近似做法,其效率远高于上一个方法。

第二个方法将数据对x坐标进行了排序,但y坐标仍然是无序的。我们是否可以进一步优化数据的结构,来提高查询的效率?k-d tree是一种可以考虑的数据结构来解决这类问题。有关k-d tree的概念和原理,以下这篇文章介绍的非常详细。

http://web.stanford.edu/class/cs106l/handouts/assignment-3-kdtree.pdf

首先,我们的问题是二维平面上的点,因此在这题中只需要实现k-d tree的二维情况。先定义点和k-d tree的节点。

class Point2D{
    int x;
    int y;
    public Point2D(int x, int y){
        this.x = x;
        this.y = y;
    }
}
class KdTreeNode{
    Point2D val;
    KdTreeNode left;
    KdTreeNode right;
    public KdTreeNode(Point2D p) {
        this.val = p;
    }
}

对于k-d tree,其定义方式有很多。有些实现中所有数据存放在叶子节点,内部节点只存放划分空间的信息。在这里,只需当做和普通的binary search tree一样处理。

接下来是k-d tree的构建。在对k-d tree有一定了解之后,我们知道对于树的每一层,交替进行这对x和y轴的划分。在这个题目中,k-d tree的构建属于数据的预处理,静态的数据,并不需要考虑k-d tree的插入删除等操作。我们选择以x坐标划分,选择的坐标作为一个节点,将数据划分为两个部分,左边部分所有数据的x坐标都不大于该节点的x坐标,右边部分所有数据的x坐标都不小于该节点的x坐标。然后递归进行,定义根节点为第0层,当层数为偶数时以x坐标划分,当层数为奇数时以y坐标划分。那如何选择划分的节点坐标?为了使k-d tree的查询是高效的,构建的k-d tree需要平衡,因此选择的节点是x/y坐标的中位数。用selection algorithm,可以在O(N)时间内找到该节点,将数据均匀地划分。

    public static KdTreeNode constructKdTree(Point2D[] array, int depth, int low, int high){
        if(low > high) return null;
        if(low == high) return new KdTreeNode(array[low]);
        int mid = low+(high-low)/2;
        Point2D p = quickSelect(array, mid, low, high, depth%2);
        KdTreeNode node = new KdTreeNode(p);
        node.left = constructKdTree(array, depth+1, low, mid-1);
        node.right = constructKdTree(array, depth+1, mid+1, high);
        return node;
    }
    public static Point2D quickSelect(Point2D[] array, int k,  int low, int high, int dimension){
        while(low<=high){
            int pivotIndex = partition(array, low, high, new Random().nextInt(high-low+1)+low,dimension);
            if(pivotIndex == k) return array[k];
            else if(pivotIndex < k) low = pivotIndex+1;
            else high = pivotIndex-1;
        }
        return null;
    }
    public static int partition(Point2D[] array, int low, int high, int pivot, int dimension){
        int pivotVal = dimension==0?array[pivot].x:array[pivot].y;
        swap(array, pivot, high);
        int index = low;
        for(int i=low;i<high;i++){
            int curVal = dimension==0?array[i].x:array[i].y;
            if(curVal<pivotVal){
                swap(array, index, i);
                index++;
            }
        }
        swap(array, high, index);
        return index;
    }
    public static void swap(Point2D[] array, int i, int j){
        if(i!=j){
            Point2D tmp = array[i];
            array[i] = array[j];
            array[j] = tmp;
        }
    }

k-d tree的构建完成,时间复杂度为O(Nlog(N)),空间复杂度为O(N)。接下来是k-d tree的查询操作。我们的问题是要获得最近的酒店列表。在链接的文章中介绍了如何查询最近的坐标和最近的k个坐标。对于这题,我只实现返回最近的坐标。

首先通过查询坐标寻找该坐标所在的划分空间,记录下遍历路径中的离该坐标最近的点。然后根据这个距离r得到中心点为查询坐标,半径为r的搜索空间,再次对kd-tree查询是否存在更近的点。

    static Point2D nearestPoint = new Point2D(Integer.MAX_VALUE, Integer.MAX_VALUE);
    static int min = Integer.MAX_VALUE;
    public static void queryHelper(KdTreeNode root, Point2D query, int depth){
        if(root == null) return;
        int distance = (query.x-root.val.x)*(query.x-root.val.x)+(query.y-root.val.y)*(query.y-root.val.y);
        if(distance < min){
            min = distance;
            nearestPoint = root.val;
        }
        int curVal = depth%2==0?query.x:query.y;
        int nodeVal = depth%2==0?query.x:query.y;
        if(curVal > nodeVal) queryHelper(root.right, query, depth+1);
        else if(curVal < nodeVal) queryHelper(root.left, query, depth+1);
        else{
            queryHelper(root.right, query, depth+1);
            queryHelper(root.left, query, depth+1);
        }
    }
    public static void queryNearestHelper(KdTreeNode root, Point2D query, int depth, double xMin, double xMax, double yMin, double yMax){
        if(root == null) return;
        int distance = (query.x-root.val.x)*(query.x-root.val.x)+(query.y-root.val.y)*(query.y-root.val.y);
        if(distance < min){
            min = distance;
            nearestPoint = root.val;
        }
        int curVal = depth%2==0?query.x:query.y;
        int nodeVal = depth%2==0?query.x:query.y;
        double rangeMin = depth%2==0?xMin:yMin;
        double rangeMax = depth%2==0?xMax:yMax;
        if(curVal > nodeVal){
            queryNearestHelper(root.right, query, depth+1, xMin, xMax, yMin, yMax);
            if(nodeVal > rangeMin) queryNearestHelper(root.left, query, depth+1, xMin, xMax, yMin, yMax);
        }
        else if(curVal < nodeVal){
            queryNearestHelper(root.left, query, depth+1, xMin, xMax, yMin, yMax);
            if(nodeVal < rangeMax)  queryNearestHelper(root.right, query, depth+1, xMin, xMax, yMin, yMax);
        }
        else{
            queryHelper(root.right, query, depth+1);
            queryHelper(root.left, query, depth+1);
        }
    }
    public static void queryNearest(KdTreeNode root, Point2D query){
        queryHelper(root, query, 0);
        double xMin = query.x-Math.sqrt(min), xMax = query.x+Math.sqrt(min), yMin = query.y-Math.sqrt(min), yMax = query.y+Math.sqrt(min);
        queryNearestHelper(root, query, 0, xMin, xMax, yMin, yMax);
    }

至此,我们完成了最近酒店的查询,时间复杂度为O(log(N))。

时间: 2024-10-13 03:04:57

Google interview question: k-nearest neighbor (k-d tree)的相关文章

K Nearest Neighbor 算法

K Nearest Neighbor算法又叫KNN算法,这个算法是机器学习里面一个比较经典的算法, 总体来说KNN算法是相对比较容易理解的算法.其中的K表示最接近自己的K个数据样本.KNN算法和K-Means算法不同的是,K-Means算法用来聚类,用来判断哪些东西是一个比较相近的类型,而KNN算法是用来做归类的,也就是说,有一个样本空间里的样本分成很几个类型,然后,给定一个待分类的数据,通过计算接近自己最近的K个样本来判断这个待分类数据属于哪个分类.你可以简单的理解为由那离自己最近的K个点来投

Google interview question: count bounded slices(min/max queue)

Question: A Slice of an array said to be a Bounded slice if Max(SliceArray)-Min(SliceArray)<=K. If Array [3,5,6,7,3] and K=2 provided .. the number of bounded slice is 9, first slice (0,0) in the array Min(0,0)=3 Max(0,0)=3 Max-Min<=K result 0<=2

Google interview question: disjoint-set questions

Question: Given a n,m which means the row and column of the 2D matrix and an array of pair A( size k). Originally, the 2D matrix is all 0 which means there is only sea in the matrix. The list pair has k operator and each operator has two integer A[i]

Google interview question: quickSort-like questions

上一篇总结了mergeSort-like questions,这篇总结一下有关quickSort的问题. Question: Given an array of object A, and an array of object B. All A's have different sizes, and all B's have different sizes. Any object A is of the same size as exactly one object B. We have a f

[C++与机器学习] k-近邻算法(K–nearest neighbors)

C++ with Machine Learning -K–nearest neighbors 我本想写C++与人工智能,但是转念一想,人工智能范围太大了,我根本介绍不完也没能力介绍完,所以还是取了他的子集.我想这应该是一个有关机器学习的系列文章,我会不定期更新文章,希望喜欢机器学习的朋友不宁赐教. 本系列特别之处是与一些实例相结合来系统的讲解有关机器学习的各种算法,由于能力和时间有限,不会向诸如Simon Haykin<<NEURAL NETWORKS>>等大块头详细的讲解某一个领

ML_聚类之Nearest neighbor search

有这么一个问题,说我在看一篇文章,觉得不错,想要从书架的众多书籍中找相类似的文章来继续阅读,这该怎么办? 于是我们想到暴力解决法,我一篇一篇的比对,找出相似的 最近邻的概念很好理解,我们通过计算知道了每一篇文章和目标文章的距离,选择距离最小的那篇作为最相近的候选文章或者距离最小的一些文章作为候选文章集. 让我们转化成更数学的表述方式:      这其实就是一个衡量相似性的问题(•?How do we measure similarity?)要完成上述想法,我们需要解决两大难题: 文档的向量化表示

imshow(K)和imshow(K,[]) 的区别

参考文献 imshow(K)直接显示K:imshow(K,[])显示K,并将K的最大值和最小值分别作为纯白(255)和纯黑(0),中间的K值映射为0到255之间的标准灰度值.

poj 2985 The k-th Largest Group 求第K大数 Treap, Binary Index Tree, Segment Tree

题目链接:点击打开链接 题意:有两种操作,合并集合,查询第K大集合的元素个数.(总操作次数为2*10^5) 解法: 1.Treap 2.树状数组 |-二分找第K大数 |-二进制思想,逼近第K大数 3.线段树 4.... Treap模板(静态数组) #include <math.h> #include <time.h> #include <stdio.h> #include <limits.h> #include <stdlib.h> const

求 区间[a,b]内满足p^k*q*^m(k&gt;m)的数的个数

题目描述: 1<=a,b<=10^18,p,q都是素数  2<=p,q<=10^9; 求在[a,b]内可以表示为  x*p^k*q^m  k > m   的数的个数 分析: 由于要小于b,因此m一定小于 log10(b)/log10(p*q); 因此我们可以枚举m,中间计数的时候需要用到容斥. 具体看代码: #include <iostream> #include <cstdio> #include <cmath> #include <