零.广告

本文所有代码实现均可以在 DML 找到,不介意的话请大家在github里给我点个Star

一.引入

K近邻算法作为数据挖掘十大经典算法之一,其算法思想可谓是intuitive,就是从训练集里找离预测点最近的K个样本来预测分类

因为算法思想简单,你可以用很多方法实现它,这时效率就是我们需要慎重考虑的事情,最简单的自然是求出测试样本和训练集所有点的距离然后排序选择前K个,这个是O(nlogn)的,而其实从N个数据找前K个数据是一个很常见的算法题,可以用最大堆(最小堆)实现,其效率是O(nlogk)的,而最广泛的算法是使用kd树来减少扫描的点,这也就是这篇文章的主要内容,本文偏实现,详细理论教程见july的文章 ,不得不服,july这篇文章巨细无遗!

二.前提:堆的实现

堆是一种二叉树,用一个数组存储,对于k号元素,k*2号是其左儿子,k*2+1号是其右儿子

而大根堆就是跟比左儿子和右儿子都大,小根堆反之。

要满足这个条件我们需要通过up( index )操作和down( index )维护它的结构

当然讲这个的文章实在有些多了,随便搜一篇大家看看:点击打开链接

大小根堆的作用是

a) 优先队列:因为第一个元素是最大或者最小的元素,所以可以实现优先队列

b) 前K个最大(最小)值:这里限制堆的大小为k,来获得O( n log k)的效率,但注意此时小根堆是获得前K个最大值,大根堆是获得前K个最小值,插入的时候先把元素和堆顶比较再决定是否插入。

因为事先KD-tree+BBF 要同时用到这两个东西,所以把它们实现在了同一个类里,感觉代码略漂亮,贴出来观赏一下:

此代码是dml / tool / heap.py

[python] view plaincopy
  1. from __future__ import division
  2. import numpy as np
  3. import scipy as sp
  4. def heap_judge(a,b):
  5. return a>b
  6. class Heap:
  7. def __init__(self,K=None,compare=heap_judge):
  8. '''''
  9. 'K'                 is the parameter to restrict the length of Heap
  10. !!! when K is confirmed,the Min heap contain Max K elements
  11. while Max heap contain Min K elements
  12. 'compare'         is the compare function which return a BOOL when pass two variable
  13. default is Max heap
  14. '''
  15. self.K=K
  16. self.compare=compare
  17. self.heap=['#']
  18. self.counter=0
  19. def insert(self,a):
  20. #print self.heap
  21. if self.K!=None:
  22. print a.x,'==='
  23. if self.K==None:
  24. self.heap.append(a)
  25. self.counter+=1
  26. self.up(self.counter)
  27. else:
  28. if self.counter<self.K:
  29. self.heap.append(a)
  30. self.counter+=1
  31. self.up(self.counter)
  32. else:
  33. if (not self.compare(a,self.heap[1])):
  34. self.heap[1]=a
  35. self.down(1)
  36. return
  37. def up(self,index):
  38. if (index==1):
  39. return
  40. '''''
  41. print index
  42. for t in range(index+1):
  43. if t==0:
  44. continue
  45. print self.heap[t].x
  46. print
  47. '''
  48. if self.compare(self.heap[index],self.heap[int(index/2)]):
  49. #fit the condition
  50. self.heap[index],self.heap[int(index/2)]=self.heap[int(index/2)],self.heap[index]
  51. self.up(int(index/2))
  52. return
  53. def down(self,index):
  54. if 2*index>self.counter:
  55. return
  56. tar_index=0
  57. if 2*index<self.counter:
  58. if self.compare(self.heap[index*2],self.heap[index*2+1]):
  59. tar_index=index*2
  60. else:
  61. tar_index=index*2+1
  62. else:
  63. tar_index=index*2
  64. if not self.compare(self.heap[index],self.heap[tar_index]):
  65. self.heap[index],self.heap[tar_index]=self.heap[tar_index],self.heap[index]
  66. self.down(tar_index)
  67. return
  68. def delete(self,index):
  69. self.heap[index],self.heap[self.counter]=self.heap[self.counter],self.heap[index]
  70. self.heap.pop()
  71. self.counter-=1
  72. self.down(index)
  73. pass
  74. def delete_ele(self,a):
  75. try:
  76. t=self.heap.index(a)
  77. except ValueError:
  78. t=None
  79. if t!=None:
  80. self.delete(t)
  81. return t

传入的时候不设置K就是正常的优先队列,设置了K就是限制堆的大小了

compare参数是比较大小的,默认是“数”的大根堆,你可以往堆里传任何类,只要有相适应的compare参数,比如我们KD-tree传的就是KD-Node

