参考链接: K最近邻居的Python实现

效果说明:

Input:输入Num个Dim维点的坐标,Points.size=(Num,Dim),输入一个目标点坐标Target、查找最近邻点数量K。Output: 求出距离Target最近的K个点的索引和距离。(具体坐标可由索引和Points列表获取)环境要求: Python 3 with numpy and matplotlib

当Dim=2、Num=30、K=4时,绘制图如下:

输出: candidate_index : [ 5 3 21 12 29 20] candidate_distance : [0. 0.1107 0.1316 0.1701 0.2225 0.2656] 【注】这里以5号点作为目标点,它距离自己本身距离为0。

思路:

1、构建kdTree:通过递归构建一个二叉树,以当前空间维度的中位数点作为分割点,依次将空间分割,注意保存每个节点的坐标索引,以及由该节点划分出的左右节点序列和左右空间边界。

注意:这里的左右指的是每个维度的左右边界,默认:左小右大。

Node类参数说明: 这里没有将点的具体坐标信息赋予节点,而是保存节点对应的坐标索引,这样需要坐标值时根据索引调用坐标即可,也比较容易debug。

self.mid                         # 节点索引(中位数)

self.left                        # 节点左空间索引列表

self.right = right                # 节点右空间索引列表

self.bound = bound  # Dim * 2    # 当前节点所在空间范围(每个维度由左右边界控制)

self.flag = flag                # 表示该节点对应的分割线应分割的维度索引(通过取模来控制变化)

self.lchild = lchild            # 左子节点地址

self.rchild = rchild            # 右子节点地址

self.par = par                    # 父节点地址

self.l_bound = l_bound            # 节点左空间范围

self.r_bound = r_bound            # 节点右空间范围

self.side = side                # 当前节点是其父节点的左节点(0)或右节点(1)

2、确定初始节点(空间)

3、查找K近邻(具体详见参考书或与基础理论相关的博文)

# kd_Tree

# Edited By ocean_waver

import numpy as np

import matplotlib.pyplot as plt

class Node(object):

def __init__(self, mid, left, right, bound, flag, lchild=None, rchild=None, par=None,

l_bound=None, r_bound=None, side=-1):

self.mid = mid

self.left = left

self.right = right

self.bound = bound  # Dim * 2

self.flag = flag

self.lchild = lchild

self.rchild = rchild

self.par = par

self.l_bound = l_bound

self.r_bound = r_bound

self.side = side

def find_median(a):

# s = np.sort(a)

arg_s = np.argsort(a)

