文章目录

  • 一、K-近邻算法(KNN)概述
  • 二、Python实现
  • 三、Python实现手写数字识别
  • 四、K值选取对数据准确率的影响

在做此项目之前,首先要明白何为KNN算法。

一、K-近邻算法(KNN)概述

最简单最初级的分类器是将全部的训练数据所对应的类别都记录下来,当测试对象的属性和某个训练对象的属性完全匹配时,便可以对其进行分类。但是怎么可能所有测试对象都会找到与之完全匹配的训练对象呢,其次就是存在一个测试对象同时与多个训练对象匹配,导致一个训练对象被分到了多个类的问题,基于这些问题呢,就产生了KNN。

KNN是通过测量不同特征值之间的距离进行分类。它的思路是:如果一个样本在特征空间中的k个最相似(即特征空间中最邻近)的样本中的大多数属于某一个类别,则该样本也属于这个类别,其中K通常是不大于20的整数。KNN算法中,所选择的邻居都是已经正确分类的对象。该方法在定类决策上只依据最邻近的一个或者几个样本的类别来决定待分样本所属的类别。

下面通过一个简单的例子说明一下:如下图,绿色圆要被决定赋予哪个类,是红色三角形还是蓝色四方形?如果K=3,由于红色三角形所占比例为2/3,绿色圆将被赋予红色三角形那个类,如果K=5,由于蓝色四方形比例为3/5,因此绿色圆被赋予蓝色四方形类。

由此也说明了KNN算法的结果很大程度取决于K的选择。

在KNN中,通过计算对象间距离来作为各个对象之间的非相似性指标,避免了对象之间的匹配问题,在这里距离一般使用欧氏距离或曼哈顿距离:

同时,KNN通过依据k个对象中占优的类别进行决策,而不是单一的对象类别决策。这两点就是KNN算法的优势。

接下来对KNN算法的思想总结一下:就是在训练集中数据和标签已知的情况下,输入测试数据,将测试数据的特征与训练集中对应的特征进行相互比较,找到训练集中与之最为相似的前K个数据,则该测试数据对应的类别就是K个数据中出现次数最多的那个分类,其算法的描述为:

1)计算测试数据与各个训练数据之间的距离;

2)按照距离的递增关系进行排序;

3)选取距离最小的K个点;

4)确定前K个点所在类别的出现频率;

5)返回前K个点中出现频率最高的类别作为测试数据的预测分类。

二、Python实现

首先呢,需要说明的是我用的是python3.6,里面有一些用法与2.7还是有些出入。

建立一个KNN.py文件对算法的可行性进行验证,如下:

#coding:utf-8from numpy import *
import operator##给出训练数据以及对应的类别
def createDataSet():group = array([[1.0,2.0],[1.2,0.1],[0.1,1.4],[0.3,3.5]])labels = ['A','A','B','B']return group,labels###通过KNN进行分类
def classify(input,dataSe t,label,k):dataSize = dataSet.shape[0]####计算欧式距离diff = tile(input,(dataSize,1)) - dataSetsqdiff = diff ** 2squareDist = sum(sqdiff,axis = 1)###行向量分别相加,从而得到新的一个行向量dist = squareDist ** 0.5##对距离进行排序sortedDistIndex = argsort(dist)##argsort()根据元素的值从大到小对元素进行排序,返回下标classCount={}for i in range(k):voteLabel = label[sortedDistIndex[i]]###对选取的K个样本所属的类别个数进行统计classCount[voteLabel] = classCount.get(voteLabel,0) + 1###选取出现的类别次数最多的类别maxCount = 0for key,value in classCount.items():if value > maxCount:maxCount = valueclasses = keyreturn classes

接下来,在命令行窗口输入如下代码:

#-*-coding:utf-8 -*-
import sys
sys.path.append("...文件路径...")
import KNN
from numpy import *
dataSet,labels = KNN.createDataSet()
input = array([1.1,0.3])
K = 3
output = KNN.classify(input,dataSet,labels,K)
print("测试数据为:",input,"分类结果为:",output)

回车之后的结果为:

测试数据为: [ 1.1 0.3] 分类为: A

