torch.optim.SGD

torch.optim.SGD(params, lr=<required parameter>, momentum=0, dampening=0, weight_decay=0, nesterov=False):随机梯度下降

  • 【我的理解】虽然叫做“随机梯度下降”,但是本质上还是还是实现的批量梯度下降,即用全部样本梯度的均值更新可学习参数。

    这里所说的全部样本可以是全部数据集,也可以是一个batch,为什么这么说?因为计算梯度是调用backward函数计算的,而backward函数又是通过损失值张量调用的,损失值的计算和样本集的选取息息相关。如果每次都使用全部样本计算损失值,那么很显然调用SGD时也就是用全部的样本梯度的均值去更新可学习参数,如果每次使用一个batch的样本计算损失值,再调用backward,那么调用SGD时就会用这个batch的梯度均值更新可学习参数。

  • params:要训练的参数,一般我们传入的都是model.parameters()

  • lr:learning_rate学习率,会梯度下降的应该都知道学习率吧,也就是步长。

  • weight_decay(权重衰退)和learning_rate(学习率)的区别

    learning_rate就是我们熟知的更新权重的方式,假设可学习参数为 θ\thetaθ ,学习率由γ\gammaγ表示,梯度均值为 ggg;计算出这批样本对应的梯度为 gtg_tgt​,迭代完第t−1t-1t−1次可学习参数的值为θt−1\theta_{t-1}θt−1​,当前第ttt次迭代的更新方式为θt=θt−1−γgt\theta_t = \theta_{t-1}-\gamma g_tθt​=θt−1​−γgt​,即梯度下降法。

    weight_decay是在L2正则化理论中出现的概念。

    什么是L2正则化?
    L2范数也被称为“权重衰减”和“岭回归”,L2的主要作用是解决过拟合,L2范数是所有权重的平方开方和,最小化L2范数,使得权重趋近于0,但是不会为0。那么为什么参数的值小,就能避免过拟合。模型的参数比较小,说明模型简单,泛化能力强。参数比较小,说明某些多项式分支的作用很小,降低了模型的复杂程度。其次参数很大,一个小的变动,都会产生很大的不同。所以采用L2正则化可以很好地解决过拟合的问题。

    在损失函数中加入L2正则化项后(大致)变为
    J(θ)=12m[∑i=1m(hθ(x(i))−y(i))2+λ∑j=1nθj2]J(\theta)=\frac{1}{2m}[\sum_{i=1}^{m} (h_{\theta}(x^{(i)})-y^{(i)})^2+\lambda\sum_{j=1}^{n}\theta^2_j] J(θ)=2m1​[i=1∑m​(hθ​(x(i))−y(i))2+λj=1∑n​θj2​]
    其中,λ∑j=1nθj2\lambda\sum_{j=1}^{n}\theta^2_jλ∑j=1n​θj2​就是正则化项,λ\lambdaλ就是weight_decay。

    根据上面的加入正则项的损失函数我们可以不严谨地推导一下,加入的正则项是权重的平方和,为了保证损失函数尽可能小,就要让权重和尽量小,所以权重会倾向于训练出较小的值。

    可是梯度不是由backward计算的吗,而backward中可未指明存在正则化项,那不成还要由优化器算法再算一遍新梯度?

    先看看官方文档是如何操作的:


    只看权重衰退和学习率,描述一下过程,计算使用第t−1t-1t−1次训练到的可学习参数的模型在这一批次中的梯度,赋值给gtg_tgt​,再用权重衰退乘上第t−1t-1t−1次训练到的可学习参数,重新赋值给gtg_tgt​,再用这个gtg_tgt​和学习率更新权重,从而实现训练参数的目的。

    也就是说正则化的实现是在原梯度的基础上加上一个系数乘以可学习参数值作为新梯度,用新梯度更新可学习参数值。并没有重新计算加入正则项后的损失函数的梯度。

    本质上是对∂J(θi)∂θi\frac{\partial{J(\theta_i)}}{\partial\theta_i}∂θi​∂J(θi​)​进行一下变形的结果。

    为了推导的方便,假设加入了正则化项后的损失函数为
    RegularizedCost=Cost+λ∑iθi2RegularizedCost=Cost+\lambda\sum_i\theta_i^2 RegularizedCost=Cost+λi∑​θi2​

    加入正则化项后的梯度为:
    ∂RegularizedCost∂θi=∂Cost∂θi+2λθi\frac{\partial{RegularizedCost}}{\partial{\theta_i}} = \frac{\partial Cost}{\partial \theta_i} + 2\lambda\theta_i ∂θi​∂RegularizedCost​=∂θi​∂Cost​+2λθi​

    可以看出∂Cost∂θi\frac{\partial Cost}{\partial \theta_i}∂θi​∂Cost​的值就是未加正则化项求出的梯度,即backward计算出的梯度。

    ∂RegularizedCost∂θi\frac{\partial{RegularizedCost}}{\partial{\theta_i}}∂θi​∂RegularizedCost​将会作为新的梯度值用于更新权重,即
    Updatedθi=θi−γ∂RegularizedCost∂θi⇒Updatedθi=γ(∂Cost∂θi+2λθi)Updated\space\space\theta_i = \theta_i - \gamma \frac{\partial{RegularizedCost}}{\partial{\theta_i}} \\ \space \\ \Rightarrow Updated\space\space\theta_i = \gamma (\frac{\partial Cost}{\partial \theta_i}+2\lambda\theta_i) Updated  θi​=θi​−γ∂θi​∂RegularizedCost​ ⇒Updated  θi​=γ(∂θi​∂Cost​+2λθi​)
    求导数得到的系数2可以扔掉,所以就得到了官方文档中的过程。

  • SGD方法的一个缺点是,其更新方向完全依赖于当前的batch,因而其更新十分不稳定。解决这一问题的一个简单的做法便是引入momentum。

    momentum即动量,它模拟的是物体运动时的惯性,即更新的时候在一定程度上保留之前更新的方向,同时利用当前batch的梯度微调最终的更新方向。这样一来,可以在一定程度上增加稳定性,从而学习地更快,并且还有一定摆脱局部最优的能力。


    从中可以看出,如果我们设置momentum参数不为0,则判断一下是否是第一次迭代,如果是第一次迭代,那么直接将 gtg_tgt​ 赋值给 btb_tbt​ (将btb_tbt​理解为临时变量吧),如果不是第一次迭代,那么就用前一次的 bt−1b_{t-1}bt−1​ 更新 btb_tbt​,更新方程为 bt=μbt−1+(1−τ)gtb_t=\mu b_{t-1} + (1-\tau)g_tbt​=μbt−1​+(1−τ)gt​(反正就是个公式,其中 τ\tauτ 为抑制参数)。由于不考虑nesterov参数,即为False,所以直接将计算得到的 btb_tbt​ 赋值回 gtg_tgt​ 。

  • 如果对Nesterov Momentum(牛顿动量)感兴趣可以看看Hinton这篇论文