三.KD-BFF的原理:

首先从KD-Tree的创建说起:(直接贴《统计学习方法》的内容了)

事实上从选择哪一个feature开始切割,还可以选择方差最大的那个参数,但是考虑到简便,以及我们可以选择更多的相似性度量方法,还是用《统计学习方法》里面的选择方式了。

然后是KD-tree搜索的方法:(来自《统计学习方法》,但注意这里是最近邻,也就是k=1的时候)

那么我们要K近邻要怎么做呢?就是用堆的第二个应用,用大根堆保持K个最小的距离,然后用根的距离(也就是其中最大的一个)来作为判断的依据是否有更近的点不在结果中,这一点很重要!

同时摘录july博客的一段读者留言讲得非常好的:

   在某一层,分割面是第ki维,分割值是kv,那么 abs(q[ki]-kv) 就是没有选择的那个分支的优先级,也就是计算的是那一维上的距离; 同时,从优先队列里面取节点只在某次搜索到叶节点后才发生,计算过距离的节点不会出现在队列的,比如1~10这10个节点,你第一次搜索到叶节点的路径是1-5-7,那么1,5,7是不会出现在优先队列的。换句话说,优先队列里面存的都是查询路径上节点对应的相反子节点,比如:搜索左子树,就把对应这一层的右节点存进队列。

大致这就是我们实现的基本思路了

四.KD-BFF的实现:

知道原理了,并且有了堆这个工具之后我们就可以着手实现这个算法了:(终于要贴代码了)

代码~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~此代码是 dml / KNN / kd.py

[python] view plaincopy
  1. from __future__ import division
  2. import numpy as np
  3. import scipy as sp
  4. from operator import itemgetter
  5. from scipy.spatial.distance import euclidean
  6. from dml.tool import Heap
  7. class KDNode:
  8. def __init__(self,x,y,l):
  9. self.x=x
  10. self.y=y
  11. self.l=l
  12. self.F=None
  13. self.Lc=None
  14. self.Rc=None
  15. self.distsToNode=None
  16. class KDTree:
  17. def __init__(self,X,y=None,dist=euclidean):
  18. self.X=X
  19. self.k=X.shape[0] #N
  20. self.y=y
  21. self.dist=dist
  22. self.P=self.maketree(X,y,0)
  23. self.P.F=None
  24. def maketree(self,data,y,deep):
  25. if data.size==0:
  26. return None
  27. lenght = data.shape[0]
  28. case = data.shape[1]
  29. p=int((case)/2)
  30. l = (deep%self.k)
  31. #print data
  32. data=np.vstack((data,y))
  33. data=np.array(sorted(data.transpose(),key=itemgetter(l))).transpose()
  34. #print data
  35. y=data[lenght,:]
  36. data=data[:lenght,:]
  37. v=data[l,p]
  38. rP=KDNode(data[:,p],y[p],l)
  39. #print data[:,p],y[p],l
  40. if case>1:
  41. ldata=data[:,data[l,:]<v]
  42. ly=y[data[l,:]<v]
  43. data[l,p]=v-1
  44. rdata=data[:,data[l,:]>=v]
  45. ry=y[data[l,:]>=v]
  46. data[l,p]=v
  47. rP.Lc=self.maketree(ldata,ly,deep+1)
  48. if rP.Lc!=None:
  49. rP.Lc.F=rP
  50. rP.Rc=self.maketree(rdata,ry,deep+1)
  51. if rP.Rc!=None:
  52. rP.Rc.F=rP
  53. return rP
  54. def search_knn(self,P,x,k,maxiter=200):
  55. def pf_compare(a,b):
  56. return self.dist(x,a.x)<self.dist(x,b.x)
  57. def ans_compare(a,b):
  58. return self.dist(x,a.x)>self.dist(x,b.x)
  59. pf_seq=Heap(compare=pf_compare)
  60. pf_seq.insert(P)    #prior sequence
  61. ans=Heap(k,compare=ans_compare)  #ans sequence
  62. while pf_seq.counter>0:
  63. t=pf_seq.heap[1]
  64. pf_seq.delete(1)
  65. flag=True
  66. if ans.counter==k:
  67. now=t.F
  68. #print ans.heap[1].x,'========'
  69. if now != None:
  70. q=x.copy()
  71. q[now.l]=now.x[now.l]
  72. length=self.dist(q,x)
  73. if length>self.dist(ans.heap[1].x,x):
  74. flag=False
  75. else:
  76. flag=True
  77. else:
  78. flag=True
  79. if flag:
  80. tp,pf_seq,ans=self.to_leaf(t,x,pf_seq,ans)
  81. #print "============="
  82. #ans.insert(tp)
  83. return ans
  84. def to_leaf(self,P,x,pf_seq,ans):
  85. tp=P
  86. if tp!=None:
  87. ans.insert(tp)
  88. if tp.x[tp.l]>x[tp.l]:
  89. if tp.Rc!=None:
  90. pf_seq.insert(tp.Rc)
  91. if tp.Lc==None:
  92. return tp,pf_seq,ans
  93. else:
  94. return self.to_leaf(tp.Lc,x,pf_seq,ans)
  95. if tp.Lc!=None:
  96. pf_seq.insert(tp.Lc)
  97. if tp.Rc==None:
  98. return tp,pf_seq,ans
  99. else:
  100. return self.to_leaf(tp.Rc,x,pf_seq,ans)

