tensorflow保存模型和加载模型的方法(Python和Android)

一、tensorflow保存模型的几种方法:

(1) tf.train.saver()保存模型

使用 tf.train.saver()保存模型,该方法保存模型文件的时候会产生多个文件,会把计算图的结构和图上参数取值分成了不同的文件存储。这种方法是在TensorFlow中是最常用的保存方式。

例如:

import tensorflow as tf
# 声明两个变量
v1 = tf.Variable(tf.random_normal([1, 2]), name="v1")
v2 = tf.Variable(tf.random_normal([2, 3]), name="v2")
init_op = tf.global_variables_initializer() # 初始化全部变量
saver = tf.train.Saver() # 声明tf.train.Saver类用于保存模型
with tf.Session() as sess:sess.run(init_op)print("v1:", sess.run(v1)) # 打印v1、v2的值一会读取之后对比print("v2:", sess.run(v2))saver_path = saver.save(sess, "save/model.ckpt")  # 将模型保存到save/model.ckpt文件print("Model saved in file:", saver_path)

运行后,会在save目录下保存了四个文件:

    其中

  • checkpoint是检查点文件,文件保存了一个目录下所有的模型文件列表;
  • model.ckpt.meta文件保存了TensorFlow计算图的结构,可以理解为神经网络的网络结构,该文件可以被 tf.train.import_meta_graph 加载到当前默认的图来使用。
  • ckpt.data : 保存模型中每个变量的取值

参考资料:
https://blog.csdn.net/michael_yt/article/details/74737489

https://blog.csdn.net/lwplwf/article/details/62419087

(2)tf.train.write_graph()

使用 tf.train.write_graph()保存模型,该方法只是保存了模型的结构,并不保存训练完毕的参数值。

(3)convert_variables_to_constants固化模型结构

很多时候,我们需要将TensorFlow的模型导出为单个文件(同时包含模型结构的定义与权重),方便在其他地方使用(如在Android中部署网络)。利用tf.train.write_graph()默认情况下只导出了网络的定义(没有权重),而利用tf.train.Saver().save()导出的文件graph_def与权重是分离的,因此需要采用别的方法。 我们知道,graph_def文件中没有包含网络中的Variable值(通常情况存储了权重),但是却包含了constant值,所以如果我们能把Variable转换为constant,即可达到使用一个文件同时存储网络架构与权重的目标。

TensoFlow为我们提供了convert_variables_to_constants()方法,该方法可以固化模型结构,将计算图中的变量取值以常量的形式保存。而且保存的模型可以移植到Android平台。

    参考资料:
    【1】https://blog.csdn.net/sinat_29957455/article/details/78511119

【2】这里主要实现第三种方法,因为该方法保存的模型可以移植到Android平台运行。以下Python代码,都共享在

Github:https://github.com/PanJinquan/tensorflow-learning-tutorials/tree/master/MNIST-Demo;

【3】移植Android的详细过程可参考本人的另一篇博客资料《将tensorflow MNIST训练模型移植到Android》:

       https://blog.csdn.net/guyuealian/article/details/79672257

二、训练和保存模型

以MNIST手写数字识别为例,这里首先使用Python版的TensorFlow实现SoftMax Regression分类器,并将训练好的模型的网络拓扑结构和参数保存为pb文件,其中convert_variables_to_constants函数,会将计算图中的变量取值以常量的形式保存:https://blog.csdn.net/sinat_29957455/article/details/78511119

