滑动平均模型:

用途:用于控制变量的更新幅度,使得模型在训练初期参数更新较快,在接近最优值处参数更新较慢,幅度较小
方式:主要通过不断更新衰减率来控制变量的更新幅度

衰减率计算公式 :
    decay = min{init_decay , (1 + num_update) / (10 + num_update)}
    其中 init_decay 为设置的初始衰减率 ,num_update 为模型参数更新次数,由此可见,随着 num_update 更新次数的增加,(1 + num_update) / (10 + num_update 这一项的计算结果越接近1

参数更新公式:
    shadow_variable = decay * shadow_variable + (1 - decay) * variable
    其中 shadow_variable 为变量更新前的数值,variable为变量更新后的数值

例如:
    x = 0
    x = 1
    此时 shadow_variable 就是 0 , variable 就是 1 , 假如此时的 衰减率 decay 是 0.5,则更新后的 x 取值为 0.5 * 0 + (1 - 0.5) * 1 = 0.5

通过以上公式可以发现,随着模型迭代次数的增加,(1 + num_update) / (10 + num_update) 这一项的计算结果越接近1,也就是 (1 - decay) * variable 更接近于 0 ,此时模型参数变化幅度减小 , 也就是 shadow_variable == decay * shadow_variable 等式越成立。

tensorflow代码示例:

#coding:utf-8
"""Created by cheng star at 2018/8/26 10:46@email : xxcheng0708@163.com
"""import tensorflow as tfv1 = tf.Variable(0.0 , dtype=tf.float32)
step = tf.Variable(0 , trainable=False)ema = tf.train.ExponentialMovingAverage(decay=0.99 , num_updates=step)
# maintain_average_op 每执行一次,其中的变量就会被更新
maintain_average_op = ema.apply([v1])with tf.Session() as sess :init = tf.global_variables_initializer()sess.run(init)# 变量初始化之后,变量的数值和滑动平均值相同,均为 0print(sess.run([v1 , ema.average(v1)]))     # [0.0 , 0.0]sess.run(maintain_average_op)print(sess.run([v1 , ema.average(v1)]))     # [0.0 , 0.0]# 更新变量的赋值sess.run(tf.assign(v1 , 5))"""执行maintain_average_op 操作,此时 step = 0 , 使用公式 min{decay , (1 + num_update)/(10 + num_update)} 计算衰减率因此,decay衰减率是 min{init_decay = 0.99 , (1 + 0) / (10 + 0) = 0.1} = 0.1因此,此时的 v1 变量值是 0 * 0.1 + (1 - 0.1) * 5 = 4.5"""sess.run(maintain_average_op)print(sess.run([v1 , ema.average(v1)]))     # [5 , 4.5]sess.run(tf.assign(step , 1000))sess.run(tf.assign(v1 , 10))"""decay = min{0.99 , (1 + 1000)/(10 + 1000) = 0.99} = 0.99    衰减率不变v1 = 4.5 * 0.99 + (1 - 0.99) * 10 = 4.555"""sess.run(maintain_average_op)print(sess.run([v1 , ema.average(v1)]))     # [10.0, 4.5549998]"""decay = min{0.99 , (1 + 1000)/(10 + 1000) = 0.99} = 0.99    衰减率不变v1 = 4.555 * 0.99 + (1 - 0.99) * 10 = 4.609"""sess.run(maintain_average_op)print(sess.run([v1 , ema.average(v1)]))     # [10.0, 4.6094499]

参考文献:Tensorflow实战Google深度学习框架. 才云科技Caicloud 郑泽宇 顾思宇 著

深度学习中滑动平均模型的作用、计算方法及tensorflow代码示例相关推荐

  1. 全民 Transformer (一): Attention 在深度学习中是如何发挥作用的

    <Attention 在深度学习中是如何发挥作用的:理解序列模型中的 Attention>    Transformer 的出现让 Deep Learning 出现了大一统的局面.Tran ...

  2. 深度学习中的 Batch_Size的作用

    Batch_Size(批尺寸)是机器学习中一个重要参数,涉及诸多矛盾,下面逐一展开. 首先,为什么需要有 Batch_Size 这个参数? Batch 的选择,首先决定的是下降的方向.如果数据集比较小 ...

  3. 深度学习中一些注意力机制的介绍以及pytorch代码实现

    文章目录 前言 注意力机制 软注意力机制 代码实现 硬注意力机制 多头注意力机制 代码实现 参考 前言 因为最近看论文发现同一个模型用了不同的注意力机制计算方法,因此懵了好久,原来注意力机制也是多种多 ...

  4. 深度学习中Flatten层的作用

    Flatten层的实现在Keras.layers.core.Flatten()类中. 作用: Flatten层用来将输入"压平",即把多维的输入一维化,常用在从卷积层到全连接层的过 ...

  5. 深度学习中常用的学习率衰减策略及tensorflow实现

    目录 引言 (1)分段常数衰减 (2)指数衰减 (3)自然指数衰减 (4)多项式衰减 (5)余弦衰减 (6)线性余弦衰减 (7)噪声线性余弦衰减 (8)倒数衰减 引言 学习率(learning rat ...

  6. 机器学习/深度学习中的常用损失函数公式、原理与代码实践(持续更新ing...)

    诸神缄默不语-个人CSDN博文目录 最近更新时间:2023.5.8 最早更新时间:2022.6.12 本文的结构是首先介绍一些常见的损失函数,然后介绍一些个性化的损失函数实例. 文章目录 1. 分类 ...

  7. 深度学习中的反向卷积

    点击上方"小白学视觉",选择加"星标"或"置顶" 重磅干货,第一时间送达 本文转自:opencv学堂 图像卷积最常见的一个功能就是输出模糊( ...

  8. sigmoid函数_深度学习中激活函数总结

    一.前言 前段时间通过引入新的激活函数Dice,带来了指标的提升,借着这个机会,今天总结下常用的一些激活函数. 激活函数在深度学习中起着非常重要的作用,本文主要介绍下常用的激活函数及其优缺点.主要分为 ...

  9. 卷积在深度学习中的作用(转自http://timdettmers.com/2015/03/26/convolution-deep-learning/)...

    卷积可能是现在深入学习中最重要的概念.卷积网络和卷积网络将深度学习推向了几乎所有机器学习任务的最前沿.但是,卷积如此强大呢?它是如何工作的?在这篇博客文章中,我将解释卷积并将其与其他概念联系起来,以帮 ...

最新文章

  1. 自动化测试(二) 单元测试junit的Test注解突然不能使用原因以及解决方案
  2. Qt Creator中常用快捷键和小技巧
  3. ICLR 2022 | 在注意力中重新思考Softmax,商汤提出cosFormer实现多项SOTA
  4. PHP3d地球,three.js绘制地球、飞机与轨迹的效果示例
  5. python表格写操作单元格合并
  6. 解决eclipse 端口被占用问题
  7. Find Any File for Mac(文件搜索软件)
  8. java语言 跨平台_Java语言不一定就跨平台
  9. html中如何调整图片的对比色,风光照片如何调出冷暖对比色?后期案例分享
  10. 人工智能的历史(History of artificial intelligence)
  11. 奇葩报错之返回值为 -1073741515 (0xc0000135) ‘未找到依赖 dll‘
  12. 海豚选房获中视银宗基金300万元天使轮融资,专注司法拍房
  13. node php v2ex,仿V2EX开源二次元论坛程序+安装教程
  14. 【动态规划】多重背包问题
  15. UUID去横杠的5种方式
  16. Unity导航寻路系统插件--A* Pathfinding Project
  17. ImageJ 用户手册——第三部分(ImageJ扩展)
  18. 政务大数据系列9:政务大数据的价值链
  19. python抓取直播源 并更新_Python爬虫实例(二)使用selenium抓取斗鱼直播平台数据...
  20. 000031中粮地产:持有大量券商股权的地产新锐

热门文章

  1. 专门打游戏的手机精选:rog3散热好 续航好 玩游戏更好!
  2. 用伪类添加翘边阴影::before和::after
  3. BGP Confederation(BGP联邦)
  4. 超详细的wireshark笔记(2)-wireshark的使用技巧
  5. 开源是不是程序员悲剧的根源?
  6. 提问:微信网页授权到第三方调用错误、调用微信公众号扫码登陆错误、微信SCOP权限错误或没有权限
  7. 永恒之蓝复现(win7/2008)
  8. 报告老板,我们的H5页面在iOS11系统上白屏了!
  9. Windows cmd卸载程序
  10. 计算机专业英语教程(第二版)Chapter 4 Database Fundamentals