4.5 权重衰减Weight Decay

  • 理论
    • 硬性限制W:
    • 柔性限制W:
    • 看图说明正则项(惩罚项)对最优解的影响
    • 参数更新的过程
  • 代码从零实现
    • 生成人工数据
    • 参数初始化
    • 定义L2范数惩罚(λ\lambdaλ后续添加)
    • 训练部分
    • 训练结果
    • ⭐⭐λ\lambdaλ的选择
  • 代码简洁实现

理论

目的: 使用正则化技术缓解过拟合,而不必再去寻找更多的训练数据!

缓解过拟合的方法:

  1. 限制模型容量(限制特征的数量):eg:调整拟合多项式的阶数
  2. 限制参数的可选范围:eg:限制W和b的范围

权重衰减就是限制了W的范围!

硬性限制W:

  • 使用均方范数作为硬性限制:直接限制W的范数小于某个值。
  • 通常不限制b,因为b只是影响曲线的上下平移,对曲线的形状没有实际作用。
  • θ\thetaθ越小,意味着这个正则项越强。

柔性限制W:

  • 使用均方范数作为柔性限制
    对于每个θ\thetaθ,都可以找到一个λ\lambdaλ使得之前的目标函数等价于下式:

  • 超参数λ\lambdaλ控制了正则项的重要程度

    • λ=0\lambda = 0λ=0:无作用
    • λ→∞\lambda → ∞λ→∞:表示w*→0
  • 为什么选择L2范数作为正则项?

    • 使用L2范数的一个原因是它对权重向量的大分量施加了巨大的惩罚。 这使得我们的学习算法偏向于在大量特征上均匀分布权重的模型。在实践中,这可能使它们对单个变量中的观测误差更为稳定。
  • 相比之下,L1惩罚会导致模型将权重集中在一小部分特征上, 而将其他权重清除为零。 这称为特征选择(feature selection),这可能是其他场景下需要的。

看图说明正则项(惩罚项)对最优解的影响

  • 要保证权重向量比较小, 最常用方法是将其范数作为惩罚项加到最小化损失的问题中。
  • 将原来的训练目标最小化训练标签上的预测损失, 调整为最小化预测损失和惩罚项之和

参数更新的过程

  • 计算梯度:

  • 更新参数:

  • 与未正则化的更新过程做对比:

    会发现正则化后只是多了一项(−ηλ)W(-\eta\lambda)W(−ηλ)Wt

  • 通常,ηλ<1\eta\lambda < 1ηλ<1,也就是说每次会先给权重W减去一部分,然后再按梯度进行下降。这就是名字"权重衰退"的来由!!!

代码从零实现

权重衰减是最广泛使用的正则化的技术之一

%matplotlib inline
import torch
from torch import nn
from d2l import torch as d2l

生成人工数据

n_train, n_test, num_inputs, batch_size = 20, 100, 200, 5
true_w, true_b = torch.ones((num_inputs, 1)) * 0.01, 0.05
train_data = d2l.synthetic_data(true_w, true_b, n_train)
train_iter = d2l.load_array(train_data, batch_size)
test_data = d2l.synthetic_data(true_w, true_b, n_test)
test_iter = d2l.load_array(test_data, batch_size, is_train=False)

参数初始化

def init_params():w = torch.normal(0, 1, size=(num_inputs, 1), requires_grad=True)b = torch.zeros(1, requires_grad=True)return [w, b]

定义L2范数惩罚(λ\lambdaλ后续添加)

def l2_penalty(w):return torch.sum(w.pow(2)) / 2

训练部分

和之前的区别在于,要接受一个输入参数
λ\lambdaλ:
l = loss(net(X), y) + lambd * l2_penalty(w) #此处引入了惩罚项!这是和之前train唯一的区别!

def train(lambd):w, b = init_params() #参数初始化net, loss = lambda X: d2l.linreg(X, w, b), d2l.squared_loss #定义模型结构为简单的线性回归,loss为平方损失num_epochs, lr = 100, 0.003animator = d2l.Animator(xlabel='epochs', ylabel='loss', yscale='log',xlim=[5, num_epochs], legend=['train', 'test'])#动态作图效果for epoch in range(num_epochs):for X, y in train_iter:with torch.enable_grad():l = loss(net(X), y) + lambd * l2_penalty(w) #此处引入了惩罚项!这是和之前train唯一的区别!l.sum().backward()d2l.sgd([w, b], lr, batch_size)if (epoch + 1) % 5 == 0:animator.add(epoch + 1, (d2l.evaluate_loss(net, train_iter, loss),d2l.evaluate_loss(net, test_iter, loss)))print('w的L2范数是:', torch.norm(w).item())

