©PaperWeekly 原创 · 作者 | 苏剑林

单位 | 追一科技

研究方向 | NLP、神经网络

在本博客中,已经多次讨论过梯度惩罚相关内容了。从形式上来看,梯度惩罚项分为两种,一种是关于输入梯度惩罚与参数梯度惩罚的一个不等式在本博客中,已经多次讨论过梯度惩罚相关内容了。从形式上来看,梯度惩罚项分为两种,一种是关于输入的梯度惩罚 ,在《对抗训练浅谈:意义、方法和思考(附Keras实现)》、《泛化性乱弹:从随机噪声、梯度惩罚到虚拟对抗训练》等文章中我们讨论过,另一种则是关于参数的梯度惩罚 ,在《从动力学角度看优化算法(五):为什么学习率不宜过小?》、《我们真的需要把训练集的损失降低到零吗?》[1] 等文章我们讨论过。

在相关文章中,两种梯度惩罚都声称有着提高模型泛化性能的能力,那么两者有没有什么联系呢?笔者从 Google 最近的一篇论文《The Geometric Occam's Razor Implicit in Deep Learning》[2] 学习到了两者的一个不等式,感觉以后可能用得上,在此做个笔记。

最终结果

假设有一个 l 层的 MLP 模型,记为:

其中 是当前层的激活函数,,并记为,即模型的原始输入,为了方便后面的推导,我们记 ;参数全体为 。设 是 的任意标量函数,那么成立不等式:

其中上式中 、和 用的是普通的 范数,也就是每个元素的平方和再开平方,而 和 用的则是矩阵的“谱范数”(参考《深度学习中的 Lipschitz 约束:泛化与生成模型》)。该不等式显示,参数的梯度惩罚一定程度上包含了输入的梯度惩罚。

推导过程

显然,为了不等式(2),我们只需要对每一个参数证明:

然后遍历所有 ,将每一式左右两端相加即可。这两个不等式的证明本质上是一个矩阵求导问题,但多数读者可能跟笔者一样,都不熟悉矩阵求导,这时候最佳的办法就是写出分量形式,然后就变成标量的求导问题。

具体来说, 写成分量形式:

然后由链式法则:

然后:

这里 是克罗内克符号。现在我们可以写出:

代入(6)得到:

两边乘以 得:

约定原始向量为列向量,求梯度后矩阵的形状反转,那么上述可以写成矩阵形式:

两边左乘 得:

两边取范数得:

等于第二个不等号来说,矩阵的范数用 范数或者谱范数都是成立的。于是选择所需要的范数后,整理可得式(3);至于式(4)的证明类似,这里不再重复。

简单评析

可能有读者会想问具体该如何理解式(2)?事实上,笔者主要觉得式(2)本身有点意思,以后说不准在某个场景用得上,所以本文主要是对此做个“笔记”,但对它并没有很好的解读结果。

至于原论文的逻辑顺序是这样的:在《从动力学角度看优化算法(五):为什么学习率不宜过小?》中我们介绍了《Implicit Gradient Regularization》(跟本篇论文同一作者),里边指出 SGD 隐式地包含了对参数的梯度惩罚项,而式(2)则说明对参数的梯度惩罚隐式地包含了对输入的梯度惩罚,而对输入的梯度惩罚又跟 Dirichlet 能量有关,Dirichlet 能量则可以作为模型复杂度的表征。所以总的一串推理下来,结论就是:SGD 本身会倾向于选择复杂度比较小的模型

不过,原论文在解读式(2)时,犯了一个小错误。它说初始阶段的 会很接近于 0,所以式(2)中括号的项会很大,因此如果要降低式(2)右边的参数梯度惩罚,那么必须要使得式(2)左边的输入梯度惩罚足够小。然而从《从几何视角来理解模型参数的初始化策略》[3] 我们知道,常用的初始化方法其实接近于正交初始化,而正交矩阵的谱范数其实为 1,如果考虑激活函数,那么初始化的谱范数其实还大于 1,所以初始化阶段 会很接近于 0 是不成立的。

事实上,对于一个没有训练崩的网络,模型的参数和每一层的输入输出基本上都会保持一种稳定的状态,所以其实整个训练过程中 、、 其实波动都不大,因此右端对参数的梯度惩罚近似等价于左端对输入的乘法惩罚。这是笔者的理解,不需要“ 会很接近于 0”的假设。

文章小结

本文主要介绍了两种梯度惩罚项之间的一个不等式,并给出了自己的证明以及一个简单的评析。

参考文献

[1] https://kexue.fm/archives/7643

[2 ]https://arxiv.org/abs/2111.15090

[3] https://kexue.fm/archives/7180

特别鸣谢

感谢 TCCI 天桥脑科学研究院对于 PaperWeekly 的支持。TCCI 关注大脑探知、大脑功能和大脑健康。

更多阅读

#投 稿 通 道#

 让你的文字被更多人看到 

如何才能让更多的优质内容以更短路径到达读者群体,缩短读者寻找优质内容的成本呢?答案就是:你不认识的人。

总有一些你不认识的人,知道你想知道的东西。PaperWeekly 或许可以成为一座桥梁,促使不同背景、不同方向的学者和学术灵感相互碰撞,迸发出更多的可能性。

