分类算法之K-Nearest Neighbors

一:认识KNN算法

K-Nearest Neighbors又称为k最近邻算法。最初由Cover和Hart在1968年提出。属于Classification(分类)算法。Classification算法一般分为两个阶段,训练阶段和学习阶段

二:算法原理

通用步骤:

  • 计算距离(常用欧式距离或马氏距离)
  • 升序排列
  • 取前k个
  • 加权平均

k的选取:

  • k太大:导致分类模糊
  • k太小:手个例影响,波动较大

如何选取k:

  • 经验
  • 均方根误差

三:算法应用

  1. 简单KNN算法实现

    ​ python中的sklearn库为我们封装了KNN算法的实现方式

    ​ datasets是sklearn中的一个存放各类数据的模块,我们可以通过导入这个模块可以获取大量标准的数据进行练习。我们以鸢尾花数据为例。

    ​ 数据解析:鸢尾花的共有4个属性,分别为sepal length 、sepal width 、 petal length 、 petal width, 通过属性的不同我们将鸢尾花分为了setosa、versicolor、virginica三类,分别用数字0、1、2表示。综上所述,4个属性为特征,类别为标签。

    代码实现:

from sklearn.datasets import load_iris
from sklearn.neighbors import KNeighborsClassifier# 获取数据
iris_data = load_iris()
print(iris_data)# 将鸢尾花数据分为两组,取前100条数据为训练集,后100条数据为测试集
# 获取训练集的特征值和标签值
trans_feature = iris_data["data"][:100]
trans_label = iris_data["target"][:100]
# 获取测试集的特征值和标签值
test_feature = iris_data["data"][-100:]
test_label = iris_data["target"][-100:]# 创建KNN模型对象,传入的参数相当于k值
KNN = KNeighborsClassifier(n_neighbors=7)# 测试模型
KNN.fit(trans_feature,trans_label)# 模型预测
KNN.predict(test_feature)# 获取精度
precision = KNN.score(test_feature, test_label)
print(precision)
  1. Cross Validation (交叉验证)

    ​ Cross Validation是一种评估模型性能的重要方法。主要用于在多个模型中(不同中模型和同一种类不同超参数组合)挑选出在当前问题场景下表现最优的模型

    引入原因:

    ​ 在训练集(train set)上训练得到的模型表现良好,但在测试集(test set)的预测结果不尽如人意,这就说明模型可能出现了过拟合(overfitting),在未知数据上的泛化能力差。

    ​ 一个改进方案是,在训练集的基础上进一步划分出新的训练集和验证集(validate set),在新训练集训练模型,在验证集测试模型,不断调整初始模型(超参数等),使得训练得到的模型在验证集上的表现最好,最后放到测试集上得到这个最优模型的评估结果。

    ​ 这个方案的问题在于模型的表现依赖于验证集的划分,可能使某些特殊样本被划入验证集,导致模型的表现出现异常(偏好或偏差)。而且训练集划了一部分给验证集后,训练模型能得到的数据就变少了,也会影响训练效果。因为通常来说,训练数据越多,越能反映出数据的真实分布,模型训练的效果就越好,越可能得到无偏估计。

    ​ 交叉验证思想应运而生,交叉验证可以充分使用所有的训练数据用于评估模型。

    • K-fold(K折交叉验证)

      k折交叉验证是最基本的cv方法,具体方法为,将训练集随机等分为k份,取其中一份为验证集评估模型,其余k-1份为训练集训练模型,重复该步骤k次,每次都取一份不同的子集为验证集,最终得到k个不同的模型(不是对一个模型迭代k次)和k个评分,综合这k个模型的表现(平均得分或其他)评估模型在当前问题中的优劣。

      训练集和验证集的划分都在模块底层自动实现,我们只需要传入一个集合即可。

      k值的选取很有讲究,选取适当大小的k很重要,经验值(empirical value)是k=10。

      算法实现:

      from sklearn.datasets import load_iris
      from sklearn.neighbors import KNeighborsClassifier
      from sklearn.model_selection import cross_val_score# 获取数据
      iris_data = load_iris()# 实例化模型
      KNN = KNeighborsClassifier(n_neighbors=7)# 提取特征值和标签
      iris_features = iris_data["data"]
      iris_labels = iris_data["target"]# 交叉验证,返回模型每次的精度
      score = cross_val_score(KNN,iris_features,iris_labels,cv=10)# 获取精度,精度一般取平均值
      precision = score.mean()
      print(precision)
      
    • Leave one out(LOO)

      LOO(留一法)每次在训练集的N个样本中选一个不同的样本作为验证集,其余样本为训练集,训练得到N-1个不同的模型。LOOCV是特殊的K-fold,当K=N时,二者相同。【不常用,了解即可】

  2. 网格搜索

    网格搜索的目的主要是用来确定超参数k的值,CV是网格搜索的一种方法。

    cv与网格搜索

    ​ 通常情况下,有很多参数是需要手动指定的(如k-近邻算法中的K值),这种叫超参数。但是手动过程繁杂,所以需要对模型预设几种超参数组合。每组超参数都采用交叉验证来进行评估。最后选出最优参数组合建立模型。

    ​ scikit learn库中的GridSearchCV可以方便地尝试不同的超参数(hyper parameter),得到最优组合。GridSearchCV中的cv参数(默认为3-fold)利用交叉验证为某一超参数组合打出合理的评分。可以使用sklearn.model_selection.GridSearchCV(estimator, param_grid=None,cv=None)对估计器模型的指定参数值进行详尽搜索。其中estimator为模型对象,param_grid为k值的备选列表,cv为数据等分个数

    ​ GridSearchCV结果展示方法:

    ​ best_params_:返回最高精度对应的k值

    ​ best_score_:返回最高精度值

    ​ best_estimator_:返回最好的参数模型

    代码实现:

    from sklearn.datasets import load_iris
    from sklearn.neighbors import KNeighborsClassifier
    from sklearn.model_selection import GridSearchCV
    # 获取数据
    iris_data = load_iris()# 获取特征值和标签
    iris_features = iris_data["data"]
    iris_labels = iris_data["target"]# 实例化模型对象
    KNN = KNeighborsClassifier()# 定义k值备选列表
    params = {"n_neighbors":[i for i in range(10,30,2)]}# 实例化网格对象
    gridcv = GridSearchCV(KNN,grid_param=params,cv=10)# 测试数据
    gridcv.fit(iris_features,iris_labels)# 返回精度最高的参数
    print(gridcv.best_params_)
    # 返回最高精度值
    print(gridcv.best_scores_)
    # 返回最高精度模型
    

