动手学深度学习(十四)——权重衰退
文章目录
- 1. 如何缓解过拟合?
- 2. 如何衡量模型的复杂度?
- 3. 通过限制参数的选择范围来控制模型容量(复杂度)
- 4. 正则化如何让权重衰退?
- 5. 可视化地看看正则化是如何利用权重衰退来达到缓解过拟合的目的的
- 6. 为什么使用平方范数而非标准范数(欧几里得距离)?
- 7. L1范数和L2范数在应用中的一些区别
1. 如何缓解过拟合?
(1)可以通过收集更多的训练数据来缓解,但是这种方法成本高,而且不受人为控制,短时间内很难取得很高的成效;
(2)也可以通过限制特征的数量来缓解过拟合,但简单地丢弃某些特征会对训练结果产生更多的影响。
(3)假设我们已经拥有了足够多的高质量数据,我们可以用正则化技术来缓解过拟合的影响。
2. 如何衡量模型的复杂度?
记住:模型越复杂,其能拟合的函数也就越复杂,但其也越是容易过拟合;但是模型过于简单,又很容易导致欠拟合,不能够学习到数据的关键特征。
正则化是基于一个基本直觉,即在所有函数fff中,函数f=0f = 0f=0(所有输入都得到值000)在某种意义上是最简单的,我们可以通过函数与零的距离来衡量函数的复杂度。但是我们应该如何精确地测量一个函数和零之间的距离呢?没有一个正确的答案。事实上,整个数学分支,包括函数分析和巴拿赫空间理论,都在致力于回答这个问题。
一种简单的方法是通过线性函数f(x)=w⊤xf(\mathbf{x}) = \mathbf{w}^\top \mathbf{x}f(x)=w⊤x中的权重向量的某个范数来度量其复杂性,例如∥w∥2\| \mathbf{w} \|^2∥w∥2。要保证权重向量比较小,最常用方法是将其范数作为惩罚项加到最小化损失的问题中。
现在,如果我们的权重向量增长的太大,我们的学习算法可能会更集中于最小化权重范数 ∥w∥2\| \mathbf{w} \|^2∥w∥2。这正是我们想要的。我们的损失由下式给出:
L(w,b)=1n∑i=1n12(w⊤x(i)+b−y(i))2.L(\mathbf{w}, b) = \frac{1}{n}\sum_{i=1}^n \frac{1}{2}\left(\mathbf{w}^\top \mathbf{x}^{(i)} + b - y^{(i)}\right)^2.L(w,b)=n1i=1∑n21(w⊤x(i)+b−y(i))2.
3. 通过限制参数的选择范围来控制模型容量(复杂度)
- 使用均方范数作为硬性限制
minL(w,b)subjectto∣∣w∣∣2⩽θmin \; L(w,b) \quad subject \; to \; ||w||^2 \leqslant \thetaminL(w,b)subjectto∣∣w∣∣2⩽θ
* 通常不限制偏移b(限制与不限制的区别不大)
* 小θ\thetaθ表示更强的正则项 - 使用均方范数作为柔性限制
* 对于每个θ\thetaθ都可以找到λ\lambdaλ使得之前的目标函数等价于下面公式(使用拉格郎日乘子证明):
minL(w,b)+λ2∣∣w∣∣2min \; L(w,b) + \frac{\lambda}{2}||w||^2minL(w,b)+2λ∣∣w∣∣2
* 超参数λ\lambdaλ控制了正则项的重要程度
* λ=0\lambda = 0λ=0:无作用
* λ→∞\lambda \rightarrow \inftyλ→∞:w∗→0w^* \rightarrow 0w∗→0
4. 正则化如何让权重衰退?
计算梯度
∂∂w(L(w,b)+λ2∣∣w∣∣2)=∂L(w,b)∂w+λw\frac{\partial}{\partial w}(L(w,b)+\frac{\lambda}{2}{||w||}^2) = \frac{\partial L(w,b)}{\partial w} + \lambda w∂w∂(L(w,b)+2λ∣∣w∣∣2)=∂w∂L(w,b)+λw时间t更新参数
wt+1=wt−η(∂L(wt,b)∂wt+λwt)w_{t+1} = w_t - \eta (\frac{\partial L(w_t,b)}{\partial w_t} + \lambda w_t)wt+1=wt−η(∂wt∂L(wt,b)+λwt)
wt+1=(1−ηλ)wt−η∂L(wt,bt)∂wtw_{t+1} = (1-\eta \lambda)w_t - \eta \frac{\partial L(w_t,b_t)}{\partial w_t}wt+1=(1−ηλ)wt−η∂wt∂L(wt,bt)- η\etaη为学习率
- 通常ηλ<1\eta \lambda < 1ηλ<1,在深度学习中通常称之为权重衰退
可以看出来加上正则项的loss function的梯度只是在wtw_twt这里加上了一个−ηλ-\eta \lambda−ηλ项,通常−ηλ<1-\eta \lambda<1−ηλ<1那么我们得到的梯度更新量就会在梯度更新的方向上回退一些,从而控制了梯度更新的步子。
5. 可视化地看看正则化是如何利用权重衰退来达到缓解过拟合的目的的
权重衰退代码实现(李沐沐神2021动手学深度学习课程中的教学代码):
训练的公式为:
y=0.05+∑i=1d0.01xi+ϵwhere ϵ∼N(0,0.012).y = 0.05 + \sum_{i = 1}^d 0.01 x_i + \epsilon \text{ where } \epsilon \sim \mathcal{N}(0, 0.01^2).y=0.05+i=1∑d0.01xi+ϵ where ϵ∼N(0,0.012).
我们选择标签是关于输入的线性函数。标签同时被均值为0,标准差为0.01高斯噪声破坏。为了使过拟合的效果更加明显,我们可以将问题的维数增加到d=200d = 200d=200,并使用一个只包含20个样本的小训练集。
%matplotlib inline
import torch
from torch import nn
from d2l import torch as d2l
# 训练样本20,测试样本100个,输入特征维度200,批次大小5(训练样本太小,很容易过拟合)
n_train, n_test, num_inputs, batch_size = 20, 100, 200, 5
true_w, true_b = torch.ones((num_inputs, 1)) * 0.01, 0.05
train_data = d2l.synthetic_data(true_w, true_b, n_train)
train_iter = d2l.load_array(train_data, batch_size)
test_data = d2l.synthetic_data(true_w, true_b, n_test)
test_iter = d2l.load_array(test_data, batch_size, is_train=False)
# 初始化模型参数
def init_params():w = torch.normal(0,1,size =(num_inputs,1),requires_grad=True)b = torch.zeros(1,requires_grad=True)return [w,b]# 定义L2范数惩罚
def l2_penalty(w):return torch.sum(w.pow(2))/2
# 定义训练函数
def train(lambd):w,b = init_params()net,loss = lambda X:d2l.linreg(X,w,b),d2l.squared_lossnum_epochs ,lr = 500,0.003animator = d2l.Animator(xlabel='epochs', ylabel='loss', yscale='log',xlim=[5, num_epochs], legend=['train', 'test'])for epoch in range(num_epochs):for X,y in train_iter:l = loss(net(X),y)+lambd*l2_penalty(w)l.sum().backward()d2l.sgd([w,b],lr,batch_size)if (epoch + 1) % 5 == 0:animator.add(epoch + 1, (d2l.evaluate_loss(net, train_iter, loss),d2l.evaluate_loss(net, test_iter, loss)))print('w的L2范数是:', torch.norm(w).item())
# 忽略正则化直接训练
train(lambd=0)
# 使用权重衰减
train(lambd=3)
由此可见正则化进行权重衰退的却是缓解了过拟合的问题(未加正则化时,其训练误差和测试误差相差巨大,明显过拟合)。
6. 为什么使用平方范数而非标准范数(欧几里得距离)?
L(w,b)+λ2∥w∥2,L(\mathbf{w}, b) + \frac{\lambda}{2} \|\mathbf{w}\|^2,L(w,b)+2λ∥w∥2,
对于λ=0\lambda = 0λ=0,我们恢复了原来的损失函数。对于λ>0\lambda > 0λ>0,我们限制∥w∥\| \mathbf{w} \|∥w∥的大小。公式中除以222:当我们取一个二次函数的导数时,222和1/21/21/2会抵消,以确保更新表达式看起来既漂亮又简单。
所以采用平方范数的原始是为了便于计算。通过平方L2L_2L2范数,我们去掉平方根,留下权重向量每个分量的平方和。这使得惩罚的导数很容易计算:导数的和等于和的导数。
7. L1范数和L2范数在应用中的一些区别
L1范数和L2范数在整个统计学习领域都是有效而受欢迎的正则化方法。
- L1正则化线性回归是统计学习中的基本模型,通常称之为“套索回归”(lasso regression)
- L1范数施加的惩罚较之L2范数更小,使得模型将其他权重清零,而将权重主要集中在一小部分特征上,也被称之为“特征选择”,在某些情景下这是非常重要和被需要的。
- L2正则化线性模型则构建成了经典的“岭回归”(ridge regression)
- L2范数对权重向量施加了巨大的惩罚,使得我们的学习算法偏向于在大量特征上均匀分布权重的模型(也就是说使得模型对单个变量中的误差更加鲁棒)
动手学深度学习(十四)——权重衰退相关推荐
- 《动手学深度学习》(四) -- LeNet、AlexNet、VGG、NiN、GoogLeNet、ResNet、DenseNet 实现
上一小节学习了卷积神经网络的卷积层和池化层的实现,趁热打铁继续学习现代卷积神经网络的搭建,欢迎小伙伴们一起学习和交流~ 为了能够应⽤softmax回归和多层感知机,我们⾸先将每个⼤小为28×2828 ...
- 动手学深度学习(四十)——长短期记忆网络(LSTM)
文章目录 一.长短期记忆网络(LSTM) 1.1 门控记忆单元 1.2 输入门.遗忘门与输出门 1.3候选记忆单元 1.4 记忆单元 1.5 隐藏状态 二.从零实现LSTM 2.1 初始化模型参数 2 ...
- 李沐动手学深度学习第四章-4.9.环境和分布偏移
我们从来没有想过数据最初从哪里来?以及我们计划最终如何处理模型的输出? 根据测试集的精度衡量,模型表现得非常出色. 但是当数据分布突然改变时,模型在部署中会出现灾难性的失败. 解决方案很简单(要求&q ...
- 【动手学深度学习PyTorch版】6 权重衰退
上一篇移步[动手学深度学习PyTorch版]5 模型选择 + 过拟合和欠拟合_水w的博客-CSDN博客 目录 一.权重衰退 1.1 权重衰退 weight decay:处理过拟合的最常见方法(L2_p ...
- 【动手学深度学习v2李沐】学习笔记07:权重衰退、正则化
前文回顾:模型选择.欠拟合和过拟合 文章目录 一.权重衰退 1.1 硬性限制 1.2 柔性限制(正则化) 1.3 参数更新法则 1.4 总结 二.代码实现 2.1 从零开始实现 2.1.1 人工数据集 ...
- 动手学深度学习(三十九)——门控循环单元GRU
文章目录 门控循环单元(GRU) 一.门控隐藏状态 1.1 重置门和更新门 1.2候选隐藏状态 1.3 隐藏状态 二.从零实现GRU 2.1 初始化模型参数 2.2 定义模型 2.3 训练与预测 2. ...
- 动手学深度学习笔记3.4+3.5+3.6+3.7
系列文章目录 动手学深度学习笔记系列: 动手学深度学习笔记3.1+3.2+3.3 文章目录 系列文章目录 前言 一.softmax回归 1.1 分类问题 1.2 网络架构 1.3 全连接层的参数开销 ...
- 《动手学深度学习》—学习笔记
文章目录 深度学习简介 起源 特点 小结 预备知识 获取和运行本书的代码 pytorch环境安装 方式一 方式二 数据操作 创建 运算 广播机制 索引 运算的内存开销 NDArray和NumPy相互变 ...
- 第1章【深度学习简介】--动手学深度学习【Tensorflow2.0版本】
项目地址:https://github.com/TrickyGo/Dive-into-DL-TensorFlow2.0 UC 伯克利李沐的<动手学深度学习>开源书一经推出便广受好评.很多开 ...
最新文章
- java qq登陆api_java方式接入QQ登录
- 2011 ScrumGathering敏捷个人.pptx
- Wdatepicker日期控件的使用指南
- php flock 死锁了,php – 防止由flock引起的死锁
- BCB中获得RichEdit 默认行间距
- 【MPI编程】MPI_Bcast广播讲解和使用
- Android安全-SO动态库注入
- C#LeetCode刷题-分治算法
- 使用vbs脚本实现自动化安装GUI程序
- 命令行查看硬盘序列号
- 雷赛服务器信号er020,伺服与雷赛控制卡配套的小技巧
- COMSOL中文指导教程全集
- 小游戏---java版2048(2048 go go go)
- 文件查找,打包压缩,解压相关分享
- 关于Could not find method javacompileOptions() for arguments
- roc_auc_score()、auc()和roc_curve()
- html 各种字符 换位键,excel替换特定位置处的字符
- 说说在 Python 中如何处理文件系统路径
- 软件加密系统Themida应用程序保护指南(四):虚拟机的选择
- DevOps系列之 —— 持续规划与设计(四)敏捷需求管理【用户故事 敏捷估算】
热门文章
- JSP+ssm计算机毕业设计大媛小南美味佳肴网站8p0nh【源码、数据库、LW、部署】
- 开发常用linux命令
- kali:ARP欺骗
- 解决 为什么会出现 “Safari浏览器打不开该网页,因为地址无效“ 的提示
- 报警系统QuickAlarm之报警规则的设定与加载
- C++ 二叉搜索树(补充)
- HTML5--制作导航栏
- Hack The Box 注册邀请码破解记录
- vivo Hi-Fi+QQ音乐 数字音乐市场的一剂良方
- android alarmmanager后台,Android各版本AlarmManager使用