详解机器学习/深度学习中的梯度消失/梯度爆炸的原因/解决方案

本文主要深入介绍深度学习中的梯度消失和梯度爆炸的问题以及解决方案。本文分为三部分,第一部分主要直观的介绍深度学习中为什么使用梯度更新,第二部分主要介绍深度学习中梯度消失及爆炸的原因,第三部分对提出梯度消失及爆炸的解决方案。有基础的同鞋可以跳着阅读。其中,梯度消失爆炸的解决方案主要包括以下几个部分。

预训练加微调
梯度剪切、权重正则(针对梯度爆炸)
使用不同的激活函数
使用BatchNormal
使用残差结构
使用LSTM网络

第一部分:为什么要使用梯度更新规则(梯度下降)

在介绍梯度消失以及爆炸之前,先简单说一说梯度消失的根源—–深度神经网络和反向传播。目前深度学习方法中,深度神经网络的发展造就了我们可以构建更深层的网络完成更复杂的任务,深层网络比如深度卷积网络,LSTM等等,而且最终结果表明,在处理复杂任务上,深度网络比浅层的网络具有更好的效果。但是,目前优化神经网络的方法都是基于反向传播的思想,即根据损失函数计算的误差通过梯度反向传播的方式,指导深度网络权值的更新优化。这样做是有一定原因的,首先,深层网络由许多非线性层堆叠而来,每一层非线性层都可以视为是一个非线性函数 f(x)f(x)f(x) (非线性来自于非线性激活函数),因此整个深度网络可以视为是一个复合的非线性多元函数

F(x)=fn(...f3(f2(f1(x)∗θ1+b)∗θ2+b)...)F(x)=f _n (...f_3 (f_2 (f_1 (x)∗θ_1 +b)∗θ_2+b)...)F(x)=fn​(...f3​(f2​(f1​(x)∗θ1​+b)∗θ2​+b)...)

其中f(x)f(x)f(x)是激活函数,bbb是偏置量biasbiasbias,这么看来神经网络的传播就显得很好理解了。
我们最终的目的是希望这个多元函数可以很好的完成输入到输出之间的映射,假设不同的输入,输出的最优解是g(x)g(x)g(x) ,那么,优化深度网络就是为了寻找到合适的权值,满足Loss=L(g(x),F(x))Loss=L(g(x),F(x))Loss=L(g(x),F(x))取得极小值点,比如最简单的损失函数:

Loss=∣∣g(x)−f(x)∣∣22Loss=∣∣g(x)−f(x)∣∣_2^2Loss=∣∣g(x)−f(x)∣∣22​

假设损失函数的数据空间是下图这样的,我们最优的权值就是为了寻找下图中的最小值点,对于这种数学寻找最小值问题,采用梯度下降的方法再适合不过了。

第二部分:梯度消失、爆炸

梯度消失与梯度爆炸其实是一种情况,看接下来的文章就知道了。
梯度消失: 1. 在深层网络中。2. 采用了不合适的激活函数,比如sigmoid
梯度爆炸: 1. 在深层网络中。2. 权值初始化值太大的情况下

下面分别从这两个角度分析梯度消失和爆炸的原因。

1. 深层网络角度

比较简单的深层网络如下:

图中是一个四层的全连接网络,假设每一层网络激活后的输出为fi(x)f_i(x)fi​(x),其中iii为第iii层, xxx代表第iii层的输入,也就是第i−1i−1i−1层的输出,fff是激活函数,那么,得出fi+1=f(fi∗wi+1+bi+1)f_{i+1}=f(f_i∗w_{i+1}+b_{i+1})fi+1​=f(fi​∗wi+1​+bi+1​),简单记为:fi+1=f(fi∗wi+1)f_{i+1}=f(f_i*w_{i+1})fi+1​=f(fi​∗wi+1​)

BP算法基于梯度下降策略,以目标的负梯度方向对参数进行调整,参数的更新为 w←w+Δww \leftarrow w+\Delta ww←w+Δw ,给定学习率 α\alphaα,得出 Δw=−α∂Loss∂w\Delta w=-\alpha \frac{\partial Loss}{\partial w}Δw=−α∂w∂Loss​。如果要更新第二隐藏层的权值信息,根据链式求导法则,更新梯度信息:

Δw2=∂Loss∂w2=∂Loss∂f4∗∂f4∂f3∗∂f3∂f2∗∂f2∂w2Δw_2 = \frac{∂Loss}{∂w_2} = \frac{∂Loss}{∂f_4}* \frac{∂f_4}{∂f_3}*\frac{∂f_3}{∂f_2}*\frac{∂f_2}{∂w_2}Δw2​=∂w2​∂Loss​=∂f4​∂Loss​∗∂f3​∂f4​​∗∂f2​∂f3​​∗∂w2​∂f2​​

很容易看出来∂f2∂w2=∂f∂(f1∗w2)∗f1\frac{\partial f_2}{\partial w_2}=\frac{\partial f}{\partial (f_1*w_2)}*f_1∂w2​∂f2​​=∂(f1​∗w2​)∂f​∗f1​,即第二隐藏层的输入。

所以说,∂f4∂f3\frac{\partial f_4}{\partial f_3}∂f3​∂f4​​ 就是对激活函数进行求导,如果此部分大于1,那么层数增多的时候,最终的求出的梯度更新将以指数形式增加,即发生梯度爆炸;如果此部分小于1,那么随着层数增多,求出的梯度更新信息将会以指数形式衰减,即发生了梯度消失。如果说从数学上看不够直观的话,下面几个图可以很直观的说明深层网络的梯度问题 1^11

那么对于四个隐层的网络来说,就更明显了,第四隐藏层比第一隐藏层的更新速度慢了两个数量级:

总结: 从深层网络角度来讲,不同的层学习的速度差异很大,表现为网络中靠近输出层的学习情况很好,靠近输入层的学习速度很慢,有时甚至训练了很久,前几层的权值和刚开始随机初始化的值差不多。因此,梯度消失、爆炸,其根本原因在于反向传播训练法则,属于先天不足,另外多说一句,HintonHintonHinton提出capsulecapsulecapsule的原因就是为了彻底抛弃反向传播,如果真能大范围普及,那真是一个革命。???

2.激活函数角度

上文中提到计算权值更新信息的时候需要计算前层偏导信息,因此如果激活函数选择不合适,比如使用 sigmoidsigmoidsigmoid,梯度消失就会很明显了,原因看下图,左图是 sigmoidsigmoidsigmoid 的损失函数图,右边是其导数的图像,如果使用 sigmoidsigmoidsigmoid 作为激活函数,其梯度是不可能超过0.25的,这样经过链式求导之后,很容易发生梯度消失,sigmoidsigmoidsigmoid 函数数学表达式为:sigmoid(x)=11+e−xsigmoid(x)=\frac{1}{1+e^{-x}}sigmoid(x)=1+e−x1​

同理,tanhtanhtanh 作为激活函数,它的导数图如下,可以看出,tanhtanhtanh 比sigmoidsigmoidsigmoid 要好一些,但是它的导数仍然是小于1的。tanhtanhtanh 数学表达为:

tanh(x)=ex−e−xex+e−xtanh(x)=\frac {e^x-e^{-x}}{e^x+e^{-x}}tanh(x)=ex+e−xex−e−x​

第三部分:梯度消失、爆炸的解决方案

解决方案1:预训练、微调