PaperWeekly 鼓励高校实验室或个人,在我们的平台上分享各类优质内容,可以是最新论文解读,也可以是学术热点剖析科研心得竞赛经验讲解等。我们的目的只有一个,让知识真正流动起来。

输入梯度惩罚与参数梯度惩罚的一个不等式相关推荐

  1. 泛化性乱弹:从随机噪声、梯度惩罚到虚拟对抗训练

    ©PaperWeekly 原创 · 作者|苏剑林 单位|追一科技 研究方向|NLP.神经网络 提高模型的泛化性能是机器学习致力追求的目标之一.常见的提高泛化性的方法主要有两种:第一种是添加噪声,比如往 ...

  2. 梯度惩罚(Pytorch)

    起因:希望深度学习输入时小的扰动不会影响结果,我们会在输入端加一些噪声,让模型自己去适应这种扰动,从而提升整体的鲁棒性,CV领域可以直接在图像输入添加噪声,NLP领域因为输入都是one-hot形式,无 ...

  3. cnn 反向传播推导_深度学习中的参数梯度推导(三)下篇

    前言 在深度学习中的参数梯度推导(三)中篇里,我们总结了CNN的BP推导第一步:BP通过池化层时梯度的计算公式.本篇(下篇)则继续推导CNN相关的其他梯度计算公式. 注意:本文默认读者已具备深度学习上 ...

  4. eta 深度学习 参数_深度学习中的参数梯度推导(一)

    必备的数学知识 矩阵微分与求导 前言 深度学习向来被很多人认为是"黑盒",因为似乎很多人都不清楚深度学习的运作方式,本系列<深度学习中的数学>的连载文章主要目的就是向大 ...

  5. 逻辑回归的参数计算:牛顿法,梯度下降法,随机梯度下降法

    逻辑回归的参数计算:牛顿法,梯度下降(上升)法,随机梯度下降法,批量梯度下降法 前面文章中对逻辑回归进行了讲解,下面来说一说逻辑回归的参数是怎么计算的. 逻辑回归的计算使用的是最大似然方法.记 z i ...

  6. 梯度值与参数更新optimizer.zero_grad(),loss.backward、和optimizer.step()、lr_scheduler.step原理解析

    在用pytorch训练模型时,通常会在遍历epochs的过程中依次用到optimizer.zero_grad(),loss.backward.和optimizer.step().lr_schedule ...

  7. 神经网络中的梯度是什么,神经网络梯度公式推导

    1.BP神经网络的MATLAB训练Gradient是什么意思?Performance是什么意思?,大神能解释一下吗?谢谢了 Gradient是梯度的意思,BP神经网络训练的时候涉及到梯度下降法,表示为 ...

  8. 不使用梯度裁剪和使用梯度裁剪的对比(tensorflow)

    一:不使用梯度裁剪 #网络搭建和模型训练 import tensorflow as tf from tensorflow.keras import layers,optimizers,datasets ...

  9. 大白话5分钟带你走进人工智能-第十一节梯度下降之手动实现梯度下降和随机梯度下降的代码(6)...

                                第十一节梯度下降之手动实现梯度下降和随机梯度下降的代码(6) 我们回忆一下,之前咱们讲什么了?梯度下降,那么梯度下降是一种什么算法呢?函数最优化 ...

最新文章

  1. python下载的文件放在哪里的-python实现文件下载的方法总结
  2. 机器学习算法小结与收割offer遇到的问题
  3. Period_JAVA
  4. Akka之actor模型
  5. WaveShaperNode
  6. 全球研发投入榜:中国第二逼近美国,以色列最下血本 | 联合国数据
  7. python文件夹,文件监听工具(pyinotify,watchdog)
  8. Office2016+Visio2016安装教程(超简单)
  9. matlab|dsolve解决常微分初值与讲解(含实例使用)
  10. sla java_Grafana中滑动窗口的Prometheus正常运行时间或SLA百分比
  11. 天猫服饰新推“良品臻选”,请了一群挑剔的女人给服装“挑刺”
  12. SnakeGame(贪吃蛇游戏)
  13. 非线性方程的数值解法:牛顿法及牛顿下山法(含Matlab程序)
  14. 29岁了还一事无成是人生的常态?
  15. 【小沐学qt】生成二维码
  16. 《大数据》第七章 聚类 K-means算法 BFR算法 CURE算法
  17. 游戏道具平台|基于Springboot+Vue实现游戏道具平台系统
  18. 基于C++的DES的EBC电子密码本加解密,CBC密码分组链接思想,以及相关流程图
  19. 如何保证API不被别人恶意调用(彩蛋)
  20. SpringBoot+SpringSecurity+MySQL+Html图书管理系统

热门文章

  1. 打开微型计算机的电源时,计算机操作与使用试题(有答案)
  2. Android 拦截WebView请求,并加入或修改参数(GET)
  3. VBA遍历文件夹下文件文件实用源码
  4. ambari 维护模式及reset API 操作
  5. art-template-loader:template
  6. Arrays和Collection之间的转换
  7. wordpress woodstock主题导入demo xml文件 execution time out
  8. IOS TextField设置大全
  9. PHP面向对象的进阶学习
  10. C/C++编译预处理指令