然后KNN就是对上面这个类的一个包装:

代码~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~此代码是 dml / KNN / knn.py

[python] view plaincopy
  1. #coding:utf-8
  2. import numpy as np
  3. import scipy as sp
  4. from scipy.spatial.distance import cdist
  5. from scipy.spatial.distance import euclidean
  6. from dml.KNN.kd import KDTree
  7. #import pylab as py
  8. class KNNC:
  9. """docstring for KNNC"""
  10. def __init__(self,X,K,labels=None,dist=euclidean):
  11. '''''
  12. X is a N*M matrix where M is the case
  13. labels is prepare for the predict.
  14. dist is the similarity measurement way,
  15. The distance function can be ‘braycurtis’, ‘canberra’,
  16. ‘chebyshev’, ‘cityblock’, ‘correlation’, ‘cosine’,
  17. ‘dice’, ‘euclidean’, ‘hamming’, ‘jaccard’, ‘kulsinski’,
  18. ‘mahalanobis’,
  19. '''
  20. self.X = np.array(X)
  21. if labels==None:
  22. np.zeros((1,self.X.shape[1]))
  23. self.labels = np.array(labels)
  24. self.K = K
  25. self.dist = dist
  26. self.KDTrees=KDTree(X,labels,self.dist)
  27. def predict(self,x,k):
  28. ans=self.KDTrees.search_knn(self.KDTrees.P,x,k)
  29. dc={}
  30. maxx=0
  31. y=0
  32. for i in range(ans.counter+1):
  33. if i==0:
  34. continue
  35. dc.setdefault(ans.heap[i].y,0)
  36. dc[ans.heap[i].y]+=1
  37. if dc[ans.heap[i].y]>maxx:
  38. maxx=dc[ans.heap[i].y]
  39. y=ans.heap[i].y
  40. return y
  41. def pred(self,test_x,k=None):
  42. '''''
  43. test_x is a N*TM matrix,and indicate TM test case
  44. you can redecide the k
  45. '''
  46. if k==None:
  47. k=self.K
  48. test_case=np.array(test_x)
  49. y=[]
  50. for i in range(test_case.shape[1]):
  51. y.append(self.predict(test_case[:,i].transpose(),k))
  52. return y

因为KNN毕竟是一个分类算法,所以我在predict是加上了分类的代码,如果只想检验Kd-tree的话,你可以直接用for_point()找最近k个点

五.测试+后记

测试:

我们选取《统计学习方法》上面的例子:

使用代码:

[python] view plaincopy
  1. X=np.array([[2,5,9,4,8,7],[3,4,6,7,1,2]])
  2. y=np.array([2,5,9,4,8,7])
  3. knn=KNNC(X,1,y)
  4. print knn.for_point([[6.5],[7]],1)

这里y是label,是用来预测的,这个例子里没有实际作用,这是用来分类的

输出中后面带了“===”的是扫描过的点,最后的是搜索的结果:

我们可以看到的确避免扫描了(2,3),Bingo!!

我们再knn.for_point([[2],[2]]):可以看到避免扫了很多点!!!

后记:

从实现写此文前后耗时两天,昨天写代码写到熄灯且刚好测试通过,怎一个爽字了得!!最后,再在github上求个Star

reference:

【1】从K近邻算法、距离度量谈到KD树、SIFT+BBF算法 http://blog.csdn.net/v_july_v/article/details/8203674

【2】《统计学习方法》 李航

【3】最大堆的插入/删除/调整/排序操作(图解+程序)  http://www.java3z.com/cwbwebhome/article/article1/1362.html?id=4745

