Tensorflow训练的模型,如何保存与载入?

目的:学习tensorflow框架的DNN,掌握如何将tensorflow训练得到的模型保存并载入,做预测?

内容:

1、tensorflow模型保存与载入的两种方法

2、实例分析1——线性回归

3、实例分析2——mnist分类


一、tensorflow模型保存与载入的两种方法

参考网址:https://blog.csdn.net/thriving_fcl/article/details/71423039,tensorflow模型保存与载入的两种方法:

方法一:

保存模型(定义变量 + 使用saver.save()方法保存)

import tensorflow as tf
import numpy as npW = tf.Variable([[1,1,1],[2,2,2]],dtype = tf.float32,name='w')
b = tf.Variable([[0,1,2]],dtype = tf.float32,name='b')init = tf.initialize_all_variables()
saver = tf.train.Saver()
with tf.Session() as sess:sess.run(init)save_path = saver.save(sess,"save/model.ckpt")

载入模型(定义变量 + 使用saver.restore()方法载入)

import tensorflow as tf
import numpy as npW = tf.Variable(tf.truncated_normal(shape=(2,3)),dtype = tf.float32,name='w')
b = tf.Variable(tf.truncated_normal(shape=(1,3)),dtype = tf.float32,name='b')saver = tf.train.Saver()
with tf.Session() as sess:saver.restore(sess,"save/model.ckpt")

该方法的缺点:在使用模型时,必须把模型的结构重新定义一次,然后载入对应名字的变量的值。但是很多时候我们都更希望能够读取一个文件然后就直接使用模型,而不是还要把模型重新定义一遍。所以就需要使用另一种方法。

方法二:不重新定义网络结构的方法

具体地址见点击打开链接


二、实例分析1——线性回归

python代码如下:

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as pltplotdata = { "batchsize":[], "loss":[] }
def moving_average(a, w=10):if len(a) < w: return a[:]    return [val if idx < w else sum(a[(idx-w):idx])/w for idx, val in enumerate(a)]#生成模拟数据
train_X = np.linspace(-1, 1, 100)
train_Y = 2 * train_X + np.random.randn(*train_X.shape) * 0.3 # y=2x,但是加入了噪声
#显示模拟数据点
plt.plot(train_X, train_Y, 'ro', label='Original data')
plt.legend()
plt.show()# 创建模型
# 占位符
X = tf.placeholder("float")
Y = tf.placeholder("float")
# 模型参数
W = tf.Variable(tf.random_normal([1]), name="weight")
b = tf.Variable(tf.zeros([1]), name="bias")# 前向结构
z = tf.multiply(X, W)+ b
tf.summary.histogram("z",z)#反向优化
cost =tf.reduce_mean( tf.square(Y - z))
tf.summary.scalar('loss_function',cost)learning_rate = 0.01
optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost) #Gradient descent# 初始化变量
init = tf.global_variables_initializer()
# 训练参数
training_epochs = 20
display_step = 2# 启动session
with tf.Session() as sess:sess.run(init)merged_summary_op = tf.summary.merge_all()summary_writer = tf.summary.FileWriter('log/mnist_with_summaries',sess.graph)# Fit all training datafor epoch in range(training_epochs):for (x, y) in zip(train_X, train_Y):sess.run(optimizer, feed_dict={X: x, Y: y})#生成summarysummary_str = sess.run(merged_summary_op,feed_dict={X:x,Y:y})summary_writer.add_summary(summary_str,epoch) #将summary写入文件#显示训练中的详细信息if epoch % display_step == 0:loss = sess.run(cost, feed_dict={X: train_X, Y:train_Y})print ("Epoch:", epoch+1, "cost=", loss,"W=", sess.run(W), "b=", sess.run(b))if not (loss == "NA" ):plotdata["batchsize"].append(epoch)plotdata["loss"].append(loss)print (" Finished!")print ("cost=", sess.run(cost, feed_dict={X: train_X, Y: train_Y}), "W=", sess.run(W), "b=", sess.run(b))#print ("cost:",cost.eval({X: train_X, Y: train_Y}))#图形显示plt.plot(train_X, train_Y, 'ro', label='Original data')plt.plot(train_X, sess.run(W) * train_X + sess.run(b), label='Fitted line')plt.legend()plt.show()plotdata["avgloss"] = moving_average(plotdata["loss"])plt.figure(1)plt.plot(plotdata["batchsize"], plotdata["avgloss"], 'b--')plt.xlabel('Minibatch number')plt.ylabel('Loss')plt.title('Minibatch run vs. Training loss')plt.show()print ("x=0.2,z=", sess.run(z, feed_dict={X: 0.2}))

