手写识别的应用场景有很多,智能手机、掌上电脑的信息工具的普及,手写文字输入,机器识别感应输出;还可以用来识别银行支票,如果准确率不够高,可能会引起严重的后果。当然,手写识别也是机器学习领域的一个Hello World任务,感觉每一个初识神经网络的人,搭建的第一个项目十之八九都是它。

我们来尝试搭建下手写识别中最基础的手写数字识别,与手写识别的不同是数字识别只需要识别0-9的数字,样本数据集也只需要覆盖到绝大部分包含数字0-9的字体类型,说白了就是简单,样本特征少,难度小很多。

一、目标

预期目标:传入一张数字图片给机器,机器通过识别,最后返回给用户图片上的数字

传入图片:

机器识别输出:

二、搭建(全连接神经网络)

环境:python3.6   tensorflow1.14

工具:pycharm

数据源:来自手写数据机器视觉数据库mnist数据集,包含7万张黑底白字手写数字图片,其中55000张为训练集,5000张为验证集,10000张为测试集。每张图片大小为28*28像素,图片纯黑色像素值为0,纯白色像素值为1。数据集的标签是长度为10的一维数组,数组中的每个元素索引号表示对应数字出现的概率。

可通过input_data模块中的read_data_sets()函数直接加载mnist数据集(详情见mnist_backward.py):

from tensorflow.examples.tutorials.mnist import input_datamnist = input_data.read_data_sets("./data/", one_hot=True)

一、定义网络模型,神经网络的前向传播(mnist_forward.py)

import tensorflow as tfINPUT_NODE=784   # 输入节点
OUTPUT_NODE=10   # 输出节点
LAYER1_NODE=500   # 隐藏节点def get_weight(shape,regularizer):w=tf.Variable(tf.truncated_normal(shape,stddev=0.1))if regularizer !=None:tf.add_to_collection('losses',tf.contrib.layers.l2_regularizer(regularizer)(w))return wdef get_bias(shape):b=tf.Variable(tf.zeros(shape))return bdef forward(x,regularizer):w1=get_weight([INPUT_NODE,LAYER1_NODE],regularizer)b1=get_bias(LAYER1_NODE)y1=tf.nn.relu(tf.matmul(x,w1)+b1)w2=get_weight([LAYER1_NODE,OUTPUT_NODE],regularizer)b2=get_bias([OUTPUT_NODE])y=tf.matmul(y1,w2)+b2return y

这里定义了网络模型输入输出节点的个数、隐藏层节点数、同时定义get_weigt()函数实现对参数w的设置,包括参数的形状和是否正则化的标志,从输入层到隐藏层的参数w1形状为[784,500],由隐藏层到输出层的参数w2形状为[500,10]。定义get_bias()实现对偏置b的设置。由输入层到隐藏层的偏置b1形状长度为500的一维数组,由隐藏层到输出层的偏置b2形状长度为10的一维数组,初始化值为全0。

二、神经网络的反向传播(mnist_backward.py)

利用训练数据集对神经网络进行训练,通过降低损失函数值,实现网络模型参数的优化,从而得到准确率高且泛化能力强的神经网络模型。

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import mnist_forward
import osbatch_size=200
learning_rate_base=0.1    # 初始学习率
learning_rate_decay=0.99    # 学习率衰减率
regularizer=0.0001    # 正则化系数
steps=50000     # 训练轮数
moving_average_decay=0.99
model_save_path="./model/"  # 模型保存路径
model_name="mnist_model"def backward(mnist):x=tf.placeholder(tf.float32,[None,mnist_forward.INPUT_NODE])y_=tf.placeholder(tf.float32,[None,mnist_forward.OUTPUT_NODE])y=mnist_forward.forward(x,regularizer)  # 调用forward()函数,设置正则化,计算yglobal_step=tf.Variable(0,trainable=False)   # 当前轮数计数器设定为不可训练类型# 调用包含所有参数正则化损失的损失函数lossce=tf.nn.sparse_softmax_cross_entropy_with_logits(logits=y,labels=tf.argmax(y_,1))cem=tf.reduce_mean(ce)loss=cem+tf.add_n(tf.get_collection('losses'))# 设定指数衰减学习率learning_ratelearning_rate=tf.train.exponential_decay(learning_rate_base,global_step,mnist.train.num_examples/batch_size,learning_rate_decay,staircase=True)# 梯度衰减对模型优化,降低损失函数train_step=tf.train.GradientDescentOptimizer(learning_rate).minimize(loss,global_step=global_step)# 定义参数的滑动平均ema=tf.train.ExponentialMovingAverage(moving_average_decay,global_step)ema_op=ema.apply(tf.trainable_variables())with tf.control_dependencies([train_step,ema_op]):train_op=tf.no_op(name='train')saver=tf.train.Saver()with tf.Session() as sess:init_op=tf.global_variables_initializer()   # 所有参数初始化sess.run(init_op)ckpt = tf.train.get_checkpoint_state(model_save_path)  # 加载指定路径下的滑动平均if ckpt and ckpt.model_checkpoint_path:saver.restore(sess, ckpt.model_checkpoint_path)for i in range(steps):   # 循环迭代steps轮xs,ys=mnist.train.next_batch(batch_size)_,loss_value,step=sess.run([train_op,loss,global_step],feed_dict={x:xs,y_:ys})if i %1000==0:print("After %d training step(s),loss on training batch is %g."%(step,loss_value))saver.save(sess,os.path.join(model_save_path,model_name),global_step=global_step)   # 当前会话加载到指定路径if __name__=='__main__':mnist = input_data.read_data_sets("./data/", one_hot=True)backward(mnist)

