基于KNN算法的手写体数字识别

KNN分类算法是一种经典的分类算法,属于懒惰学习算法的一种。

1.算法原理

工作原理:存在一个样本数据集合,也称作训练样本集,并且样本集中每个数据都存在标签,即我们知道样本集中每一数据与所属分类的对应关系。输入没有标签的新数据后,将新数据的每个特征与样本集中数据对应的特征进行比较,然后算法提取样本集中特征最相似数据(最近邻)的分类。一般来说,我们只选择样本数据集中前k个最相似的数据,这就是kNN算法中k的出处,通常k是不大于20的整数。最后,选择k个最相似数据中出现次数最多的分类,作位新数据的分类。

kNN算法的一般流程
1.收集数据:可以使用任何方法。
2.准备数据:距离计算所需要的数值,最好是结构化的数据格式。
3.分析数据:可以使用任何方法。
4.训练算法:此步骤不使用与kNN算法
5.测试算法:计算错误率
6.使用算法:首先需要输入样本数据和结构化的输出结果,然后运行kNN算法判定输入数据分别属于那个分类,最后应用对计算出的分类执行后续处理。

2.手写识别系统大致流程

使用kNN算法的手写识别系统
1.收集数据:提供文本文件
2.准备数据:编写函数img2vector(),将图像格式转换为分类器使用的向量格式。
3.分析数据:在Python命令提示符中检查数据,确保它符合要求。
4.训练算法:此步骤不适用与kNN算法
5.测试算法:编写函数使用提供的部分数据集作为测试样本,测试样本与非测试样本的区别在于测试样本是已经完成分类的数据,如果预测分类与实际类别不同,则标记为一个错误。
6.使用算法:使用已编写好的算法来对测试样本进行测试

3.算法各模块程序介绍

3.1.kNN分类算法

伪代码
计算已知类别属性的数据集中的每个点依次执行以下操作:
1.计算已知类别数据集中的点与当前点之间的距离;
2.按照距离递增次序排序;
3.选取与当前点距离最小的k个点(此处取k=3);
4.确定前k个点所在类别的出现频率;
5.返回前k个点出现频率最高的类别作为当前点的预测分类。
#算法需要调用的python库
from numpy import *
import operator
from os import listdir
from skimage import data
import matplotlib.pyplot as plt
from skimage import io,color,transform

kNN算法程序:

def classify0(inX, dataSet, labels, k):dataSetSize = dataSet.shape[0]diffMat = tile(inX, (dataSetSize,1)) - dataSetsqDiffMat = diffMat**2sqDistances = sqDiffMat.sum(axis=1)distances = sqDistances**0.5           #距离计算sortedDistIndicies = distances.argsort()     classCount={}          for i in range(k):voteIlabel = labels[sortedDistIndicies[i]]    #选择距离最小的k个点classCount[voteIlabel] = classCount.get(voteIlabel,0) + 1sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)    #排序return sortedClassCount[0][0]

3.2.归一化数据(手写识别不适用)

由于可能遇到的特征值量纲不同,而形成干扰,故需对数据进行归一化。

def autoNorm(dataSet):minVals = dataSet.min(0)maxVals = dataSet.max(0)ranges = maxVals - minValsnormDataSet = zeros(shape(dataSet))m = dataSet.shape[0]normDataSet = dataSet - tile(minVals, (m,1))normDataSet = normDataSet/tile(ranges, (m,1))   #element wise dividereturn normDataSet, ranges, minVals

3.3.将图像转换为测试向量

由于本系统的训练测试数据是由32*32的txt文件构成,且文件名代表该样本标签,如下图所示:


故需要将其样本格式化处理为一个向量

#将图像矩阵转换为矩阵
def img2vector(filename):returnVect = zeros((1,1024))fr = open(filename)for i in range(32):lineStr = fr.readline()for j in range(32):returnVect[0,32*i+j] = int(lineStr[j])return returnVect