#coding=utf-8
from tensorflow.examples.tutorials.mnist import input_data
import tensorflow as tf
from tensorflow.python.framework import graph_util
print('tensortflow:{0}'.format(tf.__version__))mnist = input_data.read_data_sets("Mnist_data/", one_hot=True)#create model
with tf.name_scope('input'):x = tf.placeholder(tf.float32,[None,784],name='x_input')#输入节点名:x_inputy_ = tf.placeholder(tf.float32,[None,10],name='y_input')
with tf.name_scope('layer'):with tf.name_scope('W'):#tf.zeros([3, 4], tf.int32) ==> [[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]]W = tf.Variable(tf.zeros([784,10]),name='Weights')with tf.name_scope('b'):b = tf.Variable(tf.zeros([10]),name='biases')with tf.name_scope('W_p_b'):Wx_plus_b = tf.add(tf.matmul(x, W), b, name='Wx_plus_b')y = tf.nn.softmax(Wx_plus_b, name='final_result')# 定义损失函数和优化方法
with tf.name_scope('loss'):loss = -tf.reduce_sum(y_ * tf.log(y))
with tf.name_scope('train_step'):train_step = tf.train.GradientDescentOptimizer(0.01).minimize(loss)print(train_step)
# 初始化
sess = tf.InteractiveSession()
init = tf.global_variables_initializer()
sess.run(init)
# 训练
for step in range(100):batch_xs,batch_ys =mnist.train.next_batch(100)train_step.run({x:batch_xs,y_:batch_ys})# variables = tf.all_variables()# print(len(variables))# print(sess.run(b))# 测试模型准确率
pre_num=tf.argmax(y,1,output_type='int32',name="output")#输出节点名:output
correct_prediction = tf.equal(pre_num,tf.argmax(y_,1,output_type='int32'))
accuracy = tf.reduce_mean(tf.cast(correct_prediction,tf.float32))
a = accuracy.eval({x:mnist.test.images,y_:mnist.test.labels})
print('测试正确率:{0}'.format(a))# 保存训练好的模型
#形参output_node_names用于指定输出的节点名称,output_node_names=['output']对应pre_num=tf.argmax(y,1,name="output"),
output_graph_def = graph_util.convert_variables_to_constants(sess, sess.graph_def,output_node_names=['output'])
with tf.gfile.FastGFile('model/mnist.pb', mode='wb') as f:#’wb’中w代表写文件,b代表将数据以二进制方式写入文件。f.write(output_graph_def.SerializeToString())
sess.close()# 注:
# convert_variables_to_constants函数,会将计算图中的变量取值以常量的形式保存:https://blog.csdn.net/sinat_29957455/article/details/78511119

三、加载和测试

批量测试:

import tensorflow as tf
import numpy as np
from tensorflow.examples.tutorials.mnist import input_data
from PIL import Image
import matplotlib
import matplotlib.pyplot as plt#模型路径
model_path = 'model/mnist.pb'
#测试数据
mnist = input_data.read_data_sets("Mnist_data/", one_hot=True)
x_test = mnist.test.images
x_labels = mnist.test.labels;with tf.Graph().as_default():output_graph_def = tf.GraphDef()with open(model_path, "rb") as f:output_graph_def.ParseFromString(f.read())tf.import_graph_def(output_graph_def, name="")with tf.Session() as sess:tf.global_variables_initializer().run()# x_test = x_test.reshape(1, 28 * 28)input_x = sess.graph.get_tensor_by_name("input/x_input:0")output = sess.graph.get_tensor_by_name("output:0")# 【1】下面是进行批量测试----------------------------------------------------------pre_num = sess.run(output, feed_dict={input_x: x_test})#利用训练好的模型预测结果#结果批量测试的准确率correct_prediction = tf.equal(pre_num, tf.argmax(x_labels, 1,output_type='int32'))accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))acc = sess.run(accuracy, feed_dict={input_x: x_test})# a = accuracy.eval({x: mnist.test.images, y_: mnist.test.labels})print('测试正确率:{0}'.format(acc))#【2】下面是进行单张图片的测试-----------------------------------------------------testImage=x_test[0];test_input = testImage.reshape(1, 28 * 28)pre_num = sess.run(output, feed_dict={input_x: test_input})#利用训练好的模型预测结果print('模型预测结果为:',pre_num)#显示测试的图片testImage = testImage.reshape(28, 28)testImage=np.array(testImage * 255, dtype="int32")fig = plt.figure(), plt.imshow(testImage, cmap='binary')  # 显示图片plt.title("prediction result:"+str(pre_num))plt.show()#保存测定的图片testImage = Image.fromarray(testImage)testImage = testImage.convert('L')testImage.save("data/test_image.jpg")# matplotlib.image.imsave('data/name.jpg', im)sess.close()

