说明

使用pytorch框架,实现对MNIST手写数字数据集的训练和识别。重点是,自己手写数字,手机拍照后传入电脑,使用你自己训练的权重和偏置能够识别。数据预处理过程的代码是重点。

分析

要识别自己用手在纸上写的数字,从特征上来看,手写数字相比于普通的电脑上的数字最大的 不同就是数字的边缘会发生不同幅度的抖动。而且,在MNIST数据集中的数字是边缘为黑色的,然后数字是不同灰度的白色的,如下所示:

在数据集中,每个数据都是28∗2828*2828∗28的灰度图,并且黑色部分都是零,其余白色的灰度值并不统一。因为如果训练时背景都是统一的时候我们测试用的图片背景也必须是统一的,否则基本无法识别出来。除非训练的时候换各种不同的背景大数据进行训练,这样特征就不会依托着背景而存在,剩下的就是要识别的物体自己所拥有的特征了。所以在这里我要做的就是在图片预处理的时候尽量让图片处理成接近测试图片的样子。

训练网络

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable
from torch.utils.data import DataLoader# 下载训练集
train_dataset = datasets.MNIST(root='./data/',train=True,transform=transforms.ToTensor(),download=False)
# 下载测试集
test_dataset = datasets.MNIST(root='./data/',train=False,transform=transforms.ToTensor(),download=False)# 设置批次数
batch_size = 100# 装载训练集
train_loader = torch.utils.data.DataLoader(dataset = train_dataset,batch_size = batch_size,shuffle=True)
# 装载测试集
test_loader = torch.utils.data.DataLoader(dataset = test_dataset,batch_size = batch_size,shuffle = True)# 自定义手写数字识别网络
class net(nn.Module):def __init__(self):super(net, self).__init__()self.Conn_layers = nn.Sequential(nn.Linear(784, 100),nn.Sigmoid(),nn.Linear(100, 10),nn.Sigmoid())def forward(self, input):output = self.Conn_layers(input)return output# 定义学习率
LR = 0.1# 定义一个网络对象
net = net()# 损失函数使用交叉熵
loss_function = nn.CrossEntropyLoss()# 优化函数使用 SGD
optimizer = optim.SGD(net.parameters(),lr = LR,momentum = 0.9,weight_decay = 0.0005
)# 定义迭代次数
epoch = 20# 进行迭代训练
for epoch in range(epoch):for i, data in enumerate(train_loader):inputs, labels = data# 转换下输入形状inputs = inputs.reshape(batch_size, 784)inputs, labels = Variable(inputs), Variable(labels)outputs = net(inputs)loss = loss_function(outputs, labels)optimizer.zero_grad()loss.backward()optimizer.step()# 初始化正确结果数为0test_result = 0# 用测试数据进行测试for data_test in test_loader:images, labels = data_test# 转换下输入形状images = images.reshape(batch_size, 784)images, labels = Variable(images), Variable(labels)output_test = net(images)# 对一个批次的数据的准确性进行判断for i in range(len(labels)):# 如果输出结果的最大值的索引与标签内正确数据相等,准确个数累加if torch.argmax(output_test[i]) == labels[i]:test_result += 1# 打印每次迭代后正确的结果数print("Epoch {} : {} / {}".format(epoch, test_result, len(test_dataset)))# 保存权重模型
torch.save(net, 'weight/test.pkl')

至此,对手写数字网络的训练已经结束,且训练的准确性为:

这个网络比较粗糙,所以准确性也只是一般,但如果要精确起来后面有很多文章可做。

图像预处理

因为我们手机拍的照片和训练集的图片有很大的区别,所以无法将手机上拍的照片直接丢到训练好的网络模型中进行识别,需要先对图片进行预处理。有几点需要对原图进行改变:

  1. 图片的大小:肯定得将拍摄到的图片转换成28∗2828*2828∗28尺寸大小的图片。
  2. 图片的通道数:由于MNIST是灰度图,所以原图的channel也得转换成1。
  3. 图片的背景:图片的背景得转换成MNIST相同的黑色,这样识别结果准确性更高。
  4. 数字的颜色:毋庸置疑,数字的颜色得变成MNIST相同的白色。
  5. 数字颜色中间深边缘前:观察MNIST的白色部分并不都是255全白,而是有渐变色的,这个渐变色模拟起来比较困难,算是难度最大的一点了。
    接下来直接上代码了:
