# coding:utf-8
"""Build a k-d tree over a small point set, find the nearest neighbour of a
query point, and (optionally) plot the 2-D space partition with matplotlib.

Changes relative to the original script: Python 3 syntax (print(), integer
``//`` division, ASCII quotes), None-safe child setters, empty-input handling,
k-dimensional points, and a correct nearest-neighbour search (the old search
always split on x and only inspected the immediate children of the ancestors
of the leaf it reached).
"""
import numpy as np

try:
    import matplotlib.pyplot as plt
except ImportError:  # plotting is optional; the tree/search code does not need it
    plt = None

# Sample points and query point used by the demo at the bottom of the file.
T = [[2, 3], [5, 4], [9, 6], [4, 7], [8, 1], [7, 2]]
S = [7, 3]


class node(object):
    """A k-d tree node holding one point plus links to its children/parent."""

    def __init__(self, point, axis=None):
        self.left = None
        self.right = None
        self.point = point
        self.parent = None
        # Splitting axis used at this node (set by build_kdtree). None means
        # "unknown"; search_kdtree then falls back to depth % k.
        self.axis = axis

    def set_left(self, left):
        """Attach ``left`` as the left child; ``None`` clears the child."""
        if left is not None:
            left.parent = self
        self.left = left

    def set_right(self, right):
        """Attach ``right`` as the right child; ``None`` clears the child."""
        if right is not None:
            right.parent = self
        self.right = right


def median(lst):
    """Return ``(lst[m], m)`` where ``m = len(lst) // 2``.

    For an even-length list this is the upper of the two middle elements.
    ``//`` keeps ``m`` an int so it can be used as an index on Python 3.
    """
    m = len(lst) // 2
    return lst[m], m


def build_kdtree(data, d):
    """Build a k-d tree from ``data`` and return its root ``node``.

    data: sequence of indexable points, all of the same length k.
    d:    axis the root splits on; each deeper level uses (d + 1) % k, which
          for k = 2 alternates 0, 1, 0, ... like the original ``not d`` toggle.
    Returns None for empty input. ``data`` itself is not modified.
    """
    if len(data) == 0:
        return None
    axis = int(d)
    data = sorted(data, key=lambda p: p[axis])
    point, m = median(data)
    tree = node(point, axis=axis)
    next_axis = (axis + 1) % len(point)
    # Points before the median go left, points after it go right.
    left, right = data[:m], data[m + 1:]
    if left:
        tree.set_left(build_kdtree(left, next_axis))
    if right:
        tree.set_right(build_kdtree(right, next_axis))
    return tree


def distance(a, b):
    """Return the Euclidean distance between points ``a`` and ``b``."""
    return sum((x - y) ** 2 for x, y in zip(a, b)) ** 0.5


def _nearest(tree, target, best, depth):
    """Depth-first nearest-neighbour search below ``tree``; updates ``best``."""
    if tree is None:
        return
    dist = distance(tree.point, target)
    if dist < best[1]:
        best[0], best[1] = tree.point, dist
    axis = getattr(tree, 'axis', None)
    if axis is None:
        axis = depth % len(target)
    diff = target[axis] - tree.point[axis]
    # Search the side of the splitting plane that contains the target first...
    if diff < 0:
        near, far = tree.left, tree.right
    else:
        near, far = tree.right, tree.left
    _nearest(near, target, best, depth + 1)
    # ...and the other side only if the plane is closer than the best match
    # so far; every point over there is at least |diff| away from the target.
    if abs(diff) < best[1]:
        _nearest(far, target, best, depth + 1)


def search_kdtree(tree, target, best=None):
    """Return the point stored in ``tree`` that is closest to ``target``.

    best: optional ``[point, distance]`` seed, updated in place. ``None`` or an
          empty list (the old default) means "start from the root point".
    Returns None when the tree is empty and no seed was given.
    """
    if tree is None:
        return best[0] if best else None
    if not best:
        best = [tree.point, distance(tree.point, target)]
    _nearest(tree, target, best, 0)
    return best[0]


def showT(tree, d, bounds=None):
    """Plot a 2-D k-d tree's points and splitting lines with matplotlib.

    tree:   root node to draw.
    d:      splitting axis of ``tree`` (falsy -> vertical line x = const,
            truthy -> horizontal line y = const).
    bounds: (xmin, xmax, ymin, ymax) of the region this node splits; defaults
            to the (0, 10, 0, 10) box the original script drew in. Each split
            line spans exactly its node's region.
    """
    if plt is None:
        raise RuntimeError("matplotlib is required for showT")
    xmin, xmax, ymin, ymax = (0, 10, 0, 10) if bounds is None else bounds
    x, y = tree.point[0], tree.point[1]
    plt.plot(x, y, 'ob')
    if d:
        plt.plot([xmin, xmax], [y, y])
        left_bounds = (xmin, xmax, ymin, y)
        right_bounds = (xmin, xmax, y, ymax)
    else:
        plt.plot([x, x], [ymin, ymax])
        left_bounds = (xmin, x, ymin, ymax)
        right_bounds = (x, xmax, ymin, ymax)
    if tree.left is not None:
        showT(tree.left, not d, left_bounds)
    if tree.right is not None:
        showT(tree.right, not d, right_bounds)


if __name__ == "__main__":
    kd_tree = build_kdtree(T, 0)
    result = search_kdtree(kd_tree, S)
    print(result)  # [7, 2]
    if plt is None:
        print("matplotlib not installed; skipping the plot")
    else:
        showT(kd_tree, 0)
        plt.annotate('S', xy=(S[0], S[1] + 0.2))
        plt.plot(S[0], S[1], '^r')
        plt.show()
# Time (时间): 2024-10-01 13:31:37