目录

一、前言:

二、理论难点:

距离度量:

欧式距离:

三、数据可视化

四、数据归一化:

五、代码实践:

理论补充

实验一: 海伦约会

实验二 使用sklearn实现knn

六、总结

1、kNN算法的优缺点

2、其他

参考资料:


一、前言:

按照机器学习实战---每一次学习都应该按照如下步骤进行思考:

  1. 收集数据:可以使用爬虫进行数据的收集,也可以使用第三方提供的免费或收费的数据。一般来讲,数据放在txt文本文件中,按照一定的格式进行存储,便于解析及处理。
  2. 准备数据:使用Python解析、预处理数据。
  3. 分析数据:可以使用很多方法对数据进行分析,例如使用Matplotlib将数据可视化。
  4. 测试算法:计算错误率。
  5. 使用算法:错误率在可接受范围内,就可以运行k-近邻算法进行分类。

k-近邻算法步骤如下:

  1. 计算已知类别数据集中的点与当前点之间的距离;
  2. 按照距离递增次序排序;
  3. 选取与当前点距离最小的k个点;
  4. 确定前k个点所在类别的出现频率;
  5. 返回前k个点所出现频率最高的类别作为当前点的预测分类。

二、理论难点:

直接看书的话,容易被什么欧拉公式什么的吓到,建议简单入手,参考:

https://cuijiahua.com/blog/2017/11/ml_1_knn.html#comments

距离度量:

欧式距离:

三、数据可视化

