文章目录

  • 相关链接
  • 基础
  • 思路
  • 主要内容
    • 概括
    • SWA图示
    • SWA算法
      • LR
      • The Algorithm
      • Batch normalization
  • 在PyTorch中使用swa
    • 最佳实践
    • Demo

最近在参加公司的AI竞赛,刚好用到了Stochastic Weight Averaging的方法,所以也简单看了下提出这个方法的论文Averaging Weights Leads to Wider Optima and Better Generalization,这是一种容易实现、简单、基本没有额外计算开销却能比较可观地提升DNN模型效果的方法,在这里写写自己对这篇论文的一些理解。

相关链接

paper:https://arxiv.org/abs/1803.05407

code:https://github.com/timgaripov/swa

pytorch:https://pytorch.org/blog/pytorch-1.6-now-includes-stochastic-weight-averaging/

基础

Loss Surfaces, Mode Connectivity, and Fast Ensembling of DNNs

swa这篇论文基本上是以上面的论文为基础,而且作者是一样的,是一脉相承的工作,所以上面的论文算是前作,其主要的发现和贡献有以下两点:

  • local optima found by SGD can be connected by simple curves of near constant loss(由SGD找到的局部最优解可以被近似恒定损失的简单曲线连接起来,这点很重要,后面可以看到swa其实就是对在这样的区域上探索得到的权重进行平均)

  • Fast Geometric Ensembling (FGE,根据上面的发现提出的一种集成方法)

    sample multiple nearby points in weight space to create high performing ensembles in the time required to train a single DNN(这种集成方法的主要思想是在训练单个DNN所需的时间内从权重空间中采样多个相近的点,用于创建高性能的模型集合)

思路

通过研读论文,我猜测其写作的大概思路是这样:

  1. FGE洞察train loss和test error几何平面,发现对FGE采样的模型的权重进行平均可以得到更好的模型权重

  2. 可否直接在SGD过程中得到这些点并进行平均

    We show that SGD with cyclical [e.g., Loshchilov and Hutter, 2017] and constant learning rates traverses regions of weight space corresponding to high-performing networks. We find that while these models are moving around this optimal set they never reach its central points. We show that we can move into this more desirable space of points by averaging the weights proposed over SGD iterations.

    简单地说,使用循环学习率调度或者固定学习率的SGD能够遍历和高性能网络相关联的权重空间区域,但是只是在这个区域周围移动而没有到达中心点。通过对SGD迭代产生的权重进行平均可以解决问题。

  3. 同样是较优的集合中的点,为什么靠近中心的点会有更好的泛化性能呢?

    • We demonstrate that SWA leads to solutions that are wider than the optima found by SGD. Keskar et al. [2017] and Hochreiter and Schmidhuber [1997] conjecture that the width of the optima is critically related to generalization. We illustrate that the loss on the train is shifted with respect to the test error. We show that SGD generally converges to a point near the boundary of the wide flat region of optimal points. SWA on the other hand is able to find a point centered in this region, often with slightly worse train loss but with substantially better test error.

    • We show that the loss function is asymmetric in the direction connecting SWA with SGD. In this direction, SGD is near the periphery of sharp ascent. Part of the reason SWA improves generalization is that it finds solutions in flat regions of the training loss in such directions.

    重点就是损失平面上平坦区域的解具有更好的泛化性能。

  4. 跟FGE的关系

    While FGE ensembles [Garipov et al., 2018] can be trained in the same time as a single model, test predictions for an ensemble of k models requires k times more computation. We show that SWA can be interpreted as an approximation to FGE ensembles but with the test-time, convenience, and interpretability of a single model.

    SWA可以解释为对FGE集成方法的近似,但是具有单个模型的测试时间、便利性和可解释性。

主要内容

概括

We emphasize that SWA is finding a solution in the same basin of attraction as SGD, as can be seen in Figure 1, but in a flatter region of the training loss. SGD typically finds points on the periphery of a set of good weights. By running SGD with a cyclical or high constant learning rate, we traverse the surface of this set of points, and by averaging we find a more centred solution in a flatter region of the training loss. Further, the training loss for SWA is often slightly worse than for SGD suggesting that SWA solution is not a local optimum of the loss.

The name SWA has two meanings:

  • it is an average of SGD weights

  • with a cyclical or constant learning rate, SGD proposals are approximately sampling from the loss surface of the DNN, leading to stochastic weights.

