文章目录

  • 本文内容
  • 1. 什么是权重衰减(Weight Decay)
  • 2. 什么是正则化?
    • 2.1 什么数据扰动
  • 3. 减小模型权重
  • 4. 为Loss增加惩罚项
    • 4.1 通过公式理解Weight Decay
    • 4.2 通过图像理解Weight Decay
      • 为什么1范数不好
  • 5. Weight Decay的实现
  • 6. weight_decay的一些trick
  • 参考资料

本文内容

Weight Decay是一个正则化技术,作用是抑制模型的过拟合,以此来提高模型的泛化性。

目前网上对于Weight Decay的讲解都比较泛,都是短短的几句话,但对于其原理、实现方式大多就没有讲解清楚,本文将会逐步解释weight decay机制。

1. 什么是权重衰减(Weight Decay)

Weight Decay是一个正则化技术,作用是抑制模型的过拟合,以此来提高模型的泛化性。

它是通过给损失函数增加模型权重L2范数的惩罚(penalty)来让模型权重不要太大,以此来减小模型的复杂度,从而抑制模型的过拟合。

看完上面那句话,可能很多人已经蒙圈了,这是在说啥。后面我会逐步进行解释,将会逐步回答以下问题:

  1. 什么是正则化?
  2. Weight Decay的减小模型参数的思想
  3. L1范数惩罚项和L2范数惩罚项是什么?
  4. 为什么Weight Decay参数是在优化器上,而不是在Loss上。
  5. weight decay的调参技巧

2. 什么是正则化?

正则化的目标是减小方差或是说减小数据扰动所造成的影响。 我们来看下图来理解一下这句话:

这幅图是随着训练次数,训练Loss和验证Loss的变化曲线。上面那条线是验证集的。很明显,这个模型出现了过拟合,因为随着训练次数的增加,训练Loss在下降,但是验证Loss却在上升。这里我们会引出三个概念:

  1. 方差(Variance):刻画数据扰动所造成的影响。
  2. 偏差(Bias):刻画学习算法本身的拟合能力。
  3. 噪声(Noise):当前任务任何学习算法能达到的期望泛化误差的下界。也就是数据的噪声导致一定会出现的那部分误差。

通常不考虑噪声,所以偏差和噪声合并称为偏差。

2.1 什么数据扰动

上面说方差是“刻画数据扰动所造成的影响”,我们可以通过下面例子来理解这句话。

假设我们要预测一个 y = x y=x y=x 的模型:

绿色的线是真正的模型 y = x y=x y=x,蓝色的点是训练数据,红色的线是预测出的模型。这个训练数据点距离真实模型的偏离程度就是数据扰动

如果我们使用数据扰动较小的数据,那么预测模型结果就会和真正模型的差距较小,例如:

当我们数据扰动越大,预测模型距离实际模型的差距就会越大。因此,我们减小过拟合就是让预测模型和真实模型尽可能的一致。通常有两种做法:

  1. 增加数据量和使用更好的数据。这也是最推荐的做法
  2. 然而,通常我们很难收集到更多的数据,所以此时就需要一些正则化技术来减小“数据扰动”对模型预测带来的影响

3. 减小模型权重

权重衰减(Weight Decay)就是减小模型的权重大小,而减小模型的权重大小就可以降低模型的复杂度,使模型变得平滑,进而减小过拟合。

假设我们的模型为: y = w 0 + w 1 x + w 2 x 2 + w 2 x 2 + ⋯ + w n x n y = w_0 + w_1 x + w_2x^2 + w_2x^2 + \cdots +w_nx^n y=w0​+w1​x+w2​x2+w2​x2+⋯+wn​xn,模型的参数为 W = ( w 0 , w 1 , w 2 , ⋯ , w n ) W=(w_0, w_1, w_2, \cdots, w_n) W=(w0​,w1​,w2​,⋯,wn​)

我们使用该模型根据一些训练数据点可能会学到如下的两种曲线:

很明显,蓝色的曲线显然过拟合了。如果我们观察 W W W 的话会发现,蓝色曲线的参数通常都比较大,而绿色曲线的参数通常都比较小。

上面只是直观的说一下。结论就是:模型权重数值越小,模型的复杂度越低

该结论可以通过实验观察出来,也可以通过数学证明。(李沐说可以证明,感兴趣的同学可以搜一下)

