目录

0. 前言

1. k-近邻算法kNN(k-Nearest Neighbor)

2. 实战案例

2.1. 简单案例

2.2. 约会网站案例

2.3. 手写识别案例


学习完机器学习实战的k-近邻算法,简单的做个笔记。文中部分描述属于个人消化后的理解,仅供参考。

所有代码和数据可以访问 我的 github

如果这篇文章对你有一点小小的帮助,请给个关注喔~我会非常开心的~

0. 前言

k-近邻算法kNN(k-Nearest Neighbor)是一种监督学习的分类算法,算法思想是通过判断向量之间的距离,决定所属的类别。

  • 优点:精度高、对异常值不敏感
  • 缺点:计算复杂度高、空间复杂度高
  • 适用数据类型:数值型和标称型

1. k-近邻算法kNN(k-Nearest Neighbor)

算法流程可描述如下:

  1. 已知待测试样本  ,训练集合 
  2. 计算待测试样本与训练集合中每一个样本的欧式距离 
  3. 对  从小到大排序
  4. 选择前  个距离最短的样本,其中出现次数最多的类别,就是待测试样本的分类结果

其中, 与  的欧式距离表示为:

注:kNN算法必须保存所有的样本数据集,并且每一个测试样本,都要计算其与所有样本数据的距离,所以时间复杂度和空间复杂度都很高。

2. 实战案例

以下将展示书中的三个案例的代码段,所有代码和数据可以在github中下载:

2.1. 简单案例

# coding:utf-8
from numpy import *
import operator"""
简单案例
"""# 创建数据集和标签
def createDataSet():group = array([[1.0, 1.1], [1.0, 1.0], [0, 0], [0, 0.1]])labels = ['A', 'A', 'B', 'B']return group, labels# 分类算法
def classify0(intX, dataSet, labels, k):dataSetSize = dataSet.shape[0]# tile():#   在行方向上重复 intX,dataSetSize 次#   在列方向上重复 intX,1 次diffMat = tile(intX, (dataSetSize, 1)) - dataSet# ** 表示平方sqDiffMat = diffMat ** 2# sum(axis=0) 表示每一列相加# sum(axis=1) 表示每一行相加sqDistances = sqDiffMat.sum(axis=1)distances = sqDistances ** 0.5# argsort():#   按照数值从小到大,对数字的索引进行排序sortedDistIndicies = distances.argsort()classCount = {}for i in range(k):voteIlabel = labels[sortedDistIndicies[i]]# {}.get(voteIlabel, 0):#   查找键值 voteIlabel,如果键值不存在则返回 0classCount[voteIlabel] = classCount.get(voteIlabel, 0) + 1# key=operator.itemgetter(1)#   获取对象第 1 个域的值sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)return sortedClassCount[0][0]if __name__ == '__main__':group, labels = createDataSet()intX = [0, 0]k = 3clasifierResult = classify0(intX, group, labels, k)print(clasifierResult)

2.2. 约会网站案例

# coding:utf-8
from numpy import *
import matplotlib.pyplot as plt
import operator"""
约会网站案例
"""# 将txt文中中的数据转换为矩阵
def file2matrix(filename):fr = open(filename)arrayOLines = fr.readlines()numberOfLines = len(arrayOLines)returnMat = zeros((numberOfLines, 3))classLabelVector = []index = 0for line in arrayOLines:# strip():#   移除字符串头尾的指定字符line = line.strip()listFromLine = line.split('\t')returnMat[index, :] = listFromLine[0:3]classLabelVector.append(int(listFromLine[-1]))index += 1return returnMat, classLabelVector# 均值归一化
def autoNorm(dataSet):# min(a):#   a=0 每列的最小值#   a=1 每行的最小值minVals = dataSet.min(0)maxVals = dataSet.max(0)meanVals = dataSet.mean(0)ranges = maxVals - minValsnormDataSet = zeros(shape(dataSet))m = dataSet.shape[0]normDataSet = dataSet - tile(meanVals, (m, 1))normDataSet = normDataSet / tile(ranges, (m, 1))return normDataSet, ranges, meanVals# 分类算法
def classify0(intX, dataSet, labels, k):dataSetSize = dataSet.shape[0]# tile():#   在行方向上重复 intX,dataSetSize 次#   在列方向上重复 intX,1 次diffMat = tile(intX, (dataSetSize, 1)) - dataSet# ** 表示平方sqDiffMat = diffMat ** 2# sum(axis=0) 表示每一列相加# sum(axis=1) 表示每一行相加sqDistances = sqDiffMat.sum(axis=1)distances = sqDistances ** 0.5# argsort():#   按照数值从小到大,对数字的索引进行排序sortedDistIndicies = distances.argsort()classCount = {}for i in range(k):voteIlabel = labels[sortedDistIndicies[i]]# {}.get(voteIlabel, 0):#   查找键值 voteIlabel,如果键值不存在则返回 0classCount[voteIlabel] = classCount.get(voteIlabel, 0) + 1# key=operator.itemgetter(1)#   获取对象第 1 个域的值sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)return sortedClassCount[0][0]# 测试分类算法
def datingClassTest():hoRatio = 0.1datingDataMat, datingLabels = file2matrix('datingTestSet2.txt')normMat, ranges, meanVals = autoNorm(datingDataMat)m = normMat.shape[0]numTestVecs = int(m * hoRatio)correctCount = 0for i in range(numTestVecs):classifierResult = classify0(normMat[i, :], normMat[numTestVecs:m, :],datingLabels[numTestVecs:m], 3)print('the classifier came back with: %d, the real answer is: %d'% (classifierResult, datingLabels[i]))if classifierResult == datingLabels[i]:correctCount += 1.0print('the total accuracy is: %f' % (correctCount / float(numTestVecs)))if __name__ == '__main__':datingDataMat, datingLabels = file2matrix('datingTestSet2.txt')fig = plt.figure()# add_subplot(321):#   将画图分割成 3 行 2 列,现在这个在从左到右从上到下第 1 个ax = fig.add_subplot(111)ax.scatter(datingDataMat[:, 1], datingDataMat[:, 2],15.0 * array(datingLabels), 15.0 * array(datingLabels))plt.show()datingClassTest()