SWA图示

Illustrations of SWA and SGD with a Preactivation ResNet-164 on CIFAR-1001. Left: test error surface for three FGE samples and the corresponding SWA solution (averaging in weight space). Middle and Right: test error and train loss surfaces showing the weights proposed by SGD (at convergence) and SWA, starting from the same initialization of SGD after 125 training epochs.

SWA算法

LR

SWA is making use of multiple samples gathered through exploration of the set of points corresponding to high performing networks.** To enforce exploration we run SGD with constant or cyclical learning rates.**(为了加强与高性能网络相对应的权重集合的探索,以恒定或周期性的学习率执行SGD)

下面是两种学习率策略的计算公式:

  • cyclical learning rate schedule: linearly decrease the learning rate from α1\alpha_1α1​ to α2\alpha_2α2​(在每个周期内线性地将学习率从α1\alpha_1α1​减少到α2\alpha_2α2​,一个周期为ccc个epoch):

    α(i)=(1−t(i))α1+t(i)α2,t(i)=1c(mod(i−1,c)+1)\begin{aligned} \alpha(i) &=(1-t(i)) \alpha_{1}+t(i) \alpha_{2}, \\ t(i) &=\frac{1}{c}(\bmod (i-1, c)+1) \end{aligned} α(i)t(i)​=(1−t(i))α1​+t(i)α2​,=c1​(mod(i−1,c)+1)​

  • constant learning rate schedule:

    α(i)=α1\alpha(i)=\alpha_1 α(i)=α1​

When using a cyclical learning rate we capture the models wiw_iwi​ that correspond to the minimum values of the learning rate. For constant learning rates we capture models at each epoch. Next, we average the weights of all the captured networks wiw_iwi​ to get our final model wSWAw_{SWA}wSWA​.

对于两种学习率调度策略,用于统计平均权重的模型是不同的。对于周期性的学习率而言,会使用对应于最小学习率的模型,也就是每个周期中最后一个epoch产生的模型;而对于恒定的学习率则比较简单,每个epoch的模型都会被用于计算平均权重。

The Algorithm

算法还是比较简单,就是对SGD产生的模型权重做等权重平均,这里就不做过多解释了,请看下面的算法步骤:

Batch normalization

If the DNN uses batch normalization, we run one additional pass over the data, as in Garipov et al. [2018], to compute the running mean and standard deviation of the activations for each layer of the network with wSWAw_{SWA}wSWA​ weights after the training is finished, since these statistics are not collected during training. For most deep learning libraries, such as PyTorch or Tensorflow, one can typically collect these statistics by making a forward pass over the data in training mode.

最后需要注意的一点是,如果神经网络中用到了BN层,则需要在训练数据上再做一次额外的前向传播,用于计算BN层的均值、标准差这些统计信息,因为对于使用wSWAw_{SWA}wSWA​权重的网络,在训练过程中是没有收集这些统计信息的。

在PyTorch中使用swa

最佳实践

上面是实践当中验证过的最佳的使用swa的方式,前面75%的训练时间使用标准的学习率衰减策略,后面的25%的训练实践使用比较高的恒定学习率,而最终的swa模型权重是由最后的25%训练时间中每个epoch得到的模型权重计算平均得到的。

Demo

下面是在PyTorch中使用swa的示例

from torch.optim.swa_utils import AveragedModel, SWALR
from torch.optim.lr_scheduler import CosineAnnealingLRloader, optimizer, model, loss_fn = ...
swa_model = AveragedModel(model)
scheduler = CosineAnnealingLR(optimizer, T_max=100)
swa_start = 5
swa_scheduler = SWALR(optimizer, swa_lr=0.05)for epoch in range(100):for input, target in loader:optimizer.zero_grad()loss_fn(model(input), target).backward()optimizer.step()if epoch > swa_start:swa_model.update_parameters(model)swa_scheduler.step()else:scheduler.step()# Update bn statistics for the swa_model at the end
torch.optim.swa_utils.update_bn(loader, swa_model)
# Use swa_model to make predictions on test data
preds = swa_model(test_input)

