文章目录

  • 模型保存
  • 模型读取
  • 测试模型
    • 搭建测试模型
  • 使用模型
  • 模型可视化

本文是在上一篇文章 《深度学习之TensorFlow》reading notes(2)—— MNIST手写数字识别的基础上写的,主要内容是进一步实现对模型的测试、保存和模型读取使用。

模型保存

先上代码:

import...
#建立模型...
#配置参数...
saver = th.train.Saver()
model_path = "log/521model.ckpt"with tf.Session as sess:...#这里是初始化和训练模型的过程#保存模型并打印save_path = saver.save(sess, model_path)print("Model saved in file: %s" % save_path)

只需要添加四句话就可以实现对训练模型的保存。好像也没啥需要解释的……

模型读取

import tensorflow as tf #导入tensorflow库
import pylab
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)tf.reset_default_graph()
# tf Graph Input
x = tf.placeholder(tf.float32, [None, 784])  # mnist data维度 28*28=784
y = tf.placeholder(tf.float32, [None, 10])  # 0-9 数字=> 10 classes
# Set model weights
W = tf.Variable(tf.random_normal([784, 10]))
b = tf.Variable(tf.zeros([10]))model_path = "log/521model.ckpt"#读取模型
print("Starting 2nd session...")
saver = tf.train.Saver()
with tf.Session() as sess:# Initialize variables
#    sess.run(tf.global_variables_initializer())# Restore model weights from previously saved modelsaver.restore(sess, model_path)

读取模型参数前,一样需要将模型中的张量重新定义一遍,其实保存的模型中的参数值和模型结构,也就是W和b的值和向前传播的结构。用

saver.restore(sess, model_path)

就可以实现对模型的读取了,下面我们测试读取得到的模型,以及使用模型对一张手写图进行判断。

测试模型

搭建测试模型

首先,需要调用测试数据集中的数据,输入到模型中,看模型预测的结果与数据集的标签是否一致,如果一致则返回true,不一致则返回false。最后统计所有true的个数除以总数,即为模型准确率。

import tensorflow as tf  # 导入tensorflow库
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)tf.reset_default_graph()
# tf Graph Input
x = tf.placeholder(tf.float32, [None, 784])  # mnist data维度 28*28=784
y = tf.placeholder(tf.float32, [None, 10])  # 0-9 数字=> 10 classes
# Set model weights
W = tf.Variable(tf.random_normal([784, 10]))
b = tf.Variable(tf.zeros([10]))
# 构建测试模型
s1 = tf.matmul(x, W) + b
pred = tf.nn.softmax(s1)
correct_prediction = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
model_path = "log/521model.ckpt"
saver = tf.train.Saver()with tf.Session() as sess:saver.restore(sess, model_path)Accaa = sess.run(accuracy,feed_dict={x: mnist.test.images, y: mnist.test.labels})print ("Accuracy:", Accaa)

首先,依旧是用模型求取pred值,也就是预测值。之后一句:

correct_prediction = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1))

其中tf.argmax()函数是用来检索最大值的,即得到最大值位置。由于y是onehot编码,所以,这个索引对应上即是预测正确,再用tf.equal()函数来确定其是否相等,就能得到正确的情况。
在通过下一句求平均,其实就是计数正确的个数,再除以总数。就可以得到准确率了。
运行结果:

runfile('E:/mnist_1/mnist_test.py', wdir='E:/mnist_1')
Extracting MNIST_data/train-images-idx3-ubyte.gz
Extracting MNIST_data/train-labels-idx1-ubyte.gz
Extracting MNIST_data/t10k-images-idx3-ubyte.gz
Extracting MNIST_data/t10k-labels-idx1-ubyte.gz
INFO:tensorflow:Restoring parameters from log/521model.ckpt
Accuracy: 0.8587

可以看到,判断的准确率还可以,达到了85.87%,这其实是我迭代了50次的结果,25次通常维持在82%~83%左右。再迭代效果不明显了。

使用模型

使用模型在读取模型之后,选用test数据集中的数据:mnist.test.next_batch(num)其中num为调用多少个数据进行测试。sess.run中将测试数据作为输入,得到模型预测的output和预测概率pred。最后,将预测值、预测概率、标签值和图形都进行输出。

import tensorflow as tf #导入tensorflow库
import pylab
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)tf.reset_default_graph()
# tf Graph Input
x = tf.placeholder(tf.float32, [None, 784])  # mnist data维度 28*28=784# Set model weights
W = tf.Variable(tf.random_normal([784, 10]))
b = tf.Variable(tf.zeros([10]))# 构建模型
s1 = tf.matmul(x, W) + b
pred = tf.nn.softmax(s1)
output = tf.argmax(pred, 1)model_path = "log/521model.ckpt"num11 = 5
#读取模型
print("Starting 2nd session...")
saver = tf.train.Saver()
with tf.Session() as sess:# Initialize variables
#    sess.run(tf.global_variables_initializer())# Restore model weights from previously saved modelsaver.restore(sess, model_path)batch_xs, batch_ys = mnist.test.next_batch(num11)outputval,predv = sess.run([output,pred], feed_dict={x: batch_xs})for i in range(num11):print(outputval[i],predv[i,outputval[i]],batch_ys[i])im = batch_xs[i]im = im.reshape(-1,28)pylab.imshow(im)pylab.show()