2.3. 手写识别案例

# coding:utf-8
from numpy import *
import operator
from os import listdir"""
手写识别案例
"""# 将01文本表示的图像转换为向量
def img2vector(filename):returnVect = zeros((1, 1024))fr = open(filename)for i in range(32):lineStr = fr.readline()for j in range(32):returnVect[0, 32 * i + j] = int(lineStr[j])return returnVect# 分类算法
def classify0(intX, dataSet, labels, k):dataSetSize = dataSet.shape[0]# tile():#   在行方向上重复 intX,dataSetSize 次#   在列方向上重复 intX,1 次diffMat = tile(intX, (dataSetSize, 1)) - dataSet# ** 表示平方sqDiffMat = diffMat ** 2# sum(axis=0) 表示每一列相加# sum(axis=1) 表示每一行相加sqDistances = sqDiffMat.sum(axis=1)distances = sqDistances ** 0.5# argsort():#   按照数值从小到大,对数字的索引进行排序sortedDistIndicies = distances.argsort()classCount = {}for i in range(k):voteIlabel = labels[sortedDistIndicies[i]]# {}.get(voteIlabel, 0):#   查找键值 voteIlabel,如果键值不存在则返回 0classCount[voteIlabel] = classCount.get(voteIlabel, 0) + 1# key=operator.itemgetter(1)#   获取对象第 1 个域的值sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)return sortedClassCount[0][0]# 测试分类算法
def handwritingClassTest():hwLabels = []# 读取目录下文件列表trainingFileList = listdir('trainingDigits')m = len(trainingFileList)trainingMat = zeros((m, 1024))for i in range(m):fileNameStr = trainingFileList[i]fileStr = fileNameStr.split('.')[0]classNumStr = int(fileStr.split('_')[0])hwLabels.append(classNumStr)trainingMat[i, :] = img2vector('trainingDigits/%s' % (fileNameStr))testFileList = listdir('testDigits')correctCount = 0.0mTest = len(testFileList)for i in range(mTest):fileNameStr = testFileList[i]fileStr = fileNameStr.split('.')[0]classNumStr = int(fileStr.split('_')[0])vectorUnderTest = img2vector('testDigits/%s' % (fileNameStr))classifierResult = classify0(vectorUnderTest, trainingMat, hwLabels, 3)print('the classifier came back with: %d, the real answer is: %d'% (classifierResult, classNumStr))if classifierResult == classNumStr:correctCount += 1.0print('the total accuracy is: %f' % (correctCount / float(mTest)))if __name__ == '__main__':handwritingClassTest()

如果这篇文章对你有一点小小的帮助,请给个关注喔~我会非常开心的~