4. 为Loss增加惩罚项

上面说了Weight Decay目的是要让模型权重小一点(控制在某一个范围内),以此来减小模型的复杂性,从而抑制过拟合。

而Weight Decay的具体做法就是在Loss后面增加一个权重的L2范数惩罚项。

4.1 通过公式理解Weight Decay

Weight Decay的具体公式就是:

L = L 0 + λ 2 ∣ ∣ W ∣ ∣ 2 L = L_0 + \frac{\lambda}{2}||W||^2 L=L0​+2λ​∣∣W∣∣2

其中 L 0 L_0 L0​ 是原本的Loss, λ \lambda λ 是一个超参,负责控制权重衰减的强弱。 ∣ ∣ W ∣ ∣ 2 ||W||^2 ∣∣W∣∣2 为模型参数的2范数的平方。

具体的,假设我们的模型有 n n n 个参数,即 W = [ w 1 , w 2 , ⋯ , w n ] W=[w_1, w_2, \cdots, w_n] W=[w1​,w2​,⋯,wn​],则 L L L 为:

L = L 0 + λ 2 ( w 1 2 + w 2 2 + ⋯ + w n 2 ) 2 = L 0 + λ 2 ( w 1 2 + w 2 2 + ⋯ + w n 2 ) \begin{aligned} L &= L_0 + \frac{\lambda}{2}\left( \sqrt{w_1^2+w_2^2+\cdots+w_n^2} \right) ^2 \\\\ &= L_0 + \frac{\lambda}{2}(w_1^2+w_2^2+\cdots+w_n^2) \end{aligned} L​=L0​+2λ​(w12​+w22​+⋯+wn2​ ​)2=L0​+2λ​(w12​+w22​+⋯+wn2​)​

从上面的公式,我们可以很明显的得到如下结论:

  1. 模型的权重越大,Loss就会越大。
  2. λ \lambda λ 越大,权重衰减的就越厉害
  3. 若 λ \lambda λ 过大,那么原本Loss的占比就会较低,最后模型就光顾着让模型权重变小了,最终模型效果就会变差。

4.2 通过图像理解Weight Decay

接下来我们用图像来感受一下Weight Decay。假设我们的模型只有两个参数W1和W2,W1和W2与Loss=2有如下关系:

这个绿色的椭圆表示,当W1和W2取绿色椭圆上的点时,Loss都是2。所以,当我们没有惩罚项时,对于Loss=2,取椭圆上的这些点都可以。若取到右上角的点,那么 W1和W2 的值就会比较大,所以我们希望W1和W2尽量往左下靠。

因为我们的惩罚项是 w 1 2 + w 2 2 w_1^2 + w_2^2 w12​+w22​,我们将其图像画出来( w 1 2 + w 2 2 = X w_1^2 + w_2^2= X w12​+w22​=X)。

上图我绘制了三条橘色图像,分别为

w 1 2 + w 2 2 = X 1 w_1^2 + w_2^2= X_1 w12​+w22​=X1​,与椭圆无焦点。
w 1 2 + w 2 2 = X 2 w_1^2 + w_2^2= X_2 w12​+w22​=X2​,与椭圆交于A点
w 1 2 + w 2 2 = X 3 w_1^2 + w_2^2= X_3 w12​+w22​=X3​,与椭圆交于B,C两点

从上图可以看到,在不改变原Loss的情况下,(W1, W2)落在A点时,惩罚项最小,即 w 1 2 + w 2 2 w_1^2 + w_2^2 w12​+w22​ 最小。

所以,我们增加2范数的惩罚,会让模型参数变小。

为什么1范数不好

可能有些同学比较好奇,为什么不取1范数,我们同样用图可以表示出来。我们将上述的2范数图像变成1范数图像(即 ∣ w 1 ∣ + ∣ w 2 ∣ = X |w_1| +|w_2|=X ∣w1​∣+∣w2​∣=X):


上图我绘制了三条橘色图像,分别为

∣ w 1 ∣ + ∣ w 2 ∣ = X 1 |w_1| + |w_2|= X_1 ∣w1​∣+∣w2​∣=X1​,与椭圆无焦点。
∣ w 1 ∣ + ∣ w 2 ∣ = X 2 |w_1| + |w_2|= X_2 ∣w1​∣+∣w2​∣=X2​,与椭圆交于A点
∣ w 1 ∣ + ∣ w 2 ∣ = X 3 |w_1| + |w_2|= X_3 ∣w1​∣+∣w2​∣=X3​,与椭圆交于B,C两点