运行结果:





三、实例分析2——mnist分类

python代码如下:

import tensorflow as tf #导入tensorflow库
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
import pylab tf.reset_default_graph()
# tf Graph Input
x = tf.placeholder(tf.float32, [None, 784]) # mnist data维度 28*28=784
y = tf.placeholder(tf.float32, [None, 10]) # 0-9 数字=> 10 classes# Set model weights
W = tf.Variable(tf.random_normal([784, 10]))
b = tf.Variable(tf.zeros([10]))# 构建模型
pred = tf.nn.softmax(tf.matmul(x, W) + b) # Softmax分类# Minimize error using cross entropy
cost = tf.reduce_mean(-tf.reduce_sum(y*tf.log(pred), reduction_indices=1))#参数设置
learning_rate = 0.01
# 使用梯度下降优化器
optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost)training_epochs = 25
batch_size = 100
display_step = 1
saver = tf.train.Saver()
model_path = "log/521model.ckpt"# 启动session
with tf.Session() as sess:sess.run(tf.global_variables_initializer())# Initializing OP# 启动循环开始训练for epoch in range(training_epochs):avg_cost = 0.total_batch = int(mnist.train.num_examples/batch_size)# 遍历全部数据集for i in range(total_batch):batch_xs, batch_ys = mnist.train.next_batch(batch_size)# Run optimization op (backprop) and cost op (to get loss value)_, c = sess.run([optimizer, cost], feed_dict={x: batch_xs,y: batch_ys})# Compute average lossavg_cost += c / total_batch# 显示训练中的详细信息if (epoch+1) % display_step == 0:print ("Epoch:", '%04d' % (epoch+1), "cost=", "{:.9f}".format(avg_cost))print( " Finished!")# 测试 modelcorrect_prediction = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1))# 计算准确率accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))print ("Accuracy:", accuracy.eval({x: mnist.test.images, y: mnist.test.labels}))# Save model weights to disksave_path = saver.save(sess, model_path)print("Model saved in file: %s" % save_path)#读取模型
print("Starting 2nd session...")
with tf.Session() as sess:# Initialize variablessess.run(tf.global_variables_initializer())# Restore model weights from previously saved modelsaver.restore(sess, model_path)# 测试 modelcorrect_prediction = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1))# 计算准确率accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))print ("Accuracy:", accuracy.eval({x: mnist.test.images, y: mnist.test.labels}))output = tf.argmax(pred, 1)batch_xs, batch_ys = mnist.train.next_batch(2)outputval,predv = sess.run([output,pred], feed_dict={x: batch_xs})print(outputval,predv,batch_ys)im = batch_xs[0]im = im.reshape(-1,28)pylab.imshow(im)pylab.show()im = batch_xs[1]im = im.reshape(-1,28)pylab.imshow(im)pylab.show()

运行结果:

Epoch: 0001 cost= 7.752772305
Epoch: 0002 cost= 4.151113472
Epoch: 0003 cost= 2.902867300
Epoch: 0004 cost= 2.292819615
Epoch: 0005 cost= 1.945046414
Epoch: 0006 cost= 1.721682055
Epoch: 0007 cost= 1.565224952
Epoch: 0008 cost= 1.448184885
Epoch: 0009 cost= 1.357409785
Epoch: 0010 cost= 1.283956942
Epoch: 0011 cost= 1.223152844
Epoch: 0012 cost= 1.171679115
Epoch: 0013 cost= 1.127339950
Epoch: 0014 cost= 1.089194359
Epoch: 0015 cost= 1.055257367
Epoch: 0016 cost= 1.025059551
Epoch: 0017 cost= 0.997867818
Epoch: 0018 cost= 0.973305143
Epoch: 0019 cost= 0.951017423
Epoch: 0020 cost= 0.930552574
Epoch: 0021 cost= 0.911731513
Epoch: 0022 cost= 0.894192883
Epoch: 0023 cost= 0.878128686
Epoch: 0024 cost= 0.862873784
Epoch: 0025 cost= 0.848758641Finished!
Accuracy: 0.8355
Model saved in file: log/521model.ckpt
Starting 2nd session...
Accuracy: 0.8355
[2 6] [[  4.44788748e-05   5.84178214e-13   9.99922991e-01   1.38609546e-094.55205260e-08   5.39752136e-06   2.67073501e-05   2.82684276e-164.42017324e-07   2.21145666e-14][  1.80836395e-08   5.05934682e-18   5.30333818e-05   5.56881845e-142.38929709e-10   7.70143487e-08   9.99946833e-01   5.28569544e-099.02450684e-11   3.62926039e-14]] [[ 0.  0.  1.  0.  0.  0.  0.  0.  0.  0.][ 0.  0.  0.  0.  0.  0.  1.  0.  0.  0.]]
换种方式,读入模型:

import tensorflow as tf #导入tensorflow库
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
import pylab########################################################################
from pylab import *
mpl.rcParams['font.sans-serif'] = ['SimHei'] # 若不添加,中文无法在图中显示
# import matplotlib
# matplotlib.rcParams['axes.unicode_minus']=False # 若不添加,无法在图中显示负号
###########################################################################tf.reset_default_graph()
# tf Graph Input
x = tf.placeholder(tf.float32, [None, 784]) # mnist data维度 28*28=784
y = tf.placeholder(tf.float32, [None, 10]) # 0-9 数字=> 10 classes# Set model weights
W = tf.Variable(tf.random_normal([784, 10]))
b = tf.Variable(tf.zeros([10]))# 构建模型
pred = tf.nn.softmax(tf.matmul(x, W) + b) # Softmax分类# Minimize error using cross entropy
cost = tf.reduce_mean(-tf.reduce_sum(y*tf.log(pred), reduction_indices=1))#参数设置
learning_rate = 0.01
# 使用梯度下降优化器
optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost)saver = tf.train.Saver()
model_path = "log/521model.ckpt"
###############################################################################
# 读取模型
print("Starting 2nd session...")
with tf.Session() as sess:# Initialize variablessess.run(tf.global_variables_initializer())# Restore model weights from previously saved modelsaver.restore(sess, model_path)# # 测试 model# correct_prediction = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1))# # 计算准确率# accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))# print("Accuracy:", accuracy.eval({x: mnist.test.images, y: mnist.test.labels}))output = tf.argmax(pred, 1)batch_xs, batch_ys = mnist.train.next_batch(2)outputval, predv = sess.run([output, pred], feed_dict={x: batch_xs})# print(outputval, predv, batch_ys)#######################################################################print(outputval)pylab.subplot(121)im = batch_xs[0]im = im.reshape(-1, 28)pylab.title('该图片中的数字为:'+ str(outputval[0]))pylab.imshow(im)pylab.subplot(122)im = batch_xs[1]im = im.reshape(-1, 28)pylab.title('该图片中的数字为:' + str(outputval[1]))pylab.imshow(im)pylab.show()

运行结果:

Starting 2nd session...
[6 1]
参考网址:

1、https://www.cnblogs.com/bonelee/p/8445261.html

2、https://blog.csdn.net/luoyexuge/article/details/78243117

3、https://blog.csdn.net/BugCreater/article/details/53293075