import cv2
import numpy as npdef image_preprocessing():# 读取图片img = cv2.imread("picture/test8.jpeg")# =====================图像处理======================== ## 转换成灰度图像gray_img = cv2.cvtColor(img , cv2.COLOR_BGR2GRAY)# 进行高斯滤波gauss_img = cv2.GaussianBlur(gray_img, (5,5), 0, 0, cv2.BORDER_DEFAULT)# 边缘检测img_edge1 = cv2.Canny(gauss_img, 100, 200)# ==================================================== ## =====================图像分割======================== ## 获取原始图像的宽和高high = img.shape[0]width = img.shape[1]# 分别初始化高和宽的和add_width = np.zeros(high, dtype = int)add_high = np.zeros(width, dtype = int)# 计算每一行的灰度图的值的和for h in range(high):for w in range(width):add_width[h] = add_width[h] + img_edge1[h][w]# 计算每一列的值的和for w in range(width):for h in range(high):add_high[w] = add_high[w] + img_edge1[h][w]# 初始化上下边界为宽度总值最大的值的索引acount_high_up = np.argmax(add_width)acount_high_down = np.argmax(add_width)# 将上边界坐标值上移,直到没有遇到白色点停止,此为数字的上边界while add_width[acount_high_up] != 0:acount_high_up = acount_high_up + 1# 将下边界坐标值下移,直到没有遇到白色点停止,此为数字的下边界while add_width[acount_high_down] != 0:acount_high_down = acount_high_down - 1# 初始化左右边界为宽度总值最大的值的索引acount_width_left = np.argmax(add_high)acount_width_right = np.argmax(add_high)# 将左边界坐标值左移,直到没有遇到白色点停止,此为数字的左边界while add_high[acount_width_left] != 0:acount_width_left = acount_width_left - 1# 将右边界坐标值右移,直到没有遇到白色点停止,此为数字的右边界while add_high[acount_width_right] != 0:acount_width_right = acount_width_right + 1# 求出宽和高的间距width_spacing = acount_width_right - acount_width_lefthigh_spacing = acount_high_up - acount_high_down# 求出宽和高的间距差poor = width_spacing - high_spacing# 将数字进行正方形分割,目的是方便之后进行图像压缩if poor > 0:tailor_image = img[acount_high_down - poor // 2 - 5:acount_high_up + poor - poor // 2 + 5, acount_width_left - 5:acount_width_right + 5]else:tailor_image = img[acount_high_down - 5:acount_high_up + 5, acount_width_left + poor // 2 - 5:acount_width_right - poor + poor // 2 + 5]# ==================================================== ## ======================小图处理======================= ## 将裁剪后的图片进行灰度化gray_img = cv2.cvtColor(tailor_image , cv2.COLOR_BGR2GRAY)# 高斯去噪gauss_img = cv2.GaussianBlur(gray_img, (5,5), 0, 0, cv2.BORDER_DEFAULT)# 将图像形状调整到28*28大小zoom_image = cv2.resize(gauss_img, (28, 28))# 获取图像的高和宽high = zoom_image.shape[0]wide = zoom_image.shape[1]# 将图像每个点的灰度值进行阈值比较for h in range(high):for w in range(wide):# 若灰度值大于100,则判断为背景并赋值0,否则将深灰度值变白处理if zoom_image[h][w] > 100:zoom_image[h][w] = 0else:zoom_image[h][w] = 255 - zoom_image[h][w]# ==================================================== #return zoom_image

在此,我在纸上写了个6,如下图所示:

然后是对图像进行分割,首先要介绍下我分割图像的方法。下面是一张进行canny边缘检测后的6:

在这里这个6有个特点,就是被白边给包围着了,因为白色的灰度值为255,黑色的灰度值为0,所以我就假设以高为很坐标,然后每个高对应着的宽的灰度值进行相加。所以会很明显发现就6这个字的整体的值比较聚集,当然有可能有零星的散点,但并不影响对6所在位置的判断。最后以高为例,得到的值的坐标图如下:

因为最大值比较容易找到,所以就找到最大值然后向两边延伸,当发现值为零时就可以把边界给标定出来了。
最后进行分割分割注意的是后面对图像进行裁剪的时候是将宽和高较长的一边减去较短的一边然后除以2平分给较短的一边的两侧,为了防止边缘检测没有包裹着数字,于是在数字四周都加了五个像素点进行裁剪,最后裁剪出来的效果如下:

这个图片就是上述代码中的tailor_image所显示出来的图片,因为显示图片的代码只作为测试使用,而且又很简单,这里就没有展示出来。
好了,接下来就是要对辛辛苦苦裁剪出来的小图进行图像进行处理了,首先还是最基本的灰度化和高斯滤波处理,然后就是对图像进行大小转换,因为MNIST数据形状就是28∗2828*2828∗28所以也要将输入图片转换成28∗2828*2828∗28的大小。大小转换完成后,就是要完成把灰度图转换成背景为0,然后数字变成白色的图片,因为这样和MNIST数据集里的数字图片特别的像。在这里我用了阈值控制的方法将背景变成黑色的。至于这100当然是将图片的灰度值打出来后观察得出来的。但是这种方法是比较危险的,因为这样的鲁棒性并不强,但后面如果要加强鲁棒性则同样可以用边缘检测把数字包裹住,然后数字之外的背景清零,这确实是一个很好的思路,但在这里就建议的用阈值控制的方法来实现背景黑化了。黑化背景后当然就是将数字白化了,之前有将数字部分都是255值,但发现识别的效果并不理想,所以这里我采用了用255-原先数字的值,这样如果原先的数字黑度深的部分就会变成白色程度深,就简单的实现了数字边缘浅,中间深的变换。最后处理得到的图像如下:

虽说看起来没有第一张图那么完美,但大概还是能达到验证数据所需的要求了。至此,数据预处理已经完成了,接下来就是激动的预测了。

预测

预测代码如下:

import torch# pretreatment.py为上面图片预处理的文件名,导入图片预处理文件
import pretreatment as PRE# 加载网络模型
net = torch.load('weight/test.pkl')# 得到返回的待预测图片值,就是pretreatment.py中的zoom_image
img = PRE.image_preprocessing()# 将待预测图片转换形状
inputs = img.reshape(-1, 784)# 输入数据转换成tensor张量类型,并转换成浮点类型
inputs = torch.from_numpy(inputs)
inputs = inputs.float()# 丢入网络进行预测,得到预测数据
predict = net(inputs)# 打印对应的最后的预测结果
print("The number in this picture is {}".format(torch.argmax(predict).detach().numpy()))

最后得到结果如图所示:

这样,整个手写数字识别基本已经完成了。

