参考 GitHub

"""
Kd树搜索的k近邻模型,和《统计学习方法》上介绍的最近邻算法差距有点大..
(1) 设定一个当前最优点集合,用来保存当前离搜索点最近的样本点集合
(2) 从根节点开始,并设其为当前节点;在此code中由query执行,下面的(3)(4)(5)(6)(7)由一个函数_search执行
(3) 如果当前节点为空,则更新集合并结束;
(4) 如果被搜索点的划分维度小于当前节点的划分维度,则设当前节点的左孩子为新的下一次应该访问的节点;反之设当前节点的右孩子为下一次应该访问的节点
(5) 如果当前节点到被搜索点的距离小于当前全局最短距离,则更新最优k点集;
(6) 如果被搜索点到当前节点划分维度的距离小于全局最短距离,则说明全局最佳点可能存在于当前节点的另外一个子树中,
所以设当前节点的另外一个孩子为下一次应该访问的节点并执行步骤(3)(4)(5)(6)(7);
(7) 搜索下一次应该访问的节点,它是由(4)设置的,直到叶子节点,即触发(3)算法减少搜索量就是通过只检查有可能足够近的点,而对于不可能是最近的k个点的那些点则永远不理会
哪些是不可能的点?答:假设当前节点离目标点够近(有可能是那k个点),比一下当前点和目标点某维度的大小,左子树都是比当前点小的点,
如果目标点也比当前点小,显然左子树的那些点应该考虑一下是不是的那k个点,因为这时左子节点必然比右子节点离目标点近,再判断一下右子节点离目标点的距离
如果这个距离比最差的最优点都小,那右子树的其他节点就不用考虑了"""
import json
from sklearn.model_selection import train_test_split
from sklearn.datasets import load_iris
import numpy as np
import pandas as pd
import timeclass Node:"""节点类"""def __init__(self, value, index, left_child, right_child):""":param value:节点存储的k维特征:param index:此节点存储的样本点在训练集中的index:param left_child:左子树:param right_child:右子树"""self.value = value.tolist()self.index = indexself.left_child = left_childself.right_child = right_childdef __repr__(self):"""print对象时相当于执行print(对象.__repr__)"""return json.dumps(self, indent=3, default=lambda obj: obj.__dict__, ensure_ascii=False)class KDTree:"""kd tree类"""def __init__(self, data):# 原始训练数据集self.data = np.asarray(data)# 构造的kd树的根节点self.kd_tree = None# 每个节点所存储的样本特征数量,即特征空间的维度self.K = self.data.shape[1]# 创建平衡kd树self._create_kd_tree(data)def __repr__(self):"""调用时相当于print的是kd_tree的根节点,它是一个Node对象,相当于print的是Node类的__repr__"""return str(self.kd_tree)def SplitData_GetNode(self, data, depth):"""根据数据集构造节点:param data:剩下待分割的数据集:param depth:现在构造的节点所处的深度:return:子树的节点"""if len(data) == 0:return Nonesplit_feature_index = depth % self.K# 对数据在split_feature_index维度进行排序data = data[data[:, split_feature_index].argsort()]# 切分点median_index = data.shape[0] // 2# 获取结点在原始训练数据集中的位置node_index = [i for i, v in enumerate(self.data) if list(v) == list(data[median_index])]return Node(# 本结点value=data[median_index],# 本结点在数据集中的位置index=node_index[0],# 左子结点left_child=self.SplitData_GetNode(data[:median_index], depth + 1),# 右子结点right_child=self.SplitData_GetNode(data[median_index + 1:], depth + 1))def _create_kd_tree(self, X):""":param X:数据集,只在初始化时调用,递归的开启:从最原始数据集开始,构造深度为0的节点:return:根节点"""self.kd_tree = self.SplitData_GetNode(X, 0)def query(self, TargetPoint, k=1):"""调用时搜索目标点,开始递归:param TargetPoint: 目标点:param k::return:"""TargetPoint = np.asarray(TargetPoint)results = self._search(TargetPoint, 0, self.kd_tree, k=k, k_neighbor_sets=list())DistanceList_KBestPoints_And_TargetPoint = np.array([results[0] for results in results])IndexList_KBestPoints_in_OriginalData = np.array([results[1] for results in results])return DistanceList_KBestPoints_And_TargetPoint, IndexList_KBestPoints_in_OriginalData@staticmethoddef _cal_node_distance(node1, node2):"""计算两个结点之间的距离"""return np.sqrt(np.sum(np.square(node1 - node2)))def _search(self, point, depth, node=None, k=1, k_neighbor_sets=None):"""搜索附近可能的那k个点:param point: 目标点,即新输入模型的样本点,现在要预测这个点的label:param depth: 当前考察节点的深度:param node: 当前考察点,即判断这个点是否确实是那k个点:param k: k近邻中的k:param k_neighbor_sets: 最优点集,存储的是(节点离目标点的距离, 节点存储值在原始训练集中的位置, 节点值):return:无"""if k_neighbor_sets is None:k_neighbor_sets = []if node is None:return k_neighbor_sets#到了叶节点,停止搜索,附近可能的点已经搜索结束if node.left_child is None and node.right_child is None:# 更新当前k近邻点集return self._update_k_neighbor_sets(k_neighbor_sets, k, node, point)# 递归地向下访问kd树if point[depth % self.K] < node.value[depth % self.K]:direct = 'left'next_node = node.left_childelse:direct = 'right'next_node = node.right_childif next_node is not None:#判断当前结点是否符合进入最优点集的条件k_neighbor_sets = self._update_k_neighbor_sets(k_neighbor_sets, k, next_node, point)#检查兄弟节点对应的区域是否相交,if direct == 'left' and node.right_child is not None:node_distance = self._cal_node_distance(point, node.right_child.value)if k_neighbor_sets[0][0] > node_distance:# 如果相交,递归地进行近邻搜索return self._search(point, depth=depth + 1, node=node.right_child, k=k,k_neighbor_sets=k_neighbor_sets)elif node.left_child is not None:node_distance = self._cal_node_distance(point, node.left_child.value)if k_neighbor_sets[0][0] > node_distance:return self._search(point, depth=depth + 1, node=node.left_child, k=k,k_neighbor_sets=k_neighbor_sets)return self._search(point, depth=depth + 1, node=next_node, k=k, k_neighbor_sets=k_neighbor_sets)def _update_k_neighbor_sets(self, k_neighbor_sets, k, node, point):"""更新最优点集:param k_neighbor_sets:最优点集:param k: k近邻中的k:param node: 符合进入最优点集的条件:param point: 目标点:return: 更新后的最优点集"""# 计算目标点与当前结点的距离node_distance = self._cal_node_distance(point, node.value)if len(k_neighbor_sets) == 0:k_neighbor_sets.append((node_distance, node.index, node.value))elif len(k_neighbor_sets) < k:# 如果“当前k近邻点集”元素数量小于kself._insert_k_neighbor_sets(k_neighbor_sets, node, node_distance)else:# 叶节点距离小于“当前k近邻点集”中最远点距离if k_neighbor_sets[0][0] > node_distance:k_neighbor_sets = k_neighbor_sets[1:] #将差的点移除self._insert_k_neighbor_sets(k_neighbor_sets, node, node_distance)return k_neighbor_sets@staticmethoddef _insert_k_neighbor_sets(k_neighbor_sets, node, node_distance):"""发现新的更符合的点:param k_neighbor_sets:最优点集:param node: 节点:param node_distance:节点与目标点间的距离:return: 无"""n = len(k_neighbor_sets)for i, item in enumerate(k_neighbor_sets):if item[0] < node_distance:# 将距离最远的结点插入到前面k_neighbor_sets.insert(i, (node_distance, node.index, node.value))breakif len(k_neighbor_sets) == n:k_neighbor_sets.append((node_distance, node.index, node.value))iris = load_iris()
df = pd.DataFrame(iris.data, columns=iris.feature_names)
df['label'] = iris.target
data = np.array(df)
X = data[:, 0:4]
y = data[:, 4]
train_X, test_X, train_y, test_y = train_test_split(X, y, test_size=0.3)  # (105, 4) (105,)
# print(train_X.shape,train_y.shape)
my_clock_start = time.perf_counter()
kd_tree = KDTree(train_X)
k = 2
cnt = 0
for i, point in enumerate(test_X):# print("目标点是{},对应标签是{}".format(point,test_y[i]))DistanceList_KBestPoints_And_x, IndexList_KBestPoints_in_OriginalData = kd_tree.query(np.array(point), k=k)predict_y = []for index in IndexList_KBestPoints_in_OriginalData:# print(train_y[index],end=",")predict_y.append(train_y[index])# print("")if (max(predict_y, key=predict_y.count) == test_y[i]):cnt += 1
print("我的模型准确率为{}".format(cnt / len(test_X)))
my_clock_end = time.perf_counter()
print("我的模型花费时间{}".format(my_clock_end - my_clock_start))cnt = 0
from sklearn.neighbors import KNeighborsClassifierclocl_start = time.perf_counter()
model = KNeighborsClassifier(n_neighbors=2)
model.fit(train_X, train_y)
predict_y = model.predict(test_X)
for index, label in enumerate(predict_y):if label == test_y[index]:cnt += 1
print("sklearn模型准确率为{}".format(cnt / len(test_X)))
clock_end = time.perf_counter()
print("sklearn模型花费时间{}".format(clock_end - clocl_start))"""
我的模型准确率为0.9333333333333333    准确率有时比sklearn中的高,要看k的选择
我的模型花费时间0.10518830000000001   时间差接近十倍,因为做了一些其他事
sklearn模型准确率为0.9111111111111111
sklearn模型花费时间0.007459700000000069
"""

