存储Tensorflow训练网络的参数
正文共3565个字,预计阅读时间10分钟。
上海站 | 高性能计算之GPU CUDA培训
训练一个神经网络的目的是啥?不就是有朝一日让它有用武之地吗?可是,在别处使用训练好的网络,得先把网络的参数(就是那些variables)保存下来,怎么保存呢?其实,tensorflow已经给我们提供了很方便的API,来帮助我们实现训练参数的存储与读取,如果想了解详情,请看晦涩难懂的官方API,接下来我简单介绍一下我的理解。
保存与读取数据全靠下面这个类实现:
class tf.train.Saver
当我们需要存储数据时,下面2条指令就够了
saver = tf.train.Saver() save_path = saver.save(sess, model_path)
解释一下,首先创建一个saver类,然后调用saver的save方法(函数),save需要传递两个参数,一个是你的训练session,另一个是文件存储路径,例如“/tmp/superNet.ckpt”,这个存储路径是可以包含文件名的。save方法会返回一个存储路径。当然,save方法还有别的参数可以传递,这里不再介绍。
然后怎么读取数据呢?看下面
saver = tf.train.Saver() load_path = saver.restore(sess, model_path)
和存储数据神似啊!不再赘述。
下面是重点!关于tf.train.Saver()使用的几点小心得!
1、save方法在实现数据读取时,它仅仅读数据,关键是得有一些提前声明好的variables来接受这些数据,因此,当save读取数据到sess时,需要提前声明与数据匹配的variables,否则程序就报错了。
2、save读取的数据不需要initialize。
3、目前想到的就这么多,随时补充。
为了对数据存储和读取有更直观的认识,我自己写了两个实验小程序,下面是第一个,训练网络并存储数据,用的MNIST数据集
import tensorflow as tf
import sys
# load MNIST data
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets('data', one_hot=True)
# 一些 hyper parameters
activation = tf.nn.relu batch_size = 100
iteration = 20000
hidden1_units = 30
# 注意!这里是存储路径!
model_path = sys.path[0] + '/simple_mnist.ckpt'
X = tf.placeholder(tf.float32, [None, 784]) y_ = tf.placeholder(tf.float32, [None, 10]) W_fc1 = tf.Variable(tf.truncated_normal([784, hidden1_units], stddev=0.2)) b_fc1 = tf.Variable(tf.zeros([hidden1_units])) W_fc2 = tf.Variable(tf.truncated_normal([hidden1_units, 10], stddev=0.2)) b_fc2 = tf.Variable(tf.zeros([10]))
def inference(img): fc1 = activation(tf.nn.bias_add(tf.matmul(img, W_fc1), b_fc1)) logits = tf.nn.bias_add(tf.matmul(fc1, W_fc2), b_fc2)
return logits
def loss(logits, labels): cross_entropy = tf.nn.softmax_cross_entropy_with_logits(logits, labels) loss = tf.reduce_mean(cross_entropy)
return loss
def evaluation(logits, labels):
correct_prediction = tf.equal(tf.argmax(logits, 1), tf.argmax(labels, 1)) accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
return accuracy logits = inference(X) loss = loss(logits, y_) train_op = tf.train.AdamOptimizer(1e-4).minimize(loss) accuracy = evaluation(logits, y_)
# 先实例化一个Saver()类saver = tf.train.Saver() init = tf.initialize_all_variables()
with tf.Session() as sess: sess.run(init) for i in xrange(iteration): batch = mnist.train.next_batch(batch_size)
if i%1000 == 0 and i: train_accuracy = sess.run(accuracy, feed_dict={X: batch[0], y_: batch[1]})
print "step %d, train accuracy %g" %(i, train_accuracy)
sess.run(train_op, feed_dict={X: batch[0], y_: batch[1]})
print '[+] Test accuracy is %f' % sess.run(accuracy, feed_dict={X: mnist.test.images, y_: mnist.test.labels})
# 存储训练好的variables
save_path = saver.save(sess, model_path)
print "[+] Model saved in file: %s" % save_path
接下来是读取数据并做测试!
import tensorflow as tf
import sys
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets('data', one_hot=True)
activation = tf.nn.relu hidden1_units = 30
model_path = sys.path[0] + '/simple_mnist.ckpt'
X = tf.placeholder(tf.float32, [None, 784]) y_ = tf.placeholder(tf.float32, [None, 10]) W_fc1 = tf.Variable(tf.truncated_normal([784, hidden1_units], stddev=0.2)) b_fc1 = tf.Variable(tf.zeros([hidden1_units])) W_fc2 = tf.Variable(tf.truncated_normal([hidden1_units, 10], stddev=0.2)) b_fc2 = tf.Variable(tf.zeros([10]))
def inference(img): fc1 = activation(tf.nn.bias_add(tf.matmul(img, W_fc1), b_fc1)) logits = tf.nn.bias_add(tf.matmul(fc1, W_fc2), b_fc2)
return logits
def evaluation(logits, labels): correct_prediction = tf.equal(tf.argmax(logits, 1), tf.argmax(labels, 1)) accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
return accuracy logits = inference(X) accuracy = evaluation(logits, y_) saver = tf.train.Saver()with tf.Session() as sess:
# 读取之前训练好的数据 load_path = saver.restore(sess, model_path)
print "[+] Model restored from %s" % load_path
print '[+] Test accuracy is %f' % sess.run(accuracy, feed_dict={X: mnist.test.images, y_: mnist.
原文链接:https://www.jianshu.com/p/83fa3aa2d0e9
查阅更为简洁方便的分类文章以及最新的课程、产品信息,请移步至全新呈现的“LeadAI学院官网”:
www.leadai.org
请关注人工智能LeadAI公众号,查看更多专业文章
大家都在看
LSTM模型在问答系统中的应用
基于TensorFlow的神经网络解决用户流失概览问题
最全常见算法工程师面试题目整理(一)
最全常见算法工程师面试题目整理(二)
TensorFlow从1到2 | 第三章 深度学习革命的开端:卷积神经网络
装饰器 | Python高级编程
今天不如来复习下Python基础
存储Tensorflow训练网络的参数相关推荐
- tensorflow 模型预训练后的参数restore finetuning
之前训练的网络中有一部分可以用到一个新的网络中,但是不知道存储的参数如何部分恢复到新的网络中,也了解到有许多网络是通过利用一些现有的网络结构,通过finetuning进行改造实现的,因此了解了一下关于 ...
- 深度网络的训练经验总结(参数篇)
(续我的上一篇博客)最早训练神经网络的时候完全什么技巧都不懂,能成功运行开源代码,并且看到loss下降就放心跑着了.随着对网络越来越多的接触,发现从数据集(train/val/test)的准备到训 ...
- 使用预训练网络训练的两种方式:Keras Applications、TensorFlow Hub
日萌社 人工智能AI:Keras PyTorch MXNet TensorFlow PaddlePaddle 深度学习实战(不定时更新) tensorflow 2.0 画出model网络模型的拓扑图 ...
- 基于FPGA的一维卷积神经网络CNN的实现(三)训练网络搭建及参数导出(附代码)
训练网络搭建 环境:Pytorch,Pycham,Matlab. 说明:该网络反向传播是通过软件方式生成,FPGA内部不进行反向传播计算. 该节通过Python获取训练数据集,并通过Pytorch框架 ...
- caffe框架训练网络参数详解
建立train_sh脚本文件 #!/sur/bin/env sh set -e /home/yourname/caffe/build/tools/caffe train --solver=/home/ ...
- 【深度学习】Weight Normalization: 一种简单的加速深度网络训练的重参数方法
前言:为什么要Normalization 深度学习是一种在给定数据的情况下,学习求解目标函数最小化或者最大化的模型.在深度网络中,模型参数往往包含了大量的weights和biases.在求解优化模型的 ...
- CVPR 2023 点云系列 | Point-NN无需训练的非参数、即插即用网络
CV - 计算机视觉 | ML - 机器学习 | RL - 强化学习 | NLP 自然语言处理 我们提出了一种用于 3D 点云分析的非参数网络 Point-NN,它由纯不可学习的组件组成:最远点采样 ...
- 使用PaddleFluid和TensorFlow训练RNN语言模型
专栏介绍:Paddle Fluid 是用来让用户像 PyTorch 和 Tensorflow Eager Execution 一样执行程序.在这些系统中,不再有模型这个概念,应用也不再包含一个用于描述 ...
- 将tensorflow训练好的模型移植到Android (MNIST手写数字识别)
将tensorflow训练好的模型移植到Android (MNIST手写数字识别) [尊重原创,转载请注明出处]https://blog.csdn.net/guyuealian/article/det ...
最新文章
- 《结对-贪吃蛇-设计文档》
- 求字典key的和python_python怎么将字典key相同的value值, 合并
- Hive 元数据库表信息
- ZZULIOJ 1128: 课程平均分
- TypeScript入坑
- 中国条码解码器市场趋势报告、技术动态创新及市场预测
- 线性回归的简洁实现(pytorch框架)
- 白鸦:我印象中的Keso
- 1.MATLAB简要介绍
- html关机命令,shutdown关机命令不起作用
- 昨天面了一位,见识到了Spring的天花板~
- 抖音自媒体是如何赚钱的,怎么做才能挣到更多的钱?
- 2020软件构造实验三
- 将ipad作为Windows10系统的的扩展显示屏
- Android 应用程序开发
- 深克隆和浅克隆的区别
- vivo手机便签如何快速彻底一键换机使用?
- css3实现向一个方向无缝连接滚动
- 2010-2019考研英语二 阅读真题+答案
- ROS机器人实践---小乌龟画圆
热门文章
- 北航 计算机学院 2011级学生会,北航学生会主席在2011级新生开学典礼发言稿.doc...
- mysql数据库搜索引擎要先进入_Mysql搜索引擎都有哪些区别
- 各纬度气候分布图_印度和中国都是季风气候显著的国家,但冬夏季风的强弱却完全不同...
- 批量修改linux换行格式,linux中sed命令批量修改
- poi生成word不可以修改_操作不懂技术就可以做小程序无限生成平台的创业项目实操教程...
- java.nio教程_Java NIO系列教程(三) Buffer
- python增量爬虫_python增量爬虫pyspider
- 大数据:技术与应用实践指南_大数据技术与应用社团 社会实践总结篇
- tshark mysql_使用tshark抓包分析http请求
- 中采购订单批导的bapi_五:认识SAP SD销售模式之第三方销售和单独采购