kNN(K Nearest Neighbor)算法是机器学习中最基础入门,也是最常用的算法之一,可以解决大多数分类与回归问题。这里以鸢尾花数据集为例,讨论分类问题中的 kNN 的思想。

鸢尾花数据集内包含 3 类共 150 条记录,每类各 50 个数据,每条记录都有 4 项特征:花萼长度(sepal length)、花萼宽度(sepal width)、花瓣长度(petal length)、花瓣宽度(petal length)。

可以通过这 4 个特征预测鸢尾花卉属于(iris-setosa,,iris-versicolour, iris-virginica)中的哪一品种,这里使用 kNN 来预测。

首先,导入鸢尾花数据集(两种方式,一种是下载鸢尾花数据集,然后从文件读取,我们采用第二种,直接从datasets中读取,返回的是字典格式的数据),并将鸢尾花数据集分为训练集和测试集。

iris = datasets.load_iris()
X = iris.data
y = iris.target
# 随机划分数据集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=20,                                                        shuffle=True)

为了方便理解 kNN,将鸢尾花的训练数据的前两个特征值,分别作为 x 轴和 y 轴数据,进行可视化。

# 数据可视化
plt.scatter(X_train[y_train == 0][:, 0], X_train[y_train == 0][:, 1], color='r')
plt.scatter(X_train[y_train == 1][:, 0], X_train[y_train == 1][:, 1], color='g')
plt.scatter(X_train[y_train == 2][:, 0], X_train[y_train == 2][:, 1], color='b')
plt.xlabel('sepal length')
plt.ylabel('sepal width')
plt.show()

如图所示,三个不同的颜色分别代表鸢尾花的三个类别。现在如果有一个新的数据(图中黑色点表示),如何判断它属于哪个类别呢?

我们需要使用的 kNN 算法,正如它的英文 K Nearest Neighbor,算法的核心思想是,选取训练集中离该数据最近的 k 个点,它们中的大多数属于哪个类别,则该新数据就属于哪个类别。

根据它的核心思想,模型中有三个需要确定的要素:

  • k 如何选择
  • 如何确定「最近」,也就是如何度量距离
  • 如何确定分类的规则

其中,k 的选择是一个超参数的选择问题,需要通过调整 K 的值确定最好的 K,最好选奇数,否则会出现同票。

可以通过交叉验证法确定模型的最佳 k 值(这里后面会谈);

度量距离的方式,一般为 Lp 距离:

p = 1 时,为曼哈顿距离:

p = 2 时,为欧式距离:

欧式距离是我们最常用的计算距离的方式。

分类的规则,采取多数表决的原则,即由输入实例的 k 个近邻的训练实例中的多数类决定输入实例的类。

代码如下:

# 计算距离,默认为欧氏距离
def calculateDistance(data1, data2, p=2):if len(data1) == len(data2) and len(data1) >= 1:sum = 0for i in range(len(data1)):sum += math.pow(abs(data1[i] - data2[i]), p)dist = math.pow(sum, 1/p)return dist# knn模型分类
def knnClassify(X_train, y_train, test_data, k):dist = [calculateDistance(train_data, test_data) for train_data in X_train]# 返回距离最近的k个训练样本的索引(下标)indexes = np.argsort(dist)[:k]count = Counter(y_train[indexes])return count.most_common(1)[0][0]if __name__ == '__main__':# 预测结果predictions = [knnClassify(X_train, y_train, test_data, 3) for test_data in X_test]# 与实际结果对比correct = np.count_nonzero((predictions == y_test) == True)print("Accuracy is: %.3f" % (correct/len(X_test)))

这里是自己实现的分类代码,在 sklearn 中有封装好的 kNN 库,代码如下:

# 创建kNN_classifier实例
kNN_classifier = KNeighborsClassifier(n_neighbors=3)
# kNN_classifier做一遍fit(拟合)的过程,没有返回值,模型就存储在kNN_classifier实例中
kNN_classifier.fit(X_train, y_train)
correct = np.count_nonzero((kNN_classifier.predict(X_test) == y_test) == True)
print("Accuracy is: %.3f" % (correct/len(X_test)))

kNN 没有显式的学习过程,这是它的优点,但在用它进行数据分类时,需要注意几个问题:

  • 不同特征有不同的量纲,必要时需进行特征归一化处理
  • kNN 的时间复杂度为O(D*N*N),D 是维度数,N 是样本数,这样,在特征空间很大和训练数据很大时,kNN 的训练时间会非常慢。这时就需要用到 kd 树,可以将时间复杂度降为O(logD*N*N)(kd 树后面会讲)。

参考文章:机器学习-kNN 算法