输入结果:这里我取了两个0作为比较,可以看到第一个〇,预测概率为100%,第二个〇,预测概率为84%,从实际图中可以看出区别。

模型可视化

下次再说,玩一会自走棋去~

《深度学习之TensorFlow》reading notes(3)—— MNIST手写数字识别之二相关推荐

  1. 深度学习案例之基于 CNN 的 MNIST 手写数字识别

    一.模型结构 本文只涉及利用Tensorflow实现CNN的手写数字识别,CNN的内容请参考:卷积神经网络(CNN) MNIST数据集的格式与数据预处理代码input_data.py的讲解请参考 :T ...

  2. 深度学习入门实例——基于keras的mnist手写数字识别

    本文介绍了利用keras做mnist数据集的手写数字识别. 参考网址 http://www.cnblogs.com/lc1217/p/7132364.html mnist数据集中的图片为28*28的单 ...

  3. python cnn代码详解图解_基于TensorFlow的CNN实现Mnist手写数字识别

    本文实例为大家分享了基于TensorFlow的CNN实现Mnist手写数字识别的具体代码,供大家参考,具体内容如下 一.CNN模型结构 输入层:Mnist数据集(28*28) 第一层卷积:感受视野5* ...

  4. 深度学习笔记:01快速构建一个手写数字识别系统以及张量的概念

    深度学习笔记:01快速构建一个手写数字识别系统 神经网络代码最好运行在GPU中,但是对于初学者来说运行在GPU上成本太高了,所以先运行在CPU中,就是慢一些. 一.安装keras框架 使用管理员模式打 ...

  5. 深度学习入门项目:PyTorch实现MINST手写数字识别

    完整代码下载[github地址]:https://github.com/lmn-ning/MNIST_PyTorch.git 目录 一.MNIST数据集介绍及下载地址 二.代码结构 三.代码 data ...

  6. 【FPGA教程案例100】深度学习1——基于CNN卷积神经网络的手写数字识别纯Verilog实现,使用mnist手写数字数据库

    FPGA教程目录 MATLAB教程目录 ---------------------------------------- 目录 1.软件版本 2.CNN卷积神经网络的原理 2.1 mnist手写数字数 ...

  7. 深度学习之基于CNN实现汉字版手写数字识别(Chinese-Mnist)

    Mnist数据集是深度学习入门的数据集,昨天发现了Chinese-Mnist数据集,与Mnist数据集类似,只不过是汉字数字,例如'一'.'二'.'三'等,本次实验利用自己搭建的CNN网络实现Chin ...

  8. mnist手写数字识别python_基于tensorflow的MNIST手写数字识别(二)--入门篇

    一.本文的意义 因为谷歌官方其实已经写了MNIST入门和深入两篇教程了,那我写这些文章又是为什么呢,只是抄袭?那倒并不是,更准确的说应该是笔记吧,然后用更通俗的语言来解释,并且补充更多,官方文章中没有 ...

  9. Python 3深度置信网络(DBN)在Tensorflow中的实现MNIST手写数字识别

    任何程序错误,以及技术疑问或需要解答的,请扫码添加作者VX:1755337994 使用DBN识别手写体 传统的多层感知机或者神经网络的一个问题: 反向传播可能总是导致局部最小值. 当误差表面(erro ...

最新文章

  1. JAVA多线程Thread VS Runnable详解
  2. linux wifi关闭5g,TP-Link路由器如何关闭5G无线Wi-Fi信号?
  3. python变量名可以包含的字符有问号吗,带问号文字的Python正则表达式
  4. 入门微信小程序(含实战) [第九篇] -- 下拉刷新和上拉加载
  5. mysql的dockerfile_dockerfile构建mysql镜像
  6. python查找当前路径,在Python中查找当前终端选项卡的当前目录
  7. Leetcode142. Linked List Cycle II环形链表2
  8. 怎样借助营销圈帮助企业扩大品牌知名度呢?
  9. show tables mysql_MySQL_解析MYSQL显示表信息的方法,在用mysql时(show tables),有时候 - phpStudy...
  10. extThree20XML extThree20JSON 引入到工程中的方式
  11. 谈谈JS的全局变量跟局部变量
  12. centos7下学习Redis(一)
  13. excel做ns流程图_NS流程图是什么图?用这款软件轻松画NS流程图
  14. 【数学模型】基于Matlab模拟超市排队系统
  15. 比例风险(Cox)回归模型——Proportional hazards model
  16. 查找算法——adjacent_find
  17. Widows Tips
  18. GPS北斗卫星时钟同步系统的原理和技术
  19. AD(altium designer)15原理图与PCB设计教程(十)——信号完整性分析
  20. 域控服务器里没有internet时间,server2008r2域控时间设置internet时间同步的方法

热门文章

  1. 关于VMWare Data Protection VDP的使用心得
  2. github push 出错:fatal: Authentication failed for 'https://github.com/ ..的解决
  3. Markdown的使用之一:表格和公式
  4. 花游双人、三级跳斩获金银
  5. Cartesi 举办的2023 黑客马拉松
  6. JAVA —— 比较日期时间大小
  7. java.lang.IllegalArgumentException: java.security.InvalidKeyException: Illegal key siz
  8. 亲子关系-《亲子关系全面技巧》书中的精髓:学会正确处理亲子关系的技巧,与孩子建立良好的关系。
  9. 使用Node.js创建命令行工具
  10. xshell编程自动备份数据库