主要内容

  • 什么是EMA?
  • 为什么EMA在测试过程中使用通常能提升模型表现?
  • Tensorflow实现
  • PyTorch实现
  • Refercences

什么是EMA?

滑动平均(exponential moving average),或者叫做指数加权平均(exponentially weighted moving average),可以用来估计变量的局部均值,使得变量的更新与一段时间内的历史取值有关。

滑动平均可以看作是变量的过去一段时间取值的均值,相比对变量直接赋值而言,滑动平均得到的值在图像上更加平缓光滑,抖动性更小,不会因为某次的异常取值而使得滑动平均值波动很大,如图 1所示。

假设我们得到一个参数θ\thetaθ在不同的 epoch 下的值
θ1,θ2,...,θt\theta_1,\theta_2,...,\theta_tθ1​,θ2​,...,θt​

当训练结束的θ\thetaθ的MovingAverage 就是:

vt=β∗vt−1+(1−β)∗vtv_t=\beta*v_{t-1}+(1-\beta)*v_tvt​=β∗vt−1​+(1−β)∗vt​

β\betaβ代表衰减率,该衰减率用于控制模型更新的速度。

Andrew Ng在Course 2 Improving Deep Neural Networks中讲到,ttt时刻变量vvv的滑动平均值大致等于过去1/(1−β)1/(1−\beta)1/(1−β)个时刻vvv值的平均。


图1 不同 β\betaβ 值做EMA的效果对比(天气预报数据)

当β\betaβ越大时,滑动平均得到的值越和vvv的历史值相关。如果β=0.9\beta=0.9β=0.9,则大致等于过去10个vvv值的平均;如果β=0.99\beta=0.99β=0.99,则大致等于过去100个vvv值的平均。(数学证明先省略,因为作者暂时没理解证明过程==)

滑动平均的好处:
  
占内存少,不需要保存过去10个或者100个历史vvv值,就能够估计其均值。(当然,滑动平均不如将历史值全保存下来计算均值准确,但后者占用更多内存和计算成本更高)

为什么EMA在测试过程中使用通常能提升模型表现?

滑动平均可以使模型在测试数据上更健壮(robust)。“采用随机梯度下降算法训练神经网络时,使用滑动平均在很多应用中都可以在一定程度上提高最终模型在测试数据上的表现。”

对神经网络边的权重 weights 使用滑动平均,得到对应的影子变量shadow_weights。在训练过程仍然使用原来不带滑动平均的权重 weights,以得到 weights 下一步更新的值,进而求下一步 weights 的影子变量 shadow_weights。之后在测试过程中使用shadow_weights 来代替 weights 作为神经网络边的权重,这样在测试数据上效果更好。因为 shadow_weights 的更新更加平滑,对于:

  • 随机梯度下降,更平滑的更新说明不会偏离最优点很远;
  • 梯度下降 batch gradient decent,影子变量作用可能不大,因为梯度下降的方向已经是最优的了,loss 一定减小;
  • mini-batch gradient decent,可以尝试滑动平均,因为mini-batch gradient decent 对参数的更新也存在抖动。

举例来说,设decay=0.999decay=0.999,直观理解,在最后的1000次训练过程中,模型早已经训练完成,正处于抖动阶段,而滑动平均相当于将最后的1000次抖动进行了平均,这样得到的权重会更加robust。

Tensorflow实现

TensorFlow 提供了 tf.train.ExponentialMovingAverage来实现滑动平均。

Example usage when creating a training model:

# Create variables.
var0 = tf.Variable(...)
var1 = tf.Variable(...)
# ... use the variables to build a training model...
...
# Create an op that applies the optimizer.  This is what we usually
# would use as a training op.
opt_op = opt.minimize(my_loss, [var0, var1])# Create an ExponentialMovingAverage object
ema = tf.train.ExponentialMovingAverage(decay=0.9999)with tf.control_dependencies([opt_op]):# Create the shadow variables, and add ops to maintain moving averages# of var0 and var1. This also creates an op that will update the moving# averages after each training step.  This is what we will use in place# of the usual training op.training_op = ema.apply([var0, var1])...train the model by running training_op...

There are two ways to use the moving averages for evaluations:

  • Build a model that uses the shadow variables instead of the variables.
    For this, use the average() method which returns the shadow variable
    for a given variable.
  • Build a model normally but load the checkpoint files to evaluate by using
    the shadow variable names. For this use the average_name() method. See
    the tf.train.Saver for more
    information on restoring saved variables.

Example of restoring the shadow variable values:

# Create a Saver that loads variables from their saved shadow values.
shadow_var0_name = ema.average_name(var0)
shadow_var1_name = ema.average_name(var1)
saver = tf.train.Saver({shadow_var0_name: var0, shadow_var1_name: var1})
saver.restore(...checkpoint filename...)
# var0 and var1 now hold the moving average values

PyTorch实现

PyTorch官方目前没有提供EMA的实现,不过自己实现也不会太复杂,下面提供一个网上大神的实现方法:

class EMA():def __init__(self, decay):self.decay = decayself.shadow = {}def register(self, name, val):self.shadow[name] = val.clone()def get(self, name):return self.shadow[name]def update(self, name, x):assert name in self.shadownew_average = (1.0 - self.decay) * x + self.decay * self.shadow[name]self.shadow[name] = new_average.clone()

