如何选择K值

首先让我们理解K值到底如何影响KNN算法。如果我们

有很多蓝色点和红色点数据,使用不同K值,最终的分类效果大概如下图。我们发现随着K值的增大,分界面越来越平滑。

一般在机器学习中我们要将数据集分为训练集和测试集,用训练集训练模型,再用测试集评价模型效果。这里我们绘制了不同k值下模型准确率。

从上图中我们发现当k=1和k=无穷大时,KNN的误差都很大。但是在某个点时能够将误差降低到最小。在本例中k=10

KNN算法伪码

我们现在设计下KNN伪码:

一. 读取数据

二. 初始化k值

三. 为了得到测试数据的预测类别,对训练集每条数据进行迭代

计算测试数据与训练集中每一条数据的距离。这里我们选用比较通用的欧几里得距离作为距离的实现方法。

对距离进行升序排列

对排列结果选择前K个值

得到出现次数最多的类

返回测试数据的预测类别

用Python实现KNN算法

这里我们使用iris数据集来构件KNN模型。

from sklearn.datasets import load_iris

import numpy as np

#从sklearn中导入iris数据集

iris = load_iris()

X = iris.data

y = iris.target

#计算data1和data2的欧几里得距离。

def euclideanDistance(data1, data2):

distance = 0

#data1和data2长度一致,这里我们就使用data1的长度

for x in range(len(data1)):

distance += np.square(data1[x]-data2[x])

return np.sqrt(distance)

#定义KNN模型

def KNN(X, y, testInstance, k):

distances = dict()

#三、为了得到测试数据的预测类别,对训练集每条数据进行迭代

for idx, trainInstance in enumerate(X):

#1. 计算测试数据与训练集中每一条数据的距离。

dist = euclideanDistance(testInstance, trainInstance)

distances[idx] = dist

#2. 对距离进行升序排列

sorted_d = sorted(distances.items(), key=lambda k:k[1])

neighbors = []

#3. 对排列结果选择前K个值

for x in range(k):

neighbors.append(sorted_d[x][0])

classVotes = dict()

#4. 得到出现次数最多的类

for x in range(len(neighbors)):

label = y[neighbors[x]]

if label in classVotes:

classVotes[label]+=1

else:

classVotes[label]=1

#5. 返回测试数据的预测类别

sortedVotes = sorted(classVotes.items(), key=lambda k:k[1], reverse=True)

return (sortedVotes[0][0], neighbors)

我们先测试下

testInstance = [7.2, 3.6, 5.1, 2.5]

k = 1

predicted, neighbors = KNN(X, y, testInstance, k)

print('predicted:',predicted)

print('neighbors:',neighbors)

运行结果

predicted: 2

neighbors:[141]

现在我们将k设置为3.

k = 3

predicted, neighbors = KNN(X, y, testInstance, k)

print('predicted:',predicted)

print('neighbors:',neighbors)

运行结果

predicted: 2

neighbors:[141, 139, 120]

k = 5

predicted, neighbors = KNN(X, y, testInstance, k)

print('predicted:',predicted)

print('neighbors:',neighbors)

运行结果

predicted: 2

neighbors:[141, 139, 120, 145, 144]

与scikit-learn比较

from sklearn.neighbors import KNeighborsClassifier

knn = KNeighborsClassifier(n_neighbors=3)

knn.fit(X, y)

#scikit中输入的数据一般都为二维数组(矩阵)。

testdata = [testInstance]

print('predicted:', knn.predict(testdata))

print('neighbors', knn.kneighbors(testdata))

运行结果

predicted: [2]

neighbors (array([[0.6164414 , 0.76811457, 0.80622577]]), array([[141, 139, 120]]))

我们仅仅用了一个例子,当k=3时,从scikit库knn算法的运行结果与我们设计的完全一致。neigbors都是[141, 139, 120]

小节

KNN算法是最简单的分类算法之一,即便算法如此简单,模型表现极佳。 KNN算法也可用于回归问题。 与所讨论方法的唯一区别是使用最近邻居的平均值而不是最近邻居的投票。