手写数字识别(识别纸上手写的数字)相关推荐

  1. 【Keras】30 秒上手 Keras+实例对mnist手写数字进行识别准确率达99%以上

    本文我们将学习使用Keras一步一步搭建一个卷积神经网络.具体来说,我们将使用卷积神经网络对手写数字(MNIST数据集)进行识别,并达到99%以上的正确率. @为什么选择Keras呢? 主要是因为简单 ...

  2. 手写数字的识别分类+技术总结

    (1)学习转载一篇关于机器学习手写数字的识别 Python 3 利用机器学习模型 进行手写体数字检测 Python 3 生成手写体数字数据集 (2)技术总结 机器学习代码实现的初级阶段,既要自己上手项 ...

  3. linux手写数字识别opencv,opencv实现KNN手写数字的识别

    人工智能是当下很热门的话题,手写识别是一个典型的应用.为了进一步了解这个领域,我阅读了大量的论文,并借助opencv完成了对28x28的数字图片(预处理后的二值图像)的识别任务. 预处理一张图片: 首 ...

  4. DL之CNN:利用卷积神经网络算法(2→2,基于Keras的API-Functional)利用MNIST(手写数字图片识别)数据集实现多分类预测

    DL之CNN:利用卷积神经网络算法(2→2,基于Keras的API-Functional)利用MNIST(手写数字图片识别)数据集实现多分类预测 目录 输出结果 设计思路 核心代码 输出结果 下边两张 ...

  5. DL之CNN:利用卷积神经网络算法(2→2,基于Keras的API-Sequential)利用MNIST(手写数字图片识别)数据集实现多分类预测

    DL之CNN:利用卷积神经网络算法(2→2,基于Keras的API-Sequential)利用MNIST(手写数字图片识别)数据集实现多分类预测 目录 输出结果 设计思路 核心代码 输出结果 1.10 ...

  6. DL之DNN:利用DNN【784→50→100→10】算法对MNIST手写数字图片识别数据集进行预测、模型优化

    DL之DNN:利用DNN[784→50→100→10]算法对MNIST手写数字图片识别数据集进行预测.模型优化 导读 目的是建立三层神经网络,进一步理解DNN内部的运作机制 目录 输出结果 设计思路 ...

  7. Dataset之Handwritten Digits:Handwritten Digits(手写数字图片识别)数据集简介、安装、使用方法之详细攻略

    Dataset之Handwritten Digits:Handwritten Digits(手写数字图片识别)数据集简介.安装.使用方法之详细攻略 目录 Handwritten Digits数据集的简 ...

  8. TF之NN:利用DNN算法(SGD+softmax+cross_entropy)对mnist手写数字图片识别训练集(TF自带函数下载)实现87.4%识别

    TF之NN:利用DNN算法(SGD+softmax+cross_entropy)对mnist手写数字图片识别训练集(TF自带函数下载)实现87.4%识别 目录 输出结果 代码设计 输出结果 代码设计 ...

  9. TF:基于CNN(2+1)实现MNIST手写数字图片识别准确率提高到99%

    TF:基于CNN(2+1)实现MNIST手写数字图片识别准确率提高到99% 导读 与Softmax回归模型相比,使用两层卷积的神经网络模型借助了卷积的威力,准确率高非常大的提升. 目录 输出结果 代码 ...

最新文章

  1. html资源文件记载进度条,用进度条显示文件读取进度《 HTML5:文件 API 》
  2. GitHub日收12000星,微软新命令行工具引爆程序员圈!
  3. R语言非独立多分组非参数检验、Kruskal–Wallis检验进行非独立多分组非参数检验(Nonparametric multiple comparisons)、当ANOVA不满足条件的情况下
  4. FPM傅里叶叠层衍射成像笔记
  5. python 守护程序检测进程是否存在_python创建守护进程的疑问
  6. Google C++ Testing Framework之断言
  7. mysql数据库自学_MySQL数据库自学
  8. mysql复制以及一主多从等常见集群概述
  9. mysql密码修改无效后,修改方法
  10. [HTML5]块和内联元素的嵌套
  11. HTML5超炫砸蛋抽奖源码
  12. t分布(Student t distribution)——正态分布的小样本抽样分布
  13. 4G通信简单验证(合宙Air720H)
  14. Word如何将A4纸打印成上下两部分可复写的二联单
  15. dbf文件怎么还原到oracle中,oracle dbf文件恢复数据
  16. 计算机怎么设置加密文件,怎么把电脑文件加密_怎么把文件加密-win7之家
  17. css 设置元素背景为透明
  18. 和差角证明托勒密定理
  19. linux c open flush,ctrl+c以及写操作失败和flush
  20. 计算机组成原理 day01 - 第一章

热门文章

  1. 传奇单机架设教程及游戏GM设置方法
  2. 多签名基础——General forking lemma(分叉引理)
  3. python斜杠作用_Python中正反斜杠(‘/’和‘\’)的意义与用法
  4. pdf合到一起java_将多个PDF文件合并/转换为一个PDF
  5. ssm+java+vue基于微信小程序的游泳馆管理系统#毕业设计
  6. win7如何更改计算机管理员用户名和密码,win7系统下修改administrator管理员账户密码的设置方法?...
  7. 告诉你宇宙的真相:现代观点
  8. Windows 窗口发送消息参数详解
  9. 黑苹果OC配置工具:OpenCore Configurator for Mac(2.48.0.0中文)
  10. 新整理的开源Odoo13发布更新的部分功能模块信息