Stochastic Weight Averaging (SWA) 随机权重平均相关推荐

  1. SWA(随机权重平均) for Pytorch

    Stochastic Weight Averaging for Pytorch 随机权重平均 一.什么是Stochastic Weight Averaging(SWA) 二.SWA与SGD的对比 三. ...

  2. 炼丹系列2: Stochastic Weight Averaging (SWA) Exponential Moving Average(EMA)

    这个系列将记录下本人平时在深度学习方面觉得实用的一些trick,可能会包括性能提升和工程优化等方面. 该系列的代码会更新到Github 炼丹系列1: 分层学习率&梯度累积 炼丹系列2: Sto ...

  3. SWA(随机权重平均)——一种全新的模型优化方法

    这两天被朋友推荐看了一篇热乎的新型优化器的文章,文章目前还只挂在arxiv上,还没发表到顶会上.本着探索的目的,把这个论文给复现了一下,顺便弥补自己在优化器方面鲜有探索的不足. 论文标题:Averag ...

  4. 模型泛化技巧“随机权重平均(Stochastic Weight Averaging, SWA)”介绍与Pytorch Lightning的SWA实现讲解

    文章目录 SWA简介 SWA公式 SWA常见参数 Pytorch Lightning的SWA源码分析 SWALR 参考资料 SWA简介 SWA,全程为"Stochastic Weight A ...

  5. 【提分trick】SWA(随机权重平均)和EMA(指数移动平均)

    1. SWA随机权重平均 1.1步骤 1.2代码 2.EMA指数移动平均 2.1步骤 2.2代码 3.总结 在kaggle比赛中,不管是目标检测任务.语义分割任务中,经常能看到SWA(Stochast ...

  6. 目标检测之五:随机权值平均(Stochastic Weight Averaging,SWA)---木有看懂

    随机权值平均(Stochastic Weight Averaging,SWA) 随机权值平均只需快速集合集成的一小部分算力,就可以接近其表现.SWA 可以用在任意架构和数据集上,都会有不错的表现.根据 ...

  7. SWA Object Detection随机权重平均【论文+代码】

    随机权重平均 摘要 Introduction SWA 实验部分 消融实验 摘要 您想在不增加推断成本和不改变检测器的情况下提高对象检测器的1.0 AP吗?让我们告诉您一个这样的秘方.这个秘方令人惊讶地 ...

  8. Stochastic Weight Averaging

    PyTorch从1.6.0版本以后开始支持Stochastic Weight Averaging. That is, after the conventional training of an obj ...

  9. SWA(Stochastic Weight Averaging)实验

    有论文说swa能涨分,那么我来实验一下 那么我将在cifar10数据集上进行实验 原理 论文地址:https://arxiv.org/pdf/2012.12645.pdf SGD倾向于收敛到loss的 ...

最新文章

  1. 一个“退学生”到CTO的逆袭之路
  2. React编写一个简易的评论区组件
  3. 让浏览器非阻塞加载javascript的几种方式
  4. 写缓冲器 + 无效队列,优化MESI协议的性能
  5. vim的配置安装和Python安装细节记录20190109
  6. 《HTML5触摸界面设计与开发》——1.4 神秘谷,是什么让触摸界面反应灵敏?...
  7. @NotEmpty@NotNull和@NotBlank的区别
  8. 《Head First》 MVC运用的设计模式
  9. 不参与,你怎么知道能有多刺激——一个币客与市场的深入对话
  10. jsp汽车4S店维修管理系统
  11. Python 计算思维训练——公式编程
  12. 霍纳法则(Horner Rule)--计算多项式的值
  13. 响铃:揭底滴滴们跨界营销“真相”,再教你玩一出好戏
  14. word用宏修改文档中图片大小
  15. grep 忽略大小写、忽略grep命令本身
  16. 《滴滴重MVVM框架Chameleon》架构篇读后感
  17. 企业如何建设网站之基础建站教程
  18. 当我谈秋招时,我谈些什么
  19. C语言实现超长整数减法
  20. python做飞机大战游戏单机_Python制作简易版飞机大战小游戏

热门文章

  1. django中的关联查询
  2. tab标签页-选项卡后边+后端所返数据的数量
  3. Linux自定义日志文件设置回滚(避免信息溢出)
  4. 基于工业4g网关的危化品运输车监控方案
  5. 计算机常用的四种加密方法,电脑常见的几种加密方法
  6. 消息推送技术干货:美团实时消息推送服务的技术演进之路
  7. PTA IP地址转换
  8. qt在表格中如何画线_如何在电子表格中的某单元格内画一根长线
  9. Oracle问题imp-10019:由于ORACLE错误12899而拒绝行
  10. Python——pyqt5的计算器(源码+打包)