反向传播中,首先定义了每轮喂入神经网络的图片数batch_size、初始学习率learning_rate_base、学习率衰减率learning_rate_decay、正则化系数regularizer、训练轮数steps、模型保存路径以及模型保存名称等相关信息。反向传播backward()函数中,先传入minist数据集,用tf.placeholder(dtype,shape)函数实现训练样本x和样本标签y_占位。y表示定义的前向传播函数forward;  tf.Variable(0,trainable=False)给当前轮数赋值,定义为不可训练类型。接着,loss表示定义的损失函数,一般为预测值与样本标签的交叉熵与正则化损失之和;train_step表示利用优化算法对模型参数进行优化,常用的优化算法有GradientDescentOptimizer、AdamOptimizer、MomentumOptimizer算法,这里使用的GradientDescentOptimizer梯度衰减算法。接着初始化saver对象,其中利用tf.global_variables_initializer()函数初始化所有模型参数,利用sess.run()函数实现模型的训练优化过程,并每隔一定轮数保存一次模型,模型训练好之后保存在ckpt中。

三、测试数据集,验证模型性能(mnist_test.py)

给神经网络模型输入测试集验证网络的准确性和泛化性(测试集和训练集是相互独立的)

# coding:utf-8
import time
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import mnist_forward
import mnist_backwardtest_interval_secs=5     # 程序循环间隔时间5秒def test(mnist):with tf.Graph().as_default() as g:    # 复现计算图x = tf.placeholder(tf.float32, [None, mnist_forward.INPUT_NODE])y_ = tf.placeholder(tf.float32, [None, mnist_forward.OUTPUT_NODE])y = mnist_forward.forward(x, None)# 实例化滑动平均的saver对象ema = tf.train.ExponentialMovingAverage(mnist_backward.moving_average_decay)ema_restore = ema.variables_to_restore()saver = tf.train.Saver(ema_restore)# 计算准确率correct_prediction=tf.equal(tf.argmax(y,1),tf.argmax(y_,1))accuracy=tf.reduce_mean(tf.cast(correct_prediction,tf.float32))while True:with tf.Session() as sess:ckpt=tf.train.get_checkpoint_state(mnist_backward.model_save_path)   # 加载指定路径下的滑动平均if ckpt and ckpt.model_checkpoint_path:saver.restore(sess,ckpt.model_checkpoint_path)global_step=ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]accuracy_score=sess.run(accuracy,feed_dict={x:mnist.test.images,y_:mnist.test.labels})print("After %s training step(s),test accuracy= %g."%(global_step,accuracy_score))else:print('No checkpoint file found')returntime.sleep(test_interval_secs)if __name__=='__main__':mnist = input_data.read_data_sets("./data/", one_hot=True)test(mnist)

首先,制定模型测试函数test(),通过tf.placeholder()给x,y_占位,调用mnist_forward文件中的前向传播过程forward()函数计算y,mnist_backward.moving_average_decay表示滑动衰减率。在with结构中,ckpt是加载训练好的模型,如果已有ckpt模型则恢复会话、轮数等。其次,制定main()函数,加载测试数据集,调用定义好的测试函数test()就行。