knn算法python代码iris_KNN算法原理及代码实现相关推荐

  1. 【数据结构与算法python】最小生成树算法-Prim算法

    1.引入 本算法涉及到在互联网中网游设计者和网络收音机所面临的问题:信息广播问题,如网游需要让所有玩家获知其他玩家所在的位置,收音机则需要让所有听众获取直播的音频数据 2.算法介绍 (1)单播解法 信 ...

  2. 【OpenCV/C++】KNN算法识别数字的实现原理与代码详解

    KNN算法识别数字 一.KNN原理 1.1 KNN原理介绍 1.2 KNN的关键参数 二.KNN算法识别手写数字 2.1 训练过程代码详解 2.2 预测分类的实现过程 三.KNN算法识别印刷数字 2. ...

  3. c4.5算法python实现_算法:用Python实现—最优化算法

    今天给大家分享一下算法,用python来实现最优化算法.废话不多说,直接上代码: 一.二分法 函数详见rres,此代码使该算法运行了两次 def asdf(x): rres=8*x**3-2*x**2 ...

  4. 聚类算法——python实现SOM算法

    算法简介 SOM网络是一种竞争学习型的无监督神经网络,将高维空间中相似的样本点映射到网络输出层中的邻近神经元. 训练过程简述:在接收到训练样本后,每个输出层神经元会计算该样本与自身携带的权向量之间的距 ...

  5. 算法python实现_Relief算法python实现

    文章目录 一.算法流程图 二.代码步骤 1.第一步:定义类和init方法 2.第二步:数据处理 3.第三步:通过计算距离,找出猜错近邻和猜对近邻 4.第四步:计算特征权重 5.第五步:根据权重过滤式选 ...

  6. pca算法python实现_PCA算法——Python实现

    一.流程生成模拟数据 模型训练 特征向量重要性分析 画图 二.Python语言 1.生成模拟数据 # 导入相关数据库 from sklearn import datasets ​ # 提取数据 dig ...

  7. 提取lbp特征java代码_LBP特征提取原理及代码实现

    老规矩,先上背景,算是表示对LBP算法提出者的一种尊敬(其实,是为了装...kkk,大家都懂ha). 一.LBP背景: LBP(Local Binary Pattern,局部二值模式)是一种用来描述图 ...

  8. Python版插入排序算法

    问题描述:在插入排序算法中,把所有元素分为前面的已排序序列和后面的未排序序列两部分,每次处理未排序序列中的第一个元素,将其插入到前面已排序序列中的合适位置,从而不停地扩大已排序序列并缩小未排序序列,直 ...

  9. Python使用超高效算法查找所有类似123-45-67+89=100的组合

    问题描述:在123456789这9个数字中间插入任意多个+和-的组合,使得表达式的值为100,输出所有符合条件的表达式. 昨天发了一个暴力测试的方法来解决问题,详见Python查找所有类似于123-4 ...

最新文章

  1. java 版本SQLHelper
  2. 第五讲 Python中的字符串(一)
  3. 测试框架之testng使用
  4. Abp vnext Web应用程序开发教程 10 —— 书与作者的关系
  5. Facebook 在西雅图和匹兹堡建立新AI实验室,业界担心大学人才争夺战烽火再起...
  6. 如何启用Domino 8 的ODS磁盘结构
  7. jdbc.postgresql源码分析
  8. 百度股市通如何实现智能选股?
  9. 51单片机流水灯三种实现方法
  10. jQuery学习笔记之选取选定复选框的同行某列元素
  11. SQL教程之使用 dbt 和 SQLfluff 整理 SQL
  12. C# 调试应用提示“无法访问此网站”问题的处理(图文)
  13. 大数据和云计算技术周报(第102期)
  14. 3、Spark 和 D3.js 分析航班大数据
  15. 【高中数学】频率分布表和频率分布直方图
  16. 飞机大战之-添加背景
  17. 使用FileOutPutStream下载docx文件报文件已损坏解决
  18. Android开发:微信应用签名如何获取
  19. 四大运营商频段最新划分情况
  20. android ocr 中文版,android ocr

热门文章

  1. 如何使用WPS软件创建文本文档?
  2. 鸿蒙开源oppo,华为鸿蒙开源,OPPO公关粗鄙言论将自己置于舆论风暴中
  3. c# redis分布式锁
  4. 《记忆错觉:记忆如何影响了我们的感知,思维与心理》
  5. JS 生成随机数/随机数组
  6. 【热门算法】ctr、cvr
  7. 狂神闲谈:正确的学习态度
  8. JSON增删改查学习笔记
  9. 花开时节不再来--再见了,我的大四
  10. html 网页图片保存