机器学习实战(一)k-近邻kNN(k-Nearest Neighbor)相关推荐

  1. 【Python机器学习】多项式回归、K近邻KNN回归的讲解及实战(图文解释 附源码)

    需要源码请点赞关注收藏后评论区留言私信~~~ 多项式回归 非线性回归是用一条曲线或者曲面去逼近原始样本在空间中的分布,它"贴近"原始分布的能力一般较线性回归更强. 多项式是由称为不 ...

  2. 机器学习经典算法具体解释及Python实现--K近邻(KNN)算法

    (一)KNN依旧是一种监督学习算法 KNN(K Nearest Neighbors,K近邻 )算法是机器学习全部算法中理论最简单.最好理解的.KNN是一种基于实例的学习,通过计算新数据与训练数据特征值 ...

  3. K 近邻法(K-Nearest Neighbor, K-NN)

    文章目录 1. k近邻算法 2. k近邻模型 2.1 模型 2.2 距离度量 2.2.1 距离计算代码 Python 2.3 kkk 值的选择 2.4 分类决策规则 3. 实现方法, kd树 3.1 ...

  4. Python+OpenCV:理解k近邻(kNN)算法(k-Nearest Neighbour (kNN) algorithm)

    Python+OpenCV:理解k近邻(kNN)算法(k-Nearest Neighbour (kNN) algorithm) 理论 kNN is one of the simplest classi ...

  5. K近邻算法(k-nearest neighbor,KNN)

    K近邻算法(k-nearest neighbor,KNN) 经过一周昏天黑地的加班之后,终于到了周末,又感觉到生活如此美好,遂提笔写一写KNN,这个也许是机器学习众多算法中思想最为简单直白的算法了,其 ...

  6. 机器学习实战(3)—— kNN实战约会网站

    机器学习实战(3)-- kNN实战约会网站 老板:小韩啊,别忘了去改进一下约会网站的配对效果. 我:好嘞好嘞!马上工作!! 好了,又要开始一天的工作啦.接着上篇文章老板布置的任务,我们来看一下这次实战 ...

  7. 机器学习-监督学习之分类算法:K近邻法 (K-Nearest Neighbor,KNN)

    目录 KNN概述 举个例子: K值选取 距离计算 曼哈顿距离,切比雪夫距离关系(相互转化) k-近邻(KNN)算法步骤 相关代码实现 简单实例:判断电影类别 创建数据集 数据可视化 分类测试 运行结果 ...

  8. 【机器学习经典算法】K近邻(KNN):核心与总结

    文章目录 1. 初识K近邻 2. 相知 2.1 K近邻三要素 2.2 KD树 2.2.1 kd树的构建 2.2.2 kd树的搜索 3. 总结 1. 初识K近邻 K-近邻(K-Nearest Neigh ...

  9. 机器学习算法(7)——K近邻(KNN)、K-means、模糊c-均值聚类、DBSCAN与层次与谱聚类算法

    1.K-近邻算法(KNN)概述 (有监督算法,分类算法) 最简单最初级的分类器是将全部的训练数据所对应的类别都记录下来,当测试对象的属性和某个训练对象的属性完全匹配时,便可以对其进行分类.但是怎么可能 ...

  10. R语言机器学习(1)-R的近邻分析—K近邻法

    转载请注明出处:https://blog.csdn.net/xiezhiliang22 对应书籍:<R数据挖掘>薛薇 中国人民大学出版社 1 前言 下面这几个部分主要讲述如何使用R语言来实 ...

最新文章

  1. MinGW安装与使用简介
  2. SQL Server 一些使用小技巧
  3. 【js笔记】数组那些事[0]
  4. mysql数据库的安全机制管理_mysql管理之安全机制
  5. 【iOS-Cocos2d游戏开发之二十一 】自定义精灵类并为你的精灵设置攻击帧(指定开始帧)以及扩展Cocos2d源码的CCAnimation简化动画创建!...
  6. 算法导论 练习12.3
  7. 仿写王者荣耀主页代码HTML CSS,CSS3实现王者荣耀匹配人员加载页面的方法
  8. cad计算机试题及答案,CAD与CAM试题及答案.doc
  9. 制作一个全盘搜索器 ,能搜索整个磁盘所有目录
  10. 【安全知识分享】2021年安全生产月主题宣讲课件(附下载)
  11. mysql存储函数中游标报错 No data - zero rows fetched, selected
  12. Datatype LP64 ILP64 LLP64 ILP32 LP32
  13. 计算机夏令营英语怎么说,“夏令营”英语怎么说
  14. Unity API常用方法和类学习笔记1
  15. 兄弟组件通过$bus调接口,数据赋值成功,但渲染没效果
  16. Java项目:SSM药品进货销售管理系统
  17. 我的新浪微博欢迎大家互粉
  18. python Numpy 生成一个随机矩阵(整数型)
  19. 删除电脑文件夹右击出现“自定义文件夹”选项
  20. 如何通过猪八戒网引流?猪八戒网怎么做推广?如何利用猪八戒网

热门文章

  1. 为什么会有jQuery、Dojo、Ext、Prototype、YUI、Zepto这么多JS包?
  2. 6410 spi 设备驱动
  3. OpenCV形态学操作
  4. web developer tips (38):如何用请求失败记录追踪重写规则
  5. java8 Lambda Stream collect Collectors 常用实例
  6. 2018.09.17 atcoder Digit Sum(数论)
  7. Asp.Net高级知识回顾_HttpModule及应用程序生命周期_1
  8. LeetCode 124. Binary Tree Maximum Path Sum
  9. Html5+NodeJS——拖拽多个文件上传到服务器
  10. Transaction rolled back because it has been marked as rollback-only 原因 和解决方案