tensorflow之ExponentialMovingAverage
tf.train.ExponentialMovingAverage
函数定义
tensorflow中提供了tf.train.ExponentialMovingAverage来实现滑动平均模型,他使用指数衰减来计算变量的移动平均值。
tf.train.ExponentialMovingAverage.__init__(self, decay, num_updates=None, zero_debias=False, name="ExponentialMovingAverage"):
decay是衰减率在创建ExponentialMovingAverage对象时,需指定衰减率(decay),用于控制模型的更新速度。影子变量的初始值与训练变量的初始值相同。当运行变量更新时,每个影子变量都会更新为:
shadowvariable=decay∗shadowvariable+(1−decay)∗variablenum_updates是ExponentialMovingAverage提供用来动态设置decay的参数,当初始化时提供了参数,即不为none时,每次的衰减率是:
min{decay,(1+num_updates)/(10+num_updates)}apply()方法添加了训练变量的影子副本,并保持了其影子副本中训练变量的移动平均值操作。在每次训练之后调用此操作,更新移动平均值。
average()和average_name()方法可以获取影子变量及其名称。
decay设置为接近1的值比较合理,通常为:0.999,0.9999等
实例代码如下:
v1 = tf.Variable(0, dtype=tf.float32) # 定义一个变量,初始值为0
step = tf.Variable(0, trainable=False) # step为迭代轮数变量,控制衰减率
ema = tf.train.ExponentialMovingAverage(0.99, step) # 初始设定衰减率为0.99
maintain_averages_op = ema.apply([v1]) # 更新列表中的变量
with tf.Session() as sess:init_op = tf.global_variables_initializer() # 初始化所有变量
sess.run(init_op)
print(sess.run([v1, ema.average(v1)])) # 输出初始化后变量v1的值和v1的滑动平均值
sess.run(tf.assign(v1, 5)) # 更新v1的值
sess.run(maintain_averages_op) # 更新v1的滑动平均值
print(sess.run([v1, ema.average(v1)]))
sess.run(tf.assign(step, 10000)) # 更新迭代轮转数step
sess.run(tf.assign(v1, 10))
sess.run(maintain_averages_op)
print(sess.run([v1, ema.average(v1)]))# 再次更新滑动平均值,
sess.run(maintain_averages_op)
print(sess.run([v1, ema.average(v1)]))# 更新v1的值为15
sess.run(tf.assign(v1, 15))sess.run(maintain_averages_op)
print(sess.run([v1, ema.average(v1)]))
#
# [0.0, 0.0]
# [5.0, 4.5]
# [10.0, 4.5549998]
# [10.0, 4.6094499]
# [15.0, 4.7133551]
计算步骤如下:
滑动平均模型的作用是提高测试值上的健壮性。那它是如何实现这个功能的呢?其实滑动平均模型的原理就是一阶滞后滤波法,其表达式如下:
上面的实例
**********************************************
输入 0.0
输出计算:
decay = min(0.99,(1+0)/(10+0)) =0.1
输出 = 0.1 * 0+(1-0.1)*0 = 0
**********************************************
输入 5.0
输出计算:
decay = min(0.99,(1+0)/(10+0)) =0.1
输出 = 0.1 * 0+(1-0.1)*5= 4.5
**********************************************
输入 10.0
输出计算:
decay = min(0.99,(1+10000)/(10+10000)) =0.99
输出 = 0.99 * 4.5+(1-0.99)*10= 4.555
**********************************************
输入 10.0
输出计算:
decay = min(0.99,(1+10000)/(10+10000)) =0.99
输出 = 0.99 * 4.555+(1-0.99)*15= 4.60945
**********************************************
输入 15.0
输出计算:
decay = min(0.99,(1+10000)/(10+10000)) =0.99
输出 = 0.99 * 4.60945+(1-0.99)*15= 4.713355
**********************************************
参考下面博客
https://blog.csdn.net/kuweicai/article/details/80517284
https://blog.csdn.net/qq_39521554/article/details/79028012
https://www.cnblogs.com/cloud-ken/p/7521609.html
tensorflow之ExponentialMovingAverage相关推荐
- tensorflow tf.train.ExponentialMovingAverage().variables_to_restore()函数 (用于加载模型时将影子变量直接映射到变量本身)
variables_to_restore函数,是TensorFlow为滑动平均值提供.之前,也介绍过通过使用滑动平均值可以让神经网络模型更加的健壮.我们也知道,其实在TensorFlow中,变量的滑动 ...
- tensorflow tf.train.ExponentialMovingAverage() (滑动平均模型)(移动平均法 Moving average,MA)(用于平滑数据波动对预测结果的影响)
tf.train.ExponentialMovingAverage 函数定义 tensorflow中提供了tf.train.ExponentialMovingAverage来实现滑动平均模型,他使用指 ...
- Tensorflow ExponentialMovingAverage 详解
tensorflow 中的 ExponentialMovingAverage 这时,再看官方文档中的公式: shadowVariable=decay∗shadowVariable+(1−decay)∗ ...
- 【TensorFlow】TensorFlow函数精讲之tf.train.ExponentialMovingAverage()
tf.train.ExponentialMovingAverage来实现滑动平均模型. 格式: tf.train.ExponentialMovingAverage(decay,num_step) 参数 ...
- tensorflow 滑动平均模型 ExponentialMovingAverage
____tz_zs学习笔记 滑动平均模型对于采用GradientDescent或Momentum训练的神经网络的表现都有一定程度上的提升. 原理:在训练神经网络时,不断保持和更新每个参数的滑动平均值, ...
- TensorFlow学习--指数移动平均/tf.train.ExponentialMovingAverage
时间序列模型 时间序列是指将同一统计指标的数值按其发生的时间先后顺序排列而成的数列.时间序列分析的主要目的是根据已有的历史数据对未来进行预测.处理与时间相关数据的方法叫做时间序列模型. 当一个平稳序列 ...
- tensorflow || 滑动平均的理解--tf.train.ExponentialMovingAverage
1 滑动平均的理解 滑动平均(exponential moving average),或者叫做指数加权平均(exponentially weighted moving average),可以用来估计变 ...
- TensorFlow学习笔记——实现经典LeNet5模型
TensorFlow实现LeNet-5模型 文章目录 TensorFlow实现LeNet-5模型 前言 一.什么是TensorFlow? 计算图 Session 二.什么是LeNet-5? INPUT ...
- TensorFlow练习16: 根据大脸判断性别和年龄
本帖使用TensorFlow做一个根据脸部推断照片人物年龄和性别的练习,网上有很多类似app. 训练数据 – Adience数据集 Adience数据集来源为Flickr相册,由用户使用iPhone或 ...
最新文章
- [CQOI2014]数三角形 组合数 + 容斥 + gcd
- shodan API 获取IP开放端口
- Windows7下安装LabelImg标注工具
- androidstudio表格中填充 宽跟长一样_Excel表格的基本操作教程,覆盖表格制作的10大知识!...
- PAT甲级题目翻译+答案 AcWing(动态规划)
- Some Fiori offline screenshot in Mac
- 娱乐项目和女朋友哪个重要?
- 1、Flutter Widget(IOS Style) - CupertinoApp;
- Android开发:Menu选项菜单
- 层层深入探究网络连接丢包之谜
- 下:比拼生态和未来,Spark和Flink哪家强?
- 推荐一款基于bootstrap的漂亮的前端模板——inspinia_admin(国内翻译的叫 H+后台主题UI框架)
- 计算机专业科研经费排名2015,2017中国大学科研经费排名
- Excel怎么批量设置图片大小
- JavaMail 发送邮件阻塞问题解决——设置 smtp 超时时间
- 数独问题(java)
- 趣图:五彩斑斓的黑,找到了
- win10如何使用WinSAT测试体验指数
- 如何避免拼多多售后?拼多多售后有哪些规则?
- data mining - 实用机器学习工具与技术 - 读书笔记( 一 )
热门文章
- KMS Server相关资料
- 使用java语言操作,如何来实现MySQL中Blob字段的存取
- VC++中多线程学习(MFC多线程)一(线程的创建、线程函数如何调用类成员呢?如何调用主对话框的成员?、MFC中的工作线程和界面线程的区别)
- html底部线条,这种APP底部横线+文字该怎么布局?css
- 2019.3.9日面试自我介绍
- vue中使用axios发送请求(二)
- 怎么做应力应变曲线_做了这么多年材料,这些力学性能测试你做对了吗?
- python用三种方式定义字符串、并依次输出_Python 字符串格式化输出的3种方式
- 使用php进行财务统计,基于php的基金财务数据接口调用代码实例
- 6年经验java笔试_不想搞Java了,6年经验去面试10分钟结束,现在Java面试为这么难...