KNN(k-nearest neighbor algorithm)--从原理到实现相关推荐

  1. k Nearest Neighbor Algorithm

    k Nearest Neighbor Algorithm k Nearest Neighbor(kNN) algorithm算法和k-Means算法一样,都是简单理解,但是实际效果出人意料的算法之一. ...

  2. K NEAREST NEIGHBOR 算法(knn)

    K Nearest Neighbor算法又叫KNN算法,这个算法是机器学习里面一个比较经典的算法, 总体来说KNN算法是相对比较容易理解的算法.其中的K表示最接近自己的K个数据样本.KNN算法和K-M ...

  3. Nearest Neighbor Algorithm

    Nearest Neighbor Algorithm 邻近算法(Nearest Neighbor)的思想实际上十分简单,就是将测试图片和储存起来的训练集一一进行相似度计算,计算出最相近的图片,这张图片 ...

  4. 机器学习之深入理解K最近邻分类算法(K Nearest Neighbor)

    [机器学习]<机器学习实战>读书笔记及代码:第2章 - k-近邻算法 1.初识 K最近邻分类算法(K Nearest Neighbor)是著名的模式识别统计学方法,在机器学习分类算法中占有 ...

  5. K Nearest Neighbor 算法

    K Nearest Neighbor算法又叫KNN算法,这个算法是机器学习里面一个比较经典的算法, 总体来说KNN算法是相对比较容易理解的算法.其中的K表示最接近自己的K个数据样本.KNN算法和K-M ...

  6. 文献记录(part81)--Clustering-based k -nearest neighbor classification for large-scale data with ...

    学习笔记,仅供参考,有错必纠 文章目录 Clustering-based k -nearest neighbor classification for large-scale data with ne ...

  7. 机器学习——K近邻算法(KNN)(K Nearest Neighbor)

    参考视频与文献: python与人工智能-KNN算法实现_哔哩哔哩_bilibili 机器学习--K近邻算法(KNN)及其python实现_清泉_流响的博客-CSDN博客_python实现knn 机器 ...

  8. kNN算法(k近邻算法,k Nearest Neighbor)

    主要内容: 1.认识kNN算法 2.kNN算法原理 3.应用举例 4.kNN改进方法 1.认识knn算法 "看一个人怎么样,看他身边的朋友什么样就知道了",kNN算法即寻找最近的K ...

  9. 模式识别之knn---KNN(k-nearest neighbor algorithm)--从原理到实现

    用官方的话来说,所谓K近邻算法,即是给定一个训练数据集,对新的输入实例,在训练数据集中找到与该实例最邻近的K个实例(也就是上面所说的K个邻居),这K个实例的多数属于某个类,就把该输入实例分类到这个类中 ...

  10. 【资源分享】今日学习打卡--k近邻法 (k Nearest Neighbor Method)

    他来了他来了,他带着礼物走来了. 今天继续学习打卡 今天给大家分享的是k近邻法,是监督学习中的最基本的分类问题模型. 博主从自己看的几个比较好的视频中,分享了一些自己认为比较好的视频.希望能在大家和博 ...

最新文章

  1. python函数结构图_Python数据结构与算法之图结构(Graph)实例分析
  2. oracle 服务名 数据库名 实例名
  3. HTML一些常用的标签
  4. html 商品展示框
  5. java调用百度搜索_Java爬虫怎么调用百度搜索引擎,对关键字的网页爬取?
  6. hive与hbase整合方式和优劣
  7. .net 4下引用低版本.net类库发生异常的解决方案
  8. cisco 2960-24 配置(生产环境)
  9. python基础——使用模块
  10. python 内推_[宜配屋]听图阁
  11. Python3系列__01Python安装
  12. C#中的主从Datagridview
  13. 三星note9刷Android9,三星Note9官方韩版安卓9固件rom线刷刷机包:N960NKSU2CSE3
  14. python从图片提取文字_用python提取图片文字
  15. DDD案例(2):从领域分析到代码实现
  16. 记一次Mysql并发死锁,引出的问题及讨论
  17. 选手机壳要擦亮双眼,不会遮挡激光对焦传感器的才是好壳子!
  18. Android JetPack Security简介
  19. C++中流控制函数 setw() setfill() setbase() setprecision()
  20. 特斯拉产业的几个问题

热门文章

  1. win7设置固定IP重启后无法上网,ipconfig显示为自动配置IPV4 169.254的地址
  2. SQL Server中的事务日志管理(6/9):大容量日志恢复模式里的日志管理
  3. Sql server 行列转换
  4. 2022年寒假ACM练习1
  5. db2 随机数函数_sql中的随机函数怎么用?
  6. jsp页面ajax用法,在jsp中使用jquery的ajax
  7. php粉层,thinkphp 模型分层
  8. 浅析单调递增子序列问题(LIS)
  9. linux mysql 5.7 配置_Linux环境下详细讲解部署MySQL5.7版本
  10. go语言mysql框架_开源轻量级数据库访问框架-Go语言中文社区