print(gridcv.best_estimator_)

分类算法之K-Nearest Neighbors相关推荐

  1. KNN(K Nearest Neighbors)分类是什么学习方法?如何或者最佳的K值?RadiusneighborsClassifer分类器又是什么?KNN进行分类详解及实践

    KNN(K Nearest Neighbors)分类是什么学习方法?如何或者最佳的K值?RadiusneighborsClassifer分类器又是什么?KNN进行分类详解及实践 如何使用GridSea ...

  2. 机器学习之深入理解K最近邻分类算法(K Nearest Neighbor)

    [机器学习]<机器学习实战>读书笔记及代码:第2章 - k-近邻算法 1.初识 K最近邻分类算法(K Nearest Neighbor)是著名的模式识别统计学方法,在机器学习分类算法中占有 ...

  3. 分类算法之K近邻和朴素贝叶斯

    分类算法之K近邻和朴素贝叶斯 K近邻 一个例子弄懂k-近邻 sklearn k-近邻算法API 朴素贝叶斯 概率论基础 贝叶斯公式 拉普拉斯平滑 sklearn朴素贝叶斯实现API K近邻 一个例子弄 ...

  4. 机器学习-监督学习之分类算法:K近邻法 (K-Nearest Neighbor,KNN)

    目录 KNN概述 举个例子: K值选取 距离计算 曼哈顿距离,切比雪夫距离关系(相互转化) k-近邻(KNN)算法步骤 相关代码实现 简单实例:判断电影类别 创建数据集 数据可视化 分类测试 运行结果 ...

  5. k近邻算法_机器学习分类算法之k近邻算法

    本编文章将介绍机器学习入门算法-k近邻算法,将会用demo演示机器学习分类算法. 在先介绍算法时,先回顾分类和回归的区别.像文章分类识别也是这样处理的,如1代表体育,2代表科技,3代表娱乐属于分类问题 ...

  6. 机器学习3—分类算法之K邻近算法(KNN)

    K邻近算法(KNN) 一.算法思想 二.KNN类KNeighborsClassifier的使用 三.KNN分析红酒类型 3.1红酒数据集 3.2红酒数据的读取 3.3将红酒的数据集拆分为训练和测试集 ...

  7. 机器学习——分类算法之K近邻+朴素贝叶斯,模型选择与调优

    目录 K-近邻算法 定义 如何求距离? 数据预处理--标准化 sklearn k-近邻算法API 案例--预测入住位置 分类问题 数据处理 k近邻算法相关问题 k值取多大?有什么影响? 优缺点 应用场 ...

  8. kNN算法(k近邻算法,k Nearest Neighbor)

    主要内容: 1.认识kNN算法 2.kNN算法原理 3.应用举例 4.kNN改进方法 1.认识knn算法 "看一个人怎么样,看他身边的朋友什么样就知道了",kNN算法即寻找最近的K ...

  9. 如下10种分类算法对比Classifier comparison

    如下10种分类算法对比 names = ["Nearest Neighbors", "Linear SVM", "RBF SVM", &qu ...

  10. 菜鸟进阶: C++实现KNN文本分类算法

    作者:finallyliuyu(转载请注明原作者和出处) (代码暂不发布源码下载版,以后会发布) KNN文本分类算法又称为(k nearest neighhor).它是一种基于事例的学习方法,也称懒惰 ...

最新文章

  1. RHEL6.3 DNS高级技术二 通过DNS主从区域复制实现DNS View负载均衡和冗余备份
  2. UVa12166 Equilibrium Mobile修改天平(二叉树+dfs)
  3. 开源项目管理软件禅道使用帮助下载
  4. 新视窗java_《计算机组成原理实验》教学大纲 - 兰州大学信息科学与工程学院.DOC...
  5. 1.8编程基础之多维数组 03 计算矩阵边缘元素之和 python
  6. java addobject_springMVC后台的值无法通过ModelAndView的addObject传到前台
  7. 逻辑回归python实现
  8. hdu-3488-Tour(KM最佳完美匹配)
  9. 基于motion的视频压缩的实时监控系统
  10. 系统目录 linux命令,linux基础命令之系统目录(示例代码)
  11. network 网络带宽
  12. python重写和装饰器_Python | 老司机教你 5 分钟读懂 Python 装饰器
  13. 淘晶驰串口屏下载工程慢怎么办
  14. 案例: 模拟京东快递单号查询
  15. 西数硬盘刷新固件_关于西数硬盘转速的fake news
  16. 解决no-console异常
  17. Spark应用启动报错:Could not locate executable null\bin\winutils.exe in the Hadoop binaries.
  18. xbox one 手柄在ubuntu下的使用和开发
  19. 如何笔记本盖上连接显示器不熄屏?
  20. 精简版Win10/11系统无法登录Xbox及Minecraft Launcher解决方法,不需要重装系统

热门文章

  1. mysql 查询本月所有天数统计对应数据
  2. 01 HA haproxy 功能简介以及应用
  3. vue 登录 动态树 表格 cud
  4. No version of NDK matched the requested version xxx 问题解决
  5. Python反爬机制-验证码
  6. Cobaltstrike内网工具的使用笔记
  7. springboot中banner图制作
  8. Please restart Word to load MathType addin properly;运行时错误53,文件未找到MathPage.WLL
  9. npm安装报错(npm ERR code EPERM npm ERR syscall mkdir npm ERR path CProgram Filesnodejsnode_ca...)
  10. 在java中重写方法应遵循规则的包括_蘑菇街2017校园招聘笔试题