Kd树实现K近邻算法相关推荐

  1. 基于KD树的K近邻算法(KNN)算法

    文章目录 KNN 简介 KNN 三要素 距离度量 k值的选择 分类决策规则 KNN 实现 1,构造kd树 2,搜索最近邻 3,预测 用kd树完成最近邻搜索 K近邻算法(KNN)算法,是一种基本的分类与 ...

  2. 基于kd树的k近邻算法——KNN

    1.简介 k近邻算法是机器学习中一种基本的分类与回归算法,对你没听错k近邻算法不仅可以用来做分类,还可以用于回归,英文全称为k-Nearest Neighbor简称k-NN.k近邻算法属于一种有监督学 ...

  3. kd树 python实现_kd树 寻找k近邻算法 python实现

    按照链接里的算法写了k近邻的python实现 from math import sqrt class KDnode: def __init__(self, data, left, right, spl ...

  4. 机器学习基础(四十三)—— kd 树( k 近邻法的实现)

    实现 k 近邻法时,主要考虑的问题是如何对训练数据进行快速 k 近邻搜索,这点在如下的两种情况时,显得尤为必要: (1)特征空间的维度大 (2)训练数据的容量很大时 k 近邻法的最简单的实现是现行扫描 ...

  5. python机器学习 | K近邻算法学习(1)

    K近邻算法学习 1 K近邻算法介绍 1.1算法定义 1.2算法原理 1.3算法讨论 1.3.1 K值选择 1.3.2距离计算 1.3.3 KD树 2 K近邻算法实现 2.1scikit-learn工具 ...

  6. K近邻算法的kd树实现

    k近邻算法的介绍 k近邻算法是一种基本的分类和回归方法,这里只实现分类的k近邻算法. k近邻算法的输入为实例的特征向量,对应特征空间的点:输出为实例的类别,可以取多类. k近邻算法不具有显式的学习过程 ...

  7. 机器学习入门笔记(三):K近邻算法

    文章目录 一.K近邻算法的基本概念 1.1 K近邻算法实现 二.K近邻分类三要素 2.1 距离度量 2.2 K值的选择 2.2.1 基于m-fold cross validation的 K值选择 2. ...

  8. 机器学习算法系列之K近邻算法

    本系列机器学习的文章打算从机器学习算法的一些理论知识.python实现该算法和调一些该算法的相应包来实现. 目录 K近邻算法 一.K近邻算法原理 k近邻算法 通俗解释 近邻距离的度量 k值的选择 KN ...

  9. 机器学习之重点汇总系列(二)——K近邻算法(k-Nearest Neighbor,kNN)

    什么是K近邻算法 引例 假设有数据集,其中前6部是训练集(有属性值和标记),我们根据训练集训练一个KNN模型,预测最后一部影片的电影类型 首先,将训练集中的所有样例画入坐标系,也将待测样例画入 然后计 ...

  10. K近邻快速算法 -- KD树、BBF改进算法

    K近邻算法即是查找与当前点(向量)距离最近的K个点(向量),距离计算一般用欧氏距离. 最简单的方法就是穷举法:计算每个向量与当前向量的欧氏距离,选取最小的K个为所求.但这种方法计算量太大,无法应对大样 ...

最新文章

  1. VISP视觉库框架结构与使用入门
  2. 【Linux】一步一步学Linux——bc命令(233)
  3. 【题解】luogu p1111 修复公路
  4. 以Settings.APPLICATION_DEVELOPMENT_SETTINGS打开开发者面板出错总结
  5. overflow-x后覆盖滚动条
  6. php 模拟提交金数据,小程序提交表单到金数据实例
  7. 图像分割算法的优缺点比较
  8. 一个80年代大学生的悲壮职业人生
  9. VISTA下载全集(下)
  10. 文件夹的菜单栏和地址栏不见了
  11. QQ浏览器网页版微信抓包和IPAD微信抓包 Wireshark
  12. [NOIP2008]笨小猴 T1
  13. 第三方短信平台服务提供商是什么?
  14. 面向Web开发人员和网站管理员的Web缓存指南
  15. 【夜读】一个人最积极的生活状态
  16. ERP开发之看板展示
  17. Unity DOTS简明教程
  18. 《史上最全、最强Java学习路线知识点整理!!全是干货!!》
  19. 大数据告诉你,离开北上广的互联网工程师最终都去了哪里?
  20. 62 Three.js 使用对象组合

热门文章

  1. 资源 | 最新版区块链术语表(中英文对照)2019-1.14
  2. 企业信息化认知的四个误区
  3. 使用Javassist 动态生成类
  4. 本科三级专业目录计算机类,大学本科专业目录
  5. python selenium 下载附件通过oss上传
  6. python生成图表
  7. python用matplotlib或boxplot作图的时候,中文标注无法正常显示,乱码为小方框的解决办法
  8. 希捷移动硬盘官方测试软件,希捷 SeaTools 硬盘检测软件安装使用教程
  9. 基于51单片机的自动电梯控制模拟系统设计
  10. 51单片机连接ESP8266串口WiFi模块