3.4.测试算法

#手写数字识别系统的测试函数
def handwritingClassTest():hwLabels = []trainingFileList = listdir('trainingDigits')m = len(trainingFileList)trainingMat = zeros((m,1024))for i in range(m):fileNameStr = trainingFileList[i]fileStr = fileNameStr.split('.')[0]classNumStr = int(fileStr.split('_')[0])#从文件名中解析分类数据hwLabels.append(classNumStr)trainingMat[i,:] = img2vector('trainingDigits/%s' %fileNameStr)testFileList = listdir('testDigits')errorCount = 0.0mTest = len(testFileList)for i in range(mTest):fileNameStr = testFileList[i]fileStr = fileNameStr.split('.')[0]classNumStr = int(fileStr.split('_')[0])vectorUnderTest = img2vector('testDigits/%s' %fileNameStr)classifierResult = classify0(vectorUnderTest,trainingMat,hwLabels,3)print("the classifier came back with: %d, the real answer is: %d" %(classifierResult,classNumStr))if(classifierResult != classNumStr):errorCount += 1.0print("\nthe total number of errors is: %d" %errorCount)print("\nthe total error rate is: %f" %(errorCount/float(mTest)))

测试效果如下:


3.5.应用图像处理函数

由于应用是使用图像作为输入的,所以需要将图像转换成32*32的txt文本格式

#应用样本处理函数,将图像变为符合要求的测试样本(32*32的txt文本)
def photosDeal():filename = '/9_1'mytest1 = io.imread('./handwriting_Yqx'+filename+'.png')print('the shape is  {}'.format(mytest1.shape))img_gray = color.rgb2gray(mytest1) #灰度化图像img_high = img_gray.shape[0]img_width = img_gray.shape[1]print('the gary_img shape is  {}'.format(img_gray.shape))print("the high of img is %d,the width is %d" %(img_high,img_width))for i in range(img_high):#二值化图像for j in range(img_width):if(img_gray[i][j] <= 0.5):img_gray[i][j] = 1else:img_gray[i][j] =0dst_img1 = transform.resize(img_gray,(32,32)) #缩放图像io.imshow(dst_img1)#plt.show()#将图片转为txtresult = ''for i in range(32):for j in range(32):result += str(int(dst_img1[i][j]))result+= '\n'with open('./txtYqx'+filename+'.txt',mode = 'w') as f:f.write(result)

3.6.应用测试程序

完成后对输入图像进行测试

 #测试函数
def YqxTrail():hwLabels = []trainingFileList = listdir('trainingDigits')m = len(trainingFileList)trainingMat = zeros((m,1024))for i in range(m):fileNameStr = trainingFileList[i]fileStr = fileNameStr.split('.')[0]classNumStr = int(fileStr.split('_')[0])#从文件名中解析分类数据hwLabels.append(classNumStr)trainingMat[i,:] = img2vector('trainingDigits/%s' %fileNameStr)testFileList = listdir('txtYqx')errorCount = 0.0mTest = len(testFileList)for i in range(mTest):fileNameStr = testFileList[i]fileStr = fileNameStr.split('.')[0]classNumStr = int(fileStr.split('_')[0])vectorUnderTest = img2vector('txtYqx/%s' %fileNameStr)classifierResult = classify0(vectorUnderTest,trainingMat,hwLabels,3)print("the classifier came back with: %d, the real answer is: %d" %(classifierResult,classNumStr))if(classifierResult != classNumStr):errorCount += 1.0print("\nthe total number of errors is: %d" %errorCount)print("\nthe total error rate is: %f" %(errorCount/float(mTest)))

4.程序运行效果

4.1.程序总体

kNN.py

