文章目录

  • 一、KNN 基本介绍
  • 二、KNN 核心思想
  • 三、KNN 算法流程
  • 四、KNN 优缺点
  • 五、Java 代码实现 KNN
  • 六、KNN 改进策略

一、KNN 基本介绍

邻近算法,或者说K最邻近(KNN,K-NearestNeighbors)分类算法是分类方法中最简单的方法之一。所谓K最近邻,就是K个最近的邻居的意思,说的是每个样本都可以用它最接近的K个邻近值来代表。近邻算法就是将数据集合中每一个记录进行分类的方法。

KNN 最初由 Cover 和 Hart 于1968年提出,是一个理论上比较成熟的方法,也是最简单的机器学习算法之一。

该方法的思路非常简单直观:如果一个样本在特征空间中的 K 个最相似(即特征空间中最邻近)的样本中的大多数属于某一个类别,则该样本也属于这个类别。


二、KNN 核心思想

KNN算法的核心思想是,如果一个样本在特征空间中的K个最相邻的样本中的大多数属于某一个类别,则该样本也属于这个类别,并具有这个类别上样本的特性。

该方法在确定分类决策上只依据最邻近的一个或者几个样本的类别来决定待分样本所属的类别。KNN方法在类别决策时,只与极少量的相邻样本有关。

由于KNN方法主要靠周围有限的邻近的样本,而不是靠判别类域的方法来确定所属类别的,因此对于类域的交叉或重叠较多的待分样本集来说,KNN方法较其他方法更为适合。

