前文回顾:模型选择、欠拟合和过拟合

文章目录

  • 一、权重衰退
    • 1.1 硬性限制
    • 1.2 柔性限制(正则化)
    • 1.3 参数更新法则
    • 1.4 总结
  • 二、代码实现
    • 2.1 从零开始实现
      • 2.1.1 人工数据集
      • 2.1.2 模型参数
      • 2.1.3 L2L_2L2​ 范数惩罚
      • 2.1.4 训练
    • 2.2 简洁实现

一、权重衰退

1.1 硬性限制

  • 在上一篇文章中,我们讲到了控制模型容量的两种方法:

    • 使用较小的参数(使得模型变小)
    • 使参数可选择的值比较少
  • 权重衰退通过限制参数值的选择范围来控制模型容量。
    • 例如,我们可以在最小化损失函数的时候增加一个限制,防止权重过大:
      min⁡l(w⃗,b)subject to∣∣w⃗∣∣2≤θ\min l(\vec{w}, b) \quad \text{subject to} \quad ||\vec{w}||^2 \leq \theta minl(w,b)subject to∣∣w∣∣2≤θ本例中,我们限制 w⃗\vec{w}w 的 L2L_2L2​ 损失不大于 θ\thetaθ
    • 我们通常不限制偏移 bbb(限不限制都差不多)
    • 选择一个小的 θ\thetaθ 意味着更强的正则项

1.2 柔性限制(正则化)

我们通常不会采用上一小节中那样的硬性限制,而是通过正则化这种柔性限制来控制模型容量。

  • L2L_2L2​ 正则化:

    • 对每个 θ\thetaθ,都可以找到 λ\lambdaλ 使得之前的目标函数等价于下式:
      min⁡l(w⃗,b)+λ2∣∣w⃗∣∣2\min{l(\vec w,b)}+\frac {\lambda}{2}||\vec w||^2minl(w,b)+2λ​∣∣w∣∣2
    • 可以通过拉格朗日乘子来证明
    • 超参数 λ\lambdaλ 控制了正则项的重要程度
      • λ=0\lambda=0λ=0 时,正则项不起作用
      • λ→∞\lambda \rightarrow \inftyλ→∞ 时,w⃗→0⃗\vec w \rightarrow\vec 0w→0
  • L1L_1L1​ 正则化:
    • 使得大部分模型参数的值等于0,已达到模型稀疏化的目的。
    • 其公式为:
      min⁡l(w⃗,b)+λ∣∣w⃗∣∣1\min{l(\vec w,b)}+\lambda||\vec w||_1minl(w,b)+λ∣∣w∣∣1​
  • 演示:
    我们以L2L_2L2​ 正则化为例进行演示,下图中:
    w⃗∗=arg⁡min⁡l(wˉ,b)+λ2∣∣wˉ∣∣2w⃗~∗=arg⁡min⁡l(wˉ~,b)\begin{aligned}&\vec w*=\arg\min{l(\bar w,b)+\frac{\lambda}2||\bar w||^2} \\ &\tilde{\vec w}*=\arg\min{l(\tilde{\bar w},b)}\end{aligned}​w∗=argminl(wˉ,b)+2λ​∣∣wˉ∣∣2w~∗=argminl(wˉ~,b)​

    绿色的曲线为只优化损失值的情况,黄色曲线为加入了正则项的情况。正则项会将权重的值从原本离原点较远的较大值,拉扯到离原点较近的较小值,从而实现对参数大小的控制。

1.3 参数更新法则

  • 计算梯度:
    ∂∂w⃗(l(w⃗,b)+λ2∣∣w⃗∣∣2)=∂l(w⃗,b)∂w+λw⃗\frac{\partial}{\partial \vec w}\Big( l(\vec w,b)+\frac{\lambda}2||\vec w||^2 \Big)=\frac{\partial l(\vec w, b)}{\partial w}+\lambda \vec w∂w∂​(l(w,b)+2λ​∣∣w∣∣2)=∂w∂l(w,b)​+λw
  • 更新参数(时间 t):
    w⃗t+1=(1−ηλ)w⃗t−η∂l(w⃗t,bt)∂w⃗t\vec w_{t+1}=(1-\eta \lambda)\vec w_t-\eta\frac{\partial l(\vec w_t, b_t)}{\partial \vec w_t}wt+1​=(1−ηλ)wt​−η∂wt​∂l(wt​,bt​)​

    • 通常 ηλ<1\eta \lambda<1ηλ<1,在深度学习中通常叫作权重衰退。这意味着每次更新参数时,现将原本的参数值缩小一些,再沿着梯度方向更新。