答案符合我们的预期,要证明算法的准确性,势必还需要通过处理复杂问题进行验证,之后另行说明。

三、Python实现手写数字识别

训练集和测试集地址如下:
链接:https://pan.baidu.com/s/1lgRDsEp5IGcVn348jjEc8w 提取码:pqm7

测试集和训练集的数字是由1024个01组成
训练集从0-9一共1934个文件
测试集从0-9一共945个文件

思路:
如果直接以每个文件作为测试集的话,效率低而且相对复杂,所以我的思路是将每个txt中的内容拼成一行,然后将所有的txt都存储在一个txt中。

这样训练集就变成一个1934行,1024列的txt,
同理测试集变成了945行1024列的txt。

每一行代表一个测试集/训练集,每一列代表一个属性,但是只有属性没有类别不行,文件名的第一位就是对应的数字,所以再将属性一起拼接在txt中,所以训练集和测试集就变成了1025列,前1024列为属性,最后一列为类别。

拉直存在一个txt后如下:

接下来就是代码实现了,为了让程序看起来更python,我使用了类进行封装。

import  pandas as pd #导入pands、numpy并设置别名pd、np
import numpy as np
import os,randomclass NumRecog():def __init__(self,test,train):self.mkfile(test)  ##生成测试集txtself.mkfile(train)  ##生成训练集txtdef mkfile(self,file):  #训练集生成函数if os.path.exists(file+'.txt'):  ##如果文件存在则不返回self.long = len(os.listdir('./digits/testDigits/')) #获取测试集的长度,以供后面随机选择测试集进行测试else:  #否则生成文件num_list = os.listdir('./digits/{}/'.format(file)) #获取所有文件for i in num_list:a = open('./digits/{}/{}'.format(file,i),encoding='utf-8') #打开文件with open(r'E:\学习\数据分析\监督学习\{}.txt'.format(file),'a',encoding='utf-8') as f: #写入新文件#每个文件撸直并以‘,’隔开每个数字f.write(','.join(list(''.join(a.readlines()).replace('\n','')+i[0]+'\n'))) def DateFrameCreate(self): #数据处理函数testF = pd.read_table('testDigits.txt',sep=',') #以,为分隔将文件变成dataframetestF.drop(testF.columns[-1],axis=1,inplace=True) #删除最后一列,最后一列为空f = pd.read_table('trainingDigits.txt',sep=',') #同理f.drop(f.columns[-1],axis=1,inplace=True)return self.KNN(testF,f)  #将处理好的数据返回给KNN函数def KNN(self,testF,f): #KNN函数i = random.randint(1,self.long) #随机取测试集一个测试集f_Set = f.iloc[:, :-1]  #最后一列是作为类型,先剔除testF_set = testF.iloc[i,:-1] #最后一列是作为类型,先剔除testF_set = np.tile(testF_set, [len(f_Set),1]) #把测试集扩展成与训练集同行数#生成新的DataframenewF = pd.DataFrame({'数字':f.iloc[:,-1],'相似度':np.sqrt(np.sum((testF_set-f_Set)**2, axis=1))})newF = newF.sort_values(by='相似度').head(100) #k=100count = newF['数字'].value_counts() #按数字统计return count.index[0],testF.iloc[i,-1] #返回统计最多的数字和测试原数字if __name__ == '__main__':numre = NumRecog('testDigits','trainingDigits')re = numre.DateFrameCreate()print('获取数字:',re[0],'  检测数字为:',re[1])

结果如下:

四、K值选取对数据准确率的影响


经过对K = 3~300的不同取值得出:
能得出结论在此事件建中k<总测试集数量10%,准确率在90%以上

