KNN是一种常见的监督学习算法,工作机制很好理解:给定测试样本,基于某种距离度量找出训练集中与其最靠近的k个训练样本,然后基于这k个“邻居”的信息来进行预测。总结一句话就是“近朱者赤,近墨者黑”。

KNN可用作分类也可用于回归,在分类任务中可使用“投票法”,即选择这k个样本中出现最多的类别标记作为测试结果;在回归任务中可使用“平均法”将这k个样本的标记平均值作为预测结果;还可以基于距离远近进行加权平均或加权投票,距离越近的样本权重越大。
       KNN和之前介绍的监督学习算法有一个很大的不同,它没有前期的训练过程,是一种“懒惰学习”的算法,只有收到测试样本后,再和训练样本进行比较处理。
       初学者容易把KNN和K-means搞混淆,虽然都有K,:-)但这是两种不同的算法,二者区别如下:

  KNN K-Means
不同点 是一种分类算法,属于监督学习的范畴,训练数据是带有label的 是一种聚类算法,属于非监督学习的范畴,训练数据没有label,杂乱无章的
没有明显的训练过程,属于lazy learning 有明确的训练过程
K的含义:与预测样本距离最近的K个样本 K的含义:K是事前人工定好的参数,假设数据集可分为K个簇
相同点 都用到了NN(nearst Neighbor)算法,一般用KD树来实现。

--KNN算法基本原理
KNN算法简单的步骤如下:
(1)计算距离:给定测试对象,计算它与训练集中每个对象的距离,空间距离的计算方法有多种,有欧式距离、夹角余弦(多在文本分类中使用)等。
(2)找邻居:圈定距离最近的k个对象,作为测试对象的近邻。
(3)做分类:根据这k个近邻归属的主要类别,对测试对象进行分类。

下面通过一个简单的示例说明下KNN算法是怎么进行分类的:

上图的蓝色方块和红色三角是已经打好label的数据,绿色圆圈是待分类的测试数据。

如果我们让K=3,那么上图实心圆圈中的两个三角和一个方块就是离测试数据最近的3个点,那么通过投票法则,测试数据会被分类为红色三角;

如果我们让K=5,那么上图虚线圆圈中的两个三角和三个方块就是离测试数据最近的5个点,通过投票法则,测试数据则会被分类为蓝色方块;

整个算法的原理是不是很简单?但实际上并没有那么简单,K如何选择?数据之间的距离怎么计算?

--K值的选择

如果K值太小,整体模型会变得复杂,容易发生过拟合,容易将一些噪声学习进来,二忽略数据的真实分布。

如果K值过大,模型会变得相对简单,可以减少学习的估计误差,但近似误差会变大,比如极端情况下K=N(N维训练样本数),则不论预测对象是什么,预测结果都将是训练集中最多的类型,这显然是一个过渡简化的模型,无法实际应用。

k值一般采用交叉验证或者Grid Search的方法确定。

--距离计算

提取数据的特征值,根据特征值组成一个n维实数向量空间(特征空间),然后计算向量之间的空间距离,如欧式距离、余弦相似度等。

对于数据,其特征空间为n维实数向量空间:

欧式距离计算公式为:

余弦相似度计算公式为:

余弦相似度的值越接近1表示其越相似,接近0表示其差异越大。余弦相似度更多应用在文本类任务中。

--代码示例

依旧以sklearn中的cancer数据集为例,做一个通过30维特征判断是否患癌症的示例,示例中数据量很少,只有569条数据,每条数据各有30个特征数值。采用sklearn中的KNN分类器,除k外都采用默认参数,距离度量采用欧式距离 。通过交叉验证法来确定最佳的K值,从下图可见,K=14时,验证准确率最高。

-Python 代码

__author__ = 'z00421185'import pandas as pd
from sklearn import datasets
import matplotlib.pyplot as plt
from sklearn.model_selection import cross_val_score
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
from sklearn.neighbors import KNeighborsClassifierbreast_data = datasets.load_breast_cancer()
data = pd.DataFrame(datasets.load_breast_cancer().data)
data.columns = breast_data['feature_names']data_np = breast_data['data']
target_np = breast_data['target']
print(data_np.shape)x_train, x_test, y_train, y_test = train_test_split(data_np, target_np, test_size=0.3, random_state=0)# 设定交叉验证k的范围,一般从1~样本数的开方
k_range = range(1, 24)
scores = []
for k in k_range:knn = KNeighborsClassifier(k, metric='euclidean')score = cross_val_score(knn, x_train, y_train, cv=10, scoring='accuracy')scores.append(score.mean())# 从折线图上看最佳K取值
plt.plot(k_range, scores)
plt.xlabel('K')
plt.ylabel('Accuracy')
plt.show()model = KNeighborsClassifier(n_neighbors=13)
model.fit(x_train, y_train)
y_pred = model.predict(x_test)
print(accuracy_score(y_test, y_pred))
---------------------------------
0.9649122807017544