from numpy import *
import operator
from os import listdir
from skimage import data
import matplotlib.pyplot as plt
from skimage import io,color,transformdef classify0(inX, dataSet, labels, k):dataSetSize = dataSet.shape[0]diffMat = tile(inX, (dataSetSize,1)) - dataSetsqDiffMat = diffMat**2sqDistances = sqDiffMat.sum(axis=1)distances = sqDistances**0.5sortedDistIndicies = distances.argsort()     classCount={}          for i in range(k):voteIlabel = labels[sortedDistIndicies[i]]classCount[voteIlabel] = classCount.get(voteIlabel,0) + 1sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)return sortedClassCount[0][0]def createDataSet():group = array([[1.0,1.1],[1.0,1.0],[0,0],[0,0.1]])labels = ['A','A','B','B']return group, labelsdef file2matrix(filename):fr = open(filename)numberOfLines = len(fr.readlines())         #get the number of lines in the filereturnMat = zeros((numberOfLines,3))        #prepare matrix to returnclassLabelVector = []                       #prepare labels return   fr = open(filename)index = 0for line in fr.readlines():line = line.strip()listFromLine = line.split('\t')returnMat[index,:] = listFromLine[0:3]classLabelVector.append(int(listFromLine[-1]))index += 1return returnMat,classLabelVectordef autoNorm(dataSet):minVals = dataSet.min(0)maxVals = dataSet.max(0)ranges = maxVals - minValsnormDataSet = zeros(shape(dataSet))m = dataSet.shape[0]normDataSet = dataSet - tile(minVals, (m,1))normDataSet = normDataSet/tile(ranges, (m,1))   #element wise dividereturn normDataSet, ranges, minValsdef datingClassTest():hoRatio = 0.50      #hold out 10%datingDataMat,datingLabels = file2matrix('datingTestSet2.txt')       #load data setfrom filenormMat, ranges, minVals = autoNorm(datingDataMat)m = normMat.shape[0]numTestVecs = int(m*hoRatio)errorCount = 0.0for i in range(numTestVecs):classifierResult = classify0(normMat[i,:],normMat[numTestVecs:m,:],datingLabels[numTestVecs:m],3)print("the classifier came back with: %d, the real answer is: %d" % (classifierResult, datingLabels[i]))if (classifierResult != datingLabels[i]):errorCount += 1.0print("the total error rate is: %f" % (errorCount/float(numTestVecs)))print(errorCount)def classifyPerson():resultList = ['not at all','in small doses','in large doses']percentTats = float(input("percentage of time spent playing video games?"))ffMiles = float(input("frequent flier miles earned per year?"))iceCream = float(input("liters of iceCream consumed per year?"))datingDataMat,datingLabels = file2matrix('datingTestSet2.txt')normMat,ranges,minVals = autoNorm(datingDataMat)inArr = array([ffMiles,percentTats,iceCream])classifierResult = classify0((inArr - minVals)/ranges,normMat,datingLabels,3)print("You will probably like this person:",resultList[classifierResult-1])#将图像矩阵转换为矩阵
def img2vector(filename):returnVect = zeros((1,1024))fr = open(filename)for i in range(32):lineStr = fr.readline()for j in range(32):returnVect[0,32*i+j] = int(lineStr[j])return returnVect#手写数字识别系统的测试函数
def handwritingClassTest():hwLabels = []trainingFileList = listdir('trainingDigits')m = len(trainingFileList)trainingMat = zeros((m,1024))for i in range(m):fileNameStr = trainingFileList[i]fileStr = fileNameStr.split('.')[0]classNumStr = int(fileStr.split('_')[0])#从文件名中解析分类数据hwLabels.append(classNumStr)trainingMat[i,:] = img2vector('trainingDigits/%s' %fileNameStr)testFileList = listdir('testDigits')errorCount = 0.0mTest = len(testFileList)for i in range(mTest):fileNameStr = testFileList[i]fileStr = fileNameStr.split('.')[0]classNumStr = int(fileStr.split('_')[0])vectorUnderTest = img2vector('testDigits/%s' %fileNameStr)classifierResult = classify0(vectorUnderTest,trainingMat,hwLabels,3)print("the classifier came back with: %d, the real answer is: %d" %(classifierResult,classNumStr))if(classifierResult != classNumStr):errorCount += 1.0print("\nthe total number of errors is: %d" %errorCount)print("\nthe total error rate is: %f" %(errorCount/float(mTest)))#应用样本处理函数,将图像变为符合要求的测试样本(32*32的txt文本)
def photosDeal():filename = '/9_1'mytest1 = io.imread('./handwriting_Yqx'+filename+'.png')print('the shape is  {}'.format(mytest1.shape))img_gray = color.rgb2gray(mytest1) #灰度化图像img_high = img_gray.shape[0]img_width = img_gray.shape[1]print('the gary_img shape is  {}'.format(img_gray.shape))print("the high of img is %d,the width is %d" %(img_high,img_width))for i in range(img_high):#二值化图像for j in range(img_width):if(img_gray[i][j] <= 0.5):img_gray[i][j] = 1else:img_gray[i][j] =0dst_img1 = transform.resize(img_gray,(32,32)) #缩放图像io.imshow(dst_img1)#plt.show()#将图片转为txtresult = ''for i in range(32):for j in range(32):result += str(int(dst_img1[i][j]))result+= '\n'with open('./txtYqx'+filename+'.txt',mode = 'w') as f:f.write(result)#测试函数
def YqxTrail():hwLabels = []trainingFileList = listdir('trainingDigits')m = len(trainingFileList)trainingMat = zeros((m,1024))for i in range(m):fileNameStr = trainingFileList[i]fileStr = fileNameStr.split('.')[0]classNumStr = int(fileStr.split('_')[0])#从文件名中解析分类数据hwLabels.append(classNumStr)trainingMat[i,:] = img2vector('trainingDigits/%s' %fileNameStr)testFileList = listdir('txtYqx')errorCount = 0.0mTest = len(testFileList)for i in range(mTest):fileNameStr = testFileList[i]fileStr = fileNameStr.split('.')[0]classNumStr = int(fileStr.split('_')[0])vectorUnderTest = img2vector('txtYqx/%s' %fileNameStr)classifierResult = classify0(vectorUnderTest,trainingMat,hwLabels,3)print("the classifier came back with: %d, the real answer is: %d" %(classifierResult,classNumStr))if(classifierResult != classNumStr):errorCount += 1.0print("\nthe total number of errors is: %d" %errorCount)print("\nthe total error rate is: %f" %(errorCount/float(mTest)))

