权重衰减weight_decay参数从入门到精通
文章目录
- 本文内容
- 1. 什么是权重衰减(Weight Decay)
- 2. 什么是正则化?
- 2.1 什么数据扰动
- 3. 减小模型权重
- 4. 为Loss增加惩罚项
- 4.1 通过公式理解Weight Decay
- 4.2 通过图像理解Weight Decay
- 为什么1范数不好
- 5. Weight Decay的实现
- 6. weight_decay的一些trick
- 参考资料
本文内容
Weight Decay是一个正则化技术,作用是抑制模型的过拟合,以此来提高模型的泛化性。
目前网上对于Weight Decay的讲解都比较泛,都是短短的几句话,但对于其原理、实现方式大多就没有讲解清楚,本文将会逐步解释weight decay机制。
1. 什么是权重衰减(Weight Decay)
Weight Decay是一个正则化技术,作用是抑制模型的过拟合,以此来提高模型的泛化性。
它是通过给损失函数增加模型权重L2范数的惩罚(penalty)来让模型权重不要太大,以此来减小模型的复杂度,从而抑制模型的过拟合。
看完上面那句话,可能很多人已经蒙圈了,这是在说啥。后面我会逐步进行解释,将会逐步回答以下问题:
- 什么是正则化?
- Weight Decay的减小模型参数的思想
- L1范数惩罚项和L2范数惩罚项是什么?
- 为什么Weight Decay参数是在优化器上,而不是在Loss上。
- weight decay的调参技巧
2. 什么是正则化?
正则化的目标是减小方差或是说减小数据扰动所造成的影响。 我们来看下图来理解一下这句话:
这幅图是随着训练次数,训练Loss和验证Loss的变化曲线。上面那条线是验证集的。很明显,这个模型出现了过拟合,因为随着训练次数的增加,训练Loss在下降,但是验证Loss却在上升。这里我们会引出三个概念:
- 方差(Variance):刻画数据扰动所造成的影响。
- 偏差(Bias):刻画学习算法本身的拟合能力。
- 噪声(Noise):当前任务任何学习算法能达到的期望泛化误差的下界。也就是数据的噪声导致一定会出现的那部分误差。
通常不考虑噪声,所以偏差和噪声合并称为偏差。
2.1 什么数据扰动
上面说方差是“刻画数据扰动所造成的影响”,我们可以通过下面例子来理解这句话。
假设我们要预测一个 y = x y=x y=x 的模型:
绿色的线是真正的模型 y = x y=x y=x,蓝色的点是训练数据,红色的线是预测出的模型。这个训练数据点距离真实模型的偏离程度就是数据扰动。
如果我们使用数据扰动较小的数据,那么预测模型结果就会和真正模型的差距较小,例如:
当我们数据扰动越大,预测模型距离实际模型的差距就会越大。因此,我们减小过拟合就是让预测模型和真实模型尽可能的一致。通常有两种做法:
- 增加数据量和使用更好的数据。这也是最推荐的做法
- 然而,通常我们很难收集到更多的数据,所以此时就需要一些正则化技术来减小“数据扰动”对模型预测带来的影响。
3. 减小模型权重
权重衰减(Weight Decay)就是减小模型的权重大小,而减小模型的权重大小就可以降低模型的复杂度,使模型变得平滑,进而减小过拟合。
假设我们的模型为: y = w 0 + w 1 x + w 2 x 2 + w 2 x 2 + ⋯ + w n x n y = w_0 + w_1 x + w_2x^2 + w_2x^2 + \cdots +w_nx^n y=w0+w1x+w2x2+w2x2+⋯+wnxn,模型的参数为 W = ( w 0 , w 1 , w 2 , ⋯ , w n ) W=(w_0, w_1, w_2, \cdots, w_n) W=(w0,w1,w2,⋯,wn)
我们使用该模型根据一些训练数据点可能会学到如下的两种曲线:
很明显,蓝色的曲线显然过拟合了。如果我们观察 W W W 的话会发现,蓝色曲线的参数通常都比较大,而绿色曲线的参数通常都比较小。
上面只是直观的说一下。结论就是:模型权重数值越小,模型的复杂度越低。
该结论可以通过实验观察出来,也可以通过数学证明。(李沐说可以证明,感兴趣的同学可以搜一下)
4. 为Loss增加惩罚项
上面说了Weight Decay目的是要让模型权重小一点(控制在某一个范围内),以此来减小模型的复杂性,从而抑制过拟合。
而Weight Decay的具体做法就是在Loss后面增加一个权重的L2范数惩罚项。
4.1 通过公式理解Weight Decay
Weight Decay的具体公式就是:
L = L 0 + λ 2 ∣ ∣ W ∣ ∣ 2 L = L_0 + \frac{\lambda}{2}||W||^2 L=L0+2λ∣∣W∣∣2
其中 L 0 L_0 L0 是原本的Loss, λ \lambda λ 是一个超参,负责控制权重衰减的强弱。 ∣ ∣ W ∣ ∣ 2 ||W||^2 ∣∣W∣∣2 为模型参数的2范数的平方。
具体的,假设我们的模型有 n n n 个参数,即 W = [ w 1 , w 2 , ⋯ , w n ] W=[w_1, w_2, \cdots, w_n] W=[w1,w2,⋯,wn],则 L L L 为:
L = L 0 + λ 2 ( w 1 2 + w 2 2 + ⋯ + w n 2 ) 2 = L 0 + λ 2 ( w 1 2 + w 2 2 + ⋯ + w n 2 ) \begin{aligned} L &= L_0 + \frac{\lambda}{2}\left( \sqrt{w_1^2+w_2^2+\cdots+w_n^2} \right) ^2 \\\\ &= L_0 + \frac{\lambda}{2}(w_1^2+w_2^2+\cdots+w_n^2) \end{aligned} L=L0+2λ(w12+w22+⋯+wn2 )2=L0+2λ(w12+w22+⋯+wn2)
从上面的公式,我们可以很明显的得到如下结论:
- 模型的权重越大,Loss就会越大。
- λ \lambda λ 越大,权重衰减的就越厉害
- 若 λ \lambda λ 过大,那么原本Loss的占比就会较低,最后模型就光顾着让模型权重变小了,最终模型效果就会变差。
4.2 通过图像理解Weight Decay
接下来我们用图像来感受一下Weight Decay。假设我们的模型只有两个参数W1和W2,W1和W2与Loss=2有如下关系:
这个绿色的椭圆表示,当W1和W2取绿色椭圆上的点时,Loss都是2。所以,当我们没有惩罚项时,对于Loss=2,取椭圆上的这些点都可以。若取到右上角的点,那么 W1和W2 的值就会比较大,所以我们希望W1和W2尽量往左下靠。
因为我们的惩罚项是 w 1 2 + w 2 2 w_1^2 + w_2^2 w12+w22,我们将其图像画出来( w 1 2 + w 2 2 = X w_1^2 + w_2^2= X w12+w22=X)。
上图我绘制了三条橘色图像,分别为
w 1 2 + w 2 2 = X 1 w_1^2 + w_2^2= X_1 w12+w22=X1,与椭圆无焦点。
w 1 2 + w 2 2 = X 2 w_1^2 + w_2^2= X_2 w12+w22=X2,与椭圆交于A点
w 1 2 + w 2 2 = X 3 w_1^2 + w_2^2= X_3 w12+w22=X3,与椭圆交于B,C两点
从上图可以看到,在不改变原Loss的情况下,(W1, W2)落在A点时,惩罚项最小,即 w 1 2 + w 2 2 w_1^2 + w_2^2 w12+w22 最小。
所以,我们增加2范数的惩罚,会让模型参数变小。
为什么1范数不好
可能有些同学比较好奇,为什么不取1范数,我们同样用图可以表示出来。我们将上述的2范数图像变成1范数图像(即 ∣ w 1 ∣ + ∣ w 2 ∣ = X |w_1| +|w_2|=X ∣w1∣+∣w2∣=X):
上图我绘制了三条橘色图像,分别为
∣ w 1 ∣ + ∣ w 2 ∣ = X 1 |w_1| + |w_2|= X_1 ∣w1∣+∣w2∣=X1,与椭圆无焦点。
∣ w 1 ∣ + ∣ w 2 ∣ = X 2 |w_1| + |w_2|= X_2 ∣w1∣+∣w2∣=X2,与椭圆交于A点
∣ w 1 ∣ + ∣ w 2 ∣ = X 3 |w_1| + |w_2|= X_3 ∣w1∣+∣w2∣=X3,与椭圆交于B,C两点
与2范数同理,在不改变原Loss的情况下,(W1, W2)落在A点时,惩罚项最小,即 ∣ w 1 ∣ + ∣ w 2 ∣ |w_1| + |w_2| ∣w1∣+∣w2∣ 最小。
但这里有个问题,我们发现此时 w 1 w_1 w1 变成 0 了。这就是为什么我们通常不用1范数,因为1范数会倾向于让一部分权重变成0。
更高的范数同理,可以参考“什么是范数(Norm),其具有哪些性质”这篇博客来感受一下每个范数不同的图像,然后将其套到上面的图中,感受一下其他范数。
5. Weight Decay的实现
通常我们在使用Weight Decay是在优化器(Optimizer)上,这就很奇怪了,上面明明都是在说Loss,为什么weight decay参数是在优化器上呢?
这是因为它们是等价的。这个很容易推导,我们用SGD来举例,SGD的更新参数的过程为:
w i ← w i − γ ∂ L ∂ w i w_i \gets w_i - \gamma \frac{\partial L}{\partial w_i} wi←wi−γ∂wi∂L
其中 γ \gamma γ 是学习率。
我们将 L = L 0 + λ 2 ∣ ∣ W ∣ ∣ 2 L = L_0 + \frac{\lambda}{2}||W||^2 L=L0+2λ∣∣W∣∣2 带进来求一下可得:
w i − γ ∂ L ∂ w i = w i − γ ( ∂ L 0 ∂ w i + λ w i ) \begin{aligned} & w_i - \gamma \frac{\partial L}{\partial w_i} \\\\ = & w_i - \gamma (\frac{\partial L_0}{\partial w_i} + \lambda w_i) \end{aligned} =wi−γ∂wi∂Lwi−γ(∂wi∂L0+λwi)
其中 ∂ L 0 ∂ w i \frac{\partial L_0}{\partial w_i} ∂wi∂L0 就是原本的梯度,所以我们为Loss增加L2正则项只需要在更新参数时,给模型的梯度加一个 λ w i \lambda w_i λwi 即可。
对应Pytorch的实现如下图:
6. weight_decay的一些trick
- weight_decay并没有你想想中的那么好,它的效果可能只有一点点,不要太指望它。尤其是当你的模型很复杂时,权重衰退的效果可能会更小了。
- 通常取1e-3,如果要尝试的话,一般也就是1e-2, 1e-3, 1e-4 这些选项。
- 权重衰退通常不对bias做。但通常bias做不做权重衰退其实效果差不多,不过最好不要做。
- weight_decay取值越大,对抑制模型的强度越大。但这并不说明越大越好,太大的话,可能会导致模型欠拟合。
针对第三点:对于一个二维曲线,bias只是让曲线整体上下移动,并不能减小模型的复杂度,所以通常不需要对bias做正则化。
参考资料
正则化之weight_decay(深度之眼): https://www.bilibili.com/video/BV1HB4y1i7Fn
权重衰退(李沐): https://www.bilibili.com/video/BV1UK4y1o7dy
从拉格朗日乘数法角度理解L1L2正则: https://www.bilibili.com/video/BV1Z44y147xA
权重衰减weight_decay参数从入门到精通相关推荐
- (pytorch-深度学习系列)pytorch避免过拟合-权重衰减的实现-学习笔记
pytorch避免过拟合-权重衰减的实现 首先学习基本的概念背景 L0范数是指向量中非0的元素的个数:(L0范数难优化求解) L1范数是指向量中各个元素绝对值之和: L2范数是指向量各元素的平方和然后 ...
- pytorch学习笔记(十二):权重衰减
文章目录 1. 方法 2. 高维线性回归实验 3. 从零开始实现 3.1 初始化模型参数 3.2 定义L2L_2L2范数惩罚项 3.3 定义训练和测试 3.4 观察过拟合 3.5 使用权重衰减 4. ...
- 权重衰减/权重衰退——weight_decay
目录 权重衰减/权重衰退--weight_decay 一.什么是权重衰减/权重衰退--weight_decay? 二.weight decay 的作用 三.设置weight decay的值为多少? 权 ...
- bean validation校验方法参数_SpringBoot参数校验 从入门到精通 解决繁琐的参数验证工作...
● 手把手教你实现 SpringBoot与Vue整合开发 前后端分离 简单例子 详解●SQL优化经历 SQL执行效率提高了1000w倍●Java面试题 详解 由易到难● SQL语句大全详解 增删改查 ...
- SpringBoot从入门到精通教程(二十七)- @Valid注解用法详解+全局处理器Exception优雅处理参数验证用法
问题痛点 用 Spring 框架写代码时,写接口类,相信大家对该类的写法非常熟悉.在写接口时要写效验请求参数逻辑,这时候我们会常用做法是写大量的 if 与 if else 类似这样的代码来做判断,如下 ...
- ECCV2020 | SOD100K:超低参数量的高效显著性目标检测算法,广义OctConv和动态权重衰减...
点击上方"AI算法修炼营",选择"星标"公众号 精选作品,第一时间送达 这篇文章收录于ECCV2020,是一篇超高效的显著性目标检测的算法,仅有100K的参数量 ...
- paddlepaddle系列之三行代码从入门到精通
PaddlePaddle系列之三行代码从入门到精通 前言 这将是PaddlePaddle系列教程的开篇,属于非官方教程.既然是非官方,自然会从一个使用者的角度出发,来教大家怎么用,会有哪些坑,以及如何 ...
- [pytorch、学习] - 3.12 权重衰减
参考 3.12 权重衰减 本节介绍应对过拟合的常用方法 3.12.1 方法 正则化通过为模型损失函数添加惩罚项使学出的模型参数更小,是应对过拟合的常用手段. 3.12.2 高维线性回归实验 impor ...
- Arduino Mixly入门到精通教程
目录 1.介绍 2.实验器材和相关资料下载链接 3. Uno Plus 开发板和米思齐软件 第1小节 简单介绍 Uno Plus 开发板 第2小节 Uno Plus 开发板的驱动安装方法 第3小节 ...
最新文章
- 图书馆座位预定管理系统前端设计_图书馆座位预约管理信息系统设计设计.doc...
- 【白话机器学习】算法理论+实战之朴素贝叶斯
- 【数据结构与算法】之深入解析“正则表达式匹配”的求解思路与算法示例
- 转载之NetApp RAID技术介绍
- 爱因斯坦留下的预言还有几个未实现?
- 百年理工计算机专业课程,这两所国内的百年理工院校,实力强劲,都是国内顶尖实力...
- workbench透明设置_ansys workbench模型能透明显示吗?非常感谢
- 【软件使用技巧】PL/SQL Developer实现双击table询
- Android Stutio 3.0 - Gradle sync failed
- C++_实现一个简单的智能指针shared_ptr
- c# asp.net页面传值方法总结
- js 调用摄像头拍照
- 学习笔记1-【计算机组成原理】-【计算机科学速成课】[40集全/精校] - Crash Course Computer Science
- Mac版PhotoShop 2020 最新版下载
- 全“芯”赋能,SOM3568核心板
- 格雷斯音频大篷车无线音箱回顾
- 智慧养老整体解决方案
- 手机摄像头的等效焦距
- 谷歌浏览器(chrome)允许跨域设置的方法
- MySQL read_only 与 super_read_only 之间的关系
热门文章
- 计算机和计算机之间如何传送文件,两台电脑实现互传文件:多种方法可选择
- 依维世苏打水让办公也可以冒出开心的小泡泡
- Python3快速入门-Python是什么
- 论坛没落了吗?传统BBS(论坛)何去何从?
- python语言中、复数类型中实数部分_python学习03.02:Python数值类型(整形、浮点型和复数)及其用法...
- 安全帽识别软件能够解决现场管理诸多问题
- qq分享提示设备未授权_友盟微信、QQ等分享提示未验证应用配置
- cad.net 利用win32api实现一个命令开关参照面板
- VS2008运行过程中出现regsvr32问题解决方法记录
- 编辑/调试汇编语言所需要工具