模型泛化技巧“随机权重平均(Stochastic Weight Averaging, SWA)”介绍与Pytorch Lightning的SWA实现讲解
文章目录
- SWA简介
- SWA公式
- SWA常见参数
- Pytorch Lightning的SWA源码分析
- SWALR
- 参考资料
SWA简介
SWA,全程为“Stochastic Weight Averaging”(随机权重平均)。它是一种深度学习中提高模型泛化能力的一种常用技巧。
其思路为:对于模型的权重,不直接使用最后的权重,而是将之前的权重做个平均。
该方法适用于深度学习,不限领域、不限Optimzer,可以和多种技巧同时使用。
SWA公式
我们的模型参数记为: θ = { w 0 , w 1 , w 2 , ⋯ , w n } \theta=\{w_0, w_1, w_2, \cdots, w_n\} θ={w0,w1,w2,⋯,wn}, n n n 为模型总参数量。
对于模型的训练,会在epoch结束后保存一个副本,第 t t t 个epoch的模型参数记为 θ t \theta_t θt。
则我们模型的最终参数为:
θ ˉ = 1 T ∑ t = 1 T θ t \bar{\theta} = \frac{1}{T} \sum^T_{t=1}\theta_t θˉ=T1t=1∑Tθt
其中 T T T 表示我们有 T T T 个不同个模型参数的副本。
该公式的意思就是将前面t个模型的权重取平均,然后作为最终的模型参数。
注意事项:
- 通常只在一个epoch结束后保存模型参数副本。
- 并不是每个epoch都要保存模型副本。通常会从模型开始很好地收敛后再开始保存模型参数副本。
SWA常见参数
通常我们在使用SWA时会有如下的超参数:
- SWA Start:从第几个epoch再开始保存模型副本。若在模型还不能很好的收敛时就开始保存模型参数副本,可能会损害模型的性能。
- SWA Learning Rate:在SWA期间采用学习率。例如,我们设置在第20个epoch开始进行SWA,则在第20个epoch后就会采用你指定的SWA Learning Rate,而不是之前的。
Pytorch Lightning的SWA源码分析
本节展示一下Pytorch Lightning中对SWA的实现,以便更清晰的认识SWA。
在开始看代码前,明确几个在Pytorch Lightning实现中的几个重要的概念:
- 平均模型(self._average_model):Pytorch Lightning会将平均的后的模型存入该变量中。
- pl_module:该变量为当前的模型。
class StochasticWeightAveraging(Callback):def __init__(self,swa_lrs: Union[float, List[float]], # swa的学习率# swa_epoch_start: 从第0.8位置的epoch开始,例如一共100个epoch,那就从第81个epoch开始swa。# 若指定整数,则会从指定的epoch开始swa。swa_epoch_start: Union[int, float] = 0.8, annealing_epochs: int = 10, # 模拟退火的epoch数。SWALR学习策略用的参数annealing_strategy: str = "cos", # 模拟退火策略。SWALR学习策略用的参数avg_fn: Optional[_AVG_FN] = None, # 平局函数,做模型参数平均时使用的函数,通常不需要指定。会使用默认的。device: Optional[Union[torch.device, str]] = torch.device("cpu"), # 平均后的model存在哪个device上):...def on_train_epoch_start(self, ...): # 在每个epoch开始前执行if (not self._initialized) and (self.swa_start <= trainer.current_epoch <= self.swa_end):# 初始化SWA,在整个SWA过程中只执行一遍self._initialized = True...# 使用原来的optimizeroptimizer = trainer.optimizers[0]...# 使用SWALR学习率策略(SWA Learning Scheduler),后面会讲self._swa_scheduler = cast(LRScheduler,SWALR(optimizer,swa_lr=self._swa_lrs, # type: ignore[arg-type]anneal_epochs=self._annealing_epochs,anneal_strategy=self._annealing_strategy,last_epoch=trainer.max_epochs if self._annealing_strategy == "cos" else -1,),)# end if, 初始化代码结束。# 接下来是SWA在epoch开始前的处理逻辑if (self.swa_start <= trainer.current_epoch <= self.swa_end):# 在SWA期间,每个epoch开始前将当前的模型参数更新到“平均模型”上。self.update_parameters(self._average_model, pl_module, self.n_averaged, self._avg_fn)if trainer.current_epoch == self.swa_end + 1:# 到最后结束的时候,将平均模型的参数迁移到模型上。self.transfer_weights(self._average_model, pl_module)@staticmethoddef update_parameters(average_model: "pl.LightningModule", model: "pl.LightningModule", n_averaged: Tensor, avg_fn: _AVG_FN) -> None:for p_swa, p_model in zip(average_model.parameters(), model.parameters()):device = p_swa.devicep_swa_ = p_swa.detach()p_model_ = p_model.detach().to(device)src = p_model_ if n_averaged == 0 else avg_fn(p_swa_, p_model_, n_averaged.to(device))p_swa_.copy_(src)n_averaged += 1@staticmethoddef avg_fn(averaged_model_parameter: Tensor, model_parameter: Tensor, num_averaged: Tensor) -> Tensor:return averaged_model_parameter + (model_parameter - averaged_model_parameter) / (num_averaged + 1)
从上述Pytorch Lightning对SWA实现的源码中我们可以获得以下信息:
- 使用SWA需要指定
SWA学习率
和从哪个epoch
开始这两个最重要的参数。 - 在开始SWA后,将会使用新的“swa_lrs”学习率和新的“SWALR”学习率策略。(但在“退火”期间,会参考模型原本的学习率)
- 每个epoch开始前,会把上一个epoch学习到的模型参数更新到“平均模型”上。
- SWA期间,使用的Optimizer和之前一样。例如你模型训练时用的是Adam,则SWA期间也用Adam。
SWALR
在上面我们提到了Pytorch Lightning实现中,在SWA期间使用的是SWALR。
SWALR使用的是“模拟退火”策略,简单来说就是:学习率是从原本的学习率逐渐过度到SWA学习率的。例如,原本你使用的学习率是0.1,指定的SWA学习率为0.01,从第20个epoch开始进行SWA。那么并不是到第20个epoch后学习率立刻从0.1变到0.01,而是从0.1逐渐过度到0.01,过度的epoch数就是指定的annealing_epochs
参数,而过度时减小的策略就是annealing_strategy
参数。
这里不使用难以理解的源码或数学,而是来通过几组实验来直观的观察一下SWALR策略下的学习率的变化来进行解释:
上述实验为:模型训练过程中学习率随epoch的变化,横坐标为epoch,纵坐标为这个epoch使用的学习率。其中图上的几个参数分别为:
- model_lr:模型一开始使用的学习率。
- swa_lr:用户指定的swa学习率
- swa_epoch_start:从第几个epoch开始swa
- annealing_epoch:模拟退火的epoch数
- annealing_strategy:模拟退火策略。目前仅支持“cos”和“linear”两种。
例如对于图一意思就是:模型一开始在Optimizer上指定的学习率是0.1,SWA学习率为0.001,从第2个epoch开始进行SWA,总共进行10(annealing_epochs
) 个epoch将学习率从0.1逐渐过度到0.001,学习率调整使用cos策略。
从上述图中很容易得出以下结论:
- 所谓的SWALR学习率策略就是让学习率从原来的学习率逐渐过度到swa学习率。过度的epoch数就是annealing_epoch
- 若你指定的swa学习率和之前的是一样的,那么SWALR相当于什么都没做。(图二)
- 若你指定的swa学习率比之前的学习率高,那么学习率就会逐渐升高(图三)。不过通常不会这么做,通常swa_lr要比model_lr小才对,因为到后面模型都稳定了,不能再用更高的学习率了。
- 若annealing_epoch数较小,那么“退火”速度较快,即从model_lr到swa_lr的过度速度就较快(图四),反正则慢。
- “cos”退火策略下学习率变化是先慢,然后快,最后再慢。(图五),而“linear”实现线性策略变化速度是一样的。(图六)
实验环境与代码如下:
lightning==2.0.1
pytorch==1.13.0
实验代码如下:
import torch
import torch.nn as nnimport lightning.pytorch as pl
from lightning.pytorch.callbacks import StochasticWeightAveragingfrom matplotlib import pyplot as pltimport numpy as npdef plot_swa_lr_curve(model_lr, # 模型的学习率swa_lr, # swa的学习率swa_epoch_start=2, # 从哪个epoch开始swaannealing_epochs=10, # 模拟退火的epoch数annealing_strategy='cos' # 模拟退火策略):lrs = []# 定义一个简单的模型,用于测试class SimpleModel(pl.LightningModule):def __init__(self):super(SimpleModel, self).__init__()self.linear = nn.Linear(1, 1)def training_step(self, batch, batch_idx, *args, **kwargs):return nn.functional.mse_loss(self.linear(torch.rand(4, 1)), torch.rand(4, 1))def configure_optimizers(self):# 使用model_lr作为测试模型的学习率return torch.optim.SGD(self.parameters(), lr=model_lr)# 重写一下StochasticWeightAveraging,用于记录学习率变化class MyStochasticWeightAveraging(StochasticWeightAveraging):def on_train_epoch_start(self, *args, **kwargs):super().on_train_epoch_start(*args, **kwargs)if hasattr(self._swa_scheduler, "_last_lr"):# 记录lr的变化lrs.append(self._swa_scheduler._last_lr[0])else:lrs.append(model_lr)# 定义trainer进行训练trainer = pl.Trainer(callbacks=[MyStochasticWeightAveraging(swa_lrs=swa_lr, swa_epoch_start=swa_epoch_start,annealing_epochs=annealing_epochs,annealing_strategy=annealing_strategy)],max_epochs=20,num_sanity_val_steps=0,enable_progress_bar=False, # Use custom progress baraccelerator='cpu',)# 训练模型trainer.fit(SimpleModel(), train_dataloaders=range(10))plt.plot(np.arange(1, len(lrs)+1).astype(dtype=np.str), lrs)plt.xlabel("epoch")plt.ylabel("learning rate")plt.text(0.7, 0.9, "model_lr: %s" % model_lr, fontsize=11, transform=plt.gca().transAxes)plt.text(0.7, 0.8, "swa_lr: %s" % swa_lr, fontsize=11, transform=plt.gca().transAxes)plt.text(0.6, 0.7, "swa_epoch_start: %s" % swa_epoch_start, fontsize=11, transform=plt.gca().transAxes)plt.text(0.6, 0.6, "annealing_epochs: %s" % annealing_epochs, fontsize=11, transform=plt.gca().transAxes)plt.text(0.6, 0.5, "annealing_strategy: %s" % annealing_strategy, fontsize=11, transform=plt.gca().transAxes) plt.show()print("lrs:", lrs) # 输出lr的变化return lrsplot_swa_lr_curve(0.1, 0.001)
参考资料
Averaging Weights Leads to Wider Optima and Better Generalization(原论文): https://arxiv.org/abs/1803.05407
PyTorch 1.6 now includes Stochastic Weight Averaging: https://pytorch.org/blog/pytorch-1.6-now-includes-stochastic-weight-averaging/
模型泛化技巧“随机权重平均(Stochastic Weight Averaging, SWA)”介绍与Pytorch Lightning的SWA实现讲解相关推荐
- Stochastic Weight Averaging (SWA) 随机权重平均
文章目录 相关链接 基础 思路 主要内容 概括 SWA图示 SWA算法 LR The Algorithm Batch normalization 在PyTorch中使用swa 最佳实践 Demo 最近 ...
- 目标检测之五:随机权值平均(Stochastic Weight Averaging,SWA)---木有看懂
随机权值平均(Stochastic Weight Averaging,SWA) 随机权值平均只需快速集合集成的一小部分算力,就可以接近其表现.SWA 可以用在任意架构和数据集上,都会有不错的表现.根据 ...
- SWA(随机权重平均)——一种全新的模型优化方法
这两天被朋友推荐看了一篇热乎的新型优化器的文章,文章目前还只挂在arxiv上,还没发表到顶会上.本着探索的目的,把这个论文给复现了一下,顺便弥补自己在优化器方面鲜有探索的不足. 论文标题:Averag ...
- 【提分trick】SWA(随机权重平均)和EMA(指数移动平均)
1. SWA随机权重平均 1.1步骤 1.2代码 2.EMA指数移动平均 2.1步骤 2.2代码 3.总结 在kaggle比赛中,不管是目标检测任务.语义分割任务中,经常能看到SWA(Stochast ...
- SWA(随机权重平均) for Pytorch
Stochastic Weight Averaging for Pytorch 随机权重平均 一.什么是Stochastic Weight Averaging(SWA) 二.SWA与SGD的对比 三. ...
- SWA Object Detection随机权重平均【论文+代码】
随机权重平均 摘要 Introduction SWA 实验部分 消融实验 摘要 您想在不增加推断成本和不改变检测器的情况下提高对象检测器的1.0 AP吗?让我们告诉您一个这样的秘方.这个秘方令人惊讶地 ...
- 炼丹系列2: Stochastic Weight Averaging (SWA) Exponential Moving Average(EMA)
这个系列将记录下本人平时在深度学习方面觉得实用的一些trick,可能会包括性能提升和工程优化等方面. 该系列的代码会更新到Github 炼丹系列1: 分层学习率&梯度累积 炼丹系列2: Sto ...
- Stochastic Weight Averaging
PyTorch从1.6.0版本以后开始支持Stochastic Weight Averaging. That is, after the conventional training of an obj ...
- SWA(Stochastic Weight Averaging)实验
有论文说swa能涨分,那么我来实验一下 那么我将在cifar10数据集上进行实验 原理 论文地址:https://arxiv.org/pdf/2012.12645.pdf SGD倾向于收敛到loss的 ...
最新文章
- 微信小程序云开发图片上传完整代码附效果图
- linux下把进程绑定到特定cpu核上运行
- 五十三、Java的记录日志Log4j框架的使用
- web前端分享JavaScript到底是什么?特点有哪些?
- 会mysql不会sql_不是吧,不会有人还不知道MySQL中具实用的SQL语句
- Hadoop平台搭建
- 亚马逊首席科学家:揭秘 Alexa 语音识别技术|AI NEXT
- python代码块使用缩进来表示_python 基础语法
- 在Intellij idea中快速重写父类方法
- 删除-驱动人生节能省电方案
- JavaScript 05
- 跨域问题 Failed to load http://xxxx/xxx/xxx/xx/xxx: No ‘Access-Control-Allow-Or
- 安卓手机怎么投屏台式计算机WIN7,手机怎么投屏到win7电脑
- python使用 photoshop-python-api 调用ps处理批量动作操作
- Unity - Timeline 之 Trimming clips(裁剪剪辑)
- 上班,老实人和精明人的区别是什么?
- A1008 Elevator (20 分)
- 高速ADC的关键指标:量化误差、offset/gain error、DNL、INL、ENOB、分辨率、RMS、SFDR、THD、SINAD、dBFS、TWO-TONE IMD
- 服务器性能评分,服务器CPU排行榜之服务器CPU性能评分
- SkipList ----- 跳表