题目要求
- 根据数据集构造
kd-tree
- 基于
kd-tree
,对于给定的x
,输出其最近邻元素及其欧式距离 - 基于
kd-tree
,对于给定的x
,和正整数n
,输出其n
个最近邻元素列表及其距离值
基本原理
$K-NearestNeighbor$,每个样本点都可以用它最近的K个近邻值来代表。
又名,基于实例的学习
代码实现
# heapq 优先队列算法,是一个原生的 python list, 0 号元素总为最小的元素
from collections import namedtuple
from math import sqrt
#定义一个命名元组
result = namedtuple("Result", "nearest_point nearest_dist nodes_visited")
data = [[2,3],[5,4],[9,6],[4,7],[8,1],[7,2]]
class Node:
def __init__(self, point, d, left, right, cnt):
self.point = point
self.d = d #维度
self.left = left
self.right = right
self.cnt = cnt #节点以下的数目(包括该节点)
class KdTree:
def __init__(self, data):
if data:
k = len(data[0])
def create(d, data_set):
if not data_set: return None
data_set.sort(key = lambda x: x[d])
pos = len(data_set)//2 # 需要分开的位置,最后会向右取
newd = (d+1)%2
return Node(data_set[pos], d, create(newd, data_set[:pos]), create(newd, data_set[pos+1:]), len(data_set))
self.root = create(0, data)
def preorder(self):
def fun(node):
print(node.point)
print(node.cnt)
if node.left: fun(node.left)
if node.right: fun(node.right)
fun(self.root)
def test(self, point):
k = len(point) # 数据维度
def travel(node, target, max_dist):
if node is None: return result([0] * k, float("inf"), 0) #出口0
nodes_visited = 1
d = node.d #比较的维度
point = node.point
nearer_node = node.left if target[d] <= point[d] else node.right #下一步走的两个点
further_node = node.right if target[d] <= point[d] else node.left
temp1 = travel(nearer_node, target, max_dist) # 进行遍历找到包含目标点的区域
nearest = temp1.nearest_point # 以此叶结点作为“当前最近点”
dist = temp1.nearest_dist # 更新最近距离
nodes_visited += temp1.nodes_visited
if dist < max_dist:
max_dist = dist # 最近点将在以目标点为球心,max_dist为半径的超球体内
if max_dist < abs(point[d] - target[d]): #出口1,另一超矩形无用
return result(nearest, dist, nodes_visited)
temp_dist = sqrt(sum((x - y)**2 for x,y in zip(point, target)))
if temp_dist < dist: # 如果“更近”
nearest = point # 更新最近点
dist = temp_dist # 更新最近距离
max_dist = dist # 更新超球体半径
# 检查另一个子结点对应的区域是否有更近的点
temp2 = travel(further_node, target, max_dist)
nodes_visited += temp2.nodes_visited
if temp2.nearest_dist < dist: # 如果另一个子结点内存在更近距离
nearest = temp2.nearest_point # 更新最近点
dist = temp2.nearest_dist # 更新最近距离
return result(nearest, dist, nodes_visited)
return travel(self.root, point, float('inf'))
if __name__ == "__main__":
kd = KdTree(data)
#kd.preorder()
print(kd.test([2,2]))
print(kd.test([11,12]))
#第三问写不出来,想用heapq的