与2范数同理,在不改变原Loss的情况下,(W1, W2)落在A点时,惩罚项最小,即 ∣ w 1 ∣ + ∣ w 2 ∣ |w_1| + |w_2| ∣w1​∣+∣w2​∣ 最小。

但这里有个问题,我们发现此时 w 1 w_1 w1​ 变成 0 了。这就是为什么我们通常不用1范数,因为1范数会倾向于让一部分权重变成0。

更高的范数同理,可以参考“什么是范数(Norm),其具有哪些性质”这篇博客来感受一下每个范数不同的图像,然后将其套到上面的图中,感受一下其他范数。

5. Weight Decay的实现

通常我们在使用Weight Decay是在优化器(Optimizer)上,这就很奇怪了,上面明明都是在说Loss,为什么weight decay参数是在优化器上呢?

这是因为它们是等价的。这个很容易推导,我们用SGD来举例,SGD的更新参数的过程为:

w i ← w i − γ ∂ L ∂ w i w_i \gets w_i - \gamma \frac{\partial L}{\partial w_i} wi​←wi​−γ∂wi​∂L​

其中 γ \gamma γ 是学习率。

我们将 L = L 0 + λ 2 ∣ ∣ W ∣ ∣ 2 L = L_0 + \frac{\lambda}{2}||W||^2 L=L0​+2λ​∣∣W∣∣2 带进来求一下可得:

w i − γ ∂ L ∂ w i = w i − γ ( ∂ L 0 ∂ w i + λ w i ) \begin{aligned} & w_i - \gamma \frac{\partial L}{\partial w_i} \\\\ = & w_i - \gamma (\frac{\partial L_0}{\partial w_i} + \lambda w_i) \end{aligned} =​wi​−γ∂wi​∂L​wi​−γ(∂wi​∂L0​​+λwi​)​

其中 ∂ L 0 ∂ w i \frac{\partial L_0}{\partial w_i} ∂wi​∂L0​​ 就是原本的梯度,所以我们为Loss增加L2正则项只需要在更新参数时,给模型的梯度加一个 λ w i \lambda w_i λwi​ 即可。

对应Pytorch的实现如下图:

6. weight_decay的一些trick

  1. weight_decay并没有你想想中的那么好,它的效果可能只有一点点,不要太指望它。尤其是当你的模型很复杂时,权重衰退的效果可能会更小了。
  2. 通常取1e-3,如果要尝试的话,一般也就是1e-2, 1e-3, 1e-4 这些选项。
  3. 权重衰退通常不对bias做。但通常bias做不做权重衰退其实效果差不多,不过最好不要做。
  4. weight_decay取值越大,对抑制模型的强度越大。但这并不说明越大越好,太大的话,可能会导致模型欠拟合。

针对第三点:对于一个二维曲线,bias只是让曲线整体上下移动,并不能减小模型的复杂度,所以通常不需要对bias做正则化。

参考资料

正则化之weight_decay(深度之眼): https://www.bilibili.com/video/BV1HB4y1i7Fn

权重衰退(李沐): https://www.bilibili.com/video/BV1UK4y1o7dy

从拉格朗日乘数法角度理解L1L2正则: https://www.bilibili.com/video/BV1Z44y147xA