idx_mid = arg_s[len(arg_s) // 2]

idx_left = np.array([arg_s[j] for j in range(0, len(arg_s) // 2)], dtype='int32')

idx_right = np.array([arg_s[j] for j in range(len(arg_s) // 2 + 1, np.size(a))], dtype='int32')

return idx_mid, idx_left, idx_right

def kd_tree_establish(root, points, dim):

# print(root.mid)

layer_flag = (root.flag + 1) % dim    # 确定分割点对应的分割线的维度

if dim == 2:

static_pos = points[root.mid, root.flag]

if root.flag == 0:

x_line = np.linspace(static_pos, static_pos, 10)

y_line = np.linspace(root.bound[1, 0], root.bound[1, 1], 10)

elif root.flag == 1:

x_line = np.linspace(root.bound[0, 0], root.bound[0, 1], 10)

y_line = np.linspace(static_pos, static_pos, 10)

plt.plot(x_line, y_line, color='darkorange')

# plt.axis([0, 1, 0, 1])

# plt.draw()

# plt.pause(0.05)

# new bound:

root.l_bound = root.bound.copy()    # 先复制一份根节点边界(Note: need to use deep copy!)

root.l_bound[root.flag, 1] = points[root.mid, root.flag]  # 改变特定边界的最大值,获取新边界

root.r_bound = root.bound.copy()

root.r_bound[root.flag, 0] = points[root.mid, root.flag]  # 改变特定边界的最小值,获取新边界

if root.left.size > 0:

# print('left : ', root.left)

mid, left, right = find_median(points[root.left, layer_flag])

mid, left, right = root.left[mid], root.left[left], root.left[right]

left_node = Node(mid, left, right, root.l_bound, layer_flag)

root.lchild = left_node

left_node.par = root

left_node.side = 0

kd_tree_establish(left_node, points, dim)

if root.right.size > 0:

# print('right : ', root.right)

mid, left, right = find_median(points[root.right, layer_flag])

mid, left, right = root.right[mid], root.right[left], root.right[right]

right_node = Node(mid, left, right, root.r_bound, layer_flag)

root.rchild = right_node

right_node.par = root

right_node.side = 1

kd_tree_establish(right_node, points, dim)

def distance(a, b, p):

"""

Lp distance:

input: a and b must have equal length

p must be a positive integer, which decides the type of norm

output: Lp distance of vector a-b"""

try:

vector = a - b

except ValueError:

print('Distance : input error !\n the coordinates have different length !')

dis = np.power(np.sum(np.power(vector, p)), 1/p)

return dis

# def search_other_branch(target, branch_node, points, dim):

def judge_cross(circle, branch, dim):

"""

Judge if a sphere in dimension(dim) and the space of the other branch cross each other

cross     : return 1

not cross : return 0"""

# print(circle, branch)

count = 0

for j in range(0, dim):

if circle[j, 1] < branch[j, 0] or circle[j, 0] > branch[j, 1]:

count = count + 1

if count == 0:

return 1    # cross

else:

return 0

if __name__ == '__main__':

# --------基本参数设置--------

Num = 30    # 训练点数量

Dim = 2        # 空间维度

Points = np.random.rand(Num, Dim) + 100    # 产生随机点

# Points = np.array([[127,163,255],[126,165,255],[127,164,255],[127,165,254],[127,165,255],[127,167,253],[126,166,255],[126,167,254]])

# Points = np.array([[  1,  0,  2],[  0,  2,  2],[  1,  1,  2],[  1,  2,  1],[  1,  2,  2],[  1,  4,  0],[  0,  3,  2],[  0,  4,  1]])

Num = Points.shape[0]    # 重新确定点数和维度,调整自定义造成的属性更改

Dim = Points.shape[1]

K = 6        # 查找近邻数量

p = 2        # 计算欧氏距离

# Target = np.array([0.1, 0.9])

Target = np.squeeze(np.random.rand(1, Dim))  # 这里只考虑一个目标点

Target = Points[5, :]  # 设定初始点

'''# Test for find_median()

idx_mid, idx_left, idx_right = find_median(Points[:, 0])

print(Points[:, 0])

print(Points[idx_mid, 0], idx_mid, idx_left, idx_right)'''

# kdTree establish

Mid, Left, Right = find_median(Points[:, 0])

L_bound = np.min(Points, axis=0)

R_bound = np.max(Points, axis=0)

Bound = np.vstack((L_bound, R_bound)).T

Root = Node(Mid, Left, Right, Bound, flag=0)

print('kdTree establish ...')

kd_tree_establish(Root, Points, Dim)

print('kdTree establish Done')

# 定位初始搜索区域

node = Root

temp = Root

side = 0    # 下降定位在终止时点所在的是左侧(side=0)还是右侧(side=1)

while temp is not None:

if Points[temp.mid, temp.flag] > Target[temp.flag]:    # 大于的情况

node = temp

temp = temp.lchild

side = 0

else:   # 包括小于和等于的情况

node = temp

temp = temp.rchild

side = 1

print('start node : ', node.mid, Points[node.mid])

# 搜索最近邻点

can_idx = np.array([], dtype='int32')

can_dis = np.array([])

temp = node

while node is not None:

# min_dis = distance(Target, Points[can_idx[-1]])

search_flag = False

temp_dis = distance(Target, Points[node.mid], 2)

if can_idx.size < K:    # 候选点列表未满

can_idx = np.append(can_idx, node.mid)

can_dis = np.append(can_dis, temp_dis)

elif temp_dis < np.max(can_dis):

can_idx[np.argmax(can_dis)] = node.mid

can_dis[np.argmax(can_dis)] = temp_dis

search_flag = False         # 查看另一支路是否为空

if side == 0 and node.rchild is not None:

branch_bound = node.rchild.bound

branch_list = node.right

search_flag = True

elif side == 1 and node.lchild is not None:

branch_bound = node.lchild.bound

branch_list = node.left

search_flag = True

if search_flag is True:     # 开始判断和搜索另一侧的支路

r = np.max(can_dis)

# 构建Dim维球体边界

temp_bound = np.array([[Target[i]-r, Target[i]+r] for i in range(0, Dim)])

if judge_cross(temp_bound, branch_bound, Dim) == 1:     # 高维球与支路空间存在交叉

for i in branch_list:

a_dis = distance(Target, Points[i], 2)

if can_idx.size < K:            # 候选未满,直接添加

can_idx = np.append(can_idx, i)

can_dis = np.append(can_dis, a_dis)

elif a_dis < np.max(can_dis):   # 候选已满,更近者替换候选最远者

can_idx[np.argmax(can_dis)] = i

can_dis[np.argmax(can_dis)] = a_dis

# 向上更新查找节点

temp = node

side = temp.side    # 更新刚离开的node所处的左右方位

node = node.par

# 输出结果

sort_idx = np.argsort(can_dis)

can_idx = can_idx[sort_idx]

can_dis = can_dis[sort_idx]

print('candidate_index :    ', can_idx)

print('candidate_distance : ', np.round(can_dis, 4))

# print(Points)

if Dim == 2:

# 绘制点

plt.scatter(Points[:, 0], Points[:, 1], color='blue')

for i in range(0, Num):

plt.text(Points[i, 0], Points[i, 1], str(i))

# 绘制框架

plt.scatter(Target[0], Target[1], c='red', s=30)

frame_X = np.array([L_bound[0], R_bound[0], R_bound[0], L_bound[0], L_bound[0]])

frame_Y = np.array([L_bound[1], L_bound[1], R_bound[1], R_bound[1], L_bound[1]])

plt.plot(frame_X, frame_Y, color='black')

# 绘制圆

for i in range(0, K):

n = np.linspace(0, 2*3.14, 300)

x = can_dis[i] * np.cos(n) + Target[0]

y = can_dis[i] * np.sin(n) + Target[1]

plt.plot(x, y, c='lightsteelblue')

# plt.axis([np.min(L_bound), np.max(R_bound), np.min(L_bound), np.max(R_bound)])

plt.draw()

plt.show()

# 验证正确性

print('\n---------- Varification of the Correctness----------\n')

dist_list = np.power(np.sum(np.power(Points - Target, p), 1), 1/p)

sorted_dist_list = np.sort(dist_list)

print('correct_dist_list  : ', np.round(sorted_dist_list[0:K], 4))

print('sorted_dist_list   : ', np.round(sorted_dist_list, 4))

print('original_dist_list : ', np.round(dist_list, 4))

[转载] Python 统计学习方法——kdTree实现K近邻搜索相关推荐

  1. 统计学习方法笔记(一)-k近邻算法原理及python实现

    k近邻法 k近邻算法 算法原理 距离度量 距离度量python实现 k近邻算法实现 案例地址 k近邻算法 kkk近邻法(kkk-NN)是一种基本分类和回归方法. 算法原理 输入:训练集 T={(x1, ...

  2. python实现e2lsh高维数据集k近邻搜索——实现流程

    lsh学习链接: LSH(Locality Sensitive Hashing)原理与实现 对高维数据查询最近邻,推荐使用p-stable LSH: minLSH是针对文档查询最近邻得方法: pyth ...

  3. Python,OpenCV中的K近邻(knn K-Nearest Neighbor)及改进版的K近邻

    Python,OpenCV中的K近邻(knn K-Nearest Neighbor)及改进版的K近邻 1. 效果图 2. 源码 参考 这篇博客将介绍将K-最近邻 (KNN K-Nearest Neig ...

  4. kd树 python实现_kd树 寻找k近邻算法 python实现

    按照链接里的算法写了k近邻的python实现 from math import sqrt class KDnode: def __init__(self, data, left, right, spl ...

  5. python人工智能——机器学习——分类算法-k近邻算法

    分类算法-k近邻算法(KNN) 定义:如果一个样本在特征空间中的k个最相似(即特征空间中最邻近)的样本中的大多数属于某一个类别,则该样本也属于这个类别. 来源:KNN算法最早是由Cover和Hart提 ...

  6. 【Python机器学习】多项式回归、K近邻KNN回归的讲解及实战(图文解释 附源码)

    需要源码请点赞关注收藏后评论区留言私信~~~ 多项式回归 非线性回归是用一条曲线或者曲面去逼近原始样本在空间中的分布,它"贴近"原始分布的能力一般较线性回归更强. 多项式是由称为不 ...

  7. 机器学习与python实战(一)-k近邻

    kNN(k-nearest neighbor)算法是一个简单而经典的机器学习分类算法,通过度量"待分类数据"和"类别已知的样本"的距离对样本进行分类. from ...

  8. python人工智能——机器学习——分类算法-k近邻算法——kaggle案例: Facebook V: Predicting Check Ins

    题目及翻译 Facebook and Kaggle are launching a machine learning engineering competition for 2016. Faceboo ...

  9. java实现k 近邻算法_K近邻算法哪家强?KDTree、Annoy、HNSW原理和使用方法介绍

    1.什么是K近邻算法 K近邻算法(KNN)是一种常用的分类和回归方法,它的基本思想是从训练集中寻找和输入样本最相似的k个样本,如果这k个样本中的大多数属于某一个类别,则输入的样本也属于这个类别. 关于 ...

最新文章

  1. 童心制物(Makeblock)受邀参加2020年韩国机器人世界展览会,倡导以先进的STEAM教育培养未来复合型人才
  2. 阿里人工智能实验室新入职两名首席科学家,年薪百万美元
  3. python for
  4. Final Cut Pro快捷键
  5. vue-cli3全面配置详解
  6. 丘成桐现身合肥分享发现数学之美 点赞中国科大年轻学子
  7. 机器学习算法(2)——AdaBoost算法
  8. win7设置计算机共享的打印机共享的打印机共享,win7,xp打印机共享设置软件 一键共享...
  9. matlab中bp神经网络梯度怎么调精度,如何提高BP神经网络模型的预测精度?
  10. 微软Media Creation Tool 创建工具 1.3 升级:支持全新下载安装Win11 22H2 Build 22621.525
  11. stata进行无交互效应模型选择—混合模型,个体效应固定,时间效应固定,双向固定和随机效应
  12. Tableau实战 Tableau官网各版块访问情况(一 ~ 四汇总)仪表盘
  13. matlab系统频域分析,基于MATLAB的系统频域分析的实现
  14. 区块链技术应用场景有哪些?
  15. vue的v-for循环中图片加载路径问题
  16. windows11-USB禁用
  17. PA1.3 代码+笔记
  18. 视频 | 苏炳添的“冠军卧室”曝光,来看看百米飞人的另一面
  19. 数字化是指用计算机,数字化
  20. html之ins标签

热门文章

  1. NYOJ113 - 字符串替换
  2. linux下载tar.gz文件夹,手把手给你细说linux-Ubuntu如何安装tar.gz文件
  3. linux 编程 调度,Linux的进程线程及调度
  4. 怎么区分zh和ch_zh,ch,sh,和z,c,s怎么分辨?
  5. LeetCode刷题目录
  6. linux 下strstr函数,Linux中strchr与strstr函数实现。
  7. bzoj 3383: [Usaco2004 Open]Cave Cows 4 洞穴里的牛之四(set+BFS)
  8. bzoj 3392: [Usaco2005 Feb]Part Acquisition 交易(最短路)
  9. bzoj 3391: [Usaco2004 Dec]Tree Cutting网络破坏
  10. 莫烦python学习笔记之numpy.array,dtype,empty,zeros,ones,arrange,linspace