tensorflow保存模型和加载模型的方法(Python和Android)
tensorflow保存模型和加载模型的方法(Python和Android)
一、tensorflow保存模型的几种方法:
(1) tf.train.saver()保存模型
使用 tf.train.saver()保存模型,该方法保存模型文件的时候会产生多个文件,会把计算图的结构和图上参数取值分成了不同的文件存储。这种方法是在TensorFlow中是最常用的保存方式。
例如:
import tensorflow as tf
# 声明两个变量
v1 = tf.Variable(tf.random_normal([1, 2]), name="v1")
v2 = tf.Variable(tf.random_normal([2, 3]), name="v2")
init_op = tf.global_variables_initializer() # 初始化全部变量
saver = tf.train.Saver() # 声明tf.train.Saver类用于保存模型
with tf.Session() as sess:sess.run(init_op)print("v1:", sess.run(v1)) # 打印v1、v2的值一会读取之后对比print("v2:", sess.run(v2))saver_path = saver.save(sess, "save/model.ckpt") # 将模型保存到save/model.ckpt文件print("Model saved in file:", saver_path)
运行后,会在save目录下保存了四个文件:
其中
- checkpoint是检查点文件,文件保存了一个目录下所有的模型文件列表;
- model.ckpt.meta文件保存了TensorFlow计算图的结构,可以理解为神经网络的网络结构,该文件可以被 tf.train.import_meta_graph 加载到当前默认的图来使用。
- ckpt.data : 保存模型中每个变量的取值
参考资料:
https://blog.csdn.net/michael_yt/article/details/74737489https://blog.csdn.net/lwplwf/article/details/62419087
(2)tf.train.write_graph()
使用 tf.train.write_graph()保存模型,该方法只是保存了模型的结构,并不保存训练完毕的参数值。
(3)convert_variables_to_constants固化模型结构
很多时候,我们需要将TensorFlow的模型导出为单个文件(同时包含模型结构的定义与权重),方便在其他地方使用(如在Android中部署网络)。利用tf.train.write_graph()默认情况下只导出了网络的定义(没有权重),而利用tf.train.Saver().save()导出的文件graph_def与权重是分离的,因此需要采用别的方法。 我们知道,graph_def文件中没有包含网络中的Variable值(通常情况存储了权重),但是却包含了constant值,所以如果我们能把Variable转换为constant,即可达到使用一个文件同时存储网络架构与权重的目标。
TensoFlow为我们提供了convert_variables_to_constants()方法,该方法可以固化模型结构,将计算图中的变量取值以常量的形式保存。而且保存的模型可以移植到Android平台。
参考资料:
【1】https://blog.csdn.net/sinat_29957455/article/details/78511119【2】这里主要实现第三种方法,因为该方法保存的模型可以移植到Android平台运行。以下Python代码,都共享在
Github:https://github.com/PanJinquan/tensorflow-learning-tutorials/tree/master/MNIST-Demo;
【3】移植Android的详细过程可参考本人的另一篇博客资料《将tensorflow MNIST训练模型移植到Android》:
https://blog.csdn.net/guyuealian/article/details/79672257
二、训练和保存模型
以MNIST手写数字识别为例,这里首先使用Python版的TensorFlow实现SoftMax Regression分类器,并将训练好的模型的网络拓扑结构和参数保存为pb文件,其中convert_variables_to_constants函数,会将计算图中的变量取值以常量的形式保存:https://blog.csdn.net/sinat_29957455/article/details/78511119
#coding=utf-8
from tensorflow.examples.tutorials.mnist import input_data
import tensorflow as tf
from tensorflow.python.framework import graph_util
print('tensortflow:{0}'.format(tf.__version__))mnist = input_data.read_data_sets("Mnist_data/", one_hot=True)#create model
with tf.name_scope('input'):x = tf.placeholder(tf.float32,[None,784],name='x_input')#输入节点名:x_inputy_ = tf.placeholder(tf.float32,[None,10],name='y_input')
with tf.name_scope('layer'):with tf.name_scope('W'):#tf.zeros([3, 4], tf.int32) ==> [[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]]W = tf.Variable(tf.zeros([784,10]),name='Weights')with tf.name_scope('b'):b = tf.Variable(tf.zeros([10]),name='biases')with tf.name_scope('W_p_b'):Wx_plus_b = tf.add(tf.matmul(x, W), b, name='Wx_plus_b')y = tf.nn.softmax(Wx_plus_b, name='final_result')# 定义损失函数和优化方法
with tf.name_scope('loss'):loss = -tf.reduce_sum(y_ * tf.log(y))
with tf.name_scope('train_step'):train_step = tf.train.GradientDescentOptimizer(0.01).minimize(loss)print(train_step)
# 初始化
sess = tf.InteractiveSession()
init = tf.global_variables_initializer()
sess.run(init)
# 训练
for step in range(100):batch_xs,batch_ys =mnist.train.next_batch(100)train_step.run({x:batch_xs,y_:batch_ys})# variables = tf.all_variables()# print(len(variables))# print(sess.run(b))# 测试模型准确率
pre_num=tf.argmax(y,1,output_type='int32',name="output")#输出节点名:output
correct_prediction = tf.equal(pre_num,tf.argmax(y_,1,output_type='int32'))
accuracy = tf.reduce_mean(tf.cast(correct_prediction,tf.float32))
a = accuracy.eval({x:mnist.test.images,y_:mnist.test.labels})
print('测试正确率:{0}'.format(a))# 保存训练好的模型
#形参output_node_names用于指定输出的节点名称,output_node_names=['output']对应pre_num=tf.argmax(y,1,name="output"),
output_graph_def = graph_util.convert_variables_to_constants(sess, sess.graph_def,output_node_names=['output'])
with tf.gfile.FastGFile('model/mnist.pb', mode='wb') as f:#’wb’中w代表写文件,b代表将数据以二进制方式写入文件。f.write(output_graph_def.SerializeToString())
sess.close()# 注:
# convert_variables_to_constants函数,会将计算图中的变量取值以常量的形式保存:https://blog.csdn.net/sinat_29957455/article/details/78511119
三、加载和测试
批量测试:
import tensorflow as tf
import numpy as np
from tensorflow.examples.tutorials.mnist import input_data
from PIL import Image
import matplotlib
import matplotlib.pyplot as plt#模型路径
model_path = 'model/mnist.pb'
#测试数据
mnist = input_data.read_data_sets("Mnist_data/", one_hot=True)
x_test = mnist.test.images
x_labels = mnist.test.labels;with tf.Graph().as_default():output_graph_def = tf.GraphDef()with open(model_path, "rb") as f:output_graph_def.ParseFromString(f.read())tf.import_graph_def(output_graph_def, name="")with tf.Session() as sess:tf.global_variables_initializer().run()# x_test = x_test.reshape(1, 28 * 28)input_x = sess.graph.get_tensor_by_name("input/x_input:0")output = sess.graph.get_tensor_by_name("output:0")# 【1】下面是进行批量测试----------------------------------------------------------pre_num = sess.run(output, feed_dict={input_x: x_test})#利用训练好的模型预测结果#结果批量测试的准确率correct_prediction = tf.equal(pre_num, tf.argmax(x_labels, 1,output_type='int32'))accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))acc = sess.run(accuracy, feed_dict={input_x: x_test})# a = accuracy.eval({x: mnist.test.images, y_: mnist.test.labels})print('测试正确率:{0}'.format(acc))#【2】下面是进行单张图片的测试-----------------------------------------------------testImage=x_test[0];test_input = testImage.reshape(1, 28 * 28)pre_num = sess.run(output, feed_dict={input_x: test_input})#利用训练好的模型预测结果print('模型预测结果为:',pre_num)#显示测试的图片testImage = testImage.reshape(28, 28)testImage=np.array(testImage * 255, dtype="int32")fig = plt.figure(), plt.imshow(testImage, cmap='binary') # 显示图片plt.title("prediction result:"+str(pre_num))plt.show()#保存测定的图片testImage = Image.fromarray(testImage)testImage = testImage.convert('L')testImage.save("data/test_image.jpg")# matplotlib.image.imsave('data/name.jpg', im)sess.close()
单个样本测试:
import tensorflow as tf
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt#模型路径
model_path = 'model/mnist.pb'
#测试图片
testImage = Image.open("data/test_image.jpg");with tf.Graph().as_default():output_graph_def = tf.GraphDef()with open(model_path, "rb") as f:output_graph_def.ParseFromString(f.read())tf.import_graph_def(output_graph_def, name="")with tf.Session() as sess:tf.global_variables_initializer().run()# x_test = x_test.reshape(1, 28 * 28)input_x = sess.graph.get_tensor_by_name("input/x_input:0")output = sess.graph.get_tensor_by_name("output:0")#对图片进行测试testImage=testImage.convert('L')testImage = testImage.resize((28, 28))test_input=np.array(testImage)test_input = test_input.reshape(1, 28 * 28)pre_num = sess.run(output, feed_dict={input_x: test_input})#利用训练好的模型预测结果print('模型预测结果为:',pre_num)#显示测试的图片# testImage = test_x.reshape(28, 28)fig = plt.figure(), plt.imshow(testImage,cmap='binary') # 显示图片plt.title("prediction result:"+str(pre_num))plt.show()
读取图片进行测试:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import cv2 as cv
#模型路径
model_path = 'model/mnist.pb'
#测试图片
testImage = cv.imread("data/test_image.jpg");with tf.Graph().as_default():output_graph_def = tf.GraphDef()with open(model_path, "rb") as f:output_graph_def.ParseFromString(f.read())tf.import_graph_def(output_graph_def, name="")with tf.Session() as sess:tf.global_variables_initializer().run()# x_test = x_test.reshape(1, 28 * 28)input_x = sess.graph.get_tensor_by_name("input/x_input:0")output = sess.graph.get_tensor_by_name("output:0")#对图片进行测试testImage=cv.cvtColor(testImage, cv.COLOR_BGR2GRAY)testImage=cv.resize(testImage,dsize=(28, 28))test_input=np.array(testImage)test_input = test_input.reshape(1, 28 * 28)pre_num = sess.run(output, feed_dict={input_x: test_input})#利用训练好的模型预测结果print('模型预测结果为:',pre_num)# cv.imshow("image",testImage)# cv.waitKey(0)#显示测试的图片fig = plt.figure(), plt.imshow(testImage,cmap='binary') # 显示图片plt.title("prediction result:"+str(pre_num))plt.show()
源码Github:https://github.com/PanJinquan/MNIST-TensorFlow-Python
上面TensorFlow保存训练好的模型,可以移植到Android,详细过程可以参考另一篇博客资料《将tensorflow MNIST训练模型移植到Android》:https://blog.csdn.net/guyuealian/article/details/79672257
tensorflow保存模型和加载模型的方法(Python和Android)相关推荐
- Keras框架下的保存模型和加载模型
在Keras框架下训练深度学习模型时,一般思路是在训练环境下训练出模型,然后拿训练好的模型(即保存模型相应信息的文件)到生产环境下去部署.在训练过程中我们可能会遇到以下情况: 需要运行很长时间的程序在 ...
- tensorflow 保存训练loss_tensorflow2.0保存和加载模型 (tensorflow2.0官方教程翻译)
最新版本:https://www.mashangxue123.com/tensorflow/tf2-tutorials-keras-save_and_restore_models.html 英文版本: ...
- python保存模型与参数_基于pytorch的保存和加载模型参数的方法
当我们花费大量的精力训练完网络,下次预测数据时不想再(有时也不必再)训练一次时,这时候torch.save(),torch.load()就要登场了. 保存和加载模型参数有两种方式: 方式一: torc ...
- TensorFlow 保存和加载模型
参考: 保存和恢复模型官方教程 tensorflow2保存和加载模型 TensorFlow2.0教程-keras模型保存和序列化
- PyTorch | 保存和加载模型教程
点击上方"算法猿的成长",选择"加为星标" 第一时间关注 AI 和 Python 知识 图片来自 Unsplash,作者: Jenny Caywood 2019 ...
- Pytorch 保存和加载模型
当保存和加载模型时,需要熟悉三个核心功能: 1. torch.save :将序列化对象保存到磁盘.此函数使用Python的 pickle 模块进行序列化.使 用此函数可以保存如模型.tensor.字典 ...
- 【pytorch】(六)保存和加载模型
文章目录 保存和加载模型 保存加载模型参数 保存加载模型和参数 保存和加载模型 import torch from torch import nn from torch.utils.data impo ...
- 机器学习代码实战——保存和加载模型(Save and Load Model)
文章目录 1.实验目的 2.保存与加载模型 2.1.pickle方法 2.2.joblib方法 1.实验目的 每当我们训练完一个模型后,我们需要保存训练好的模型留给下次用或者再次训练,因此我将给出两种 ...
- pytorch load state dict_Pytorch学习记录-使用Pytorch进行深度学习,保存和加载模型
新建 Microsoft PowerPoint 演示文稿 (2).jpg 保存和加载模型 在完成60分钟入门之后,接下来有六节tutorials和五节关于文本处理的tutorials.争取一天一节.不 ...
最新文章
- Win64 驱动内核编程-22.SHADOW SSDT HOOK(宋孖健)
- Lindström–Gessel–Viennot lemma
- 面对想法简单客户的有效需求分析
- PowerShell 扩展工具第四波!
- c语言文件归并问题_通知:土地有变!土地归并:每亩补9万?明年起:合村并镇!能否启动?1个好消息!...
- vue-count-to插件使用方法
- libtorch下载
- Waiting Processed Cancelable ShowDialog (Release 2)
- C++下用什么矩阵运算库比较好
- 算法笔记_面试_0.刷leetcode_基础知识范围
- 基于FPGA的中值滤波器设计
- 计算机管理系统论文参考文献,关于计算机系统管理的论文参考文献 计算机系统管理论文参考文献哪里找...
- 学生环境保护绿色家园 WEB静态网页作业模板 大学生环境保护网页代码 dreamweaver网页设计作业制作 dw个人网页作业成品
- 大众点评字体解密(最新)2020/4/17
- 谈谈电子设计中PCB上的ESD防护方法
- 数字手写体识别python实现(全连接神经网络)
- vulnhub靶机-DC2-Writeup
- Excel中列和行之间的互换技巧。
- Linux USB驱动分析(一)----USB2.0协议分析
- 论文精读2: Ground-to-Aerial Image Geo-LocalizationWith a Hard Exemplar Reweighting Triplet Loss