K最近邻(k-Nearest Neighbor,KNN)分类算法可以说是最简单的机器学习算法了。它采用测量不同特征值之间的距离方法进行分类。它的思想很简单:如果一个样本在特征空间中的k个最相似(即特征空间中最邻近)的样本中的大多数属于某一个类别,则该样本也属于这个类别。

其算法描述如下:

1)计算已知类别数据集中的点与当前点之间的距离;

2)按照距离递增次序排序;

3)选取与当前点距离最小的k个点;

4)确定前k个点所在类别的出现频率;

5)返回前k个点出现频率最高的类别作为当前点的预测分类。

对于机器学习而已,Python需要额外安装三件宝,分别是Numpy,scipy和Matplotlib。前两者用于数值计算,后者用于画图。安装很简单,直接到各自的官网下载回来安装即可。安装程序会自动搜索我们的python版本和目录,然后安装到python支持的搜索路径下。反正就python和这三个插件都默认安装就没问题了。

一般实现一个算法后,我们需要先用一个很小的数据库来测试它的正确性,否则一下子给个大数据给它,它也很难消化,而且还不利于我们分析代码的有效性。

补充用python实现的代码,要给python装numpy和matplotlib库,建议直接装anaconda,装好了anaconda默认安装了spyder,里面集成了这两个库,比较方便。

首先,我们新建一个kNN.py脚本文件,文件里面包含两个函数,一个用来生成小数据库,一个实现kNN分类算法。代码如下:

# -*- coding: utf-8 -*-from numpy import *
def createDataSet():#创建一个很小的数据库group=array([[1.0,0.9],[1.0,1.0],[0.1,0.2],[0.0,0.1]])labels=['A','A','B','B']return group,labels
#[1.2,1.0]
def kNNClassify(newInput,dataset,labels,k):#knn算法核心#先求数据库中每个点到所求点之间的距离numSamples=dataset.shape[0] #获取数据库的行数diff=tile(newInput,(numSamples,1))-dataset#使用tile函数迭代创建一个numSample行1列的array与dataset做差squaredDiff=diff**2#diff中的每一个数都平方squaredDist=sum(squaredDiff,axis=1)#每一行两数求和distance=squaredDist**0.5#再开方#再对距离进行排序sortedDistIndices=argsort(distance)#argsort函数对distance中元素从小到大排序,返回序号classCount={}#统计距离小于等于k的每个点的类别for i in xrange(k):voteLabel=labels[sortedDistIndices[i]]classCount[voteLabel]=classCount.get(voteLabel,0)+1maxCount=0              #找出离所求点最近的k个点中最多的类别         for key,value in classCount.items():if maxCount<value:maxCount=valuemaxIndex=key#返回所求点的类型,算法到此结束return maxIndex

  然后我们在命令行中或在建一个python文件这样测试即可:

import kNN
from numpy import *dataSet,labels=kNN.createDataSet()
testX=array([0,0.05])
k=3
maxIndex=kNN.kNNClassify(testX,dataSet,labels,3)
print maxIndex

  运行程序:此时点[0,0.05]最接近B类。

应用:这里我们用kNN来分类一个大点的数据库,包括数据维度比较大和样本数比较多的数据库。这里我们用到一个手写数字的数据库,可以到这里下载。这个数据库包括数字0-9的手写体。每个数字大约有200个样本。每个样本保持在一个txt文件中。手写体图像本身的大小是32x32的二值图,转换到txt文件保存后,内容也是32x32个数字,0或者1.

数据库解压后有两个目录:目录trainingDigits存放的是大约2000个训练数据,testDigits存放大约900个测试数据。

编写kNN2.py:

# -*- coding: utf-8 -*-
from numpy import *
import os
def kNNClassify(newInput,dataset,labels,k):#knn算法核心#先求数据库中每个图像与所要分类图像像素值差的平方再开方,用这种计算方法表示距离(相似度)俗称欧氏距离numSamples=dataset.shape[0] #获取数据库的行数(即文件夹下的文件数)diff=tile(newInput,(numSamples,1))-dataset#使用tile函数迭代创建一个numSample行1列的array与dataset做差squaredDiff=diff**2#diff中的每一个数都平方squaredDist=sum(squaredDiff,axis=1)#每一行的数求和distance=squaredDist**0.5#再开方#再对距离进行排序sortedDistIndices=argsort(distance)classCount={}#统计距离为k的每个图像的类别(即统计相似度最小的k个图像所表示的数字)for i in xrange(k):  voteLabel = labels[sortedDistIndices[i]]  classCount[voteLabel] = classCount.get(voteLabel, 0) + 1maxCount=0            #找出离所求图像类别最近的k个图像中最多的类别       for key,value in classCount.items():if maxCount<value:maxCount=valuemaxIndex=key#返回所求图像的类型(类型即数字)return maxIndex#函数img2vector把一张32*32的图像转化成一行向量imgVector
def img2vector(filename):rows=32cols=32imgVector=zeros((1,rows*cols))fileIn=open(filename)for row in xrange(rows):lineStr=fileIn.readline()for col in xrange(cols):imgVector[0,row*32+col]=int(lineStr[col])return imgVector#函数loadDataSet从文件夹中加载多个文件数据,python对文件数据流加载到内存的操作很方便,这里的代码可以仔细理解加记忆一下
def loadDataSet():dataSetDir='/home/chao/Desktop/python_work/knn/'trainingFileList=os.listdir(dataSetDir+'trainingDigits')numSamples=len(trainingFileList)train_x=zeros((numSamples,1024))#使用zeros函数为train_x分配numSamples行,每行1024列,每行为一个图像转化后的数据,总共numSamples行train_y=[]#用来存放每个图像的真实值for i in xrange(numSamples):filename=trainingFileList[i]train_x[i,:]=img2vector(dataSetDir+'trainingDigits/%s'%filename)label=int(filename.split('_')[0])train_y.append(label)testingFileList=os.listdir(dataSetDir+'testDigits')numSamples=len(testingFileList)test_x=zeros((numSamples,1024))#同train_x,但这里表示的是测试图像文件的test_y=[]for i in xrange(numSamples):filename=testingFileList[i]test_x[i,:]=img2vector(dataSetDir+'testDigits/%s'%filename)label=int(filename.split('_')[0])test_y.append(label)return train_x,train_y,test_x,test_y#测试预测准确率
def testHandWritingClass():print "第一步:加载数据。。。"train_x,train_y,test_x,test_y=loadDataSet()numTestSamples=test_x.shape[0]#返回待测试图像的个数print "数据加载完成"matchCount=0#用来表示预测正确的图像的个数#每个待测图像都要调用一次knn预测其值for i in xrange(numTestSamples):print ipredict=kNNClassify(test_x[i],train_x,train_y,3)#这里k=3准确率达98.63%,如果改成k=1的话会达到98.97%if predict==test_y[i]:matchCount+=1accuracy=float(matchCount)/numTestSamplesprint matchCount#打印正确预测个数print "accuracy is:%.2f%%"%(accuracy*100)#打印正确预测准确率

测试非常简单,编写一个main.py:

# -*- coding: utf-8 -*-
import kNN2
kNN2.testHandWritingClass()

然后运行main.py观察正确率:

933个预测正确
accuracy is:98.63%

版权声明:原文地址http://www.cnblogs.com/lcbg/p/6491900.html

转载于:https://www.cnblogs.com/lcbg/p/6491900.html