通过对测试数据的预测得到准确率,从而判断出训练出的神经网络模型性能的好坏。当准确率低时,可能原因有模型需要改进,或者是训练数据量太少导致过拟合等。

运行以上三个文件,运行结果如下:

从终端显示的运行结果可以看出,随着训练轮数的增加,网络模型的损失函数值在不断降低,在测试集上的准确率也在不断提升,具有较好的泛化能力。

四、输入真实图片,输出预测结果(mnist_app.py)

任务分两个函数完成:

(1)pre_pic()函数,对手写数字图片做预处理

(2)restore_model()函数,将符合神经网络输入要求的图片喂给复现的神经网络模型,输出预测值。

# coding:utf-8
import tensorflow as tf
import mnist_forward
import mnist_backward
from PIL import Image
import numpy as npdef restore_model(testPicArr):with tf.Graph().as_default() as tg:     # 创建一个默认图x = tf.placeholder(tf.float32, [None, mnist_forward.INPUT_NODE])y = mnist_forward.forward(x, None)preValue=tf.argmax(y,1)   # 得到概率最大的预测值'''实现滑动平均模型,参数moving_average_decay用于控制模型的更新速度,训练过程会对每一个变量维护一个影子变量这个影子变量的初始值就是相应变量的初始值,每次变量更新时,影子变量随之更细'''variable_averages=tf.train.ExponentialMovingAverage(mnist_backward.moving_average_decay)variable_to_restore=variable_averages.variables_to_restore()saver=tf.train.Saver(variable_to_restore)with tf.Session() as sess:# 通过checkpoint文件定位到最新保存的模型ckpt=tf.train.get_checkpoint_state(mnist_backward.model_save_path)if ckpt and ckpt.model_checkpoint_path:saver.restore(sess,ckpt.model_checkpoint_path)preValue=sess.run(preValue,feed_dict={x:testPicArr})return preValueelse:print("no checkpoint file found")return -1# 预处理函数,包括resize,转变灰度图,二值化操作等
def pre_pic(picName):img=Image.open(picName)reIm=img.resize((28,28),Image.ANTIALIAS)im_arr=np.array(reIm.convert('L'))threshold=50     # 设定合理的阈值for i in range(28):for j in range(28):im_arr[i][j]=255-im_arr[i][j]   # 模型要求黑底白字,输入图为白底黑字,对每个像素点的值改为255-原值=互补的反色if (im_arr[i][j]<threshold):im_arr[i][j]=0else:im_arr[i][j]=255nm_arr=im_arr.reshape([1,784])   # 1行784列nm_arr=nm_arr.astype(np.float32)img_ready=np.multiply(nm_arr,1.0/255.0)   # 从0-255之间的数变为0-1之间的浮点数return img_readyif __name__=='__main__':testNum=int(input("input the number of test pictures:"))for i in range(testNum):testPic=input("the path of test picture:")testPicArr=pre_pic(testPic)preValue=restore_model(testPicArr)print("the prediction number is",preValue)

在pre_pic()函数中,网络要求输入是28*28像素点的值,先将图片尺寸resize,模型要求的是黑底白字,但输入的图是白底黑字,则每个像素点的值改为255减去原值得到互补的反色。再对图片做二值化处理,这样可以滤掉噪声。nm_arr把图片拉成1行784列,并把值变为浮点数。restore_model()函数,计算输出y,网络输出的是一个一维数组(10个可能性概率),数组中最大的那个元素所对应的索引号就是预测的结果。

运行mnist_app.py文件,结果如下:

先输入需要识别的图片number数,然后传入图片路径,最后返回识别结果。我们传入的图片2.jpg,5.jpg如下所示:

         

预测结果也是2,5,说明模型还可以。但是,前面我们也提到过,如果数字识别用来识别银行支票97%的准确率不算高,然后卷积神经网络就开始大放异彩了...........................

最后,本人微信公众号:

放心,不用出钱也不用报班,只是单纯的想多两个粉丝罢了。