使用方法,分为初始化、注册和更新三个步骤。

// init
ema = EMA(0.999)// register
for name, param in model.named_parameters():if param.requires_grad:ema.register(name, param.data)// update
for name, param in model.named_parameters():if param.requires_grad:ema.update(name, param.data) 

Refercences

[1]. 理解滑动平均(exponential moving average)

[2]. EMA 指数滑动平均原理和实现 (PyTorch)
[3]. tf.train.ExponentialMovingAverage

机器学习模型性能提升技巧:指数加权平均(EMA)相关推荐

  1. 提高机器学习模型性能的五个关键方法

    提高机器学习模型性能的五个关键方法 1. 数据预处理 2. 特征工程 3. 机器学习算法 4. 模型集成与融合 5. 数据增强 以下是各个方面的具体分析和方法: [ 说明:1.这里主要是各个关键方法的 ...

  2. R使用交叉验证(cross validation)进行机器学习模型性能评估

    R使用交叉验证(cross validation)进行机器学习模型性能评估 目录 R使用交叉验证(cross validation)进行机器学习模型性能评估

  3. 机器学习 模型性能评估_如何评估机器学习模型的性能

    机器学习 模型性能评估 Table of contents: 目录: Why evaluation is necessary?为什么需要评估? Confusion Matrix混淆矩阵 Accurac ...

  4. 在数据增强、蒸馏剪枝下ERNIE3.0分类模型性能提升

    在数据增强.蒸馏剪枝下ERNIE3.0模型性能提升 项目链接: https://aistudio.baidu.com/aistudio/projectdetail/4436131?contributi ...

  5. 机器学习模型性能评估(二):P-R曲线和ROC曲线

    上文简要介绍了机器学习模型性能评估的四种方法以及应用场景,并详细介绍了错误率与精度的性能评估方法.本文承接上文,继续介绍模型性能评估方法:P-R曲线和ROC曲线.                   ...

  6. 4个提高深度学习模型性能的技巧

    点击上方"AI遇见机器学习",选择"星标"公众号 原创干货,第一时间送达 深度学习是一个广阔的领域,但我们大多数人在构建模型时都面临一些共同的难题 在这里,我们 ...

  7. 一份风控模型性能提升秘籍奉上|附视频+实操(详版)

    最近,番茄星球课堂为大家带来了一次主题为"信贷风控拒绝演绎实战"的直播课盛宴,内容充实,干货满满! 课程分为两次专题展开,分别为<拒绝推论场景描述.方法介绍与案例分享> ...

  8. 第三代英特尔 至强 可扩展处理器(Ice Lake)和英特尔 深度学习加速助力阿里巴巴 Transformer 模型性能提升

    第三代英特尔® 至强® 可扩展处理器采用了英特尔10 纳米 + 制程技术.相比于第二代英特尔® 至强® 可扩展处理器,该系列处理器内核更多.内存容量和频率更高.阿里巴巴集团和英特尔的技术专家共同探索了 ...

  9. 使用学习曲线诊断机器学习模型性能

    学习曲线是模型学习性能随经验或时间变化的曲线. 学习曲线是机器学习中广泛使用的诊断工具,用于从训练数据集中增量学习算法.该模型可以在训练数据集和每次训练更新后的验证数据集上进行评估,并可以创建测试性能 ...

最新文章

  1. List集合去重的一种方法
  2. windows编程,消息函数中拦截消息的问题
  3. 接口 DataInput
  4. Asp.Net Web控件 (八)(TabControl 选项卡控件)
  5. windows设备管理器
  6. 中小型研发团队架构落地实践18篇,含案例、代码
  7. 面向对象的三大特性和五大原则
  8. drop user和drop user cascade的区别
  9. 单片机原理及应用复习
  10. php txt bom,使用 PHP 函数或者软件去除文件的 BOM 头字符 - 文章教程
  11. vue实现div高度可拖拽
  12. 服务器装系统鼠标键盘不能动,装系统鼠标键盘不能动
  13. windows照片查看器无法显示此图片问题
  14. python多个if_Python之条件判断/if嵌套/如何写嵌套代码
  15. day08---(05)课程大纲-章节和小节列表功能(接口)
  16. 处理txt文件下载下来以后,排版格式不对的问题
  17. swift - 不成文规定
  18. C语言递归之苹果分盘问题
  19. 什么是深拷贝和浅拷贝?以及怎么实现深拷贝和浅拷贝?
  20. 声音管理AudioManager

热门文章

  1. 【Autogluon】傻瓜式深度学习框架
  2. idea在plugins中搜不到插件MyBatisX
  3. 常见的VC Link错误
  4. 如何在谷歌地球上画路线或者运动轨迹?根据纬经高信息在谷歌地球Google earth中画运动轨迹,首先将Excel文件纬经高信息转换为.csv文件,再转换为.kml文件,最终在谷歌地球中显示。
  5. 西南民族大学第十届校赛(同步赛)ABCEHJM题解
  6. 几种常见的数据分区方法
  7. git小乌龟解决代码冲突
  8. fastdfs上传文件资料(PDF,视频,图片,FileCaseUtil,FileUploadUtil)并生成缩略图
  9. 免费的电脑监控软件有哪些?可以一直免费使用的
  10. 【WIN7系统不是万能滴】