Tensorflow训练的模型,如何保存与载入?相关推荐

  1. TF:利用TF的train.Saver将训练好的W、b模型文件保存+新建载入刚训练好模型(用于以后预测新的数据)

    TF:利用TF的train.Saver将训练好的W.b模型文件保存+新建载入刚训练好模型(用于以后预测新的数据) 目录 输出结果 代码设计 输出结果 代码设计 import tensorflow as ...

  2. tensorflow笔记:模型的保存与训练过程可视化

    tensorflow笔记系列:  (一) tensorflow笔记:流程,概念和简单代码注释  (二) tensorflow笔记:多层CNN代码分析  (三) tensorflow笔记:多层LSTM代 ...

  3. 使用TensorFlow训练WDL模型性能问题定位与调优

    简介 TensorFlow是Google研发的第二代人工智能学习系统,能够处理多种深度学习算法模型,以功能强大和高可扩展性而著称.TensorFlow完全开源,所以很多公司都在使用,但是美团点评在使用 ...

  4. pytorch多卡并行模型的保存与载入

    pytorch多卡并行模型的保存与载入 当模型是在数据并行方式在多卡上进行训练的训练和保存,那么载入的时候也是一样需要是多卡.并且,load_state_dict()函数的调用要放在DataParal ...

  5. 将TensorFlow训练的模型移植到Android手机

    2019独角兽企业重金招聘Python工程师标准>>> 前言 本文中出现的TF皆为TensorFlow的简称. 先说两句题外话吧,TensorFlow 前两天热热闹闹的发布了正式版r ...

  6. C#使用公共语言拓展(CLE)调用Python3(使用TensorFlow训练的模型)

    对于Python2来说,使用IronPython可以方便的实现C#调用Python,但是对于特定需求,比如使用TensorFlow(最低支持Python3.5),就没办法使用IronPython了,为 ...

  7. tensorflow训练yolov3模型(检测雪人为例,自己的数据和标签,windows环境)

    惯例先放效果 所有代码包含训练.测试图片视频打包在: 地址 下载代码: git clone https://github.com/YunYang1994/tensorflow-yolov3 或者点此下 ...

  8. 深度学习小技巧(二):如何保存和恢复scikit-learn训练的模型

    深度学习小技巧(一):如何保存和恢复TensorFlow训练的模型 在许多情况下,在使用scikit学习库的同时,你需要将预测模型保存到文件中,然后在使用它们的时候还原它们,以便重复使用以前的工作.比 ...

  9. 如何用java语言调用tensorflow训练好的模型

    1.TensorFlow的训练模型在Android和Java的应用及调用 2.tensorflow的python离线训练java在线预测方案 3.tensorflow训练的模型在java中的使用 4. ...

最新文章

  1. pandas读写结构化数据(read_csv,read_table, read_excel, read_html, read_sql)
  2. ORACLE查看当前连接用户的权限信息或者角色信息
  3. SAP 电商云 Spartacus UI ROUTING_FEATURE 的使用场景
  4. 万字详解Lambda、Stream和日期
  5. gm怎么刷东西 rust_RUST:2020年7月第三周 修补和更新
  6. java中线程池的使用方法
  7. 处理血压信号_测血压检测健康,8款高品质血压计推荐
  8. java-net-php-python-jsp刺绣作品展示网站计算机毕业设计程序
  9. STM32实战总结:HAL之触摸屏
  10. 永恒之蓝漏洞复现及上传后门程序
  11. 大数据Flink面试考题___Flink高频考点,万字超全整理(建议)
  12. 闲鱼双11全链路营销体系初体验
  13. itunes备份文件的位置在哪
  14. mysql的用户名迁移SCHEMA_数据库实时转移之Confluent环境搭建(二)
  15. 前端正则表达式指定邮箱域名匹配
  16. 6-23 sdust-Java-可实现多种排序的Book类
  17. GitLab安装使用(SSH+Docker两种方式)
  18. 后台执行linux命令
  19. Matlab自带排序函数sort用法
  20. STM32 的核心Cortex-M3 处理器

热门文章

  1. 计算机一级考试选择题知识点,计算机一级选择题必背知识点 考试题型有哪些...
  2. ISCC2021wp
  3. Python 模拟Laguerre Polynomial拉盖尔多项式
  4. api存在csrf攻击吗_使用rest api防止单页应用上的csrf攻击
  5. hdu 6863 Isomorphic Strings
  6. 解决Idea中yml文件不显示小绿叶图标
  7. 安装snipe-IT遇到的php问题
  8. C#中如何使用Sqlite、SqliCe等本地数据库?
  9. python刷步数程序设计_乐心健康间接修改微信步数-Docker持久运行python脚本
  10. 微信发布的辟谣小程序