kNN处理鸢尾花数据集相关推荐

  1. python决策树分类 导入数据集_BPNN、决策树、KNN、SVM分类鸢尾花数据集Python实现...

    数据集处理 数据获取 使用sklearn的dataset获取数据 from sklearn import datasets from sklearn.model_selection import tr ...

  2. python KNN分类算法 使用鸢尾花数据集实战

    KNN分类算法,又叫K近邻算法,它概念极其简单,但效果又很优秀. 如觉得有帮助请点赞关注收藏啦~~~ KNN算法的核心是,如果一个样本在特征空间中的K个最相似,即特征空间中最邻近的样本中的大多数属于某 ...

  3. python进行KNN算法分析实战(鸢尾花数据集)

    KNN算法分析实战(鸢尾花数据集) 目录 KNN算法分析实战(鸢尾花数据集) 代码效果图 一.导入需要的包 二. 1.导入数据 ​ 2.建立训练集和测试集 3.设置K值 4. 十重交叉验证K值 5.模 ...

  4. Python原生代码实现KNN算法(鸢尾花数据集)

    一.作业题目 Python原生代码实现KNN分类算法,使用鸢尾花数据集. KNN算法介绍: K最近邻(k-Nearest Neighbor,KNN)分类算法,是机器学习算法之一. 该方法的思路是:如果 ...

  5. KNN算法实现,(鸢尾花数据集分类)

    KNN算法实现实例 KNN算法 一,KNN算法概述 二,算法步骤 KNN算法的具体实现 一.数据获取 二.具体代码的实现 参考博客 KNN算法 一,KNN算法概述 knn算法,又叫k-近邻算法.属于一 ...

  6. KNN算法实现鸢尾花数据集分类

    KNN算法实现鸢尾花数据集分类 作者介绍 数据集介绍 KNN算法介绍 用KNN实现鸢尾花分类 作者介绍 乔冠华,女,西安工程大学电子信息学院,2020级硕士研究生,张宏伟人工智能课题组. 研究方向:机 ...

  7. 用鸢尾花数据集实现knn分类算法

    一.题目: 原生python实现knn分类算法,用鸢尾花数据集 二.算法设计 1. 准备数据,对数据进行预处理 2. 选用合适的数据结构存储训练数据和测试元组 3. 设定参数,如k 4.维护一个大小为 ...

  8. 各种机器学习方法实现多分类(KNN,Logistics,Decision tree,byeis,SVM)以鸢尾花数据集为例

    之前做的关于分类问题大都是基于CNN实现图像分类如手写体识别-CNN实现minst识别,已有的参考也是比较多,还整理过一个关于PIMA糖尿病人数据集的分类,该问题属于对于数据的二分类问题,根据数据判断 ...

  9. 实验一:鸢尾花数据集分类

    实验一:鸢尾花数据集分类 一.问题描述 利用机器学习算法构建模型,根据鸢尾花的花萼和花瓣大小,区分鸢尾花的品种.实现一个基础的三分类问题. 二.数据集分析 Iris 鸢尾花数据集内包含 3 种类别,分 ...

最新文章

  1. xauth: (stdin):1: bad display name LSPPC-Lenny:1 in add command
  2. Android中获取系统语言(适用于Android7.0以上系统)
  3. 采购订单模板_采购必备:如何搭建合规的采购流程
  4. cocos2dx在不同安卓机型下scrollview裁剪失败
  5. Nginx+Mongodb 文件存储方案
  6. 请确保已载入内核模块vmmon_冒充市场监管部门短信诈骗,多人已受骗!
  7. java arraydeque poll,Java ArrayDeque pollLast()方法
  8. 【ANSYS】网格划分技术之映射网格
  9. 微信小程序前端登录模块设计
  10. CodeMeter***大赛战况:百人参赛,无人摘金!
  11. 联想x100e linux,联想小红ThinkPad X100e笔记本拆解!
  12. 虚幻c++入门到入土(一)VS插件Resharper使用
  13. 谈判如何在谈判中_谈判工作的十大规则
  14. NPOI 操作word 创建页眉页脚,页眉中插入图片
  15. 设置不显示桌面上的计算机图标不见了,显示桌面,教您显示桌面图标不见了如何恢复...
  16. 解决QQ语音通话后耳机失效的问题
  17. 虚拟机安装centos7上网设置总结
  18. 大小写字母转换 (15分)
  19. python TypeError: Descriptors cannot not be created directly错误解决
  20. 设计原则之合成复用原则

热门文章

  1. python回调函数的作用是_回调函数的意义以及python实现
  2. java多线程的应用-使用两个线程打印12A34B56C78D
  3. java实现发红包案例(一)
  4. 安卓系统怎么安装软件_这些系统帮助我们实现了在PC上安装安卓系统!
  5. enumerator迭代器和Iterator迭代器浅述
  6. 用python编写掷100次硬币_认识概率,用python模拟掷硬币
  7. Linux:下载wget
  8. 【恒指早盘分析】投资从学会区分能力和运气开始
  9. vue 项目使用three.js 实现3D看房效果
  10. bzoj 1242: Zju1015 Fishing Net 弦图判定