# -*- coding: UTF-8 -*-
from matplotlib.font_manager import FontProperties
import matplotlib.lines as mlines
import matplotlib.pyplot as plt
import numpy as np"""
函数说明:打开并解析文件,对数据进行分类:1代表不喜欢,2代表魅力一般,3代表极具魅力Parameters:filename - 文件名
Returns:returnMat - 特征矩阵classLabelVector - 分类Label向量Modify:2017-03-24
"""
def file2matrix(filename):#打开文件fr = open(filename)#读取文件所有内容arrayOLines = fr.readlines()#得到文件行数numberOfLines = len(arrayOLines)#返回的NumPy矩阵,解析完成的数据:numberOfLines行,3列returnMat = np.zeros((numberOfLines,3))#返回的分类标签向量classLabelVector = []#行的索引值index = 0for line in arrayOLines:#s.strip(rm),当rm空时,默认删除空白符(包括'\n','\r','\t',' ')line = line.strip()#使用s.split(str="",num=string,cout(str))将字符串根据'\t'分隔符进行切片。listFromLine = line.split('\t')#将数据前三列提取出来,存放到returnMat的NumPy矩阵中,也就是特征矩阵returnMat[index,:] = listFromLine[0:3]#根据文本中标记的喜欢的程度进行分类,1代表不喜欢,2代表魅力一般,3代表极具魅力if listFromLine[-1] == 'didntLike':classLabelVector.append(1)elif listFromLine[-1] == 'smallDoses':classLabelVector.append(2)elif listFromLine[-1] == 'largeDoses':classLabelVector.append(3)index += 1return returnMat, classLabelVector"""
函数说明:可视化数据Parameters:datingDataMat - 特征矩阵datingLabels - 分类Label
Returns:无
Modify:2017-03-24
"""
def showdatas(datingDataMat, datingLabels):#设置汉字格式font = FontProperties(fname=r"c:\windows\fonts\simsun.ttc", size=14)#将fig画布分隔成1行1列,不共享x轴和y轴,fig画布的大小为(13,8)#当nrow=2,nclos=2时,代表fig画布被分为四个区域,axs[0][0]表示第一行第一个区域fig, axs = plt.subplots(nrows=2, ncols=2,sharex=False, sharey=False, figsize=(13,8))numberOfLabels = len(datingLabels)LabelsColors = []for i in datingLabels:if i == 1:LabelsColors.append('black')if i == 2:LabelsColors.append('orange')if i == 3:LabelsColors.append('red')#画出散点图,以datingDataMat矩阵的第一(飞行常客例程)、第二列(玩游戏)数据画散点数据,散点大小为15,透明度为0.5axs[0][0].scatter(x=datingDataMat[:,0], y=datingDataMat[:,1], color=LabelsColors,s=15, alpha=.5)#设置标题,x轴label,y轴labelaxs0_title_text = axs[0][0].set_title(u'每年获得的飞行常客里程数与玩视频游戏所消耗时间占比',FontProperties=font)axs0_xlabel_text = axs[0][0].set_xlabel(u'每年获得的飞行常客里程数',FontProperties=font)axs0_ylabel_text = axs[0][0].set_ylabel(u'玩视频游戏所消耗时间占',FontProperties=font)plt.setp(axs0_title_text, size=9, weight='bold', color='red') plt.setp(axs0_xlabel_text, size=7, weight='bold', color='black') plt.setp(axs0_ylabel_text, size=7, weight='bold', color='black')#画出散点图,以datingDataMat矩阵的第一(飞行常客例程)、第三列(冰激凌)数据画散点数据,散点大小为15,透明度为0.5axs[0][1].scatter(x=datingDataMat[:,0], y=datingDataMat[:,2], color=LabelsColors,s=15, alpha=.5)#设置标题,x轴label,y轴labelaxs1_title_text = axs[0][1].set_title(u'每年获得的飞行常客里程数与每周消费的冰激淋公升数',FontProperties=font)axs1_xlabel_text = axs[0][1].set_xlabel(u'每年获得的飞行常客里程数',FontProperties=font)axs1_ylabel_text = axs[0][1].set_ylabel(u'每周消费的冰激淋公升数',FontProperties=font)plt.setp(axs1_title_text, size=9, weight='bold', color='red') plt.setp(axs1_xlabel_text, size=7, weight='bold', color='black') plt.setp(axs1_ylabel_text, size=7, weight='bold', color='black')#画出散点图,以datingDataMat矩阵的第二(玩游戏)、第三列(冰激凌)数据画散点数据,散点大小为15,透明度为0.5axs[1][0].scatter(x=datingDataMat[:,1], y=datingDataMat[:,2], color=LabelsColors,s=15, alpha=.5)#设置标题,x轴label,y轴labelaxs2_title_text = axs[1][0].set_title(u'玩视频游戏所消耗时间占比与每周消费的冰激淋公升数',FontProperties=font)axs2_xlabel_text = axs[1][0].set_xlabel(u'玩视频游戏所消耗时间占比',FontProperties=font)axs2_ylabel_text = axs[1][0].set_ylabel(u'每周消费的冰激淋公升数',FontProperties=font)plt.setp(axs2_title_text, size=9, weight='bold', color='red') plt.setp(axs2_xlabel_text, size=7, weight='bold', color='black') plt.setp(axs2_ylabel_text, size=7, weight='bold', color='black')#设置图例didntLike = mlines.Line2D([], [], color='black', marker='.',markersize=6, label='didntLike')smallDoses = mlines.Line2D([], [], color='orange', marker='.',markersize=6, label='smallDoses')largeDoses = mlines.Line2D([], [], color='red', marker='.',markersize=6, label='largeDoses')#添加图例axs[0][0].legend(handles=[didntLike,smallDoses,largeDoses])axs[0][1].legend(handles=[didntLike,smallDoses,largeDoses])axs[1][0].legend(handles=[didntLike,smallDoses,largeDoses])#显示图片plt.show()"""
函数说明:main函数Parameters:无
Returns:无Modify:2017-03-24
"""
if __name__ == '__main__':#打开的文件名filename = "datingTestSet.txt"#打开并处理数据datingDataMat, datingLabels = file2matrix(filename)showdatas(datingDataMat, datingLabels)

四、数据归一化:

因为我们希望我们的样本的特征都是同等重要的。而如果只按照公式来算:

显然飞机的里程会更为重要,所以需要归一化:
归一化的方法有挺多种,目的是如将取值范围处理为0到1或者-1到1之间

本次练习采用:
newValue = (oldValue - min) / (max - min)

其中min和max分别是数据集中的最小特征值和最大特征值。虽然改变数值取值范围增加了分类器的复杂度,但为了得到准确结果,我们必须这样做。

五、代码实践:

理论补充

统计学习方法-KNN原理学习-统计学习的课后作业

PS:KD树的构建可以看教材的实现代码

实验一: 海伦约会

import numpy as np
import operator"""
函数说明:kNN算法,分类器Parameters:inX - 用于分类的数据(测试集)dataSet - 用于训练的数据(训练集)labes - 分类标签k - kNN算法参数,选择距离最小的k个点
Returns:sortedClassCount[0][0] - 分类结果"""
def classify0(inX, dataSet, labels, k):#numpy函数shape[0]返回dataSet的行数dataSetSize = dataSet.shape[0]#在列向量方向上重复inX共1次(横向),行向量方向上重复inX共dataSetSize次(纵向)diffMat = np.tile(inX, (dataSetSize, 1)) - dataSet#二维特征相减后平方sqDiffMat = diffMat**2#sum()所有元素相加,sum(0)列相加,sum(1)行相加sqDistances = sqDiffMat.sum(axis=1)#开方,计算出距离distances = sqDistances**0.5#返回distances中元素从小到大排序后的索引值sortedDistIndices = distances.argsort()#定一个记录类别次数的字典classCount = {}for i in range(k):#取出前k个元素的类别voteIlabel = labels[sortedDistIndices[i]]#dict.get(key,default=None),字典的get()方法,返回指定键的值,如果值不在字典中返回默认值。#计算类别次数classCount[voteIlabel] = classCount.get(voteIlabel,0) + 1#python3中用items()替换python2中的iteritems()#key=operator.itemgetter(1)根据字典的值进行排序#key=operator.itemgetter(0)根据字典的键进行排序#reverse降序排序字典sortedClassCount = sorted(classCount.items(),key=operator.itemgetter(1),reverse=True)#返回次数最多的类别,即所要分类的类别return sortedClassCount[0][0]"""
函数说明:打开并解析文件,对数据进行分类:1代表不喜欢,2代表魅力一般,3代表极具魅力Parameters:filename - 文件名
Returns:returnMat - 特征矩阵classLabelVector - 分类Label向量
"""
def file2matrix(filename):#打开文件fr = open(filename)#读取文件所有内容arrayOLines = fr.readlines()#得到文件行数numberOfLines = len(arrayOLines)#返回的NumPy矩阵,解析完成的数据:numberOfLines行,3列returnMat = np.zeros((numberOfLines,3))#返回的分类标签向量classLabelVector = []#行的索引值index = 0for line in arrayOLines:#s.strip(rm),当rm空时,默认删除空白符(包括'\n','\r','\t',' ')line = line.strip()#使用s.split(str="",num=string,cout(str))将字符串根据'\t'分隔符进行切片。listFromLine = line.split('\t')#将数据前三列提取出来,存放到returnMat的NumPy矩阵中,也就是特征矩阵returnMat[index,:] = listFromLine[0:3]#根据文本中标记的喜欢的程度进行分类,1代表不喜欢,2代表魅力一般,3代表极具魅力if listFromLine[-1] == 'didntLike':classLabelVector.append(1)elif listFromLine[-1] == 'smallDoses':classLabelVector.append(2)elif listFromLine[-1] == 'largeDoses':classLabelVector.append(3)index += 1return returnMat, classLabelVector"""
函数说明:对数据进行归一化Parameters:dataSet - 特征矩阵
Returns:normDataSet - 归一化后的特征矩阵ranges - 数据范围minVals - 数据最小值"""
def autoNorm(dataSet):#获得数据的最小值minVals = dataSet.min(0)maxVals = dataSet.max(0)#最大值和最小值的范围ranges = maxVals - minVals#shape(dataSet)返回dataSet的矩阵行列数normDataSet = np.zeros(np.shape(dataSet))#返回dataSet的行数m = dataSet.shape[0]#原始值减去最小值normDataSet = dataSet - np.tile(minVals, (m, 1))#除以最大和最小值的差,得到归一化数据normDataSet = normDataSet / np.tile(ranges, (m, 1))#返回归一化数据结果,数据范围,最小值return normDataSet, ranges, minVals"""
函数说明:通过输入一个人的三维特征,进行分类输出Parameters:无
Returns:无
"""
def classifyPerson():#输出结果resultList = ['讨厌','有些喜欢','非常喜欢']#三维特征用户输入precentTats = float(input("玩视频游戏所耗时间百分比:"))ffMiles = float(input("每年获得的飞行常客里程数:"))iceCream = float(input("每周消费的冰激淋公升数:"))#打开的文件名filename = "datingTestSet.txt"#打开并处理数据datingDataMat, datingLabels = file2matrix(filename)#训练集归一化normMat, ranges, minVals = autoNorm(datingDataMat)#生成NumPy数组,测试集inArr = np.array([ffMiles, precentTats, iceCream])#测试集归一化norminArr = (inArr - minVals) / ranges#返回分类结果classifierResult = classify0(norminArr, normMat, datingLabels, 3)#打印结果print("你可能%s这个人" % (resultList[classifierResult-1]))"""
函数说明:main函数Parameters:无
Returns:无
"""
if __name__ == '__main__':classifyPerson()

结果:

实验二 使用sklearn实现knn

import numpy as np
import operator
from os import listdir
from sklearn.neighbors import KNeighborsClassifier as kNN"""
函数说明:将32x32的二进制图像转换为1x1024向量。Parameters:filename - 文件名
Returns:returnVect - 返回的二进制图像的1x1024向量Modify:2020-07-14
"""
def img2vector(filename):#创建1x1024零向量returnVect = np.zeros((1, 1024))#打开文件fr = open(filename)#按行读取for i in range(32):#读一行数据lineStr = fr.readline()#每一行的前32个元素依次添加到returnVect中for j in range(32):returnVect[0, 32*i+j] = int(lineStr[j])#返回转换后的1x1024向量return returnVect"""
函数说明:手写数字分类测试Parameters:无
Returns:无Modify:2020-07-15
"""
def handwritingClassTest():#测试集的LabelshwLabels = []#返回trainingDigits目录下的文件名trainingFileList = listdir('trainingDigits')#返回文件夹下文件的个数m = len(trainingFileList)#初始化训练的Mat矩阵,测试集trainingMat = np.zeros((m, 1024))#从文件名中解析出训练集的类别for i in range(m):#获得文件的名字fileNameStr = trainingFileList[i]#获得分类的数字classNumber = int(fileNameStr.split('_')[0])#将获得的类别添加到hwLabels中hwLabels.append(classNumber)#将每一个文件的1x1024数据存储到trainingMat矩阵中trainingMat[i,:] = img2vector('trainingDigits/%s' % (fileNameStr))#构建kNN分类器neigh = kNN(n_neighbors = 3, algorithm = 'auto')#拟合模型, trainingMat为训练矩阵,hwLabels为对应的标签neigh.fit(trainingMat, hwLabels)#返回testDigits目录下的文件列表testFileList = listdir('testDigits')#错误检测计数errorCount = 0.0#测试数据的数量mTest = len(testFileList)#从文件中解析出测试集的类别并进行分类测试for i in range(mTest):#获得文件的名字fileNameStr = testFileList[i]#获得分类的数字classNumber = int(fileNameStr.split('_')[0])#获得测试集的1x1024向量,用于训练vectorUnderTest = img2vector('testDigits/%s' % (fileNameStr))#获得预测结果# classifierResult = classify0(vectorUnderTest, trainingMat, hwLabels, 3)classifierResult = neigh.predict(vectorUnderTest)print("分类返回结果为%d\t真实结果为%d" % (classifierResult, classNumber))if(classifierResult != classNumber):errorCount += 1.0print("总共错了%d个数据\n错误率为%f%%" % (errorCount, errorCount/mTest * 100))

结果:

上述代码使用的algorithm参数是auto,更改algorithm参数为brute,使用暴力搜索,你会发现,运行时间变长了,变为10s+。更改n_neighbors参数,你会发现,不同的值,检测精度也是不同的。自己可以尝试更改这些参数的设置,加深对其函数的理解。

六、总结

1、kNN算法的优缺点

优点:

  • 简单好用,容易理解,精度高,理论成熟,既可以用来做分类也可以用来做回归
  • 可用于数值型数据和离散型数据;
  • 训练时间复杂度为O(n);
  • 对异常值不敏感

缺点:

  • 计算复杂性高;空间复杂性高;
  • 样本不平衡问题(即有些类别的样本数量很多,而其它样本的数量很少);
  • 一般数值很大的时候不用这个,计算量太大。但是单个样本又不能太少,否则容易发生误分。
  • 最大的缺点是无法给出数据的内在含义。

2、其他

关于algorithm参数kd_tree的原理,可以查看《统计学方法 李航》书中的讲解;
关于距离度量的方法还有雪夫距离、马氏距离、巴氏距离;

参考资料:

  1. https://cuijiahua.com/blog/2017/11/ml_5_bayes_2.html
  2. 统计学习方法-李航
  3. 统计学习方法-课后练习
  4. 机器学习实战

笔记会放在我的Github里面

统计学方法机器学习实战(二) K近邻算法相关推荐

  1. 机器学习实战之K近邻算法

    k近邻算法概述 简单地说,K近邻算法采用测量不同特征值之间的距离方法进行分类. 优 点 :精度高.对异常值不敏感.无数据输入假定. 缺点:计算复杂度高.空间复杂度高. 适用数据范围:数值型和标称型. ...

  2. 【机器学习实战】k近邻算法实战——手写识别系统

    文章目录 手写识别系统 步骤: 准备数据:将图像转换为测试向量 测试算法:使用k-近邻算法识别手写数字 [完整代码] 手写识别系统 为了简单起见,这里构造的系统只能识别数字0到9,参见图2-6.需要识 ...

  3. 机器学习-分类之K近邻算法(KNN)原理及实战

    k近邻算法(KNN) 简介 KNN算法是数据挖掘分类技术中最简单的方法之一.它通过测量不同特征值之间的距离进行分类的.其基本思路为:如果一个样本在特征空间中的k个最近邻样本中的大多数属于某一个类别,则 ...

  4. 【白话机器学习】算法理论+实战之K近邻算法

    作者1. 写在前面 如果想从事数据挖掘或者机器学习的工作,掌握常用的机器学习算法是非常有必要的,在这简单的先捋一捋, 常见的机器学习算法: 监督学习算法:逻辑回归,线性回归,决策树,朴素贝叶斯,K近邻 ...

  5. k近邻算法_【白话机器学习】算法理论+实战之K近邻算法

    1. 写在前面 如果想从事数据挖掘或者机器学习的工作,掌握常用的机器学习算法是非常有必要的,在这简单的先捋一捋, 常见的机器学习算法: 监督学习算法:逻辑回归,线性回归,决策树,朴素贝叶斯,K近邻,支 ...

  6. 白话机器学习算法理论+实战之K近邻算法

    1. 写在前面 如果想从事数据挖掘或者机器学习的工作,掌握常用的机器学习算法是非常有必要的,比如我之前写过的一篇十大机器学习算法的小总结,在这简单的先捋一捋, 常见的机器学习算法: 监督学习算法:逻辑 ...

  7. 01. 机器学习笔记01——K近邻算法 , CV_example

    K近邻算法(K-nearest neighbor,KNN算法) 李航博士<统计学习方法> 最近邻(k-Nearest Neighbors,KNN)算法是一种分类算法 应用场景:字符识别.文 ...

  8. 2、python机器学习基础教程——K近邻算法鸢尾花分类

    一.第一个K近邻算法应用:鸢尾花分类 import numpy as np from sklearn.datasets import load_iris from sklearn.model_sele ...

  9. 机器学习——聚类之k近邻算法及python使用

    聚类算法之k近邻及python使用 什么是k近邻算法 k近邻算法流程 使用sklearn进行代码实现 数据集介绍 标准化 代码实现 写在开头,套用我的老师的一句话目前所有自然学科的前沿都是在研究数学, ...

最新文章

  1. Ajax缓存解决办法
  2. Intelli IDEA导入jar包
  3. 20年的老程序员对新入行的朋友的一些建议
  4. 【SAM】loj#6401. 字符串
  5. matlab vision hdl,MATLAB下载,MATLAB购买,MATLAB试用,MATLAB介绍,MATLAB评价
  6. python多进程间通信
  7. 上粱正,下粱不歪——网吧母盘制作流程(转)
  8. java 员工管理系统
  9. java的hashmap排序_Java HashMap两种简便排序方法解析
  10. 中标麒麟使用centos源_中标麒麟操作系统使用笔记
  11. 在linux下MySQL的常用操作命令
  12. 按键精灵找文字的基础代码模板
  13. s3c2440芯片累加汇编语言,S3C2440—3.用点亮LED来熟悉裸机开发的详细流程
  14. 01_多操作系统课题研究[2011-01-21]
  15. [乐意黎]Nginx 重写wordpress路径于二级子目录方法
  16. C语言单元测试之安装gtest教程及一个简单样例
  17. 用例图、活动图、时序图、类图的详细介绍
  18. 其实你也可以制作一款专属的书架app,信不信看看就知道
  19. html对颜色加深,css字体阴影如何加深?
  20. Oracle等待事件(三)—— buffer busy waits 常见原因及对应解决方法

热门文章

  1. 第1章 数字图像处理绪论
  2. 本质安全设备标准(IEC60079-11)的理解(二)
  3. python实现扎破气球----童年经典游戏
  4. 数据中心机柜散热解决方案,知道这两点就够了!
  5. 游戏行业从业者过了30岁后都是如何发展的?游戏美术设计呢?
  6. 燕十八mysql基础复习
  7. 使用element-ui中tree树状图
  8. 【CodeForces】CF13C Sequence(配数学证明)
  9. 生鲜电商Instacart再度下调估值:降至130亿美元 暂不IPO
  10. Info.plist 文件 和pch文件