knn算法(分类)-机器学习相关推荐

  1. Python机器学习实验二:1.编写代码,实现对iris数据集的KNN算法分类及预测

    Python机器学习实验二:编写代码,实现对iris数据集的KNN算法分类及预测 1.编写代码,实现对iris数据集的KNN算法分类及预测,要求: (1)数据集划分为测试集占20%: (2)n_nei ...

  2. KNN算法的机器学习基础

    KNN算法的机器学习基础 https://mp.weixin.qq.com/s/985Ym3LjFLdkmqbytIqpJQ 本文原标题 : Machine Learning Basics with ...

  3. 机器学习与深度学习——通过knn算法分类鸢尾花数据集iris求出错误率并进行可视化

    什么是knn算法? KNN算法是一种基于实例的机器学习算法,其全称为K-最近邻算法(K-Nearest Neighbors Algorithm).它是一种简单但非常有效的分类和回归算法. 该算法的基本 ...

  4. KNN算法——分类部分

    1.核心思想 如果一个样本在特征空间中的k个最相邻的样本中的大多数属于某一个类别,则该样本也属于这个类别,并具有这个类别上样本的特性.也就是说找出一个样本的k个最近邻居,将这些邻居的属性的平均值赋给该 ...

  5. python的knn算法list_机器学习实战学习笔记1——KNN算法

    一.KNN算法概述: 1.KNN算法的工作原理是: (1)存在一个训练样本集,并且知道样本集中每一数据与所属分类的对应关系,即每个数据都存在分类标签. (2)若此时输入不带标签的新数据之后,将新数据的 ...

  6. python的knn算法list_[机器学习]kNN算法python实现(实例:数字识别)

    # 使用好任何机器学习算法的前提是选好Featuresfrom numpy import *import operatorfrom os import listdirdef classify0(inX ...

  7. 机器学习入门-kNN算法实现手写数字识别

    实验环境 Python:3.7.0 Anconda:3-5.3.1 64位 操作系统:win10 开发工具:sublime text(非必要) 简介 本次实验中的重点为采用kNN算法进行手写数字识别, ...

  8. 【机器学习_4】机器学习算法分类

    [机器学习_4]机器学习算法分类 机器学习算法可以分为传统的机器学习算法和深度学习. 传统机器学习算法主要包括以下五类: 回归:建立一个回归方程来预测目标值,用于连续型分布预测 分类:给定大量带标签的 ...

  9. 文本分类——KNN算法

    上一篇文章已经描述了朴素贝叶斯算法newgroup的分类实现,这篇文章采用KNN算法实现newgroup的分类. 文中代码参考:http://blog.csdn.net/yangliuy/articl ...

  10. 基于Python Scrapy爬虫改进KNN算法的网站分类系统

    目 录 摘 要 I Abstract II 第 1 章 绪 论 1 1.1课题的研究背景和意义 1 1.1.1目前网站分类的研究情况 1 1.1.2现有解决方案的优点与不足 1 1.1.3基于特征熵值 ...

最新文章

  1. 专访浪潮王虹莉 探互联网服务器市场的未来
  2. 数据处理_流数据处理利器
  3. php测试框架,PHPUnit使用
  4. 3.调用empty而不是检查size()是否为0
  5. 如何使用用户数据脚本在EC2实例上安装Apache Web Server
  6. Qt工作笔记-setWindowFlags的巧妙使用(使用|、、~运算符)
  7. in module     from . import multiarray ImportError: DLL load failed: 找不到指定的模块解决
  8. scala数据类型_Scala数据类型示例教程
  9. 京东一面:高并发下,如何保证分布式唯一全局 ID 生成?
  10. python学习笔记 - 函数: 传参列表副本(不改变列表本身)
  11. ajax里面可以alert吗,Javascript和AJAX,仅在使用alert()时有效
  12. python实现百度贴吧自动签到
  13. windows10怎么显示文件后缀(扩展名)?
  14. matlab读取wav文件
  15. ffmpeg-nvenc
  16. 【网页制作】CSS基本选择器讲解(附讲解视频)
  17. web前端开发远程调试工具Weinre
  18. hdu5454 Excited Database (线段树)
  19. C++ 知识结构思维导图
  20. 各种activation function(激活函数) 简介

热门文章

  1. .net chart(图表)控件的使用
  2. 压缩之后神经网络忘记了什么?Google研究员给出了答案
  3. “3D几何与视觉技术”全球在线研讨会第八期~识别3D中的物体和场景
  4. 任意形状文本检测:Look More Than Once
  5. 收藏 | 程序员生涯指南,在GitHub上获3.6万星
  6. 新突破!CVPR2019接收论文:新的基于自编码变换的无监督表示学习方法—AET
  7. 计算机内部总线和外部总线,科学网-怎样将计算机内部总线扩展为外部网络?-姜咏江的博文...
  8. 推荐系统遇上深度学习(一)--FM模型理论和实践
  9. 深度学习(四十三)条件变分自编码器概述
  10. 机器学习(五)PCA数据降维