一、断点续训

为防止突然断电、参数白跑的情况发生,在backward中加入类似于之前test中加载ckpt的操作,给所有w和b赋保存在ckpt中的值:
1. 如果存储断点文件的目录文件夹中,包含有效断点状态文件,则返回该文件:

  • 参数说明
    checkpoint_dir: 表示存储断点文件的目录
    latest_filename: 断点文件的可选名称,默认为checkpoint
ckpt = tf.train.get_checkpoint_state(checkpoint_dir,\latest_filename = None)

2. 如果ckpt存在,且保存的模型在指定路径中存在

 if ckpt and ckpt.model_checkpoint_path:

3. 恢复当前会话,将ckpt中的值赋给 w 和 b

  • 参数说明

sess:表示当前会话,之前保存的结果会被加载入这个会话
ckpt.model_checkpoint_path:表示模型存储的位置,不需要提供模型的名字,因为有了位置会自动去查看checkpoint文件,看最新的模型叫什么

 saver.restore(sess, ckpt.model_checkpoint_path)

4. 完整代码:

# 断点续训 breakpoint_continue.py
ckpt = tf.train.get_checkpoint_state(MODEL_SAVE_PATH)
if ckpt and ckpt.model_checkpoint_path:
# 恢复当前会话,将ckpt中的值赋给 w 和 b        saver.restore(sess, ckpt.model_checkpoint_path)

反向传播的with结构中加入加载ckpt的操作后:

# mnist_backward.py
# coding: utf-8import tensorflow as tf
# 导入imput_data模块
from tensorflow.examples.tutorials.mnist import input_data
import mnist_forward
import os# 定义超参数
BATCH_SIZE = 200LEARNING_RATE_BASE = 0.1 #初始学习率
LEARNING_RATE_DECAY = 0.99 # 学习率衰减率
REGULARIZER = 0.0001 # 正则化参数STEPS = 50000 #训练轮数MOVING_AVERAGE_DECAY = 0.99
MODEL_SAVE_PATH = "./model/"
MODEL_NAME = "mnist_model"
def backward(mnist):# placeholder占位x = tf.placeholder(tf.float32, shape = (None, mnist_forward.INPUT_NODE))y_ = tf.placeholder(tf.float32, shape = (None, mnist_forward.OUTPUT_NODE))# 前向传播推测输出yy = mnist_forward.forward(x, REGULARIZER)# 定义global_step轮数计数器,定义为不可训练global_step = tf.Variable(0, trainable = False)# 包含正则化的损失函数# 交叉熵ce = tf.nn.sparse_softmax_cross_entropy_with_logits(logits = y, \labels = tf.argmax(y_, 1))cem = tf.reduce_mean(ce)# 使用正则化时的损失函数loss = cem + tf.add_n(tf.get_collection('losses'))# 定义指数衰减学习率learning_rate = tf.train.exponential_decay(LEARNING_RATE_BASE,global_step, mnist.train.num_examples/BATCH_SIZE,LEARNING_RATE_DECAY,staircase = True)# 定义反向传播方法:包含正则化train_step = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss,\global_step = global_step)# 定义滑动平均时,加上:ema = tf.train.ExponentialMovingAverage(MOVING_AVERAGE_DECAY, global_step)ema_op = ema.apply(tf.trainable_variables())with tf.control_dependencies([train_step, ema_op]):train_op = tf.no_op(name = 'train')# 实例化saversaver = tf.train.Saver()# 训练过程with tf.Session() as sess:# 初始化所有参数init_op = tf.global_variables_initializer()sess.run(init_op)# 断点续训 breakpoint_continue.pyckpt = tf.train.get_checkpoint_state(MODEL_SAVE_PATH)if ckpt and ckpt.model_checkpoint_path: # 恢复当前会话,将ckpt中的值赋给 w 和 b     saver.restore(sess, ckpt.model_checkpoint_path) # 循环迭代for i in range(STEPS):# 将训练集中一定batchsize的数据和标签赋给左边的变量xs, ys = mnist.train.next_batch(BATCH_SIZE)# 喂入神经网络,执行训练过程train_step_, loss_value, step = sess.run([train_op, loss, global_step], \feed_dict = {x: xs, y_: ys})if i % 1000 == 0: # 拼接成./MODEL_SAVE_PATH/MODEL_NAME-global_step路径# 打印提示print("after %d steps, loss on traing batch is %g" %(step, loss_value))saver.save(sess, os.path.join(MODEL_SAVE_PATH, MODEL_NAME), \global_step = global_step)def main():mnist = input_data.read_data_sets('./data/', one_hot = True)# 调用定义好的测试函数backward(mnist)
# 判断python运行文件是否为主文件,如果是,则执行
if __name__ == '__main__':main()