1.4 总结

  • 权重衰退通过 L2L_2L2​ 正则项使得模型参数不会过大,从而控制模型复杂度。
  • 正则项权重是控制模型复杂度的超参数。

二、代码实现

2.1 从零开始实现

2.1.1 人工数据集

权重衰退是最广泛使用的正则化的技术之一。

import torch
from torch import nn
from d2l import torch as d2l

生成人工数据集:
y=0.05+∑i=1d0.01xi+ϵwhereϵ∼N(0,0.012)y=0.05+\sum_{i=1}^d0.01x_i + \epsilon \quad \text{where} \quad \epsilon\sim \mathcal{N}(0, 0.01^2)y=0.05+i=1∑d​0.01xi​+ϵwhereϵ∼N(0,0.012)

n_train, n_test, num_inputs, batch_size = 20, 100, 200, 5
true_w, true_b = torch.ones((num_inputs, 1)) * 0.01, 0.05
train_data = d2l.synthetic_data(true_w, true_b, n_train)
train_iter = d2l.load_array(train_data, batch_size, is_train=True)
test_data = d2l.synthetic_data(true_w, true_b, n_test)
test_iter = d2l.load_array(test_data, batch_size, is_train=False)

2.1.2 模型参数

初始化模型参数

# 初始化模型参数
def init_params():w = torch.normal(0, 1, size=(num_inputs, 1), requires_grad=True)b = torch.zeros(1, requires_grad=True)return [w, b]

2.1.3 L2L_2L2​ 范数惩罚

定义 L2L_2L2​ 范数惩罚

# 定义L2范数惩罚
def l2_penalty(w):return torch.sum(w.pow(2)) / 2

2.1.4 训练

本次的训练函数和之前训练函数的最大区别是:增加了输入参数lambd。我们用超参数lambd来控制正则项的重要程度。当lambd等于0时,相当于没有正则化;当lambd趋近于无穷时,相当于权重趋近于0.

# 训练函数
def train(lambd):w, b = init_params()net, loss = lambda X: d2l.linreg(X, w, b), d2l.squared_lossnum_epochs, lr = 100, 0.003animator = d2l.Animator(xlabel='epochs', ylabel='loss', yscale='log',xlim=[5, num_epochs], legend=['train', 'test'])for epoch in range(num_epochs):for X, y in train_iter:# with torch.enable_grad():l = loss(net(X), y) + lambd * l2_penalty(w)l.sum().backward()d2l.sgd([w, b], lr, batch_size)if (epoch + 1) % 5 == 0:animator.add(epoch + 1, (d2l.evaluate_loss(net, train_iter, loss),d2l.evaluate_loss(net, test_iter, loss)))d2l.plt.show()print('w的L2范数是:', torch.norm(w).item())

首先,我们令lambd=0,忽视正则化直接进行训练。

train(lambd=0)

此时发生了严重的过拟合,训练误差不断减小,但测试误差一直很高。结果如下图所示:

使用权重衰减后,解决了过拟合的问题。

train(lambd=3)

2.2 简洁实现

L2L_2L2​ 正则化可以写在目标函数中,也可以写在训练算法里面
在简洁实现中,我们将权重衰减写在训练算法中

def train_concise(wd):net = nn.Sequential(nn.Linear(num_inputs, 1))for param in net.parameters():param.data.normal_()loss = nn.MSELoss(reduction='none')num_epochs, lr = 100, 0.003trainer = torch.optim.SGD([{"params": net[0].weight,"weight_decay": wd}, {"params": net[0].bias }], lr=lr)animator = d2l.Animator(xlabel='epochs', ylabel='loss', yscale='log',xlim=[5, num_epochs], legend=['train', 'test'])for epoch in range(num_epochs):for X, y in train_iter:# with torch.enable_grad():trainer.zero_grad()l = loss(net(X), y)l.mean().backward()trainer.step()if (epoch + 1) % 5 == 0:animator.add(epoch + 1,(d2l.evaluate_loss(net, train_iter, loss),d2l.evaluate_loss(net, test_iter, loss)))print('w的L2范数:', net[0].weight.norm().item())d2l.plt.show()

类似从零开始实现,我们也分别在不使用和使用正则化的情况下进行训练。

train_concise(0)
train_concise(3)

不使用正则化的结果如下图所示:

使用正则化的结果如下图所示:

下一篇:【动手学深度学习v2李沐】学习笔记08:丢弃法