Python人工智能--实现手写数字识别相关推荐

  1. python svm实现手写数字识别——直接可用

    python svm实现手写数字识别--直接可用 1.训练 1.1.训练数据集下载--已转化成csv文件 1.2 .训练源码 2.预测单张图片 2.1.待预测图像 2.2.预测源码 2.3.预测结果 ...

  2. Python不同方法实现手写数字识别结果和代码

    一.背景 手写数字识别是符号识别的一个分支,虽然只是识别简单的10个数字,但却有着非常大的实用价值.在我们的日常生活中,每天都要进行大量的文档处理工作,比如税单,银行支票,汇款单,信用卡账单的处理,以 ...

  3. python手写数字识别实验报告_机器学习python实战之手写数字识别

    看了上一篇内容之后,相信对K近邻算法有了一个清晰的认识,今天的内容--手写数字识别是对上一篇内容的延续,这里也是为了自己能更熟练的掌握k-NN算法. 我们有大约2000个训练样本和1000个左右测试样 ...

  4. 利用python卷积神经网络手写数字识别_卷积神经网络使用Python的手写数字识别

    为了使机器更智能,开发人员正在研究机器学习和深度学习技术.人类通过反复练习和重复执行任务来学习执行任务,从而记住了如何执行任务.然后,他大脑中的神经元会自动触发,它们可以快速执行所学的任务.深度学习与 ...

  5. python实现mnist手写数字识别

    看了<python神经网络编程>,跟着书上敲了一下mnist手写数字的代码,对神经网络有了初步的了解. 此项目为三层神经网络识别,激活函数采用sigmoid函数,数据集为mnist手写数字 ...

  6. 利用python卷积神经网络手写数字识别_Keras深度学习:卷积神经网络手写数字识别...

    引言:最近在闭关学习中,由于多久没有写博客了,今天给大家带来学习的一些内容,还在学习神经网络的同学,跑一跑下面的代码,给你一些自信吧!Nice 奥里给! 正文:首先该impor的库就不多说了,不会的就 ...

  7. 【Python学习】 - 手写数字识别 - python读入mnist数据集的多种方法

    写在前面: 其实网上有很多读入mnist数据的代码,但是都是比较麻烦冗长的函数,本篇文章介绍几种不算很麻烦的,借用库函数读入数据的方法. 方法1: 方法2: 方法3:

  8. Python+Keras实现手写数字识别

    机器学习入门级别项目,话不多说,我看的教程在这里,代码来源也是这里,快附上B站教程:https://www.bilibili.com/video/BV16g4y1z7Qu?from=search&am ...

  9. Python TensorFlow框架 实现手写数字识别系统

    手写数字识别算法的设计与实现 本文使用python基于TensorFlow设计手写数字识别算法,并编程实现GUI界面,构建手写数字识别系统.这是本人的本科毕业论文课题,当然,这个也是机器学习的基本问题 ...

最新文章

  1. What is a lambda expression in C++11?
  2. c#_文件的保存与读取
  3. 凑微分公式_武忠祥真题班归纳(更新至多元函数微分学)
  4. 90后,是时候想想你的副业了
  5. 商用工程开发中的一些习惯
  6. Redis 5种数据结构使用及注意事项
  7. 使用layui框架时,select的onchange事件没有生效。
  8. arduino控制点阵屏与蜂鸣器_还在为遥控项目发愁?Arduino遥控套装解决你的所有问题(下)...
  9. [转载] Python和java中的垃圾回收机制
  10. 拓端tecdat|在R语言中轻松创建关联网络
  11. 必做作业3:原型化系统---乘车app
  12. SpringBoot-Learning-作者:翟永超
  13. 井字棋小游戏c语言简单编码,C语言实现简易井字棋游戏
  14. python3.6library 学习 1.introduction,2.built-infunction
  15. Day10 多态 final
  16. 基于opencv-python的人脸识别、眼睛识别和微笑识别
  17. f2fs系列文章fsck(五)
  18. 群晖添加第三方套件源提示无效位置的解决方法(解决群晖 CA 根证书过期的问题)
  19. 软件开发本质论——自然之路 1
  20. js中避免全局变量冗杂的解决方式

热门文章

  1. 浅谈IT项目成熟度的评估
  2. java鼠标指针锤子,如果用4个技能打不过他,那就用7个
  3. 什么是区块链?详细介绍区块链。
  4. idea 跳转到方法调用处
  5. matlab位置1处索引超出数组边界错误怎么改呢
  6. 10个帮程序员减压放松的网站,爽!
  7. 无人机飞控处理器DFU方式刷机方法(STM32单片机)
  8. 微信小程序开发大赛经验总结
  9. 使用python进行数据分析之电影评分
  10. matlab绩点计算程序_用Matlab计算学分绩