博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
Python 实现 KD-Tree 最近邻算法
阅读量:6801 次
发布时间:2019-06-26

本文共 2421 字,大约阅读时间需要 8 分钟。

这里将写了一个KDTree类,仅实现了最近邻,K近邻之后若有时间再更新:

from collections import namedtuplefrom operator import itemgetterfrom pprint import pformatimport numpy as npclass Node(namedtuple('Node', 'location left_child right_child')):    def __repr__(self):        return pformat(tuple(self))class KDTree():    def __init__(self, points):        self.tree = self._make_kdtree(points)        if len(points) > 0:            self.k = len(points[0])        else:            self.k = None    def _make_kdtree(self, points, depth=0):        if not points:            return None        k = len(points[0])        axis = depth % k        points.sort(key=itemgetter(axis))        median = len(points) // 2        return Node(            location=points[median],            left_child=self._make_kdtree(points[:median], depth + 1),            right_child=self._make_kdtree(points[median + 1:], depth + 1))    def find_nearest(self,                     point,                     root=None,                     axis=0,                     dist_func=lambda x, y: np.linalg.norm(x - y)):        if root is None:            root = self.tree            self._best = None        # 若不是叶节点,则继续向下走        if root.left_child or root.right_child:            new_axis = (axis + 1) % self.k            if point[axis] < root.location[axis] and root.left_child:                self.find_nearest(point, root.left_child, new_axis)            elif root.right_child:                self.find_nearest(point, root.right_child, new_axis)        # 回溯:尝试更新 best        dist = dist_func(root.location, point)        if self._best is None or dist < self._best[0]:            self._best = (dist, root.location)        # 若超球与另一边超矩形相交        if abs(point[axis] - root.location[axis]) < self._best[0]:            new_axis = (axis + 1) % self.k            if root.left_child and point[axis] >= root.location[axis]:                self.find_nearest(point, root.left_child, new_axis)            elif root.right_child and point[axis] < root.location[axis]:                self.find_nearest(point, root.right_child, new_axis)        return self._best

测试:

point_list = [(2, 3, 3), (5, 4, 4), (9, 6, 7), (4, 7, 7), (8, 1, 1), (7, 2, 2)]kdtree = KDTree(point_list)point = np.array([5, 5, 5])print(kdtree.find_nearest(point))

输出:

(1.4142135623730951, (5, 4, 4))

与 Scikit-Learn 性能对比(上是我的实现,下是 Scikit-Learn 的实现):

1331521-20190320170654201-1994937339.png

1331521-20190320170703983-471788386.png

可以看到仅相差 1 毫秒,所以性能说得过去。

(本文完)

转载于:https://www.cnblogs.com/gscnblog/p/10566157.html

你可能感兴趣的文章
h5实体
查看>>
模板字符串
查看>>
使用WebDriver遇到的一些问题汇总
查看>>
AI:你们是不是在等一顶红帽子?
查看>>
三周第二次课 3.4 usermod命令 3.5 用户密码管理 3.6 mkpasswd命令
查看>>
六周第一次课 9.1 正则介绍_grep上 9.2 grep中 9.3 grep下
查看>>
Window 2012 R2系统从无命令行配置开启GUI的功能,实现操作系统图形化界面。
查看>>
ToastUtil,一个简单的Toast封装
查看>>
如何在Centos7进行网络配置
查看>>
orabbix结合python发送图形报表
查看>>
Android权限处理分类
查看>>
找不到TouchCopy16激活码位置?
查看>>
IT兄弟连 JavaWeb教程 URI、URL
查看>>
为什么说甲骨文裁员也属无奈之举?
查看>>
Wdatepicker日期控件的使用指南
查看>>
数据包结构分析
查看>>
[转]一位前辈对IT工程师的职业前途的一点个人看法
查看>>
Linux文件系统只读Read-only file system
查看>>
Migrate to new vCenter Server while using dvSwitches
查看>>
ConcurrentHashMap原理分析
查看>>