利用Tensorflow实现手写数字识别(附python代码)相关推荐

  1. 【手写数字识别】基于Lenet网络实现手写数字识别附matlab代码

    1 内容介绍 当今社会,人工智能得到快速发展,而模式识 别作为人工智能的一个重要应用领域也得到了飞 速发展,它利用计算机通过计算的方法根据样本的 特征对样本进行分类,其中的光学字符识别技术受 到广大研 ...

  2. OpenCV+TensorFlow图片手写数字识别(附源码)

    初次接触TensorFlow,而手写数字训练识别是其最基本的入门教程,网上关于训练的教程很多,但是模型的测试大多都是官方提供的一些素材,能不能自己随便写一串数字让机器识别出来呢?纸上得来终觉浅,带着这 ...

  3. 图像识别:利用KNN实现手写数字识别(mnist数据集)

    图像识别:利用KNN实现手写数字识别(mnist数据集) 步骤: 1.数据的加载(trainSize和testSize不要设置的太大) 2.k值的设定(不宜过大) 3.KNN的核心:距离的计算 4.k ...

  4. 基于tensorflow的手写数字识别

    基于tensorflow的手写数字识别 数据准备 引入包 加载数据 查看数据信息 查看一张图片 数据预处理 搭建网络模型 模型的预测与评价 模型的展示 对一张图片进行预测 准确率 数据准备 引入包 i ...

  5. 利用CNN进行手写数字识别

    资源下载地址:https://download.csdn.net/download/sheziqiong/85884967 资源下载地址:https://download.csdn.net/downl ...

  6. 实战六:手把手教你用TensorFlow进行手写数字识别

    手把手教你用TensorFlow进行手写数字识别 github下载地址 目录 手写体数字MNIST数据集介绍 MNIST Softmax网络介绍 实战MNIST Softmax网络 MNIST CNN ...

  7. 基于深度学习的手写数字识别、python实现

    基于深度学习的手写数字识别.python实现 一.what is 深度学习 二.加深层可以减少网络的参数数量 三.深度学习的手写数字识别 一.what is 深度学习 深度学习是加深了层的深度神经网络 ...

  8. 教你用TensorFlow实现手写数字识别

    弱者用泪水安慰自己,强者用汗水磨练自己. 这段时间因为项目中有一块需要用到图像识别,最近就一直在炼丹,宝宝心里苦,但是宝宝不说... 能点开这篇文章的朋友估计也已经对TensorFlow有了一定了解, ...

  9. 利用机器学习进行手写数字识别

    本次案例中,我们的目标是从数万个手写图像的数据集中正确识别数字. 数据介绍: 数据文件 train.csv 和 test.csv 包含从 0 到 9 的手绘数字的灰度图像. 每个图像高 28 像素,宽 ...

最新文章

  1. linux新内核的freeze框架以及意义
  2. kibana 更新 索引模式_elasticsearch – 如何在kibana中自动配置索引模式
  3. web page web form php,Web Pages Razor
  4. 创新创业计划书_创践——大学生创新创业实务 ——如何撰写一份优秀的商业计划书...
  5. 【转】Microsoft Cloud全新认证体系介绍
  6. android fragment学习6--FragmentTabHost底部布局
  7. Spark中的数据本地性
  8. c++成员声明中的非法限定名_new 一个对象有哪两个过程?很多人在面试中都问住了...
  9. GRE阅读高频机经原文及答案之Design-Engineering
  10. aliplayer阿里云播放器直播及录播前端代码
  11. 万网域名是否注册批量查询工具
  12. OSChina 周三乱弹 —— 我在 if 里,你却在 else
  13. 如何正确地卸载Service Worker?
  14. 电商设计师(美工)必备的素材网站!
  15. Uncaught TypeError: Cannot read property ‘length‘ of null解决经验贴
  16. Python pandas库|任凭弱水三千,我只取一瓢饮(4)
  17. 用 JS 进行 Base64 编码、解码
  18. GIT如何设置只提交文件夹或者目录,而忽略内容?
  19. TJPU-32 分解质因数
  20. 背景建模技术(四):视频分析(VideoAnalysis)模块

热门文章

  1. HDOJ-2614 Beat bfs广搜
  2. 超级玛丽游戏python实现
  3. 计算机中的数据存储与PTA
  4. c语言notify头文件,SendNotifyMessage()函数
  5. mysql timestamp 类型_MySQL timestamp类型
  6. 解决RabbitMQ保错 Error: unable to connect to node rabbit@localhost: nodedown
  7. Intranet+Intranet QA-11/20 游记
  8. sqlmap问题及解决办法
  9. IE8展示SVG图像问题处理
  10. 演讲稿丨钟义信 弘扬 Simon的源头创新精神开拓“AI”的新理念新路径