单个样本测试:

import tensorflow as tf
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt#模型路径
model_path = 'model/mnist.pb'
#测试图片
testImage = Image.open("data/test_image.jpg");with tf.Graph().as_default():output_graph_def = tf.GraphDef()with open(model_path, "rb") as f:output_graph_def.ParseFromString(f.read())tf.import_graph_def(output_graph_def, name="")with tf.Session() as sess:tf.global_variables_initializer().run()# x_test = x_test.reshape(1, 28 * 28)input_x = sess.graph.get_tensor_by_name("input/x_input:0")output = sess.graph.get_tensor_by_name("output:0")#对图片进行测试testImage=testImage.convert('L')testImage = testImage.resize((28, 28))test_input=np.array(testImage)test_input = test_input.reshape(1, 28 * 28)pre_num = sess.run(output, feed_dict={input_x: test_input})#利用训练好的模型预测结果print('模型预测结果为:',pre_num)#显示测试的图片# testImage = test_x.reshape(28, 28)fig = plt.figure(), plt.imshow(testImage,cmap='binary')  # 显示图片plt.title("prediction result:"+str(pre_num))plt.show()

读取图片进行测试:

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import cv2 as cv
#模型路径
model_path = 'model/mnist.pb'
#测试图片
testImage = cv.imread("data/test_image.jpg");with tf.Graph().as_default():output_graph_def = tf.GraphDef()with open(model_path, "rb") as f:output_graph_def.ParseFromString(f.read())tf.import_graph_def(output_graph_def, name="")with tf.Session() as sess:tf.global_variables_initializer().run()# x_test = x_test.reshape(1, 28 * 28)input_x = sess.graph.get_tensor_by_name("input/x_input:0")output = sess.graph.get_tensor_by_name("output:0")#对图片进行测试testImage=cv.cvtColor(testImage, cv.COLOR_BGR2GRAY)testImage=cv.resize(testImage,dsize=(28, 28))test_input=np.array(testImage)test_input = test_input.reshape(1, 28 * 28)pre_num = sess.run(output, feed_dict={input_x: test_input})#利用训练好的模型预测结果print('模型预测结果为:',pre_num)# cv.imshow("image",testImage)# cv.waitKey(0)#显示测试的图片fig = plt.figure(), plt.imshow(testImage,cmap='binary')  # 显示图片plt.title("prediction result:"+str(pre_num))plt.show()

源码Github:https://github.com/PanJinquan/MNIST-TensorFlow-Python

上面TensorFlow保存训练好的模型,可以移植到Android,详细过程可以参考另一篇博客资料《将tensorflow MNIST训练模型移植到Android》:https://blog.csdn.net/guyuealian/article/details/79672257