作者:华为云专家 周捷

机器学习笔记(十)---- KNN(K Nearst Neighbor)相关推荐

  1. 机器学习笔记十四:随机森林

    在上一篇机器学习笔记十三:Ensemble思想(上)中,简要的提了一下集成学习的原理和两种主要的集成学习形式.  而在这部分要讲的随机森林,就算是其中属于bagging思路的一种学习方法.为了篇幅,b ...

  2. 机器学习笔记 十四:k-近邻算法(kNN)的实现

    目录 1. 什么是机器学习 2. k-近邻算法(kNN) 2.1 kNN的原理 2.2 k-近邻算法的一般流程 2.3 kNN伪代码 3. 函数介绍 3.1 get()函数:利用字典统计列表中元素出现 ...

  3. 机器学习笔记(5) KNN算法

    这篇其实应该作为机器学习的第一篇笔记的,但是在刚开始学习的时候,我还没有用博客记录笔记的打算.所以也就想到哪写到哪了. 你在网上搜索机器学习系列文章的话,大部分都是以KNN(k nearest nei ...

  4. 机器学习笔记(十)降维和度量学习

    10.降维和度量学习 10.1k近邻学习 k近邻(k-NearestNeighbor,简称kNN)学习是一种常用的监督学习方法,其原理是:给定测试样本,基于某种距离度量找出训练集中与其最靠近的k个训练 ...

  5. 机器学习笔记(十五)规则学习

    15.规则学习 15.1基本概念 机器学习中的规则(rule)通常是指语义明确.能描述数据分布所隐含的客观规律或领域概念.可写成若-则-形式的逻辑规则.规则学习(rulelearning)是从训练数据 ...

  6. 2018-3-20李宏毅机器学习笔记十----------Logistic Regression

    上节讲到:既然是一个直线型,只需要求解w和b.为何还要那么费劲的使用概率??? 视频:李宏毅机器学习(2017)_哔哩哔哩 (゜-゜)つロ 干杯~-bilibili https://www.bilib ...

  7. 机器学习笔记(十六)强化学习

    16.强化学习 16.1任务与奖赏 强化学习(reinforcementlearning)的过程就是机器通过一系列的动作和环境交互,从而得到最佳的动作序列.图示: 强化学习任务用马尔可夫决策(Mark ...

  8. 机器学习笔记(十二)计算学习理论

    12.计算学习理论 12.1基础知识 计算学习理论(computationallearning theory)研究的是关于通过计算来进行学习的理论,即关于机器学习的理论基础,其目的是分析学习任务的困难 ...

  9. 机器学习笔记(十八)——HMM的参数估计

    一.HMM中的第三个基本问题 参数估计问题:给定一个观察序列O=O1O2-OTO=O_1O_2\dots O_T,如何调节模型μ=(A,B,π)\mu = (A, B, \pi)的参数,使得P(O|μ ...

最新文章

  1. 模型大十倍,性能提升几倍?谷歌研究员进行了一番研究
  2. 本地应用 v-on补充
  3. nginx 子请求接收响应_Nginx详解其原理
  4. 基础练习 高精度加法
  5. swift瀑布流实现_CSS 实现瀑布流布局(display: flex)
  6. 怎么创建数据表的实体类和业务类_微服务项目第13天:商品分类业务的实现
  7. HTML粘性滑块导航源码-可用来做首页
  8. jenkins集成tomcat
  9. oracle错误ORA-00604 递归sql级别1出现错误 ora-00942 表或试图不存在 ORA-06512 在line 11...
  10. php 7 中对数值 * 100 出现很多小数_PHP快速入门第二讲:数据类型
  11. WebApi的调用-1.前端调用
  12. Java对象映射XML文件
  13. HTML5 css3 阴影效果
  14. modelandview 跳转问题_ModelAndView 跳转的使用
  15. 基于Android的防疫信息管理系统源码【包调试运行】
  16. R语言mad函数、median函数、mean函数计算中位数绝对偏差、中位数、均值实战
  17. 机器人学与OROCOS-KDL(三)姿态描述与旋转矩阵
  18. 《按自己的意愿过一生》读书笔记
  19. 6种不同画法画平行线_学会6种常用平行线的判定方法,数学成绩悄悄涨20分
  20. AD2020库安装及查找库

热门文章

  1. antd vue关闭模态对话框_如何在Bootstrap项目中用Vue.js替代jQuery
  2. delphi控件切图界面闪烁_DirectUI用户手册.pdf
  3. oracle数据泵导入导出_【软件】R语言数据导入与导出
  4. php no route to host,java.net.NoRouteToHostException: No route to host解决方法
  5. PX4 CMakeLists.txt 文件剖析
  6. spark 源码分析之十九 -- DAG的生成和Stage的划分
  7. 蓝桥杯 单点最短路径问题
  8. Ubuntu apt-get方式安装Subversion
  9. 大家一起做训练 第一场 A Next Test
  10. 面向对象基本原则-转载