本文转载自微信公众号[机器学习炼丹术]
https://blog.csdn.net/qq_34107425/article/details/107722503

这两天被朋友推荐看了一篇热乎的新型优化器的文章,文章目前还只挂在arxiv上,还没发表到顶会上。本着探索的目的,把这个论文给复现了一下,顺便弥补自己在优化器方面鲜有探索的不足。

论文标题:Averaging Weights Leads to Wider Optima and Better Generalization

论文链接:https://arxiv.org/abs/1803.05407.pdf

官方代码:https://github.com/timgaripov/swa


随机权重平均(SWA)和随机梯度下降(SGD)有很多相似之处,准确来说SWA是基于SGD的改进版本。文章的主要比较对象也是SGD。

。典型的深度神经网络训练过程就是用 SGD 来优化一个损失函数,同时使用一个衰减的学习率,直到收敛为止。SWA 论文的结果证明,取 SGD 轨迹的多点简单平均值,以一个周期或者不变的学习 率,会比传统训练有更好的泛化效果。论文的结果同样了证明了,随机权重平均 (SWA)相比可以找到更广的最优值域。

SWA的灵感来自于实际观察,每次学习率循环结束时产生的局部最小值趋向于在损失面的边缘区域累积,这些边缘区域上的损失值较小(上面左图中,显示低损失的红色区域上的点W1,W2和W3)。通过对几个这样的点取平均,很有可能得到一个甚至更低损失的、全局化的通用解(上面左图上的Wswa)。

看一下swa的算法流程:

对 WSWA做了一个周期性的滑动平均,周期为 c,即每 c 步进行一次滑动平均, 其他时间就按普通 SGD 进行更新。可见,计算操作也是十分的简便。

就论文中的结果而言,可以看出 SWA相对于 SGD 来说有以下优点:1,不依赖学 习率的变化,设置恒定学习率的时候依然可以达到很好的收敛效果。而对于 SGD 而言,恒定学习率会导致收敛情况变差。2,收敛速度十分快,在原测试集上可 以在 150 个 epoch 就收敛得非常好,也非常平稳振荡幅度非常小。

个人主观分析:SWA 加入了周期性滑动平均操作来限制权重的变化,解决 了传统 SGD 在反向过程中的权重振荡问题。 SGD 是依靠当前 batch 的数据来更新 参数,每一个 epoch 都会调整一次参数,随机挑选的梯度方向极有可能不是最佳 梯度方向,甚至与最佳梯度方向有一个很大的夹角,这样大刀阔斧调整的参数, 极其容易振荡。而 SWA限制更新频率,对周期内参数取滑动均值,这样就解决 了 SGD 的问题。


关于复现

这次并没有用作者给的官方代码实现,而是用keras框架搭的一个自己的网络,用自己的数据集,来比较SGD和SWA的差异。代码暂时不方便上传,不过估计也没人care。。。不过对于SWA优化器的keras实现参考了https://github.com/kristpapadopoulos/keras-stochastic-weight-averaging

先看SGD的结果:

train loss和val loss大相径庭的原因跟数据集有关系,不用太care,主要关注的是train loss。

再看SWA的结果:

对 SGD 而言,每个 epoch 都会大刀阔斧调整参数(因为我设定的是恒定学习率)。 所以 SGD 的 loss 曲线抖动会比较大,而且要接近 200 个 epoch 才有比较好的收 敛效果。 对于 SWA而言,抖动幅度会比 SGD 小很多,而且 cycle 设置越大抖动会越小。大 概 150 个 epoch 就有比较好的收敛情况。这也是作者分析的,SWA 可以更容易 找到最优值域。


总结

SGD 的出现就是为了解决 mini-BatchGD 计算比较费时的问题,然而 SGD 的收敛 结果不会太好,往往不会好于 mini-batchGD。SGD 为了计算节省时间,舍弃了训 练稳定、最优收敛结果等性质。然而,SWA 刚好补充了 SGD 的这几点不足,在 计算方面,SWA 相对于 SGD 的计算量的增长可以忽略不记,只多了一个周期性 的权重平均而已。但是在稳定性和最优值域的寻找方面,SWA 是要远好于 SGD 的。尤其是这种训练时的稳定性,使得 SWA 比 SGD 更适合用于在线学习等模型 要求稳定性的领域。