参考

[1] How does SGD weight_decay work? - autograd - PyTorch Forums

[2] L1、L2正则化知识详解 - 简书

[3] L1,L2正则化的原理与区别 - CSDN博客

[4] 各种优化方法总结比较(sgd/momentum/Nesterov/adagrad/adadelta) - 博客园

torch.optim.SGD参数详解(除nesterov)相关推荐

  1. torch.optim.sgd参数详解

    SGD(随机梯度下降)是一种更新参数的机制,其根据损失函数关于模型参数的梯度信息来更新参数,可以用来训练神经网络.torch.optim.sgd的参数有:lr(学习率).momentum(动量).we ...

  2. sgd 参数 详解_关于torch.optim的灵活使用详解(包括重写SGD,加上L1正则)

    torch.optim的灵活使用详解 1. 基本用法: 要构建一个优化器Optimizer,必须给它一个包含参数的迭代器来优化,然后,我们可以指定特定的优化选项, 例如学习速率,重量衰减值等. 注:如 ...

  3. sgd 参数 详解_代码笔记--PC-DARTS代码详解

    DARTS是可微分网络架构搜搜索,PC-DARTS是DARTS的拓展,通过部分通道连接的方法在网络搜索过程中减少计算时间的内存占用.接下来将会结合论文和开源代码来详细介绍PC-DARTS. 1 总体框 ...

  4. PyTorch搜索Tensor指定维度的前K大个(K小个)元素--------(torch.topk)命令参数详解及举例

    torch.topk 语法 torch.topk(input, k, dim=None, largest=True, sorted=True, *, out = None) 作用 返回输入tensor ...

  5. torch.optim.SGD()

    其中的SGD就是optim中的一个算法(优化器):随机梯度下降算法 PyTorch 的优化器基本都继承于 "class Optimizer",这是所有 optimizer 的 ba ...

  6. conv2d的输入_pytorch1.0中torch.nn.Conv2d用法详解

    Conv2d的简单使用 torch 包 nn 中 Conv2d 的用法与 tensorflow 中类似,但不完全一样. 在 torch 中,Conv2d 有几个基本的参数,分别是 in_channel ...

  7. pytorch---之BN层参数详解及应用(1,2,3)(1,2)?

    BN层参数详解(1,2) 一般来说pytorch中的模型都是继承nn.Module类的,都有一个属性trainning指定是否是训练状态,训练状态与否将会影响到某些层的参数是否是固定的,比如BN层(对 ...

  8. PyTorch实现AlexNet模型及参数详解

    文章目录 一.卷积池化层原理 二.全连接层原理 三.模型参数详解 注:AlexNet论文错误点 1.卷积池化层1 (1)卷积运算 (2)分组 (3)激活函数层 (4)池化层 (5)归一化处理 (6)参 ...

  9. pytorch之torch.nn.Conv2d()函数详解

    文章目录 一.官方文档介绍 二.torch.nn.Conv2d()函数详解 参数详解 参数dilation--扩张卷积(也叫空洞卷积) 参数groups--分组卷积 三.代码实例 一.官方文档介绍 官 ...

  10. python:flatten()参数详解

    python:flatten()参数详解 这篇博客主要写flatten()作用,及其参数的含义 flatten()是对多维数据的降维函数. flatten(),默认缺省参数为0,也就是说flatten ...

最新文章

  1. python对字幕的改动
  2. 人脸特征值能存放在sql server中吗_钥匙丢了进不了门,Out了!只要自己没丢就能进门...
  3. 计算机网络实验设计应用题,计算机网络实验三实验报告.doc
  4. PMCAFF今天摆地摊了,然后……
  5. [转] Apache日志分析常用Shell命令
  6. cut、grep和排序命令
  7. 语言4位bcd码怎么加加_S7300400如何使用SCL语言调用SFC1(READ_CLK)读取日期和时间?...
  8. CVPR 2020|打脸SOTA!不能忍,谷歌发起图像匹配挑战赛
  9. 更深更宽的孪生网络,有效提升目标跟踪精度,代码开源
  10. Node.js 14 发布,改进了诊断功能
  11. 转载-配置tomcat让shtml嵌套文件显示
  12. Trustdata:《2018年Q1中国移动互联网行业发展分析报告》
  13. Landsat卫星MSS/TM/ETM数据(转自ESRI社区)(二)
  14. 查看java/jdk版本
  15. ubuntu下搜狗拼音输入法不见了
  16. 获取浏览器唯一标识_探讨浏览器指纹 fingerprint
  17. AXD 汇编调试经验,使用及问题
  18. 身份证男女识别---进一步优化03
  19. Matplotlib常见图形绘制(折线图、散点图 、柱状图 、直方图 、饼图 、条形图)
  20. 解决麒麟V10上传文件乱码问题

热门文章

  1. 2022年Cs231n PPT笔记-训练CNN
  2. 互联网快讯:粉笔科技布局线下打造双核驱动;极米产品获用户青睐;迅雷发布2021年财报;荣耀Magic4系列国内发布
  3. SC16IS750在STM32的应用
  4. Qt之调用Windows图片查看器预览图片
  5. NOI(OJ)编程基础篇
  6. 转载-卷影复制服务(VSS)详细介绍
  7. 无法确认设备和计算机之间的连接,如何解决“爱思助手”无法识别设备或连接超时等故障?...
  8. 金融信息化及交易管理系统(股票交易系统APP)
  9. 短视频剪辑怎么做?4步教你快速入门
  10. 博弈论(五)——#10247. 「一本通 6.7 练习 4」S-Nim