TensorFlow神经网络(五)输入手写数字图片进行识别
一、断点续训
为防止突然断电、参数白跑的情况发生,在backward中加入类似于之前test中加载ckpt
的操作,给所有w和b赋保存在ckpt中的值:
1. 如果存储断点文件的目录文件夹中,包含有效断点状态文件,则返回该文件:
- 参数说明
checkpoint_dir
: 表示存储断点文件的目录
latest_filename
: 断点文件的可选名称,默认为checkpoint
ckpt = tf.train.get_checkpoint_state(checkpoint_dir,\latest_filename = None)
2. 如果ckpt存在,且保存的模型在指定路径中存在
if ckpt and ckpt.model_checkpoint_path:
3. 恢复当前会话,将ckpt中的值赋给 w 和 b
- 参数说明
sess
:表示当前会话,之前保存的结果会被加载入这个会话
ckpt.model_checkpoint_path
:表示模型存储的位置,不需要提供模型的名字,因为有了位置会自动去查看checkpoint文件,看最新的模型叫什么
saver.restore(sess, ckpt.model_checkpoint_path)
4. 完整代码:
# 断点续训 breakpoint_continue.py
ckpt = tf.train.get_checkpoint_state(MODEL_SAVE_PATH)
if ckpt and ckpt.model_checkpoint_path:
# 恢复当前会话,将ckpt中的值赋给 w 和 b saver.restore(sess, ckpt.model_checkpoint_path)
在反向传播的with结构中加入加载ckpt的操作后:
# mnist_backward.py
# coding: utf-8import tensorflow as tf
# 导入imput_data模块
from tensorflow.examples.tutorials.mnist import input_data
import mnist_forward
import os# 定义超参数
BATCH_SIZE = 200LEARNING_RATE_BASE = 0.1 #初始学习率
LEARNING_RATE_DECAY = 0.99 # 学习率衰减率
REGULARIZER = 0.0001 # 正则化参数STEPS = 50000 #训练轮数MOVING_AVERAGE_DECAY = 0.99
MODEL_SAVE_PATH = "./model/"
MODEL_NAME = "mnist_model"
def backward(mnist):# placeholder占位x = tf.placeholder(tf.float32, shape = (None, mnist_forward.INPUT_NODE))y_ = tf.placeholder(tf.float32, shape = (None, mnist_forward.OUTPUT_NODE))# 前向传播推测输出yy = mnist_forward.forward(x, REGULARIZER)# 定义global_step轮数计数器,定义为不可训练global_step = tf.Variable(0, trainable = False)# 包含正则化的损失函数# 交叉熵ce = tf.nn.sparse_softmax_cross_entropy_with_logits(logits = y, \labels = tf.argmax(y_, 1))cem = tf.reduce_mean(ce)# 使用正则化时的损失函数loss = cem + tf.add_n(tf.get_collection('losses'))# 定义指数衰减学习率learning_rate = tf.train.exponential_decay(LEARNING_RATE_BASE,global_step, mnist.train.num_examples/BATCH_SIZE,LEARNING_RATE_DECAY,staircase = True)# 定义反向传播方法:包含正则化train_step = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss,\global_step = global_step)# 定义滑动平均时,加上:ema = tf.train.ExponentialMovingAverage(MOVING_AVERAGE_DECAY, global_step)ema_op = ema.apply(tf.trainable_variables())with tf.control_dependencies([train_step, ema_op]):train_op = tf.no_op(name = 'train')# 实例化saversaver = tf.train.Saver()# 训练过程with tf.Session() as sess:# 初始化所有参数init_op = tf.global_variables_initializer()sess.run(init_op)# 断点续训 breakpoint_continue.pyckpt = tf.train.get_checkpoint_state(MODEL_SAVE_PATH)if ckpt and ckpt.model_checkpoint_path: # 恢复当前会话,将ckpt中的值赋给 w 和 b saver.restore(sess, ckpt.model_checkpoint_path) # 循环迭代for i in range(STEPS):# 将训练集中一定batchsize的数据和标签赋给左边的变量xs, ys = mnist.train.next_batch(BATCH_SIZE)# 喂入神经网络,执行训练过程train_step_, loss_value, step = sess.run([train_op, loss, global_step], \feed_dict = {x: xs, y_: ys})if i % 1000 == 0: # 拼接成./MODEL_SAVE_PATH/MODEL_NAME-global_step路径# 打印提示print("after %d steps, loss on traing batch is %g" %(step, loss_value))saver.save(sess, os.path.join(MODEL_SAVE_PATH, MODEL_NAME), \global_step = global_step)def main():mnist = input_data.read_data_sets('./data/', one_hot = True)# 调用定义好的测试函数backward(mnist)
# 判断python运行文件是否为主文件,如果是,则执行
if __name__ == '__main__':main()
结果发现模型自动接着之前开机的结束的50000次开始往后训练了,在两次ctrl+C之后,再重新执行时仍可以从断点继续:
二、如何对输入的手写数字图片,输出正确预测结果
- 除了
minist_forward, mnist_backward, mnist_test
之外,增加mnist_app.py
一个py文件
自己遇到的问题之
(一)main函数没有写对
把
if __name__ == '__main__':main()
写成了
if __name__ == 'main':main()
结果代码根本跑不出结果!!!
(二)input从控制台读入返回的是str型!!!
参见博客https://blog.csdn.net/qq_41151066/article/details/81745352
所以导致输入图片张数时出现错误:
TypeError: 'str' object cannot be interpreted as an integer
(三)然后又发现这个错误
NameError: name 'raw_input' is not defined
好家伙,参考文章https://blog.csdn.net/hochean_/article/details/79582627
把raw_input
改成 input
最终代码改成:
# mnist_app.py
# coding:utf-8
import tensorflow as tf
import numpy as np
from PIL import Image
import mnist_forward
import mnist_backwarddef restore_model(testPicArr):# 重现计算图with tf.Graph().as_default() as tg:x = tf.placeholder(tf.float32, [None, mnist_forward.INPUT_NODE])y = mnist_forward.forward(x, None)preValue = tf.argmax(y, 1) # y 的最大值对应的列表索引号# 实例化带有滑动平均值的savervariable_averages = tf.train.ExponentialMovingAverage(\mnist_backward.MOVING_AVERAGE_DECAY)variables_to_restore = variable_averages.variables_to_restore()saver = tf.train.Saver(variables_to_restore)# 用with结构加载ckptwith tf.Session() as sess:ckpt = tf.train.get_checkpoint_state(mnist_backward.MODEL_SAVE_PATH)# 如果ckpt存在,恢复ckpt的参数和信息到当前会话if ckpt and ckpt.model_checkpoint_path:saver.restore(sess, ckpt.model_checkpoint_path)# 把刚刚准备好的图片喂入网络,执行预测操作preValue = sess.run(preValue, feed_dict = {x: testPicArr})return preValueelse:print("No checkpoint file found!")return -1def pre_pic(picName):# 打开图片img = Image.open(picName)# 用消除锯齿的方法resize图片尺寸reIm = img.resize((28, 28), Image.ANTIALIAS)# 转化成灰度图,并转化成矩阵im_arr = np.array(reIm.convert('L'))# 二值化阈值threshold = 50# 模型要求黑底白字,故需要进行反色for i in range(28):for j in range(28):im_arr[i][j] = 255 - im_arr[i][j]# 二值化,过滤噪声,留下主要特征if(im_arr[i][j] < threshold):im_arr[i][j] = 0else: im_arr[i][j] = 255# 整理矩阵形状nm_arr = im_arr.reshape([1, 784])# 由于模型要求是浮点数,先改为浮点型nm_arr = nm_arr.astype(np.float32)# 0到255浮点转化成0到1浮点img_ready = np.multiply(nm_arr, 1.0/255.0)# 返回预处理完的图片return img_readydef application():# 输入要识别的图片数目 # input从控制台读入返回的是str型!!!testNum = int(input("Input the number of test pictures:") )for i in range(testNum):# 给出识别图片的路径 # raw_input从控制台读入字符串testPic = input("The path of test pictures:") # 接收的图片需进行预处理testPicArr = pre_pic(testPic)# 把整理好的图片喂入神经网络preValue = restore_model(testPicArr)# 输出预测结果print("The prediction number is :", preValue)# 程序从main函数开始执行
def main():# 调用application函数application()if __name__ == '__main__':main()
and then______
我的人工智障程序识别我画的没有封口的0的结果是2,超难过:
又画了一张数字1:
是不是没有训练好,明天再试试。
【接上】2018.11.15
今天重新写了一个2,并且图片改成500*500像素的图片,而不是之前的长宽不一的,如下图:
然后识别结果就正确了:
于是接着写了剩下的9个数字,结果如下:
然后改了一下数字,最终,原谅我只能识别6和9之外的8个数字:
【接上】2018.11.18更新
从 https://github.com/cj0012/AI-Practice-Tensorflow-Notes/tree/master/pic 下载了图片,进行识别,结果全都可以识别出来:
【注】内容来自mooc人工智能实践第六讲
TensorFlow神经网络(五)输入手写数字图片进行识别相关推荐
- 《人工智能实践:Tensorflow笔记》听课笔记22_6.1输入手写数字图片输出识别结果
附:课程链接 第六讲.全连接网络实践 6.1输入手写数字图片输出识别结果 由于个人使用Win7系统,并未完全按照课程所讲,以下记录的也基本是我的结合课程做的Windows系统+PyCharm操作.且本 ...
- DL之LiRDNNCNN:利用LiR、DNN、CNN算法对MNIST手写数字图片(csv)识别数据集实现(10)分类预测
DL之LiR&DNN&CNN:利用LiR.DNN.CNN算法对MNIST手写数字图片(csv)识别数据集实现(10)分类预测 目录 输出结果 设计思路 核心代码 输出结果 数据集:Da ...
- PyTorch之实现LeNet-5卷积神经网络对mnist手写数字图片进行分类
论文:Gradient-based learning applied to document recognition 简单介绍 意义: 对手写数据集进行识别,对后续卷积网络的发展起到了奠基作用 特点: ...
- Tensorflow学习教程------模型参数和网络结构保存且载入,输入一张手写数字图片判断是几...
首先是模型参数和网络结构的保存 #coding:utf-8 import tensorflow as tf from tensorflow.examples.tutorials.mnist impor ...
- TensorFlow 2.0 mnist手写数字识别(CNN卷积神经网络)
TensorFlow 2.0 (五) - mnist手写数字识别(CNN卷积神经网络) 源代码/数据集已上传到 Github - tensorflow-tutorial-samples 大白话讲解卷积 ...
- DL之CNN:利用卷积神经网络算法(2→2,基于Keras的API-Functional)利用MNIST(手写数字图片识别)数据集实现多分类预测
DL之CNN:利用卷积神经网络算法(2→2,基于Keras的API-Functional)利用MNIST(手写数字图片识别)数据集实现多分类预测 目录 输出结果 设计思路 核心代码 输出结果 下边两张 ...
- TF之LoR:基于tensorflow利用逻辑回归算LoR法实现手写数字图片识别提高准确率
TF之LoR:基于tensorflow利用逻辑回归算LoR法实现手写数字图片识别提高准确率 目录 输出结果 设计代码 输出结果 设计代码 #TF之LoR:基于tensorflow实现手写数字图片识别准 ...
- TF之LiR:基于tensorflow实现手写数字图片识别准确率
TF之LiR:基于tensorflow实现手写数字图片识别准确率 目录 输出结果 代码设计 输出结果 Extracting MNIST_data\train-images-idx3-ubyte.gz ...
- 02:一文全解:利用谷歌深度学习框架Tensorflow识别手写数字图片(初学者篇)
标签(空格分隔): 王小草Tensorflow笔记 笔记整理者:王小草 笔记整理时间2017年2月24日 Tensorflow官方英文文档地址:https://www.tensorflow.org/g ...
最新文章
- MySQL开启federated引擎实现数据库表映射
- shiro教程:整合ehcache缓存
- 最短路径(Dijkstra、Bellman-Ford和SPFA算法)
- 今天,在苏州落户了.
- 设计模式之Builder模式
- Oracle数据库之导入导出
- 定时语音提醒软件实现
- COGS2259 异化多肽
- learn git branching 重新开始
- EXCEL学生成绩里计算年级名次、班级名次
- 如何进行旅游app开发定制
- [机器学习笔记] 用Python分析 TED演讲数据(更新中)
- word如何设置上标形式_word怎样设置上标
- php邀请码插件,织梦DedeCMS的会员邀请码注册插件 后台可生成邀请码
- 年轻输得起,蓝桥杯明年我要拿国一
- 大一下期计算机考试试题操作题,2016年大一计算机考试操作题
- 使用RXTXcomm进行串口通信
- 广州交警发布科目三路考秘笈
- 发动机测试人员如何无需取样就能精准把握机油质量?
- cpu的外频,内频,超频
热门文章
- 2018.11.04 洛谷P1081 开车旅行(倍增)
- ADO.NET 2.0 的并行控制与数据存取冲突侦测
- 7.Springcloud的Ribbon的自定义算法实现
- Linux万兆网络配置
- Spring MVC数据验证简介
- 03-04 元素定位工具
- BurpSuite下载CA证书
- python01_Python编码环境安装与基本语法
- c++ 从double变为long int 数据丢失_面试官:Java 中有几种基本数据类型是什么?各自占用多少字节?...
- raster | R语言中的空间栅格对象及其基本处理方法(Ⅳ):数据聚合、重采样