【机器学习的Tricks】随机权值平均优化器swa相关推荐

  1. 目标检测之五:随机权值平均(Stochastic Weight Averaging,SWA)---木有看懂

    随机权值平均(Stochastic Weight Averaging,SWA) 随机权值平均只需快速集合集成的一小部分算力,就可以接近其表现.SWA 可以用在任意架构和数据集上,都会有不错的表现.根据 ...

  2. PyTorch的损失函数和优化器

    文章目录 PyTorch的损失函数和优化器 损失函数 优化器 总结 PyTorch的损失函数和优化器 损失函数 一般来说,PyTorch的损失函数有两种形式:函数形式和模块形式.前者调用的是torch ...

  3. 使用什么优化器_优化器怎么选?一文教你选择适合不同ML项目的优化器

    选自lightly.ai 机器之心编译 编辑:小舟.杜伟 为机器学习项目选择合适的优化器不是一件简单的事. 优化器是深度学习领域的重要组成模块之一,执行深度学习任务时采用不同的优化器会产生截然不同的效 ...

  4. oracle stalestats_深入理解oracle优化器统计数据(Optimizer Statistics)

    理解oracle优化器统计数据 首先来介绍oracle数据库使用基于规则优化器(RBO)来决定如何执行一个sql语句.基于规则优化器顾名思义,它是遵循一组规则来判断一个sql语句的执行计划.这组规则是 ...

  5. mysql 优化器代码_Mysql查询优化器

    Mysql查询优化器 本文的目的主要是通过告诉大家,查询优化器为我们做了那些工作,我们怎么做,才能使查询优化器对我们的sql进行优化,以及启示我们sql语句怎么写,才能更有效率.那么到底mysql到底 ...

  6. mysql 优化器提示_Mysql查询优化器

    Mysql查询优化器 本文的目的主要是通过告诉大家,查询优化器为我们做了那些工作,我们怎么做,才能使查询优化器对我们的sql进行优化,以及启示我们sql语句怎么写,才能更有效率.那么到底mysql到底 ...

  7. Oracle 19C优化器中自动使用了filter操作

    文章涉及问题和Oracle 19c版本及新特性并无直接关系,团队正在进行有条不紊的Oracle19c升级工作,刚好此版本遇到问题,影响大且波及范围广,故记录之,以便大家有他山石可引鉴.话不多说,看正文 ...

  8. 3.6 权值初始化-机器学习笔记-斯坦福吴恩达教授

    权值初始化 0值初始化 在逻辑回归中,我们通常会初始化所有权值为 0 ,假如在如下的神经网络也采用 0 值初始化: 则可以得到: a1(1)=a2(2)a^{(1)}_1=a^{(2)}_2a1(1) ...

  9. 单目标应用:白鲸优化算法(Beluga whale optimization,BWO)优化双向长短时记忆BiLSTM的权值和阈值(提供MATLAB代码)

    一.算法简介 白鲸优化算法(Beluga whale optimization,BWO)由Changting Zhong等人于2022年提出,该算法模拟了白鲸游泳,觅食和"鲸鱼坠落" ...

最新文章

  1. Navicat for Oracle
  2. python3 _笨方法学Python_日记_DAY3
  3. 洛阳理工Linux实验报告,洛阳理工学院实验报告.doc
  4. 关于程序员面试的一点想法
  5. 阿里云插件新版发布,多特性助力提升开发者体验
  6. Python学习入门1:Python 新手入门引导
  7. 操作系统——进程的状态及转换
  8. Oracle01877,Cognos错误:RQP-DEF-0177 执行操作“sqlOpenResult”(状态为“-28”)时出错...
  9. python面试题之Python是如何进行内存管理的
  10. 能买?这款手机搭载联发科P60+32G,仅售399元
  11. Abaqus取消汉化(汉译英,英译汉)
  12. python 立方体切割块数_用参数化su计算立方体切割体积
  13. telnet php,使用php实现telnet功能
  14. 【渝粤题库】陕西师范大学202131组织行为学作业(高起本、专升本)
  15. RHEL8破解root密码
  16. 购买代购的产品算违法吗——看空姐代购被判刑有感
  17. exchange服务器维护,Exchange服务器之禁用和删除Exchange邮箱深入探讨
  18. Linux下dcm2niix使用
  19. Android RrecyclerView条目跳转到指定位置
  20. MySQL基础学习笔记(带目录)

热门文章

  1. win10彻底关闭电脑的自动更新
  2. 特此郑重声明!我的文章全部是原创作品!转载请注明出处!
  3. NVM 安装node.js后没有npm
  4. CubeMX生成的代码烧录一次后无法再烧录(识别)STM32,需按住reset后放开才能烧录
  5. 亚马逊首席技术官:2023年及未来五大技术趋势预测 | 美通社头条
  6. java的抽象方法_java抽象方法是什么
  7. STL的allocaotr
  8. 网络数据的背后——网络日志的分析指标
  9. 程序员只能吃“青春饭”?IT行业年龄焦虑如何破局?
  10. 元宇宙产业委联席秘书长叶毓睿:去中心化和去中介化的定义、区别,以及和元宇宙的关系