kNNTest.py

import kNN
from numpy import *
import importlib
import matplotlib
import matplotlib.pyplot as plt #导入matplotlib库,并将matplotlib.pyplot模块命名为plt
#kNN.photosDeal()
kNN.YqxTrail()

4.2.运行效果:


由于训练样本较少,故错误率较高。

基于KNN算法的手写体数字识别相关推荐

  1. 课程设计(毕业设计)—基于机器学习KNN算法手写数字识别系统—计算机专业课程设计(毕业设计)

    机器学习KNN算法手写数字识别系统 下载本文手写数字识别系统完整的代码和课设报告的链接(或者可以联系博主koukou(壹壹23七2五六98),获取源码和报告):https://download.csd ...

  2. 【Python】基于kNN算法的手写识别系统的实现与分类器测试

    基于kNN算法的手写识别系统 1.      数据准备 使用windows画图工具,手写0-9共10个数字,每个数字写20遍,共200个BMP文件. 方法如下,使用画图工具,打开网格线,调整像素为32 ...

  3. 并行化实现基于BP神经网络的手写体数字识别

    并行化实现基于BP神经网络的手写体数字识别 手写体数字识别可以堪称是神经网络学习的"Hello World" ,我今天要说的是如何实现BP神经网络算法的并行化,我们仍然是以手写体数 ...

  4. 基于BP神经网络的手写体数字识别matlab仿真实现

    目录 一.理论基础 二.核心程序 三.测试结果 一.理论基础 文字.数字识别是一个典型的模式识别问题,也是模式识别中一个非常重要的应用领域.在文字.数字识别系统中,手写体的文字与识别是一个较难的领域, ...

  5. 基于SVM+HOG的手写体数字识别

    本文是对下面这篇文章的一些略微详细的解释... OpenCV Hog+SVM 学习 最近在学习数字识别,搜索资料的时候,发现了这篇文章.文章很久了,是2013年发的,那时候我才刚上大学.....用的是 ...

  6. knn分类算法实现手写体数字识别python

    之前写过knn分类算法代码,想把knn用于设别手写体数字,看下正确率. 大概思路:获取图片(可以自己写,我之前有写过黑白图片转文本的代码,也可以网上找,反正数据量大会更好)->转成文本-> ...

  7. 基于tensorflow的minst手写体数字识别

    引言 TensorFlow 是一个采用数据流图,用于数值计算的开源软件库.它是一个不严格的"神经网络"库,可以利用它提供的模块搭建大多数类型的神经网络.它可以基于CPU或GPU运行 ...

  8. 基于matlab的手写体数字识别系统

    摘要:随着科学技术的发展,机器学习成为一大学科热门领域,是一门专门研究计算机怎样模拟或实现人类的学习行为的交叉学科.文章在matlab软件的基础上,利用BP神经网络算法完成手写体数字的识别. 机器学习 ...

  9. 手写体数字识别的两种方法

    基于贝叶斯模型和KNN模型分别对手写体数字进行识别 首先,我们准备了0~9的训练集和测试集,这些手写体全部经过像素转换,用0,1表示,有颜色的区域为0,没有颜色的区域为1.实现代码如下: # 图片处理 ...