下面举一个具体的例子(来源于:https://zhuanlan.zhihu.com/p/143092725)

如下图所示,图中绿色的点就是我们要预测的那个点。

假设 K=3。那么 KNN 算法就会找到与它距离最近的三个点(这里用圆圈把它圈起来了),看看哪种类别多一些,比如这个例子中是蓝色三角形多一些,新来的绿色点就归类到蓝三角了。


但是,当 K=5 的时候,判定就变成不一样了。这次变成红圆多一些,所以新来的绿点被归类成红圆。如下图所示:

从这个例子中,我们就能看得出 K 的取值是很重要的。


三、KNN 算法流程

  1. 准备数据,对数据进行预处理
  2. 计算测试样本点(也就是待分类点)到其他每个样本点的距离
  3. 对每个距离进行排序,然后选择出距离最小的K个点
  4. 对K个点所属的类别进行比较,根据少数服从多数的原则,将测试样本点归入在K个点中占比最高的那一类

注意:由于 KNN 算法中需要计算两点之间的距离,距离有很多种度量方式,比如常见的曼哈顿距离、欧式距离、切比雪夫距离等等。不过通常 KNN 算法中使用的是欧式距离。


四、KNN 优缺点

KNN 优点

  • KNN 方法思路简单,易于理解,易于实现,无需估计参数(同样是分类算法,逻辑回归需要先对数据进行大量训练,最后才会得到一个算法模型。而 KNN 算法却不需要,它没有明确的训练数据的过程,或者说这个过程很快)
  • 模型训练时间快
  • 对异常值不敏感
  • 预测效果好

KNN 缺点

  • 对内存要求较高,因为该算法存储了所有训练数据
  • 预测阶段可能很慢,因为要从大量的训练数据中找到最近的 K 个点
  • 当样本不平衡时,如一个类的样本容量很大,而其他类样本容量很小时,有可能导致当输入一个新样本时,该样本的K个邻居中大容量类的样本占多数

五、Java 代码实现 KNN

由于网络上关于 Python 实现 KNN 的博客实在是太多啦,所以本篇博客就以 Java 实现 KNN !Python 的话可以直接调用 sklearn,非常方便~

TrainDataSet:训练集对象

public class TrainDataSet {/*** 特征集合**/public List<double[]> features = new ArrayList<>();/*** 标签集合**/public List<Integer> labels = new ArrayList<>();/*** 特征向量维度**/public int featureDim;public int size() {return labels.size();}public double[] getFeature(int index) {return features.get(index);}public int getLabel(int index) {return labels.get(index);}public void addData(double[] feature, int label) {if (features.isEmpty()) {featureDim = feature.length;} else {if (featureDim != feature.length) {throwDimensionMismatchException(feature.length);}}features.add(feature);labels.add(label);}public void throwDimensionMismatchException(int errorLen) {throw new RuntimeException("DimensionMismatchError: 你应该传入维度为 " + featureDim + " 的特征向量 , 但你传入了维度为 " + errorLen + " 的特征向量");}}

KNearestNeighbors:KNN算法对象

public class KNearestNeighbors {/*** 训练数据集**/TrainDataSet trainDataSet;/*** k值**/int k;/*** 距离公式**/DistanceType distanceType;/*** @param trainDataSet: 训练数据集* @param k:            k值*/public KNearestNeighbors(TrainDataSet trainDataSet, int k, DistanceType distanceType) {this.trainDataSet = trainDataSet;this.k = k;this.distanceType = distanceType;}// 传入特征,返回预测值public int predict(double[] feature) {if (feature.length != trainDataSet.featureDim) {trainDataSet.throwDimensionMismatchException(feature.length);}PriorityQueue<Node> nodePriorityQueue = new PriorityQueue<>();for (int i = 0; i < trainDataSet.size(); i++) {nodePriorityQueue.add(new Node(trainDataSet.getLabel(i), calcDistance(trainDataSet.getFeature(i), feature)));}int cnt = 0;Map<Integer, Integer> map = new HashMap<>();int predictLabel = -1;int maxNum = -1;for (int i = 0; i < k && !nodePriorityQueue.isEmpty(); i++) {int label = nodePriorityQueue.poll().label;if (map.containsKey(label)) {map.replace(label, map.get(label) + 1);} else {map.put(label, 1);}if (map.get(label) > maxNum) {maxNum = map.get(label);predictLabel = label;}cnt++;}if (cnt != k || maxNum == -1) {throw new RuntimeException("predict fail");}return predictLabel;}// 计算距离private double calcDistance(double[] arr1, double[] arr2) {switch (distanceType) {case EuclideanDistance:return calcEuclideanDistance(arr1, arr2);case ManhattanDistance:return calcManhattanDistance(arr1, arr2);case ChebyshevDistance:return calcChebyshevDistance(arr1, arr2);default:break;}throw new RuntimeException("未知的distanceType: " + distanceType);}// 计算欧式距离private double calcEuclideanDistance(double[] arr1, double[] arr2) {double res = 0d;for (int i = 0; i < arr1.length; i++) {res += Math.pow(arr1[i] - arr2[i], 2);}return Math.sqrt(res);}// 计算曼哈顿距离private double calcManhattanDistance(double[] arr1, double[] arr2) {double res = 0d;for (int i = 0; i < arr1.length; i++) {res += Math.abs(arr1[i] - arr2[i]);}return res;}// 计算切比雪夫距离private double calcChebyshevDistance(double[] arr1, double[] arr2) {double res = 0d;for (int i = 0; i < arr1.length; i++) {res = Math.max(res, Math.abs(arr1[i] - arr2[i]));}return res;}private static class Node implements Comparable<Node> {int label;double distance;public Node(int label, double distance) {this.label = label;this.distance = distance;}@Overridepublic int compareTo(Node o) {return Double.compare(distance, o.distance);}}public enum DistanceType {// 欧式距离EuclideanDistance,// 曼哈顿距离ManhattanDistance,// 切比雪夫距离ChebyshevDistance;}}

六、KNN 改进策略

目前对 KNN 算法改进的方向主要可以分为 4 类

  • 寻求更接近于实际的距离函数以取代标准的欧氏距离,典型的工作包括 WAKNN、VDM
  • 搜索更加合理的 K 值以取代指定大小的 K 值典型的工作包括 SNNB、 DKNAW
  • 运用更加精确的概率估测方法去取代简单的投票机制,典型的工作包括 KNNDW、LWNB、 ICLNB
  • 建立高效的索引,以提高 KNN 算法的运行效率,代表性的研究工作包括 KDTree、 NBTree

【机器学习】K近邻算法(K-NearestNeighbors , KNN)详解 + Java代码实现相关推荐

  1. 机器学习之重点汇总系列(二)——K近邻算法(k-Nearest Neighbor,kNN)

    什么是K近邻算法 引例 假设有数据集,其中前6部是训练集(有属性值和标记),我们根据训练集训练一个KNN模型,预测最后一部影片的电影类型 首先,将训练集中的所有样例画入坐标系,也将待测样例画入 然后计 ...

  2. K近邻算法学习(KNN)

    K近邻算法--KNN 机器学习--K近邻算法(KNN) 基本知识点 基本原理 示例 关于KNN的基本问题 距离如何计算? k如何定义大小? k为为什么不定义一个偶数? KNN的优缺点 代码实现 第一次 ...

  3. 机器学习:Linear Discriminant Analysis(过程详解+实例代码MATLAB实现

    目录 LDA概念 线性判别分析(LDA)-二分类 LDA二分类过程 举个例子 线性判别分析-多分类 LDA多分类过程 Experiment 3: Linear Discriminant Analysi ...

  4. 线性回归 - 机器学习多元线性回归 - 一步一步详解 - Python代码实现

    目录 数据导入 单变量线性回归 绘制散点图 相关系数R 拆分训练集和测试集 多变量线性回归 数据检验(判断是否可以做线性回归) 训练线性回归模型 先甩几个典型的线性回归的模型,帮助大家捡起那些年被忘记 ...

  5. java注解式开发_JAVA语言之Spring MVC注解式开发使用详解[Java代码]

    本文主要向大家介绍了JAVA语言的Spring MVC注解式开发使用详解,通过具体的内容向大家展示,希望对大家学习JAVA语言有所帮助. MVC注解式开发即处理器基于注解的类开发, 对于每一个定义的处 ...

  6. 【大道至简】机器学习算法之EM算法(Expectation Maximization Algorithm)详解(附代码)---通俗理解EM算法。

    ☕️ 本文来自专栏:大道至简之机器学习系列专栏

  7. 帝国竞争算法(imperialist competitive algorithm, ICA )详解+Java代码实现

    前言 这段时间用过这个算法做过相关的工作,今天就介绍一下吧.虽然感觉效果嘛,勉勉强强啦.不过每种算法肯定有其适用的地方,用到了就Mark一下方便后人吧~ 介绍 帝国竞争算法(imperialist c ...

  8. 枚举算法经典日期问题详解java

    目录 枚举算法 日期问题 枚举思想 具体代码 枚举算法 枚举算法是我们在日常中使用到的最多的一个算法,它的核心思想就是:枚举所有的可能. 枚举法的本质就是从所有候选答案中去搜索正确的解. 使用该算法需 ...

  9. 一文速学数模-分类模型(一)SVM(Support Vector Machines)支持向量机算法原理以及应用详解+Python代码实现

    目录 前言 一.引论 二.理论铺垫 线性可分性(linear separability) 超平面 决策边界

最新文章

  1. Hibernate Annotation中英文文档链接下载 (Hibernate 注解)
  2. Golang之Go Module使用
  3. 【TensorFlow】——broadcast_to(在不复制内存的情况下自动扩张tensor)
  4. php 5.5 链接redis,PHP实例:PHP5.5安装PHPRedis扩展及连接测试方法
  5. 【图像处理】基于matlab GUI Hough变换+PDE图像去雨(带面板)【含Matlab源码 811期】
  6. excel转word后表格超出页面_excel转word后表格显示不全
  7. 统计推断——假设检验——检验的功效(势)
  8. Windows系统管理24招
  9. 分水岭算法 c语言实现,分水岭算法的应用
  10. webpy快速入门 搭建python服务器
  11. 中国联通(广东省分公司)研发技术初面
  12. Midjourney API 接口对接历程
  13. cocos creator 牌面翻转
  14. opengl 画椭圆_漫谈椭圆的几何性质(之一)
  15. Ubuntu 1804 切换国内源
  16. SCAU2021春季个人排位赛第七场 (部分题解))
  17. css实现渐变色字体
  18. 扫描网段找出树莓派IP
  19. Oracle_如何应对润秒
  20. 手机地图导航测试用例

热门文章

  1. 为什么MAC地址和IP地址不能合并成一个地址?
  2. LED显示屏公司如何在这个激烈的竟争中抓住机遇?
  3. ppt显示无法插入视频 解决方案
  4. 20160420 每天半小时学英语
  5. 在泰国旅居的第5天,我定了两个新目标
  6. 线控转向系统Carsim和Simulink联合仿真模型,带Carsim数据库
  7. NOIP2018出征策
  8. 小米、华为、iPhone三款钱包的竞品分析
  9. Ubuntu 常用命令收集[菜鸟版]
  10. 关于汉字在不同编码方式中的大小