权重衰减weight_decay参数从入门到精通相关推荐

  1. (pytorch-深度学习系列)pytorch避免过拟合-权重衰减的实现-学习笔记

    pytorch避免过拟合-权重衰减的实现 首先学习基本的概念背景 L0范数是指向量中非0的元素的个数:(L0范数难优化求解) L1范数是指向量中各个元素绝对值之和: L2范数是指向量各元素的平方和然后 ...

  2. pytorch学习笔记(十二):权重衰减

    文章目录 1. 方法 2. 高维线性回归实验 3. 从零开始实现 3.1 初始化模型参数 3.2 定义L2L_2L2​范数惩罚项 3.3 定义训练和测试 3.4 观察过拟合 3.5 使用权重衰减 4. ...

  3. 权重衰减/权重衰退——weight_decay

    目录 权重衰减/权重衰退--weight_decay 一.什么是权重衰减/权重衰退--weight_decay? 二.weight decay 的作用 三.设置weight decay的值为多少? 权 ...

  4. bean validation校验方法参数_SpringBoot参数校验 从入门到精通 解决繁琐的参数验证工作...

    ● 手把手教你实现 SpringBoot与Vue整合开发 前后端分离 简单例子 详解●SQL优化经历  SQL执行效率提高了1000w倍●Java面试题 详解 由易到难● SQL语句大全详解 增删改查 ...

  5. SpringBoot从入门到精通教程(二十七)- @Valid注解用法详解+全局处理器Exception优雅处理参数验证用法

    问题痛点 用 Spring 框架写代码时,写接口类,相信大家对该类的写法非常熟悉.在写接口时要写效验请求参数逻辑,这时候我们会常用做法是写大量的 if 与 if else 类似这样的代码来做判断,如下 ...

  6. ECCV2020 | SOD100K:超低参数量的高效显著性目标检测算法,广义OctConv和动态权重衰减...

    点击上方"AI算法修炼营",选择"星标"公众号 精选作品,第一时间送达 这篇文章收录于ECCV2020,是一篇超高效的显著性目标检测的算法,仅有100K的参数量 ...

  7. paddlepaddle系列之三行代码从入门到精通

    PaddlePaddle系列之三行代码从入门到精通 前言 这将是PaddlePaddle系列教程的开篇,属于非官方教程.既然是非官方,自然会从一个使用者的角度出发,来教大家怎么用,会有哪些坑,以及如何 ...

  8. [pytorch、学习] - 3.12 权重衰减

    参考 3.12 权重衰减 本节介绍应对过拟合的常用方法 3.12.1 方法 正则化通过为模型损失函数添加惩罚项使学出的模型参数更小,是应对过拟合的常用手段. 3.12.2 高维线性回归实验 impor ...

  9. Arduino Mixly入门到精通教程

    目录 1.介绍 2.实验器材和相关资料下载链接 3. Uno Plus 开发板和米思齐软件 第1小节  简单介绍 Uno Plus 开发板 第2小节 Uno Plus 开发板的驱动安装方法 第3小节 ...

最新文章

  1. 图书馆座位预定管理系统前端设计_图书馆座位预约管理信息系统设计设计.doc...
  2. 【白话机器学习】算法理论+实战之朴素贝叶斯
  3. 【数据结构与算法】之深入解析“正则表达式匹配”的求解思路与算法示例
  4. 转载之NetApp RAID技术介绍
  5. 爱因斯坦留下的预言还有几个未实现?
  6. 百年理工计算机专业课程,这两所国内的百年理工院校,实力强劲,都是国内顶尖实力...
  7. workbench透明设置_ansys workbench模型能透明显示吗?非常感谢
  8. 【软件使用技巧】PL/SQL Developer实现双击table询
  9. Android Stutio 3.0 - Gradle sync failed
  10. C++_实现一个简单的智能指针shared_ptr
  11. c# asp.net页面传值方法总结
  12. js 调用摄像头拍照
  13. 学习笔记1-【计算机组成原理】-【计算机科学速成课】[40集全/精校] - Crash Course Computer Science
  14. Mac版PhotoShop 2020 最新版下载
  15. 全“芯”赋能,SOM3568核心板
  16. 格雷斯音频大篷车无线音箱回顾
  17. 智慧养老整体解决方案
  18. 手机摄像头的等效焦距
  19. 谷歌浏览器(chrome)允许跨域设置的方法
  20. MySQL read_only 与 super_read_only 之间的关系

热门文章

  1. 计算机和计算机之间如何传送文件,两台电脑实现互传文件:多种方法可选择
  2. 依维世苏打水让办公也可以冒出开心的小泡泡
  3. Python3快速入门-Python是什么
  4. 论坛没落了吗?传统BBS(论坛)何去何从?
  5. python语言中、复数类型中实数部分_python学习03.02:Python数值类型(整形、浮点型和复数)及其用法...
  6. 安全帽识别软件能够解决现场管理诸多问题
  7. qq分享提示设备未授权_友盟微信、QQ等分享提示未验证应用配置
  8. cad.net 利用win32api实现一个命令开关参照面板
  9. VS2008运行过程中出现regsvr32问题解决方法记录
  10. 编辑/调试汇编语言所需要工具