《动手深度学习》4.5 权重衰减Weight Decay
4.5 权重衰减Weight Decay
- 理论
- 硬性限制W:
- 柔性限制W:
- 看图说明正则项(惩罚项)对最优解的影响
- 参数更新的过程
- 代码从零实现
- 生成人工数据
- 参数初始化
- 定义L2范数惩罚(λ\lambdaλ后续添加)
- 训练部分
- 训练结果
- ⭐⭐λ\lambdaλ的选择
- 代码简洁实现
理论
目的: 使用正则化技术缓解过拟合,而不必再去寻找更多的训练数据!
缓解过拟合的方法:
- 限制模型容量(限制特征的数量):eg:调整拟合多项式的阶数
- 限制参数的可选范围:eg:限制W和b的范围
权重衰减就是限制了W的范围!
硬性限制W:
- 使用均方范数作为硬性限制:直接限制W的范数小于某个值。
- 通常不限制b,因为b只是影响曲线的上下平移,对曲线的形状没有实际作用。
- θ\thetaθ越小,意味着这个正则项越强。
柔性限制W:
使用均方范数作为柔性限制:
对于每个θ\thetaθ,都可以找到一个λ\lambdaλ使得之前的目标函数等价于下式:
超参数λ\lambdaλ控制了正则项的重要程度
- λ=0\lambda = 0λ=0:无作用
- λ→∞\lambda → ∞λ→∞:表示w*→0
为什么选择L2范数作为正则项?
- 使用L2范数的一个原因是它对权重向量的大分量施加了巨大的惩罚。 这使得我们的学习算法偏向于在大量特征上均匀分布权重的模型。在实践中,这可能使它们对单个变量中的观测误差更为稳定。
相比之下,L1惩罚会导致模型将权重集中在一小部分特征上, 而将其他权重清除为零。 这称为特征选择(feature selection),这可能是其他场景下需要的。
看图说明正则项(惩罚项)对最优解的影响
- 要保证权重向量比较小, 最常用方法是将其范数作为惩罚项加到最小化损失的问题中。
- 将原来的训练目标最小化训练标签上的预测损失, 调整为最小化预测损失和惩罚项之和。
参数更新的过程
计算梯度:
更新参数:
与未正则化的更新过程做对比:
会发现正则化后只是多了一项(−ηλ)W(-\eta\lambda)W(−ηλ)Wt
通常,ηλ<1\eta\lambda < 1ηλ<1,也就是说每次会先给权重W减去一部分,然后再按梯度进行下降。这就是名字"权重衰退"的来由!!!
代码从零实现
权重衰减是最广泛使用的正则化的技术之一
%matplotlib inline
import torch
from torch import nn
from d2l import torch as d2l
生成人工数据
n_train, n_test, num_inputs, batch_size = 20, 100, 200, 5
true_w, true_b = torch.ones((num_inputs, 1)) * 0.01, 0.05
train_data = d2l.synthetic_data(true_w, true_b, n_train)
train_iter = d2l.load_array(train_data, batch_size)
test_data = d2l.synthetic_data(true_w, true_b, n_test)
test_iter = d2l.load_array(test_data, batch_size, is_train=False)
参数初始化
def init_params():w = torch.normal(0, 1, size=(num_inputs, 1), requires_grad=True)b = torch.zeros(1, requires_grad=True)return [w, b]
定义L2范数惩罚(λ\lambdaλ后续添加)
def l2_penalty(w):return torch.sum(w.pow(2)) / 2
训练部分
和之前的区别在于,要接受一个输入参数
λ\lambdaλ:
l = loss(net(X), y) + lambd * l2_penalty(w)
#此处引入了惩罚项!这是和之前train唯一的区别!
def train(lambd):w, b = init_params() #参数初始化net, loss = lambda X: d2l.linreg(X, w, b), d2l.squared_loss #定义模型结构为简单的线性回归,loss为平方损失num_epochs, lr = 100, 0.003animator = d2l.Animator(xlabel='epochs', ylabel='loss', yscale='log',xlim=[5, num_epochs], legend=['train', 'test'])#动态作图效果for epoch in range(num_epochs):for X, y in train_iter:with torch.enable_grad():l = loss(net(X), y) + lambd * l2_penalty(w) #此处引入了惩罚项!这是和之前train唯一的区别!l.sum().backward()d2l.sgd([w, b], lr, batch_size)if (epoch + 1) % 5 == 0:animator.add(epoch + 1, (d2l.evaluate_loss(net, train_iter, loss),d2l.evaluate_loss(net, test_iter, loss)))print('w的L2范数是:', torch.norm(w).item())
训练结果
- 不使用正则化,直接训练:严重过拟合!
- 选择比较大一点的lambda,使用权重衰减
由于训练集只有20,太小了,而且lambda比较小,对W的限制性一般,所以拟合效果还是比较差。
- 保持λ=3\lambda = 3λ=3,数据集大小改成200
果然,数据集足够大可以很好的避免过拟合。
- 保持数据集不变,调大λ\lambdaλ的值
过拟合有所减轻,但是发现λ\lambdaλ也不能太大,会出现欠拟合的结果。
⭐⭐λ\lambdaλ的选择
- 当lambda的值很小时,其惩罚项值不大,还是会出现过拟合现象
- 当lambda的值逐渐调大的时候,过拟合现象的程度越来越低,
- 但是当labmda的值超过一个阈值时,就会出现欠拟合现象,因为其惩罚项太大,导致丢失太多的特征,甚至一些比较重要的特征。
lamda值是提升模型泛化能力的,但是不能设置过高,否则也会导致梯度消失,也不能设置过低,将会导致梯度爆炸
在调参的时候,可以通过网格搜索来确定最佳的正则化参数。
一般的做法的是,首先在0.0到0.1之间的各个数量级上进行网格搜索,然后在找到某个级别后,再对该级别进行网格搜索。
沐神说,一般都设置为e-2,e-3这种的,权重衰减的效果确实有限,如果没有明显效果就只能换别的方法了。
代码简洁实现
区别:把weight decay放在了训练过程里,而没有定义在损失函数loss(或者说目标函数)里。
而计算loss的过程没有变化:l = loss(net(X), y)
这样其实更适合计算机的训练,减少了求导的复杂性。但本质是一样的,因为weighth decay就是每次更新时给w多减了一项而已。
def train_concise(wd):net = nn.Sequential(nn.Linear(num_inputs, 1))for param in net.parameters():param.data.normal_()loss = nn.MSELoss()num_epochs, lr = 100, 0.003trainer = torch.optim.SGD([{"params": net[0].weight,'weight_decay': wd}, {"params": net[0].bias}], lr=lr)animator = d2l.Animator(xlabel='epochs', ylabel='loss', yscale='log',xlim=[5, num_epochs], legend=['train', 'test'])for epoch in range(num_epochs):for X, y in train_iter:with torch.enable_grad():trainer.zero_grad()l = loss(net(X), y)l.backward()trainer.step()if (epoch + 1) % 5 == 0:animator.add(epoch + 1, (d2l.evaluate_loss(net, train_iter, loss),d2l.evaluate_loss(net, test_iter, loss)))print('w的L2范数:', net[0].weight.norm().item())
《动手深度学习》4.5 权重衰减Weight Decay相关推荐
- 动手深度学习13——计算机视觉:数据增广、图片分类
文章目录 一.数据增广 1.1 为何进行数据增广? 1.2 常见图片增广方式 1.2.1 翻转 1.2.2 切割(裁剪) 1.2.3 改变颜色 1.2.4 综合使用 1.3 使用图像增广进行训练 1. ...
- 【动手深度学习-笔记】注意力机制(一)注意力机制框架
生物学中的注意力提示 非自主性提示: 在没有主观意识的干预下,眼睛会不自觉地注意到环境中比较突出和显眼的物体. 比如我们自然会注意到一堆黑球中的一个白球,马路上最酷的跑车等. 自主性提示: 在主观意识 ...
- 动手深度学习13:计算机视觉——语义分割、风格迁移
文章目录 一.语义分割 1.1 语义分割简介 1.2 Pascal VOC2012 语义分割数据集 1.2.1下载.读取数据集 1.2.2 构建字典(RGB颜色值和类名互相映射) 1.2.3 数据预处 ...
- 深度学习相关概念:权重初始化
深度学习相关概念:权重初始化 1.全零初始化(×) 2.随机初始化 2.1 高斯分布/均匀分布 2.1.1权重较小-N(0,0.01)\pmb{\mathcal{N}(0,0.01)}N(0,0.01 ...
- 【动手深度学习-笔记】注意力机制(四)自注意力、交叉注意力和位置编码
文章目录 自注意力(Self-Attention) 例子 Self-Attention vs Convolution Self-Attention vs RNN 交叉注意力(Cross Attenti ...
- 动手深度学习笔记(四十)7.4. 含并行连结的网络(GoogLeNet)
动手深度学习笔记(四十)7.4. 含并行连结的网络(GoogLeNet) 7.4. 含并行连结的网络(GoogLeNet) 7.4.1. Inception块 7.4.2. GoogLeNet模型 7 ...
- 动手深度学习笔记(一)2.1数据操作
动手深度学习笔记(一) 2. 预备知识 2.1. 数据操作 2.1.1. 入门 2.1.2. 运算符 2.1.3. 广播机制 2.1.4. 索引和切片 2.1.5. 节省内存 2.1.6. 转换为其他 ...
- 动手深度学习笔记(四十五)8.1. 序列模型
动手深度学习笔记(四十五)8.1. 序列模型 8.1. 序列模型 8.1.1. 统计工具 8.1.1.1. 自回归模型 8.1.1.2. 马尔可夫模型 8.1.1.3. 因果关系 8.1.2. 训练 ...
- 权值衰减weight decay的理解
1. 介绍 权值衰减weight decay即L2正则化,目的是通过在Loss函数后加一个正则化项,通过使权重减小的方式,一定减少模型过拟合的问题. L1正则化:即对权重矩阵的每个元素绝对值求和, λ ...
- 深度学习Trick——用权重约束减轻深层网络过拟合|附(Keras)实现代码
向AI转型的程序员都关注了这个号???????????? 机器学习AI算法工程 公众号:datayx 在深度学习中,批量归一化(batch normalization)以及对损失函数加一些正则项这 ...
最新文章
- php手册数组函数,PHP - Manual手册 - 函数参考 - Array 数组函数 - array_diff计算数组的差集...
- python版本选择-【小白学python】之一:版本选择
- Android Studio 提示与技巧(官方文档翻译)
- [译]模型-视图-提供器 模式
- 看漫画,学电子,我居然看懂了!
- MySQL查询select语句详解
- signature=f0dd2033ed5bb3cdb94f9136381f7750,Lesson 8: Signature Assignment
- linux 终端 拼音,告诉你Ubuntu中文智能拼音输入法配置的方法及命令
- 程序员的 升级 ,价值观的改变
- mt4 谐波_MT4指标Harmonic Dasboard — 外汇谐波仪表盘交易系统
- 关于Loadrunner11破解的各种问题。。。泪奔。。。
- 从MDK4到MDK5之“盘古开天辟地”
- php对字符数组进行排序,php数组去重_php对数组中字符串去重并排序例子
- 如何解决RS485 通讯接口被主站占用的问题
- Android 混淆总结
- 整一篇整一篇,python3实现自动重启路由器的上的花生壳(selenium)
- 求救 关于ORA-01115的错误
- 单片机课程设计:四位密码锁代码
- .c_str()函数解析
- 施工企业数字化管理系统赋能项目全生命周期 强化过程管控精细化
热门文章
- 应急响应中的入侵排查和权限维持
- 4.6 linux文件系统-虚拟文件系统VFS
- unity学习手记之角色动画
- 最近火爆全网的猫猫回收站教程,小七给你们搞来了
- python bunch制作可导入数据_Python 之 Bunch Pattern
- 记录mysql中如何统计日周月季度年
- containers matlab,Matlab 中实用数据结构之 containers.Map
- 美通社企业新闻汇总 | 2019.1.28 | 万豪集团2018年创增长新纪录;英特尔宣布AI合作伙伴创新激励计划...
- 家庭局域网的组建(2台或2台以上)
- HEVC: I帧、P帧及B帧