此方法来自Hinton在2006年发表的一篇论文,Hinton为了解决梯度的问题,提出采取无监督逐层训练方法,其基本思想是每次训练一层隐节点,训练时将上一层隐节点的输出作为输入,而本层隐节点的输出作为下一层隐节点的输入,此过程就是逐层“预训练”(pre-training);在预训练完成后,再对整个网络进行“微调”(fine-tunning)。Hinton在训练深度信念网络(Deep Belief Networks中,使用了这个方法,在各层预训练完成后,再利用BP算法对整个网络进行训练。此思想相当于是先寻找局部最优,然后整合起来寻找全局最优,此方法有一定的好处,但是目前应用的不是很多了

解决方案2:梯度剪切、正则

梯度剪切:主要是针对梯度爆炸提出的,其思想是设置一个梯度剪切阈值,然后更新梯度的时候,如果梯度超过这个阈值,那么就将其强制限制在这个范围之内。这可以防止梯度爆炸。

注:在WGAN中也有梯度剪切限制操作,但是和这个是不一样的,WGAN限制梯度更新信息是为了保证lipchitz条件。

另外一种解决梯度爆炸的手段是采用权重正则化(weithts regularization)比较常见的是 l1l1l1 正则,和 l2l2l2 正则,在各个深度框架中都有相应的API可以使用正则化,比如在tensorflow中:
若搭建网络的时候已经设置了正则化参数,则调用以下代码可以直接计算出正则损失:

regularization_loss = tf.add_n(tf.losses.get_regularization_losses(scope='my_resnet_50'))

如果没有设置初始化参数,也可以使用以下代码计算 l2l2l2 正则损失:

l2_loss = tf.add_n([tf.nn.l2_loss(var) for var in tf.trainable_variables() if 'weights' in var.name])

正则化是通过对网络权重做正则限制过拟合,仔细看正则项在损失函数的形式:

Loss=(y−WTx)2+α∣∣W∣∣2Loss=(y−W^Tx)^2 +α∣∣W∣∣^2Loss=(y−WTx)2+α∣∣W∣∣2

其中,α\alphaα 是指正则项系数,因此,如果发生梯度爆炸,权值的范数就会变的非常大,通过正则化项,可以部分限制梯度爆炸的发生。

注:事实上,在深度神经网络中,往往是梯度消失出现的更多一些。

解决方案3:使用 relu、leakrelu、elu 等激活函数

relu

思想很简单,如果激活函数的导数为1,那么就不存在梯度消失爆炸的问题了,每层的网络都可以得到相同的更新速度,relurelurelu 就这样应运而生。先看一下 relurelurelu 的数学表达式:

其函数图像:

从上图中,我们可以很容易看出,relurelurelu 函数的导数在正数部分是恒等于1的,因此在深层网络中使用 relurelurelu 激活函数就不会导致梯度消失和爆炸的问题。

relurelurelu 的优点:

– 解决了梯度消失、爆炸的问题
– 计算方便,计算速度快
– 加速了网络的训练

同时也存在一些缺点:

– 由于负数部分恒为0,会导致一些神经元无法激活(可通过设置小学习率部分解决)
– 输出不是以0为中心的

尽管relurelurelu也有缺点,但是仍然是目前使用最多的激活函数。

leakrelu

leakreluleakreluleakrelu就是为了解决relurelurelu的0区间带来的影响,其数学表达为:leakrelu=max(k∗x,x)leakrelu=max(k∗x,x)leakrelu=max(k∗x,x)其中kkk是leakleakleak系数,一般选择0.01或者0.02,或者通过学习而来。

leakreluleakreluleakrelu解决了0区间带来的影响,而且包含了relurelurelu的所有优点。

elu

elueluelu激活函数也是为了解决relurelurelu的0区间带来的影响,其数学表达为:

其函数及其导数数学形式为:

但是elueluelu相对于leakreluleakreluleakrelu来说,计算要更耗时间一些。

解决方案4:BatchNorm

Batchnorm是深度学习发展以来提出的最重要的成果之一了,目前已经被广泛的应用到了各大网络中,具有加速网络收敛速度,提升训练稳定性的效果,Batchnorm本质上是解决反向传播过程中的梯度问题。batchnorm全名是batch normalization,简称BN,即批规范化,通过规范化操作将输出信号x规范化保证网络的稳定性。
具体的batchnorm原理非常复杂,在这里不做详细展开,此部分大概讲一下batchnorm解决梯度的问题上。具体来说就是反向传播中,经过每一层的梯度会乘以该层的权重,举个简单例子:
正向传播中f2=f1(wT∗x+b)f_2=f_1(w^T*x+b)f2​=f1​(wT∗x+b),那么反向传播中,∂f2∂x=∂f2∂f1w\frac {\partial f_2}{\partial x}=\frac{\partial f_2}{\partial f_1}w∂x∂f2​​=∂f1​∂f2​​w,反向传播式子中有www的存在,所以www的大小影响了梯度的消失和爆炸,batchnorm就是通过对每一层的输出规范为均值和方差一致的方法,消除了www带来的放大缩小的影响,进而解决梯度消失和爆炸的问题,或者可以理解为BN将输出从饱和区拉倒了非饱和区。

解决方案5:残差结构

残差结构说起残差的话,不得不提这篇论文了:《Deep Residual Learning for Image Recognition》,关于这篇论文的解读,可以参考知乎链接:https://zhuanlan.zhihu.com/p/31852747 这里只简单介绍残差如何解决梯度的问题(resnet为代表)。事实上,就是残差网络的出现导致了image net比赛的终结,自从残差提出后,几乎所有的深度网络都离不开残差的身影,相比较之前的几层,几十层的深度网络,在残差网络面前都不值一提,残差可以很轻松的构建几百层,一千多层的网络而不用担心梯度消失过快的问题,原因就在于残差的捷径(shortcut)部分,其中残差单元如下图所示:

相比较于以前网络的直来直去结构,残差中有很多这样的跨层连接结构,这样的结构在反向传播中具有很大的好处,见下式:

式子的第一个因子 ∂loss∂xL\frac{\partial loss}{\partial {{x}_{L}}}∂xL​∂loss​表示的损失函数到达 L 的梯度,小括号中的1表明短路机制可以无损地传播梯度,而另外一项残差梯度则需要经过带有weights的层,梯度不是直接传递过来的。残差梯度不会那么巧全为-1,而且就算其比较小,有1的存在也不会导致梯度消失。所以残差学习会更容易。

注:上面的推导并不是严格的证明。

解决方案6:LSTM

LSTM全称是长短期记忆网络(long-short term memory networks),是不那么容易发生梯度消失的,主要原因在于LSTM内部复杂的“门”(gates),如下图,LSTM通过它内部的“门”可以接下来更新的时候“记住”前几次训练的”残留记忆“,因此,经常用于生成文本中。目前也有基于CNN的LSTM,感兴趣的可以尝试一下。

介绍LSTM的一篇非常好的博客:
https://blog.csdn.net/gzj_1101/article/details/79376798

参考资料:

1.《Neural networks and deep learning》
2.《机器学习》周志华
3. https://www.cnblogs.com/willnote/p/6912798.html
4. https://www.zhihu.com/question/38102762
5. http://www.jianshu.com/p/9dc9f41f0b29

转载出处:

https://blog.csdn.net/qq_25737169/article/details/78847691#commentBox
感谢该博主的分享,真的能学到很多的东西。

详解机器学习/深度学习中的梯度消失/梯度爆炸的原因/解决方案相关推荐

  1. DL之AF:机器学习/深度学习中常用的激活函数(sigmoid、softmax等)简介、应用、计算图实现、代码实现详细攻略

    DL之AF:机器学习/深度学习中常用的激活函数(sigmoid.softmax等)简介.应用.计算图实现.代码实现详细攻略 目录 激活函数(Activation functions)相关配图 各个激活 ...

  2. 深度学习中多层全连接网络的梯度下降法及其变式

    深度学习中多层全连接网络的梯度下降法及其变式 1 梯度下降法 2 梯度下降的变式 1.SGD 2.Momentum 3.Adagrad 4.RMSprop 5.Adam 6.小结 1 梯度下降法 梯度 ...

  3. 梯度消失和梯度爆炸_梯度消失、爆炸的原因及解决办法

    一.引入:梯度更新规则 目前优化神经网络的方法都是基于反向传播的思想,即根据损失函数计算的误差通过梯度反向传播的方式,更新优化深度网络的权值.这样做是有一定原因的,首先,深层网络由许多非线性层堆叠而来 ...

  4. ztree在刷新时第一个父节点消失_从反向传播推导到梯度消失and爆炸的原因及解决方案(从DNN到RNN,内附详细反向传播公式推导)...

    引言:参加了一家公司的面试和另一家公司的笔试,都问到了这个题!看来很有必要好好准备一下,自己动手推了公式,果然理解更深入了!持续准备面试中... 一. 概述: 想要真正了解梯度爆炸和消失问题,必须手推 ...

  5. RNN梯度消失和爆炸的原因 以及 LSTM如何解决梯度消失问题

    RNN梯度消失和爆炸的原因 经典的RNN结构如下图所示: 假设我们的时间序列只有三段,  为给定值,神经元没有激活函数,则RNN最简单的前向传播过程如下: 假设在t=3时刻,损失函数为  . 则对于一 ...

  6. nfa确定化 dfa最小化_深度学习中的优化:梯度下降,确定全局最优值或与之接近的局部最优值...

    深度学习中的优化是一项极度复杂的任务,本文是一份基础指南,旨在从数学的角度深入解读优化器. 一般而言,神经网络的整体性能取决于几个因素.通常最受关注的是网络架构,但这只是众多重要元素之一.还有一个常常 ...

  7. 深度学习中的激活函数与梯度消失

    转载请注明出处:http://www.cnblogs.com/willnote/p/6912798.html 前言 深度学习的基本原理是基于人工神经网络,信号从一个神经元进入,经过非线性的激活函数,传 ...

  8. 机器学习(深度学习)中的反向传播算法与梯度下降

    这是自己在CSDN的第一篇博客,目的是为了给自己学习过的知识做一个总结,方便后续温习,避免每次都重复搜索相关文章. 一.反向传播算法 定义:反向传播(Backpropagation,缩写为BP)是&q ...

  9. 机器学习深度学习中反向传播之偏导数链式法则

    前记    无论是机器学习还是深度学习,都是构造目标函数,这个目标函数内部有很多未知变量,我们的目标就是求得这些未知变量.    那么如何构造目标函数?这是一个非常优美的话题(本文未讲,先欠着).美好 ...

  10. 机器学习/深度学习中的常用损失函数公式、原理与代码实践(持续更新ing...)

    诸神缄默不语-个人CSDN博文目录 最近更新时间:2023.5.8 最早更新时间:2022.6.12 本文的结构是首先介绍一些常见的损失函数,然后介绍一些个性化的损失函数实例. 文章目录 1. 分类 ...

最新文章

  1. 在docker镜像中加入环境变量
  2. who is the one who actually know the essential things in life?
  3. JAVA复习5(总结+循环链表)
  4. 八大排序:Java实现八大排序及算法复杂度分析
  5. 10个强大实用数据地图,不懂代码也能做!(附demo)
  6. 一文详解 Serverless 技术选型
  7. GDAL\OGR C#中文路径不支持的问题解决方法
  8. 数据包络分析--Malmquist指数
  9. pdf转换成word转换器注册码
  10. 浙大计算机海归教授,科学网—人才引进的“拿来主义”——我看浙江大学海外招聘 - 周波的博文...
  11. SUN软件包管理的命令:pkgadd
  12. 【VALSE 2019 PPT】香港科技大学沈劭劼最新研究-《无人机视觉感知与导航》-总结
  13. 忘记Jenkins管理员密码的解决办法
  14. SpringBoot整合JpaMapper实现基于mybatis的快速开发
  15. 如何维持手机电池寿命_关于如何延长智能手机电池寿命的一些提示
  16. asp毕业设计——基于asp+access的网上投票系统设计与实现(毕业论文+程序源码)——网上投票系统
  17. A2. Gsensor调试
  18. 空中夺命“杀手锏”!以色列研发致命性无人机,让人毛骨悚然
  19. Thinkpad L440 无线驱动突然无法使用,无法搜索到无线上网
  20. 系统System文件损坏或丢失的简单解决办法

热门文章

  1. vue音频wavesurfer波形图
  2. 【Microsoft Azure 的1024种玩法】三十四.将本地数据文件快速迁移到Azure Blob云存储最佳实践
  3. 关于django后台界面的美化
  4. Android - View 和 ViewGroup
  5. 校园导航系统 数据结构
  6. java时间处理--判断当前时间是否在一个时间区间内
  7. 独自一人开发返利平台小程序日记(准备开源中):万事开头难,既然做了,那就只能咬牙坚持了
  8. 土豆网总裁回忆与乔布斯的会面
  9. 知识积累 | GATK的使用
  10. 用计算机如何绘制流程图,电脑上怎么绘制流程图?电脑小白也能学会的流程图制作方法...