什么是EMA

指数移动平均(exponential moving average),也叫做权重移动平均(weighted moving average),可以用来估计变量的局部均值,使得变量的更新与一段时间内的历史取值有关。在采用 SGD 或者其他的一些优化算法 (Adam, Momentum) 训练神经网络时,通常会使用EMA的方法。 它的意义在于利用滑动平均的参数来提高模型在测试数据上的健壮性。(在SGD优化算法中,也会通过使用动量或者改变学习率的方式加快收敛速度)。

EMA公式

shadowVariable 为最后经过 EMA 处理后得到的参数值,Variable 为当前 epoch 轮次的参数值。EMA 对每一个待更新训练学习的变量 (variable) 都会维护一个影子变量 (shadow variable)。影子变量的初始值就是这个变量的初始值。由上述公式可知, decay 控制着模型更新的速度,越大越趋于稳定。实际运用中,通常会设为一个十分接近 1 的常数 (0.999 或 0.9999)。

EMA为什么可以提升模型性能

EMA可以使得模型在测试数据上更健壮,“采用随机梯度下降算法训练神经网络时,使用滑动平均在很多应用中都可以在一定程度上提高最终模型在测试数据上的表现。”

对神经网络边的的权重进行移动平均,得到对应的影子权重:shadow_weights,在训练过程中仍然使用不带滑动平均的权重(原始weights),以得到 weights 下一步更新的值,进而求下一步 weights 的影子权重 shadow_weights。在测试过程中,则使用影子权重代替原始weights,这样在测试数据上的效果更好,因为shadow_weights的更新更加平滑。

  • 随机梯度下降:更平滑的更新说明不会偏离最优点很远
  • batch gradient decent:影子变量作用可能不大,因为梯度下降的方向已经是最优的了,loss 一定减小
  • mini-batch gradient decent:可以尝试滑动平均,因为mini-batch gradient decent 对参数的更  新也存在抖动

pytorch实现 :

class EMA():def __init__(self, model, decay):self.model = modelself.decay = decayself.shadow = {}self.backup = {}def register(self):for name, param in self.model.named_parameters():if param.requires_grad:self.shadow[name] = param.data.clone()def update(self):for name, param in self.model.named_parameters():if param.requires_grad:assert name in self.shadownew_average = (1.0 - self.decay) * param.data + self.decay * self.shadow[name]self.shadow[name] = new_average.clone()def apply_shadow(self):for name, param in self.model.named_parameters():if param.requires_grad:assert name in self.shadowself.backup[name] = param.dataparam.data = self.shadow[name]def restore(self):for name, param in self.model.named_parameters():if param.requires_grad:assert name in self.backupparam.data = self.backup[name]self.backup = {}# 初始化
ema = EMA(model, 0.999)
ema.register()# 训练过程中,更新完参数后,同步update shadow weights
def train():optimizer.step()ema.update()# eval前,apply shadow weights;eval之后,恢复原来模型的参数
def evaluate():ema.apply_shadow()# evaluateema.restore()

【深度学习】关于EMA:指数移动平均相关推荐

  1. TensorFlow学习--指数移动平均/tf.train.ExponentialMovingAverage

    时间序列模型 时间序列是指将同一统计指标的数值按其发生的时间先后顺序排列而成的数列.时间序列分析的主要目的是根据已有的历史数据对未来进行预测.处理与时间相关数据的方法叫做时间序列模型. 当一个平稳序列 ...

  2. 图说2016深度学习十大指数级增长

    转自:https://www.52ml.net/21402.html http://mp.weixin.qq.com/s?__biz=MzI3MTA0MTk1MA==&mid=26519906 ...

  3. EMA(指数移动平均)及其深度学习应用

    在深度学习中,经常会使用EMA(指数移动平均)这个方法对模型的参数做平均,以提高测试指标并增加模型鲁棒. 1.基于数学的介绍 1.1 公式例子 我们有关于"温度-天数"的数据 :在 ...

  4. 深度学习中EMA的使用场景

    目录 什么是EMA EMA在深度学习中的使用场景 实际代码比对-验证无法使用ema进行训练 使用实际值训练,使用ema测试,正常 使用ema测试与训练,accuracy异常 在复习<Tensor ...

  5. Python 计算EMA(指数移动平均线)

    总结 使用递归和循环两种方法来完成 python环境下循环相比于递归更快,更适应极端样本情况 递归 def _ema(arr,i=None):N = len(arr) α = 2/(N+1) #平滑指 ...

  6. python 移动平均线_Python 计算EMA(指数移动平均线)

    总结 使用递归和循环两种方法来完成 python环境下循环相比于递归更快,更适应极端样本情况 递归 def _ema(arr,i=None): N = len(arr) α = 2/(N+1) #平滑 ...

  7. EMA - 指数移动平均

    EMA 基本概念见 Wikipedia,本文不赘述. 基本公式 S[0] = Y[0] S[i] = Y[i] * alpha + S[i-1] * (1 - alpha) 其中 alpha 为平滑因 ...

  8. 【炼丹技巧】指数移动平均(EMA)【在一定程度上提高最终模型在测试数据上的表现(例如accuracy、FID、泛化能力...)】

    本文中心: 1.指数移动平均(Exponential Moving Average)EMA作用: ema不参与实际的训练过程,是用在测试过程的,相比对变量直接赋值而言,移动平均得到的值在图像上更加平缓 ...

  9. 【提分trick】SWA(随机权重平均)和EMA(指数移动平均)

    1. SWA随机权重平均 1.1步骤 1.2代码 2.EMA指数移动平均 2.1步骤 2.2代码 3.总结 在kaggle比赛中,不管是目标检测任务.语义分割任务中,经常能看到SWA(Stochast ...

最新文章

  1. 如何删除输入文本元素上的边框突出显示
  2. ultrascale和arm区别_[原创] Avnet Zynq UltraScale+MPSoC系列Ultra96开发方案
  3. 【Python】any() all() 用法
  4. 软件需求工程与UML建模第十二周作业
  5. 计算机老师教育叙事,信息技术教育叙事范文10篇 初中
  6. minimax算法_使用Minimax算法玩策略游戏
  7. LaTeX参考文献取消doi输出
  8. 出租车语音全自动服务器,出租车语音提示器工作原理
  9. 苹果6s最大屏幕尺寸_羡慕苹果3DTouch好用?安卓这个功能不比它差!
  10. 如何修改默认的FTP帐号或密码
  11. '.'和'..'还有'./'和'../'
  12. html5 手机模板 解放区,解放区异形模板
  13. python 将pcm编码文件转化为wav音频文件
  14. hihocoder 1272 买零食
  15. 论我的dfs经验总结
  16. qt mysql 不能创建表_Qt 数据库创建表失败原因之数据库关键字
  17. 斯坦福Introduction to NLP:第十讲关系抽取
  18. Scrum基础框架,快速配置Scrum自动化场景
  19. BNO55移植到STM32平台及其他单片机平台
  20. 禅道的安装及使用—以windows为例

热门文章

  1. 线程同步互斥机制--互斥锁
  2. 常用邮箱服务器地址、端口(POP3/SMTP)
  3. 文字识别(四)--大批量生成文字训练集
  4. 什么是Redis?为什么要用Redis?
  5. 什么是项目管理?项目经理应该如何进行管理?
  6. 《Microduino实战》——导读
  7. 滤波器基础03——Sallen-Key滤波器、多反馈滤波器与Bainter陷波器
  8. 默认连接电脑的模式为MTP
  9. MTP模式与USB存储模式(MTP in Android)
  10. vscode的下载与安装教程