结果发现模型自动接着之前开机的结束的50000次开始往后训练了,在两次ctrl+C之后,再重新执行时仍可以从断点继续:


二、如何对输入的手写数字图片,输出正确预测结果

  • 除了minist_forward, mnist_backward, mnist_test之外,增加mnist_app.py一个py文件

自己遇到的问题之
(一)main函数没有写对

if __name__ == '__main__':main()

写成了

if __name__ == 'main':main()

结果代码根本跑不出结果!!!
(二)input从控制台读入返回的是str型!!!
参见博客https://blog.csdn.net/qq_41151066/article/details/81745352
所以导致输入图片张数时出现错误:

TypeError: 'str' object cannot be interpreted as an integer

(三)然后又发现这个错误

NameError: name 'raw_input' is not defined

好家伙,参考文章https://blog.csdn.net/hochean_/article/details/79582627
raw_input改成 input
最终代码改成:

# mnist_app.py
# coding:utf-8
import tensorflow as tf
import numpy as np
from PIL import Image
import mnist_forward
import mnist_backwarddef restore_model(testPicArr):# 重现计算图with tf.Graph().as_default() as tg:x = tf.placeholder(tf.float32, [None, mnist_forward.INPUT_NODE])y = mnist_forward.forward(x, None)preValue = tf.argmax(y, 1) # y 的最大值对应的列表索引号# 实例化带有滑动平均值的savervariable_averages = tf.train.ExponentialMovingAverage(\mnist_backward.MOVING_AVERAGE_DECAY)variables_to_restore = variable_averages.variables_to_restore()saver = tf.train.Saver(variables_to_restore)# 用with结构加载ckptwith tf.Session() as sess:ckpt = tf.train.get_checkpoint_state(mnist_backward.MODEL_SAVE_PATH)# 如果ckpt存在,恢复ckpt的参数和信息到当前会话if ckpt and ckpt.model_checkpoint_path:saver.restore(sess, ckpt.model_checkpoint_path)# 把刚刚准备好的图片喂入网络,执行预测操作preValue = sess.run(preValue, feed_dict = {x: testPicArr})return preValueelse:print("No checkpoint file found!")return -1def pre_pic(picName):# 打开图片img = Image.open(picName)# 用消除锯齿的方法resize图片尺寸reIm = img.resize((28, 28), Image.ANTIALIAS)# 转化成灰度图,并转化成矩阵im_arr = np.array(reIm.convert('L'))# 二值化阈值threshold = 50# 模型要求黑底白字,故需要进行反色for i in range(28):for j in range(28):im_arr[i][j] = 255 - im_arr[i][j]# 二值化,过滤噪声,留下主要特征if(im_arr[i][j] < threshold):im_arr[i][j] = 0else: im_arr[i][j] = 255# 整理矩阵形状nm_arr = im_arr.reshape([1, 784])# 由于模型要求是浮点数,先改为浮点型nm_arr = nm_arr.astype(np.float32)# 0到255浮点转化成0到1浮点img_ready = np.multiply(nm_arr, 1.0/255.0)# 返回预处理完的图片return img_readydef application():# 输入要识别的图片数目 # input从控制台读入返回的是str型!!!testNum = int(input("Input the number of test pictures:") )for i in range(testNum):# 给出识别图片的路径 # raw_input从控制台读入字符串testPic = input("The path of test pictures:") # 接收的图片需进行预处理testPicArr = pre_pic(testPic)# 把整理好的图片喂入神经网络preValue = restore_model(testPicArr)# 输出预测结果print("The prediction number is :", preValue)# 程序从main函数开始执行
def main():# 调用application函数application()if __name__ == '__main__':main()

and then______
我的人工智障程序识别我画的没有封口的0的结果是2,超难过:

又画了一张数字1:

是不是没有训练好,明天再试试。


【接上】2018.11.15
今天重新写了一个2,并且图片改成500*500像素的图片,而不是之前的长宽不一的,如下图:

然后识别结果就正确了:

于是接着写了剩下的9个数字,结果如下:

然后改了一下数字,最终,原谅我只能识别6和9之外的8个数字:


【接上】2018.11.18更新
从 https://github.com/cj0012/AI-Practice-Tensorflow-Notes/tree/master/pic 下载了图片,进行识别,结果全都可以识别出来:

【注】内容来自mooc人工智能实践第六讲