训练结果

  • 不使用正则化,直接训练:严重过拟合!
  • 选择比较大一点的lambda,使用权重衰减
    由于训练集只有20,太小了,而且lambda比较小,对W的限制性一般,所以拟合效果还是比较差。
  • 保持λ=3\lambda = 3λ=3,数据集大小改成200
    果然,数据集足够大可以很好的避免过拟合。
  • 保持数据集不变,调大λ\lambdaλ的值

过拟合有所减轻,但是发现λ\lambdaλ也不能太大,会出现欠拟合的结果。

⭐⭐λ\lambdaλ的选择

  • 当lambda的值很小时,其惩罚项值不大,还是会出现过拟合现象
  • 当lambda的值逐渐调大的时候,过拟合现象的程度越来越低,
  • 但是当labmda的值超过一个阈值时,就会出现欠拟合现象,因为其惩罚项太大,导致丢失太多的特征,甚至一些比较重要的特征。

lamda值是提升模型泛化能力的,但是不能设置过高,否则也会导致梯度消失,也不能设置过低,将会导致梯度爆炸

调参的时候,可以通过网格搜索来确定最佳的正则化参数。
一般的做法的是,首先在0.0到0.1之间的各个数量级上进行网格搜索,然后在找到某个级别后,再对该级别进行网格搜索

沐神说,一般都设置为e-2,e-3这种的,权重衰减的效果确实有限,如果没有明显效果就只能换别的方法了。

代码简洁实现

区别:把weight decay放在了训练过程里,而没有定义在损失函数loss(或者说目标函数)里。

而计算loss的过程没有变化:l = loss(net(X), y)

这样其实更适合计算机的训练,减少了求导的复杂性。但本质是一样的,因为weighth decay就是每次更新时给w多减了一项而已。

def train_concise(wd):net = nn.Sequential(nn.Linear(num_inputs, 1))for param in net.parameters():param.data.normal_()loss = nn.MSELoss()num_epochs, lr = 100, 0.003trainer = torch.optim.SGD([{"params": net[0].weight,'weight_decay': wd}, {"params": net[0].bias}], lr=lr)animator = d2l.Animator(xlabel='epochs', ylabel='loss', yscale='log',xlim=[5, num_epochs], legend=['train', 'test'])for epoch in range(num_epochs):for X, y in train_iter:with torch.enable_grad():trainer.zero_grad()l = loss(net(X), y)l.backward()trainer.step()if (epoch + 1) % 5 == 0:animator.add(epoch + 1, (d2l.evaluate_loss(net, train_iter, loss),d2l.evaluate_loss(net, test_iter, loss)))print('w的L2范数:', net[0].weight.norm().item())

