声明:文章仅作知识整理、分享,如有侵权请联系作者删除博文,谢谢!

网上有很多关于梯度消失-爆炸这方面的文章,相似的也比较多,最近对不同文章进行整理,修改部分文章公式错误,形成整理。

1、概念

目前优化神经网络的方法都是基于BP,即根据损失函数计算的误差通过梯度反向传播的方式,指导深度网络权值的更新优化。其中将误差从末层往前传递的过程需要链式法则(Chain Rule)的帮助,因此反向传播算法可以说是梯度下降在链式法则中的应用。

而链式法则是一个连乘的形式,所以当层数越深的时候,梯度将以指数形式传播。梯度消失问题和梯度爆炸问题一般随着网络层数的增加会变得越来越明显。在根据损失函数计算的误差通过梯度反向传播的方式对深度网络权值进行更新时,得到的梯度值接近0或特别大,也就是梯度消失或爆炸。梯度消失或梯度爆炸在本质原理上其实是一样的。

1.1、梯度消失

经常出现,产生的原因有:一是在深层网络中,二是采用了不合适的损失函数,比如sigmoid。当梯度消失发生时,接近于输出层的隐藏层由于其梯度相对正常,所以权值更新时也就相对正常,但是当越靠近输入层时,由于梯度消失现象,会导致靠近输入层的隐藏层权值更新缓慢或者更新停滞。这就导致在训练时,只等价于后面几层的浅层网络的学习。

梯度消失的影响:

1)浅层基本不学习,后面几层一直在学习,失去深度的意义。

2)无法收敛,相当于浅层网络。

1.2、梯度爆炸

根据链式法则,如果每一层神经元对上一层的输出的偏导乘上权重结果都大于1的话,在经过足够多层传播之后,误差对输入层的偏导会趋于无穷大。这种情况又会导致靠近输入层的隐含层神经元调整变动极大。梯度爆炸一般出现在深层网络和权值初始化值太大的情况下。另外,初始学习率太小或太大也会出现梯度消失或爆炸。

梯度爆炸的影响:

1)模型不稳定,导致更新过程中的损失出现显著变化;

2)训练过程中,在极端情况下,权重的值变得非常大,以至于溢出,导致模型损失变成 NaN等等。

2、产生梯度消失和梯度爆炸的原因

梯度消失的根源—–深度神经求导网络和反向传播。目前深度学习方法中,深度神经网络的发展造就了我们可以构建更深层的网络完成更复杂的任务,深层网络比如深度卷积网络,LSTM等等,而且最终结果表明,在处理复杂任务上,深度网络比浅层的网络具有更好的效果。但是,目前优化神经网络的方法都是基于反向传播的思想,即根据损失函数计算的误差通过梯度反向传播的方式,指导深度网络权值的更新优化。下面将从这3个角度分析一下产生这两种现象的根本原因:

比较简单的深层网络如下:

2.1、深层网络

如图所示的含有3个隐藏层的神经网络,梯度消失问题发生时,接近于输出层的hidden layer 3等的权值更新相对正常,但前面的hidden layer 1的权值更新会变得很慢,导致前面的层权值几乎不变,仍接近于初始化的权值,这就导致hidden layer 1相当于只是一个映射层,对所有的输入做了一个同一映射,这是此深层网络的学习就等价于只有后几层的浅层网络的学习了。

图中是一个四层的全连接网络,假设每一层网络激活后的输出为fi(x),其中i为第i层, x代表第i层的输入,也就是第i−1层的输出,f是激活函数,那么,得出:

简单记为:

BP算法基于梯度下降策略,以目标的负梯度方向对参数进行调整,参数的更新为w←w+Δw,给定学习率α,得出:

如果要激活函数的导数、网络初值(w,b)连续相乘表现为w的更新第一隐藏层量。避免网络不work的权值信息过程就是调整这部分连乘结果,根据链式求导法则使其保持在1附近。学习率决定的网络学习的快慢,过大或过小也会直接影响网络的参数更新梯度信息过程。