TensorFlow神经网络(五)输入手写数字图片进行识别相关推荐

  1. 《人工智能实践:Tensorflow笔记》听课笔记22_6.1输入手写数字图片输出识别结果

    附:课程链接 第六讲.全连接网络实践 6.1输入手写数字图片输出识别结果 由于个人使用Win7系统,并未完全按照课程所讲,以下记录的也基本是我的结合课程做的Windows系统+PyCharm操作.且本 ...

  2. DL之LiRDNNCNN:利用LiR、DNN、CNN算法对MNIST手写数字图片(csv)识别数据集实现(10)分类预测

    DL之LiR&DNN&CNN:利用LiR.DNN.CNN算法对MNIST手写数字图片(csv)识别数据集实现(10)分类预测 目录 输出结果 设计思路 核心代码 输出结果 数据集:Da ...

  3. PyTorch之实现LeNet-5卷积神经网络对mnist手写数字图片进行分类

    论文:Gradient-based learning applied to document recognition 简单介绍 意义: 对手写数据集进行识别,对后续卷积网络的发展起到了奠基作用 特点: ...

  4. Tensorflow学习教程------模型参数和网络结构保存且载入,输入一张手写数字图片判断是几...

    首先是模型参数和网络结构的保存 #coding:utf-8 import tensorflow as tf from tensorflow.examples.tutorials.mnist impor ...

  5. TensorFlow 2.0 mnist手写数字识别(CNN卷积神经网络)

    TensorFlow 2.0 (五) - mnist手写数字识别(CNN卷积神经网络) 源代码/数据集已上传到 Github - tensorflow-tutorial-samples 大白话讲解卷积 ...

  6. DL之CNN:利用卷积神经网络算法(2→2,基于Keras的API-Functional)利用MNIST(手写数字图片识别)数据集实现多分类预测

    DL之CNN:利用卷积神经网络算法(2→2,基于Keras的API-Functional)利用MNIST(手写数字图片识别)数据集实现多分类预测 目录 输出结果 设计思路 核心代码 输出结果 下边两张 ...

  7. TF之LoR:基于tensorflow利用逻辑回归算LoR法实现手写数字图片识别提高准确率

    TF之LoR:基于tensorflow利用逻辑回归算LoR法实现手写数字图片识别提高准确率 目录 输出结果 设计代码 输出结果 设计代码 #TF之LoR:基于tensorflow实现手写数字图片识别准 ...

  8. TF之LiR:基于tensorflow实现手写数字图片识别准确率

    TF之LiR:基于tensorflow实现手写数字图片识别准确率 目录 输出结果 代码设计 输出结果 Extracting MNIST_data\train-images-idx3-ubyte.gz ...

  9. 02:一文全解:利用谷歌深度学习框架Tensorflow识别手写数字图片(初学者篇)

    标签(空格分隔): 王小草Tensorflow笔记 笔记整理者:王小草 笔记整理时间2017年2月24日 Tensorflow官方英文文档地址:https://www.tensorflow.org/g ...

最新文章

  1. MySQL开启federated引擎实现数据库表映射
  2. shiro教程:整合ehcache缓存
  3. 最短路径(Dijkstra、Bellman-Ford和SPFA算法)
  4. 今天,在苏州落户了.
  5. 设计模式之Builder模式
  6. Oracle数据库之导入导出
  7. 定时语音提醒软件实现
  8. COGS2259 异化多肽
  9. learn git branching 重新开始
  10. EXCEL学生成绩里计算年级名次、班级名次
  11. 如何进行旅游app开发定制
  12. [机器学习笔记] 用Python分析 TED演讲数据(更新中)
  13. word如何设置上标形式_word怎样设置上标
  14. php邀请码插件,织梦DedeCMS的会员邀请码注册插件 后台可生成邀请码
  15. 年轻输得起,蓝桥杯明年我要拿国一
  16. 大一下期计算机考试试题操作题,2016年大一计算机考试操作题
  17. 使用RXTXcomm进行串口通信
  18. 广州交警发布科目三路考秘笈
  19. 发动机测试人员如何无需取样就能精准把握机油质量?
  20. cpu的外频,内频,超频

热门文章

  1. 2018.11.04 洛谷P1081 开车旅行(倍增)
  2. ADO.NET 2.0 的并行控制与数据存取冲突侦测
  3. 7.Springcloud的Ribbon的自定义算法实现
  4. Linux万兆网络配置
  5. Spring MVC数据验证简介
  6. 03-04 元素定位工具
  7. BurpSuite下载CA证书
  8. python01_Python编码环境安装与基本语法
  9. c++ 从double变为long int 数据丢失_面试官:Java 中有几种基本数据类型是什么?各自占用多少字节?...
  10. raster | R语言中的空间栅格对象及其基本处理方法(Ⅳ):数据聚合、重采样