机器学习:k邻近算法(KNN)
title: 机器学习:k邻近算法(KNN)
date: 2019-11-16 20:20:41
mathjax: true
categories:
- 机器学习
tags: - 机器学习
什么是K邻近算法?
工作原理是:存在一个样本数 据集合,也称作训练样本集,并且样本集中每个数据都存在标签,即我们知道样本集中每一数据 与所属分类的对应关系。输入没有标签的新数据后,将新数据的每个特征与样本集中数据对应的 特征进行比较,然后算法提取样本集中特征最相似数据(最近邻)的分类标签。一般来说,我们 只选择样本数据集中前k个最相似的数据,这就是k-近邻算法中k的出处,通常k是不大于20的整数。 最后,选择k个最相似数据中出现次数最多的分类,作为新数据的分类
举例子
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-cV3JqQgh-1573908679233)( https://photos-1258741719.cos.ap-beijing.myqcloud.com/machine-learning/ 1573907111850.png)]
特征分析:动作片:打斗次数更多,爱情片:亲吻次数更多
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-D0lKE6ss-1573908679257)(C:\Users\admin\AppData\Roaming\Typora\typora-user-images\1573907179891.png)]
即使不知道未知电影属于哪种类型,我们也可以通过某种方法计算出来。首先计算未知电影
与样本集中其他电影的距离
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-aE8ZARu7-1573908679258)( https://photos-1258741719.cos.ap-beijing.myqcloud.com/machine-learning/ 1573907207887.png)]
现在我们得到了样本集中所有电影与未知电影的距离,按照距离递增排序,可以找到k个距
离最近的电影。假定k=3,则三个最靠近的电影依次是He’s Not Really into Dudes、Beautiful Woman
和California Man。k-近邻算法按照距离最近的三部电影的类型,决定未知电影的类型,而这三部
电影全是爱情片,因此我们判定未知电影是爱情片。
构建KNN算法:
算法原理:
(1) 收集数据:可以使用任何方法。 (2) 准备数据:距离计算所需要的数值,最好是结构化的数据格式。 (3) 分析数据:可以使用任何方法。 (4) 训练算法:此步骤不适用于k-近邻算法。 (5) 测试算法:计算错误率。 (6) 使用算法:首先需要输入样本数据和结构化的输出结果,然后运行k-近邻算法判定输 入数据分别属于哪个分类,最后应用对计算出的分类执行后续的处理
导入数据
from numpy import * import operator def createDataSet():group = array([[1.0,1.1],[1.0,1.0],[0,0],[0,0.1]])labels = ['A','A','B','B']return group,labels
查看数据:
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-r8w04SwZ-1573908679260)( https://photos-1258741719.cos.ap-beijing.myqcloud.com/machine-learning/ 1573907418563.png)]
KNN算法伪代码:
对未知类别属性的数据集中的每个点依次执行以下操作:
(1) 计算已知类别数据集中的点与当前点之间的距离;
(2) 按照距离递增次序排序;
(3) 选取与当前点距离最小的k个点;
(4) 确定前k个点所在类别的出现频率;
(5) 返回前k个点出现频率最高的类别作为当前点的预测分类
def classify0(inX,dataSet,labels,K):dataSetSize = dataSet.shape[0]#行的个数diffMat = tile(inX,(dataSetSize,1)) - dataSet#将输入向量复制减去dataset就可以得到差值sqDiffMat = diffMat**2sqDistances = sqDiffMat.sum(axis = 1)#对于矩阵的每行求和,结果就是每行求和的矩阵distances = sqDistances**0.5sortedDisIndicies = distances.argsort()#返回的是矩阵从小到达的索引值classCount = {}for i in range(K):voteIlabel = labels[sortedDisIndicies[i]]#获得指定表亲classCount[voteIlabel] = classCount.get(voteIlabel,0) + 1sortedClassCount = sorted(classCount.items(),key=lambda x:x[1],reverse=True)return sortedClassCount[0][0]
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-GBeDKzSn-1573908679263)( https://photos-1258741719.cos.ap-beijing.myqcloud.com/machine-learning/1573907549202.png)]
运行示例:
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-ydK31oKp-1573908679264)(C https://photos-1258741719.cos.ap-beijing.myqcloud.com/machine-learning/1573907600578.png)]
使用 k-近邻算法改进约会网站的配对效果
数据地址
题目描述
海伦使用约会网站寻找约会对象。经过一段时间之后,她发现曾交往过三种类型的人:
- 不喜欢的人
- 魅力一般的人
- 极具魅力的人
她希望:
- 工作日与魅力一般的人约会
- 周末与极具魅力的人约会
- 不喜欢的人则直接排除掉
现在她收集到了一些约会网站未曾记录的数据信息,这更有助于匹配对象的归类。
开发流程
收集数据:提供文本文件 准备数据:使用 Python 解析文本文件 分析数据:使用 Matplotlib 画二维散点图 训练算法:此步骤不适用于 k-近邻算法 测试算法:使用海伦提供的部分数据作为测试样本。测试样本和非测试样本的区别在于:测试样本是已经完成分类的数据,如果预测分类与实际类别不同,则标记为一个错误。 使用算法:产生简单的命令行程序,然后海伦可以输入一些特征数据以判断对方是否为自己喜欢的类型。
打开文件:
def file2matrix(filename):fr = open(filename)arrayOLines = fr.readlines()numberOfLines = len(arrayOLines)#得到文件行数returnMat = zeros((numberOfLines,3))#创建一个n行3列数据classLabelVector = []index = 0for line in arrayOLines:line = line.strip()#去除两边空格listFromLine = line.split('\t')returnMat[index,:] = listFromLine[0:3]classLabelVector.append(int(listFromLine[-1]))#结果集index +=1return returnMat,classLabelVector
注意:由于NumPy库提供的数组操作并不支持Python自带的数组类型,因此在编写代码时要注意不要使用错误的数组类型
画出散点图:
import matplotlib import matplotlib.pyplot as plt fig = plt.figure() ax = fig.add_subplot(111) #1*1网格 ax.scatter(datingDataMat[:,1],datingDataMat[:,2],15.0*array(datinglabels),15.0*array(datinglabels))#画出第二列第三列图像 plt.show()
结果:
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-LJvYdTcO-1573908679267)( https://photos-1258741719.cos.ap-beijing.myqcloud.com/machine-learning/11162037.png)]
数据归一化处理:
def autoNorm(dataSet):minVals = dataSet.min(0)#最大值,每列的maxVals = dataSet.max(0)ranges = maxVals-minVals#差值normDataSet = zeros(shape(dataSet))#构建一个和dataset类似的0矩阵m = dataSet.shape[0]normDataSet = dataSet - tile(minVals,(m,1))normDataSet = normDataSet/tile(ranges,(m,1))return normDataSet,ranges,minVals
进行数据分析:
def datingClassTest():hoRation = 0.10datingDataMat,datingLabels = file2matrix('./2.KNN/datingTestSet2.txt')normMat,ranges,minVals = autoNorm(datingDataMat)m = normMat.shape[0]numtestVecs = int(m*hoRation)errorCount = 0.0for i in range(numtestVecs):classifierResult = classify0(normMat[i,:],normMat[numtestVecs:m,:],datingLabels[numtestVecs:m],5)print("the classifier came back with: %d ,the real answer is: %d"%(classifierResult,datingLabels[i]))if(classifierResult!= datingLabels[i]):errorCount+=1;print("the total error rate is:%f"%(errorCount/float(numtestVecs)))
结果:
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-JVeoNOdf-1573908679268)(C:\Users\admin\AppData\Roaming\Typora\typora-user-images\1573907990818.png)]
判定函数:
def classifyPerson():resultList = ['not at all','in small doses','in large doses']percentTats = float(input("percentage of time spent playing video games?"))ffMiles = float(input("frequent flier miles earned per year?"))iceCream = float(input("liters of ice cream consumed per year?"))datingDataaMat,datingLabels = file2matrix('./2.KNN/datingTestSet2.txt')normMat,ranges,minVals = autoNorm(datingDataMat)#数据归一化inArr = array([ffMiles,percentTats,iceCream])classifierResult = classify0((inArr-minVals)/ranges,normMat,datingLabels,3)print('you will probably like this person:',resultList[classifierResult-1])
运行:
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-grvpc061-1573908679270)(C:\Users\admin\AppData\Roaming\Typora\typora-user-images\1573908051140.png)]
手写数字识别
算法流程:
(1) 收集数据:提供文本文件。
(2) 准备数据:编写函数classify0(),将图像格式转换为分类器使用的list格式。
(3) 分析数据:在Python命令提示符中检查数据,确保它符合要求。
(4) 训练算法:此步骤不适用于k-近邻算法。
(5) 测试算法:编写函数使用提供的部分数据集作为测试样本,测试样本与非测试样本
的区别在于测试样本是已经完成分类的数据,如果预测分类与实际类别不同,则标记
为一个错误。
(6) 使用算法:本例没有完成此步骤,若你感兴趣可以构建完整的应用程序,从图像中提
取数字,并完成数字识别,美国的邮件分拣系统就是一个实际运行的类似系统。
文本向量处理:
将图像格式化处理为一个向量。我们将把一个32× 32的二进制图像矩阵转换为1×1024的向量 #把一个32×32的二进制图像矩阵转换为1×1024的向量,这样前两节使用的分类器就可以处理数字图像信息了 def img2vector(filename):returnVect = zeros((1,1024))fr = open(filename)for i in range(32):lineStr = fr.readline()for j in range(32):returnVect[0,32*i+j] = int(lineStr[j])# 将变量赋值进入矩阵return returnVect
算法:
from os import listdir # def handwritingClassTest():# 1. 导入训练数据hwLabels = []trainingFileList = listdir('./2.KNN/trainingDigits') # load the training setm = len(trainingFileList)#m就是看有多少文件for i in trainingFileList:print(i)trainingMat = zeros((m, 1024))# hwLabels存储0~9对应的index位置, trainingMat存放的每个位置对应的图片向量for i in range(m):fileNameStr = trainingFileList[i]fileStr = fileNameStr.split('.')[0] # take off .txtclassNumStr = int(fileStr.split('_')[0])hwLabels.append(classNumStr)#获得label值# 将 32*32的矩阵->1*1024的矩阵trainingMat[i, :] = img2vector('./2.KNN/trainingDigits/%s' % fileNameStr)# 2. 导入测试数据testFileList = listdir('./2.KNN/testDigits') # iterate through the test seterrorCount = 0.0mTest = len(testFileList)for i in range(mTest):fileNameStr = testFileList[i]fileStr = fileNameStr.split('.')[0] # take off .txtclassNumStr = int(fileStr.split('_')[0])vectorUnderTest = img2vector('./2.KNN/testDigits/%s' % fileNameStr)classifierResult = classify0(vectorUnderTest, trainingMat, hwLabels, 10)#print("the classifier came back with: %d, the real answer is: %d" % (classifierResult, classNumStr))if (classifierResult != classNumStr): errorCount += 1.0print("\nthe total number of errors is: %d" % errorCount)print("\nthe total error rate is: %f" % (errorCount / float(mTest)))
小结
近似误差,更关注于“训练”。最小化近似误差,即为使估计值尽量接近真实值,但是这个接近只是对训练样本(当前问题)而言,模型本身并不是最接近真实分布。换一组样本,可能就不近似了。这种只管眼前不顾未来预测的行为,即为过拟合。估计误差,更关注于“测试”、“泛化”。最小化估计误差,即为使估计系数尽量接近真实系数,但是此时对训练样本(当前问题)得到的估计值不一定是最接近真实值的估计值;但是对模型本身来说,它能适应更多的问题(测试样本)也就是泛化能力更强
K, K的取值
对查询点标签影响显著(效果拔群)。k值小的时候 近似误差小,估计误差大。 k值大 近似误差大,估计误差小。
如果选择较小的 k 值,就相当于用较小的邻域中的训练实例进行预测,“学习”的近似误差(approximation error)会减小,只有与输入实例较近的(相似的)训练实例才会对预测结果起作用。但缺点是“学习”的估计误差(estimation error)会增大,预测结果会对近邻的实例点非常敏感。如果邻近的实例点恰巧是噪声,预测就会出错。换句话说,k 值的减小就意味着整体模型变得复杂,容易发生过拟合。
如果选择较大的 k 值,就相当于用较大的邻域中的训练实例进行预测。其优点是可以减少学习的估计误差。但缺点是学习的近似误差会增大。这时与输入实例较远的(不相似的)训练实例也会对预测起作用,使预测发生错误。 k 值的增大就意味着整体的模型变得简单。
太大太小都不太好,可以用交叉验证(cross validation)来选取适合的k值。
近似误差和估计误差,请看这里:https://www.zhihu.com/question/60793482
距离度量 Metric/Distance Measure
距离度量 通常为 欧式距离(Euclidean distance),还可以是 Minkowski 距离 或者 曼哈顿距离。也可以是 地理空间中的一些距离公式。(更多细节可以参看 sklearn 中 valid_metric 部分)
分类决策 (decision rule)
分类决策 在 分类问题中 通常为通过少数服从多数 来选取票数最多的标签,在回归问题中通常为 K个最邻点的标签的平均值。
KNN算法是很基本的机器学习算法了,它非常容易学习,在维度很高的时候也有很好的分类效率,因此运用也很广泛,这里总结下KNN的优缺点。
KNN的主要优点有:
1) 理论成熟,思想简单,既可以用来做分类也可以用来做回归
2) 可用于非线性分类
3) 训练时间复杂度比支持向量机之类的算法低,仅为O(n)
4) 和朴素贝叶斯之类的算法比,对数据没有假设,准确度高,对异常点不敏感
5) 由于KNN方法主要靠周围有限的邻近的样本,而不是靠判别类域的方法来确定所属类别的,因此对于类域的交叉或重叠较多的待分样本集来说,KNN方法较其他方法更为适合
6)该算法比较适用于样本容量比较大的类域的自动分类,而那些样本容量较小的类域采用这种算法比较容易产生误分
KNN的主要缺点有:
1)计算量大,尤其是特征数非常多的时候
2)样本不平衡的时候,对稀有类别的预测准确率低
3)KD树,球树之类的模型建立需要大量的内存
4)使用懒散学习方法,基本上不学习,导致预测时速度比起逻辑回归之类的算法慢
5)相比决策树模型,KNN模型可解释性不强
最重要一点,请看链接-》链接
机器学习:k邻近算法(KNN)相关推荐
- 机器学习实战读书笔记--k邻近算法KNN
k邻近算法的伪代码: 对未知类别属性的数据集中的每个点一次执行以下操作: (1)计算已知类别数据集中的点与当前点之间的距离: (2)按照距离递增次序排列 (3)选取与当前点距离最小的k个点 (4)确定 ...
- k邻近算法(KNN)实例
一 k近邻算法原理 k近邻算法是一种基本分类和回归方法. 原理:K近邻算法,即是给定一个训练数据集,对新的输入实例,在训练数据集中找到与该实例最邻近的K个实例,这K个实例的多数属于某个类,就把该输入实 ...
- 2 机器学习 K近邻算法(KNN) 学习曲线 交叉验证 手写数字识别
机器学习 1 K-近邻算法介绍 1.1 分类问题 分类问题:根据已知样本的某些特征,判断一个未知样本属于哪种样本类别. 与回归问题相比,分类问题的输出结果是离散值,用于指定输入的样本数据属于哪个类别. ...
- k折交叉验证优缺点_R语言中K邻近算法的初学者指南:从菜鸟到大神(附代码&链接)...
作者:Leihua Ye, UC Santa Barbara 翻译:陈超 校对:冯羽 本文约2300字,建议阅读10分钟 本文介绍了一种针对初学者的K临近算法在R语言中的实现方法. 本文呈现了一种在R ...
- 机器学习3—分类算法之K邻近算法(KNN)
K邻近算法(KNN) 一.算法思想 二.KNN类KNeighborsClassifier的使用 三.KNN分析红酒类型 3.1红酒数据集 3.2红酒数据的读取 3.3将红酒的数据集拆分为训练和测试集 ...
- 刻意练习:机器学习实战 -- Task01. K邻近算法
背景 这是我们为拥有 Python 基础的同学推出的精进技能的"机器学习实战" 刻意练习活动,这也是我们本学期推出的第三次活动了. 我们准备利用8周时间,夯实机器学习常用算法,完成 ...
- K邻近算法概述、欧式距离、Scikit-learn使用 、kNN邻近算法距离度量、曼哈顿距离、切比雪夫距离、闵可夫斯基距离、标准化欧氏距离、余弦距离、汉明距离、杰卡德距离、马氏距离
一.K-邻近算法概述 K邻近算(K Nearest Neighbor算法,KNN算法):如果一个样本在特征空间中的k个最相似(即特征空间中最邻近)的样本中的大多数属于某一个类别,则该样本也属于这个类别 ...
- 《机器学习实战》K邻近算法
K邻近算法 存在一个样本数据集合,样本集中每个数据都存在标签,即我们知道样本集中每一数据与所属分类的对应关系.输入没有标签的新数据后,将新数据的每个特征与样本集中数据对应的特征进行比较,然后算法提取样 ...
- K邻近算法(K-NN)
简单记录学习<统计学习方法>书中的k近邻模型. k近邻法(k-nearest neighbor,k-NN)是一种基本分类与回归方法. k邻近算法 k邻近算犯法简单.直观描述:给定一个训练数 ...
最新文章
- 西部开源学习笔记BOOK2-《unit 4》
- Java并发:volatile内存可见性和指令重排
- JAVA调用shell脚本实例
- 使用dom breakpoint找到修改属性的javascript代码
- Jonit Bayesian 的推导
- Android开发笔记(一百六十一)NFC近场通信
- cacti安装的一个错误
- Spring Boot中Bean管理
- 英文横版游戏《玛丽师傅》源码H5+安卓+IOS三端源码
- 锐捷网关交换机开启dhcp服务
- Excel进行灵敏度分析
- 天正2014打开加载lisp_天正CAD2014对不信任加载项的解决方法
- 盈透api python封装_[转载]用MT4来接入IB盈透TWS平台交易外汇
- 八、服务器【Ubuntu】GPU-TeslaP100部署
- 百度松果 买礼物(贪心)
- ubuntu下通过鼠标右键创建txt文件
- Java解析快手视频去水印教程--2020年10月最新有效代码
- 相亲时候遇到的92年妹子,老衲想下手了
- 仿网易蜗牛读书小程序
- 学习pandas库笔记(pd.read_excel)