【动手学深度学习v2李沐】学习笔记07:权重衰退、正则化相关推荐

  1. 使用AWS最便宜的GPU实例  from 动手学深度学习v2 李沐大神

    使用AWS最便宜的GPU实例  from 动手学深度学习v2 李沐大神 视频链接https://www.bilibili.com/video/BV1MA411L78X?t=493 由于购买的电脑没有配 ...

  2. 动手学深度学习V2——李沐Bilibili直播视频Jupyter Notebook安装

    在哔哩哔哩上发现李沐是视频直播讲解<动手学深度V2>- Pytorch,准备按照视频中的安装教程来搭建一个新的虚拟环境d2l,李沐使用的是Jupyter Notebook 而不是 Pych ...

  3. 李沐动手学深度学习v2/总结1

    总结 编码过程 数据 数据预处理 模型 参数,初始化参数 超参数 损失函数,先计算损失,清空梯度(防止有累积的梯度),再对损失后向传播计算损失关于参数的梯度 优化算法,使用优化算法更新参数 训练求参数 ...

  4. 资源 | 李沐等人开源中文书《动手学深度学习》预览版上线

    来源:机器之心 本文约2000字,建议阅读10分钟. 本文为大家介绍了一本交互式深度学习书籍. 近日,由 Aston Zhang.李沐等人所著图书<动手学深度学习>放出了在线预览版,以供读 ...

  5. 最新版 | 2020李沐《动手学深度学习》

    点击上方"AI遇见机器学习",选择"星标"公众号 重磅干货,第一时间送达 强烈推荐李沐等人的<动手学深度学习>最新版!完整中文版 PDF 终于 在 ...

  6. 用PyTorch实现的李沐《动手学深度学习》,登上GitHub热榜,获得700+星

    晓查 发自 凹非寺  量子位 报道 | 公众号 QbitAI 李沐老师的<动手学深度学习>是一本入门深度学习的优秀教材,也是各大在线书店的计算机类畅销书. 作为MXNet的作者之一,李沐老 ...

  7. 李沐《动手学深度学习》PyTorch 实现版开源,瞬间登上 GitHub 热榜!

    点击上方"AI有道",选择"星标"公众号 重磅干货,第一时间送达 李沐,亚马逊 AI 主任科学家,名声在外!半年前,由李沐.Aston Zhang 等人合力打造 ...

  8. 【深度学习】李沐《动手学深度学习》的PyTorch实现已完成

    这个项目是中文版<动手学深度学习>中的代码进行整理,用Pytorch实现,是目前全网最全的Pytorch版本. 项目作者:吴振宇博士 简介   Dive-Into-Deep-Learnin ...

  9. 李沐《动手学深度学习》新增PyTorch和TensorFlow实现,还有中文版

    李沐老师的<动手学深度学习>已经有Pytorch和TensorFlow的实现了,并且有了中文版. 网址:http://d2l.ai/ 简介 李沐老师的<动手学深度学习>自一年前 ...

最新文章

  1. 企业分布式微服务云SpringCloud SpringBoot mybatis (九)服务链路追踪(Spring Cloud Sleuth)...
  2. console js刷新页面_Console.js使用说明
  3. 为什么 Django 框架持续统治着 Python 开发?
  4. ML.NET 1.3.1 发布,.NET 跨平台机器学习框架
  5. 二维数组中最大连通子数组
  6. 免费 | 开源操作系统年度盛会最新日程曝光,邀您一同开启烧脑模式!
  7. MFC的Dialogbox多行文本框(CEdit)有最大字符限制,默认最大显示长度
  8. python问题整理
  9. (转)常用英语100句
  10. paip.python错误解决19
  11. 《Thinkphp5使用Socket服务》 入门篇
  12. webstorm汉化后乱码现象解决
  13. mysql根据身份证得到年龄_MySQL根据身份证获取省份 年龄 性别
  14. 中国石油井架行业发展前景与投资盈利预测报告2022-2027
  15. 英特尔服务器主板g41性能,英特尔g41显卡好用吗 英特尔g41显卡评测【详解】
  16. 前端开发step1,2,3
  17. 关于Tomcat以及我是个小机灵鬼这回事
  18. glTF模型在线查看利器【glTF Viewer 2.0】
  19. python的cfg是什么模块_python操作cfg配置文件方式
  20. android bean对象,Android GreenDao 保存 JavaBean 或者List JavaBean类型数据

热门文章

  1. 使用Kubernetes最常见的10个错误
  2. 构建之法阅读笔记002
  3. 染书CRMA|一个贴身的智慧校园
  4. SIIM-ACR Pneumothorax Segmentation 气胸x光识别比赛数据处理
  5. 欧国联 法国 vs 德国
  6. three.js加载obj模型和材质
  7. 用XAMPP搭建本地:Web服务器,访问服务器,下载服务器。
  8. 某型无人机群的监视覆盖任务航路规划
  9. 2022年怎样的企业才能迎难而上?这场年会给你答案
  10. 2015 电子科大校园招聘名单(更新中)