tensorflow保存模型和加载模型的方法(Python和Android)相关推荐

  1. Keras框架下的保存模型和加载模型

    在Keras框架下训练深度学习模型时,一般思路是在训练环境下训练出模型,然后拿训练好的模型(即保存模型相应信息的文件)到生产环境下去部署.在训练过程中我们可能会遇到以下情况: 需要运行很长时间的程序在 ...

  2. tensorflow 保存训练loss_tensorflow2.0保存和加载模型 (tensorflow2.0官方教程翻译)

    最新版本:https://www.mashangxue123.com/tensorflow/tf2-tutorials-keras-save_and_restore_models.html 英文版本: ...

  3. python保存模型与参数_基于pytorch的保存和加载模型参数的方法

    当我们花费大量的精力训练完网络,下次预测数据时不想再(有时也不必再)训练一次时,这时候torch.save(),torch.load()就要登场了. 保存和加载模型参数有两种方式: 方式一: torc ...

  4. TensorFlow 保存和加载模型

    参考: 保存和恢复模型官方教程 tensorflow2保存和加载模型 TensorFlow2.0教程-keras模型保存和序列化

  5. PyTorch | 保存和加载模型教程

    点击上方"算法猿的成长",选择"加为星标" 第一时间关注 AI 和 Python 知识 图片来自 Unsplash,作者: Jenny Caywood 2019 ...

  6. Pytorch 保存和加载模型

    当保存和加载模型时,需要熟悉三个核心功能: 1. torch.save :将序列化对象保存到磁盘.此函数使用Python的 pickle 模块进行序列化.使 用此函数可以保存如模型.tensor.字典 ...

  7. 【pytorch】(六)保存和加载模型

    文章目录 保存和加载模型 保存加载模型参数 保存加载模型和参数 保存和加载模型 import torch from torch import nn from torch.utils.data impo ...

  8. 机器学习代码实战——保存和加载模型(Save and Load Model)

    文章目录 1.实验目的 2.保存与加载模型 2.1.pickle方法 2.2.joblib方法 1.实验目的 每当我们训练完一个模型后,我们需要保存训练好的模型留给下次用或者再次训练,因此我将给出两种 ...

  9. pytorch load state dict_Pytorch学习记录-使用Pytorch进行深度学习,保存和加载模型

    新建 Microsoft PowerPoint 演示文稿 (2).jpg 保存和加载模型 在完成60分钟入门之后,接下来有六节tutorials和五节关于文本处理的tutorials.争取一天一节.不 ...

最新文章

  1. Win64 驱动内核编程-22.SHADOW SSDT HOOK(宋孖健)
  2. Lindström–Gessel–Viennot lemma
  3. 面对想法简单客户的有效需求分析
  4. PowerShell 扩展工具第四波!
  5. c语言文件归并问题_通知:土地有变!土地归并:每亩补9万?明年起:合村并镇!能否启动?1个好消息!...
  6. vue-count-to插件使用方法
  7. libtorch下载
  8. Waiting Processed Cancelable ShowDialog (Release 2)
  9. C++下用什么矩阵运算库比较好
  10. 算法笔记_面试_0.刷leetcode_基础知识范围
  11. 基于FPGA的中值滤波器设计
  12. 计算机管理系统论文参考文献,关于计算机系统管理的论文参考文献 计算机系统管理论文参考文献哪里找...
  13. 学生环境保护绿色家园 WEB静态网页作业模板 大学生环境保护网页代码 dreamweaver网页设计作业制作 dw个人网页作业成品
  14. 大众点评字体解密(最新)2020/4/17
  15. 谈谈电子设计中PCB上的ESD防护方法
  16. 数字手写体识别python实现(全连接神经网络)
  17. vulnhub靶机-DC2-Writeup
  18. Excel中列和行之间的互换技巧。
  19. Linux USB驱动分析(一)----USB2.0协议分析
  20. 论文精读2: Ground-to-Aerial Image Geo-LocalizationWith a Hard Exemplar Reweighting Triplet Loss

热门文章

  1. 关于0xFFFFFFFF和alpha,温故而知新
  2. 物理磁盘空间使用已满导致数据库hang起
  3. IOS将字符串转换为日期时间格式
  4. ZOJ-2770 Burn the Linked Camp 差分约束
  5. hahahahahah
  6. css display属性理解
  7. 脚本配置文件(通过一个案例解释下什么叫脚本配置文件)
  8. iOS-Xcode代码统计
  9. 你知道JVM内存的那些事吗?
  10. 首个AI国际标准有望明年出台,创新工场等多家国内公司已参与