,很容易看出来:

,即第一层的输入。

所以说af4/af3就是对激活函数进行求导,如果此部分大于1,那么层数增多的时候,最终的求出的梯度更新将以指数形式增加,即发生梯度爆炸,如果此部分小于1,那么随着层数增多,求出的梯度更新信息将会以指数形式衰减,即发生了梯度消失。

如果说从数学上看不够直观的话,下面几个图可以很直观的说明深层网络的梯度问题:

图中的曲线表示权值更新的速度,对于下图两个隐层的网络来说,已经可以发现隐藏层2的权值更新速度要比隐藏层1更新的速度慢。

那么对于四个隐层的网络来说,就更明显了,第四隐藏层比第一隐藏层的更新速度慢了两个数量级:

总结:从深层网络角度来讲,不同的层学习的速度差异很大,表现为网络中靠近输出的层学习的情况很好,靠近输入的层学习的很慢,有时甚至训练了很久,前几层的权值和刚开始随机初始化的值差不多。因此,梯度消失、爆炸,其根本原因在于反向传播训练法则,属于先天不足。

2.2、激活函数

以下图的反向传播为例(假设每一层只有一个神经元且对于每一层:

偏置b可以推导出:

而sigmoid的导数为:

同理,使用tanh作为损失函数,它的导数图如下,可以看出,tanh比sigmoid要好一些,但是它的倒数仍然是小于1的。tanh数学表达为:

如果接近输出层的激活函数求导后梯度值大于1,那么层数增多的时候,最终求出的梯度很容易指数级增长,就会产生梯度爆炸;相反,如果小于1,那么经过链式法则的连乘形式,也会很容易衰减至0,就会产生梯度消失。

2.3、初始化权值太大

如上图所示,当:

,也就是w比较大的情况。根据链式相乘(反向传播)可得,则前面的网络层比后面的网络层梯度变化更快,很容易发生梯度爆炸的问题。

3、总结

深层网络出现梯度消失或爆炸,主要是由于链式求导发现传播引起。参数的更新为w←w+Δw,给定学习率α,得出:

激活函数的导数、网络初值(w,b)连续相乘表现为w的更新量。避免网络不work的过程就是调整这部分连乘结果,使其保持在1附近。学习率决定的网络学习的快慢,过大或过小也会直接影响网络的参数更新过程。

参考文章:

1、https://zhuanlan.zhihu.com/p/25631496

2、https://zhuanlan.zhihu.com/p/72589432

深层网络梯度消失-爆炸原因相关推荐

  1. 深度学习100问之深入理解Vanishing/Exploding Gradient(梯度消失/爆炸)

    这几天正在看梯度消失/爆炸,在深度学习的理论中梯度消失/爆炸也是极其重要的,所以就抽出一段时间认真地研究了一下梯度消失/爆炸的原理,以下为参考网上的几篇文章总结得出的. 本文分为四个部分:第一部分主要 ...

  2. 算法基础--梯度消失的原因

    深度学习训练中梯度消失的原因有哪些?有哪些解决方法? 1.为什么要使用梯度反向传播? 归根结底,深度学习训练中梯度消失的根源在于梯度更新规则的使用.目前更新深度神经网络参数都是基于反向传播的思想,即基 ...

  3. 也来谈谈RNN的梯度消失/爆炸问题

    ©PaperWeekly 原创 · 作者|苏剑林 单位|追一科技 研究方向|NLP.神经网络 尽管 Transformer 类的模型已经攻占了 NLP 的多数领域,但诸如 LSTM.GRU 之类的 R ...

  4. 谈谈RNN的梯度消失/爆炸问题

    尽管 Transformer 类的模型已经攻占了 NLP 的多数领域,但诸如 LSTM.GRU 之类的 RNN 模型依然在某些场景下有它的独特价值,所以 RNN 依然是值得我们好好学习的模型.而于 R ...

  5. ResNet(残差网络)和梯度消失/爆炸

    ResNet解决的不是梯度弥散或爆炸问题,kaiming的论文中也说了:臭名昭著的梯度弥散/爆炸问题已经很大程度上被normalized initialization and intermediate ...

  6. Dropout、梯度消失/爆炸、Adam优化算法,神经网络优化算法看这一篇就够了

    作者 | mantch 来源 | 知乎 1. 训练误差和泛化误差 对于机器学习模型在训练数据集和测试数据集上的表现.如果你改变过实验中的模型结构或者超参数,你也许发现了:当模型在训练数据集上更准确时, ...

  7. 梯度消失/爆炸与RNN家族的介绍(LSTM GRU B-RNN Multi-RNNs)-基于cs224n的最全总结

    vanishing gradients and fancy RNNs(RNN家族与梯度消失) 文章目录 vanishing gradients and fancy RNNs(RNN家族与梯度消失) 内 ...

  8. <美团>深度学习训练中梯度消失的原因有哪些?有哪些解决方法?

    梯度消失产生的主要原因有:一是使用了深层网络,二是采用了不合适的损失函数. (1)目前优化神经网络的方法都是基于BP,即根据损失函数计算的误差通过梯度反向传播的方式,指导深度网络权值的更新优化.其中将 ...

  9. rnn 梯度消失爆炸

    文章目录 梯度消失和爆炸原理 求导知识 RNN推导 梯度消失和爆炸原理 求导知识 y=x2y = x^2y=x2 dy\mathrm{d} {y}dy 导数 dydx\Large \frac {\ma ...

  10. ztree在刷新时第一个父节点消失_从反向传播推导到梯度消失and爆炸的原因及解决方案(从DNN到RNN,内附详细反向传播公式推导)...

    引言:参加了一家公司的面试和另一家公司的笔试,都问到了这个题!看来很有必要好好准备一下,自己动手推了公式,果然理解更深入了!持续准备面试中... 一. 概述: 想要真正了解梯度爆炸和消失问题,必须手推 ...

最新文章

  1. js学习笔记(执行上下文、闭包、this部分)
  2. Linux网络不可达解决方法
  3. 总结4:input文本输入框自动提示
  4. django1.5 连接mysql_django1.5.5使用mysql
  5. LeetCode(合集)两数之和总结 (1,167,1346)
  6. 这个为生信学习打造的开源Bash教程真香!!(目录更新)!
  7. oracle PROFILE的使用学习
  8. php代码 编码转换,php字符编码转换代码
  9. 提取身份证信息的自定义函数
  10. 在 RAID 磁盘上面架构 LVM 系统
  11. redis整理の配置
  12. STC单片机烧录时的坑不要踩
  13. Amos实操教程|调节效应检验
  14. 让DEVCpp支持C11
  15. Linux云计算架构-docker容器命名和资源配额控制(2)
  16. 只需10分钟,给你全世界!水经注全球三维离线GIS系统
  17. 【JavaEE】社区版IDEA(2021.X版本及之前)创建SpringBoot项目
  18. CAN总线BUS OFF
  19. oracle导出报错LRM 00101,Oracle:ORA-01078与LRM-00109报错
  20. (初中数学)2018网校新初一初二初三数学年卡尖子班(全国人教版)

热门文章

  1. 怎样和求职者聊天_我如何学会欣赏求职者
  2. 前沿计算技术于推动设计技术发展
  3. 提升营业额的正确方法
  4. Safari 兼容问题累积
  5. 直角三角形斜边用计算机怎么算,直角三角形斜边怎么算 计算方法有哪些
  6. CCF-CSP计算机职业资格认证备考
  7. 软件破解逆向安全(十二)内存特征码
  8. 百度指数批量查询器,百度指数
  9. Python关键词百度指数采集,抓包Cookie及json数据处理
  10. python 卡方分布函数_推断统计分析(二):python验证三大抽样分布