variables_to_restore函数,是TensorFlow为滑动平均值提供。之前,也介绍过通过使用滑动平均值可以让神经网络模型更加的健壮。我们也知道,其实在TensorFlow中,变量的滑动平均值都是由影子变量所维护的,如果你想要获取变量的滑动平均值需要获取的是影子变量而不是变量本身。

1、滑动平均值模型文件的保存

import tensorflow as tfif __name__ == "__main__":v = tf.Variable(0.,name="v")#设置滑动平均模型的系数ema = tf.train.ExponentialMovingAverage(0.99)#设置变量v使用滑动平均模型,tf.all_variables()设置所有变量op = ema.apply([v])#获取变量v的名字print(v.name)#v:0#创建一个保存模型的对象save = tf.train.Saver()sess = tf.Session()#初始化所有变量init = tf.initialize_all_variables()sess.run(init)#给变量v重新赋值sess.run(tf.assign(v,10))#应用平均滑动设置sess.run(op)#保存模型文件save.save(sess,"./model.ckpt")#输出变量v之前的值和使用滑动平均模型之后的值print(sess.run([v,ema.average(v)]))#[10.0, 0.099999905]

上面的代码,是如何来保存一个滑动平均值的模型文件,之前有介绍过滑动平均值和模型文件的保存,所以这里就不再重复了。

2、滑动平均值模型文件的读取

v = tf.Variable(1.,name="v")
#定义模型对象
saver = tf.train.Saver({"v/ExponentialMovingAverage":v})
sess = tf.Session()
saver.restore(sess,"./model.ckpt")
print(sess.run(v))
#0.0999999

对于模型文件的读取,在上一篇博客中有介绍过,这里特别需要注意的一个地方就是,在使用tf.train.Saver函数中,所传递的模型参数是{“v/ExponentialMovingAverage”:v}而不是{“v”:v},如果你使用的是后面的参数,那么你得到的结果将是10而不是0.09,那是因为后者获取的是变量本身而不是影子变量。是不是感觉使用这种方式来读取模型文件的时候,还需要输入一大串的变量名称。

3、variables_to_restore函数的使用

v = tf.Variable(1.,name="v")
#滑动模型的参数的大小并不会影响v的值
ema = tf.train.ExponentialMovingAverage(0.99)
print(ema.variables_to_restore())
#{'v/ExponentialMovingAverage': <tf.Variable 'v:0' shape=() dtype=float32_ref>}
sess = tf.Session()
saver = tf.train.Saver(ema.variables_to_restore())
saver.restore(sess,"./model.ckpt")
print(sess.run(v))
#0.0999999

通过使用variables_to_restore函数,可以使在加载模型的时候将影子变量直接映射到变量的本身,所以我们在获取变量的滑动平均值的时候只需要获取到变量的本身值而不需要去获取影子变量。

引用文章:对TensorFlow中的variables_to_restore函数详解

tensorflow tf.train.ExponentialMovingAverage().variables_to_restore()函数 (用于加载模型时将影子变量直接映射到变量本身)相关推荐

  1. Gazebo加载模型时黑屏

    Gazebo加载模型时黑屏 1. 黑屏状态 2. 解决办法1 3. 解决办法2 1. 黑屏状态 Gazebo加载模型的时候会发现一直处于这种状态 这可能是因为model库加载不正确导致的 2. 解决办 ...

  2. tensorflow tf.train.ExponentialMovingAverage() (滑动平均模型)(移动平均法 Moving average,MA)(用于平滑数据波动对预测结果的影响)

    tf.train.ExponentialMovingAverage 函数定义 tensorflow中提供了tf.train.ExponentialMovingAverage来实现滑动平均模型,他使用指 ...

  3. Tensorflow学习(二)之——保存加载模型、Saver的用法

    1. Saver的背景介绍 我们经常在训练完一个模型之后希望保存训练的结果,这些结果指的是模型的参数,以便下次迭代的训练或者用作测试.Tensorflow针对这一需求提供了Saver类. Saver类 ...

  4. keras dense sigmoid_tf.keras一个存在自定义层时加载模型时的小坑

    前言 Tensorflow在现在的doc里强推Keras,用过之后感觉真的很爽,搭模型简单,模型结构可打印,瞬间就能train起来不用自己写get_batch和evaluate啥的,跟用原生tenso ...

  5. pytorch加载模型时出现.....ckpt_100.pth is a zip archive (did you mean to use torch.jit.load()?)

    在测试加载训练好的模型时出现上方问题,参考这篇文章,原因是训练和测试的torch版本不一致. 训练的时候是1.6,测试的时候是1.2,因此需要先在1.6版本下加载模型,重新保存,在保存的时候设置use ...

  6. tensorflow tf.train.Saver.restore() (用于下次训练时恢复模型)

    # 保存当前的Session到文件目录tf.train.Saver().save(sess, 'net/my_net.ckpt') # 然后在下次训练时恢复模型: tf.train.Saver().r ...

  7. tf.train.ExponentialMovingAverage

    http://blog.csdn.net/uestc_c2_403/article/details/72235334 tf.train.ExponentialMovingAverage(decay, ...

  8. tensorflow 加载模型

    训练模型 import tensorflow as tf import numpy as np import matplotlib.pyplot as plt money=np.array([[109 ...

  9. tensorflow中保存模型、加载模型做预测(不需要再定义网络结构)

    下面用一个线下回归模型来记载保存模型.加载模型做预测 参考文章: http://blog.csdn.net/thriving_fcl/article/details/71423039 训练一个线下回归 ...

最新文章

  1. 一站式论文提升服务,助您顺利发文章!
  2. 网页css样式中英对照,css中文样式(含中英文对照表).doc
  3. python(16)-列表list,for循环
  4. lvs工作在第几层_LVS 原理(调度算法、四种模式、四层负载均衡和七层 的区别)...
  5. iOS 后台语音播报功能开发过程中的那些坑
  6. 开源 非开源_开源周中的女性
  7. 黑马程序员C++学习笔记(第二阶段核心:面向对象)(二)
  8. C++ gdb core调试 崩溃日志 都是问号??
  9. 基于C#窗体的酒店管理系统
  10. R 计算时间序列自相关性教程
  11. 最详细的选型攻略!选择工业相机必须搞懂这10大要素!(建议收藏)
  12. Proxmark3系列教程1——PM3用法
  13. WebView下载文件
  14. Java中Springboot实战之签到功能详解(超全面)
  15. 文件转换-----(类型,格式)
  16. bzoj 4987 Tree
  17. 超详细的MySQL基本操作
  18. Android百度地图POI检索无标记显示问题
  19. 人工智能七大应用领域!你难道还没真香吗?
  20. 现代化医院PACS/RIS系统概述

热门文章

  1. php对称字符串,PHP实现简单的对称加密和解密方法 - str_split
  2. php7-sapnwrfc
  3. SAP-MM-移动类型解析之收货03--退货
  4. Smartforms 设置纸张打印格式
  5. 定义工厂(Plant)
  6. Excel单元格里面提取或去掉某些字符
  7. ALEIDoc EDI(5)--Inbound Function
  8. 跨工厂物料状态/特定工厂的物料状态
  9. SAP 选择屏幕的收起与展开(Collapse and Expand)
  10. 数据从程序中传入到form中