最新文章

  1. Android数据持久化:文件存储
  2. 远程访问,文件的压缩,ip地址的设置(9,11,12unit)
  3. Buck开关电源拓扑结构分析
  4. 既要宽广,又要深邃,这也行
  5. Linux+pycharm下 安装tensorflow时遇到的bug
  6. MAC电脑:安装mysql报ERROR 1045 (28000)Access denied
  7. 项目中通用的顶部标题和返回的TitleBar
  8. Mybatis-Plus的SQL语句组拼原理
  9. 求出数组中元素的总和_数组中所有元素的总和可被给定数K整除
  10. Element-UI Form表单 resetFields() 重置表单无效问题
  11. 【MySQL】MySQL Shell 简介与使用
  12. c语言实验转换字母顺序结构,实验1顺序结构的程序设计-实验报告.doc
  13. 按下什么使物体复位_什么是继电器?继电器如何接线?
  14. 电机与拖动matlab仿真,电机与拖动MATLAB仿真与学习指导(普通高等教育十一五电气信息类规划教材)...
  15. mysql 查询父子关系_递归查询具有父子关系的表
  16. 1538G. Gift Set
  17. php 二维码在线识别api
  18. Unity FairyGUI(十二)
  19. JSF 2 简介,第 2 部分: 模板及复合组件
  20. 机器视觉和计算机视觉理解

热门文章

  1. 计算机本科毕业生去当兵,关于本科毕业生入伍,副连级待遇
  2. PC版马赛克拼图生成 AndreaMosaic单文件版!
  3. java导出excel限制大小_解决java poi导出excel2003不能超过65536行的问题
  4. 王垠:谈 Linux,Windows 和 Mac ( 2013)
  5. Immersionbar学习笔记
  6. jquery中.eq()与:eq()的区别
  7. 微星的测试软件显示教程,Dragon Center使用教程
  8. java开发环境搭建——UltraEdit下载安装
  9. 密码计算机手机版,手机密码软件
  10. 使用图形注意网络进行欺诈检测