《动手深度学习》4.5 权重衰减Weight Decay相关推荐

  1. 动手深度学习13——计算机视觉:数据增广、图片分类

    文章目录 一.数据增广 1.1 为何进行数据增广? 1.2 常见图片增广方式 1.2.1 翻转 1.2.2 切割(裁剪) 1.2.3 改变颜色 1.2.4 综合使用 1.3 使用图像增广进行训练 1. ...

  2. 【动手深度学习-笔记】注意力机制(一)注意力机制框架

    生物学中的注意力提示 非自主性提示: 在没有主观意识的干预下,眼睛会不自觉地注意到环境中比较突出和显眼的物体. 比如我们自然会注意到一堆黑球中的一个白球,马路上最酷的跑车等. 自主性提示: 在主观意识 ...

  3. 动手深度学习13:计算机视觉——语义分割、风格迁移

    文章目录 一.语义分割 1.1 语义分割简介 1.2 Pascal VOC2012 语义分割数据集 1.2.1下载.读取数据集 1.2.2 构建字典(RGB颜色值和类名互相映射) 1.2.3 数据预处 ...

  4. 深度学习相关概念:权重初始化

    深度学习相关概念:权重初始化 1.全零初始化(×) 2.随机初始化 2.1 高斯分布/均匀分布 2.1.1权重较小-N(0,0.01)\pmb{\mathcal{N}(0,0.01)}N(0,0.01 ...

  5. 【动手深度学习-笔记】注意力机制(四)自注意力、交叉注意力和位置编码

    文章目录 自注意力(Self-Attention) 例子 Self-Attention vs Convolution Self-Attention vs RNN 交叉注意力(Cross Attenti ...

  6. 动手深度学习笔记(四十)7.4. 含并行连结的网络(GoogLeNet)

    动手深度学习笔记(四十)7.4. 含并行连结的网络(GoogLeNet) 7.4. 含并行连结的网络(GoogLeNet) 7.4.1. Inception块 7.4.2. GoogLeNet模型 7 ...

  7. 动手深度学习笔记(一)2.1数据操作

    动手深度学习笔记(一) 2. 预备知识 2.1. 数据操作 2.1.1. 入门 2.1.2. 运算符 2.1.3. 广播机制 2.1.4. 索引和切片 2.1.5. 节省内存 2.1.6. 转换为其他 ...

  8. 动手深度学习笔记(四十五)8.1. 序列模型

    动手深度学习笔记(四十五)8.1. 序列模型 8.1. 序列模型 8.1.1. 统计工具 8.1.1.1. 自回归模型 8.1.1.2. 马尔可夫模型 8.1.1.3. 因果关系 8.1.2. 训练 ...

  9. 权值衰减weight decay的理解

    1. 介绍 权值衰减weight decay即L2正则化,目的是通过在Loss函数后加一个正则化项,通过使权重减小的方式,一定减少模型过拟合的问题. L1正则化:即对权重矩阵的每个元素绝对值求和, λ ...

  10. 深度学习Trick——用权重约束减轻深层网络过拟合|附(Keras)实现代码

    向AI转型的程序员都关注了这个号???????????? 机器学习AI算法工程   公众号:datayx 在深度学习中,批量归一化(batch normalization)以及对损失函数加一些正则项这 ...

最新文章

  1. php手册数组函数,PHP - Manual手册 - 函数参考 - Array 数组函数 - array_diff计算数组的差集...
  2. python版本选择-【小白学python】之一:版本选择
  3. Android Studio 提示与技巧(官方文档翻译)
  4. [译]模型-视图-提供器 模式
  5. 看漫画,学电子,我居然看懂了!
  6. MySQL查询select语句详解
  7. signature=f0dd2033ed5bb3cdb94f9136381f7750,Lesson 8: Signature Assignment
  8. linux 终端 拼音,告诉你Ubuntu中文智能拼音输入法配置的方法及命令
  9. 程序员的 升级 ,价值观的改变
  10. mt4 谐波_MT4指标Harmonic Dasboard — 外汇谐波仪表盘交易系统
  11. 关于Loadrunner11破解的各种问题。。。泪奔。。。
  12. 从MDK4到MDK5之“盘古开天辟地”
  13. php对字符数组进行排序,php数组去重_php对数组中字符串去重并排序例子
  14. 如何解决RS485 通讯接口被主站占用的问题
  15. Android 混淆总结
  16. 整一篇整一篇,python3实现自动重启路由器的上的花生壳(selenium)
  17. 求救 关于ORA-01115的错误
  18. 单片机课程设计:四位密码锁代码
  19. .c_str()函数解析
  20. 施工企业数字化管理系统赋能项目全生命周期 强化过程管控精细化

热门文章

  1. 应急响应中的入侵排查和权限维持
  2. 4.6 linux文件系统-虚拟文件系统VFS
  3. unity学习手记之角色动画
  4. 最近火爆全网的猫猫回收站教程,小七给你们搞来了
  5. python bunch制作可导入数据_Python 之 Bunch Pattern
  6. 记录mysql中如何统计日周月季度年
  7. containers matlab,Matlab 中实用数据结构之 containers.Map
  8. 美通社企业新闻汇总 | 2019.1.28 | 万豪集团2018年创增长新纪录;英特尔宣布AI合作伙伴创新激励计划...
  9. 家庭局域网的组建(2台或2台以上)
  10. HEVC: I帧、P帧及B帧