之前的RNN,无法很好地学习到时序数据的长期依赖关系。因为BPTT会发生梯度消失和梯度爆炸的问题。

RNN梯度消失和爆炸

对于RNN来说,输入时序数据xt时,RNN 层输出ht。这个ht称为RNN 层的隐藏状态,它记录过去的信息。

语言模型的任务是根据已经出现的单词预测下一个将要出现的单词。

学习正确解标签过程中,RNN层通过向过去传递有意义的梯度,能够学习时间方向上的依赖关系。如果这个梯度在中途变弱(甚至没有包含任何信息),权重参数将不会被更新,也就是所谓的RNN层无法学习长期的依赖关系。梯度的流动如下图绿色箭头。

随着时间增加,RNN会产生梯度变小(梯度消失)或梯度变大(梯度爆炸)。

RNN 层在时间方向上的梯度传播,如下图。

反向传播的梯度流经tanh、+、MatMul(矩阵乘积)运算。

+的反向传播,将上游传来的梯度原样传给下游,梯度值不变。

tanh的计算图如下。它将上游传来的梯度乘以tanh的导数传给下游。

y=tanh(x)的值及其导数的值如下图。导数值小于1,x越远离0,值越小。反向传播梯度经过tanh节点要乘上tanh的导数,这就导致梯度越来越小。

如果RNN层的激活函数使用ReLU,可以抑制梯度消失,当ReLU输入为x时,输出是max(0,x)。x大于0时,反向传播将上游的梯度原样传递到下游,梯度不会退化。

对于MatMul(矩阵乘积)节点。仅关注RNN层MatMul节点时的梯度反向传播如下图。每一次矩阵乘积计算都使用相同的权重Wh。

N = 2  # mini-batch的大小
H = 3  # 隐藏状态向量的维数
T = 20  # 时序数据的长度dh = np.ones((N, H))#初始化为所有元素均为 1 的矩阵,dh是梯度np.random.seed(3)Wh = np.random.randn(H, H)#梯度的大小随时间步长呈指数级增加,发生梯度爆炸
#Wh = np.random.randn(H, H) * 0.5
#梯度的大小随时间步长呈指数级减小,发生梯度消失,权重梯度不能被更新,模型无法学习长期的依赖关系
norm_list = []
for t in range(T):dh = np.dot(dh, Wh.T)#根据反向传播的 MatMul 节点的数量更新 dh 相应次数norm = np.sqrt(np.sum(dh**2)) / N#mini-batch(N)中的平均L2 范数,L2 范数对所有元素的平方和求平方根.norm_list.append(norm)#将各步的 dh 的大小(范数)添加到 norm_list 中print(norm_list)# 绘制图形
plt.plot(np.arange(len(norm_list)), norm_list)
plt.xticks([0, 4, 9, 14, 19], [1, 5, 10, 15, 20])
plt.xlabel('time step')
plt.ylabel('norm')
plt.show()

如果Wh是标量,由于Wh被反复乘了T次,当Wh大于1时,梯度呈指数级增加;当 Wh 小于1时,梯度呈指数级减小。

如果wh是矩阵,矩阵的奇异值表示数据的离散程度,根据奇异值(多个奇异值中的最大值)是否大于1,可以预测梯度大小的变化。奇异值比1大是梯度爆炸的必要非充分条件。

梯度裁剪gradients clipping

梯度裁剪(gradients clipping)是解决解决梯度爆炸的一个方法。

将神经网络用到的所有参数的梯度整合成一个,用g表示,将阈值设置为threshold,如果梯度g的L2范数大于等于该阈值,就按如下方式修正梯度。

dW1 = np.random.rand(3, 3) * 10
dW2 = np.random.rand(3, 3) * 10
grads = [dW1, dW2]
max_norm = 5.0#阈值def clip_grads(grads, max_norm):total_norm = 0for grad in grads:total_norm += np.sum(grad ** 2)total_norm = np.sqrt(total_norm)#L2 范数对所有元素的平方和求平方根rate = max_norm / (total_norm + 1e-6)if rate < 1:#如果梯度的L2范数total_norm大于等于阈值max_norm,rate是小于1的,此时就需要修正梯度for grad in grads:grad *= rateprint('before:', dW1.flatten())
clip_grads(grads, max_norm)
print('after:', dW1.flatten())
before: [7.14418135 3.58857143 7.82910303 8.04057218 8.8617387  1.899638863.0606848  8.14163088 5.25490409]
after: [1.43122195 0.71891263 1.56843501 1.61079946 1.77530697 0.380562130.61315903 1.63104494 1.05273561]

解决梯度消失

为了解决梯度消失,需要从根本上改变 RNN 层的结构。

LSTM 和GRU中增加了一种门结构,可以学习到时序数据的长期依赖关系。

普通RNN的缺陷—梯度消失和梯度爆炸相关推荐

  1. RNN、LSTM、GRU 的梯度消失及梯度爆炸

    文章目录 RNN.LSTM.GRU 的梯度消失及梯度爆炸 RNN RNN 结构 前向传播 损失函数 后向传播(BPTT) LSTM LSTM 结构 前向传播 后向传播 GRU GRU 结构 前向传播 ...

  2. 深度学习——梯度消失、梯度爆炸

    本文参考:深度学习之3--梯度爆炸与梯度消失 梯度消失和梯度爆炸的根源:深度神经网络结构.反向传播算法 目前优化神经网络的方法都是基于反向传播的思想,即根据损失函数计算的误差通过反向传播的方式,指导深 ...

  3. 【深度学习】梯度消失和梯度爆炸问题的最完整解析

    作者丨奥雷利安 · 杰龙 来源丨机械工业出版社<机器学习实战:基于Scikit-Learn.Keras和TensorFlow> 编辑丨极市平台 1 梯度消失与梯度爆炸 正如我们在第10章中 ...

  4. sigmoid函数解决溢出_梯度消失和梯度爆炸及解决方法

    一.为什么会产生梯度消失和梯度爆炸? 目前优化神经网络的方法都是基于BP,即根据损失函数计算的误差通过梯度反向传播的方式,指导深度网络权值的更新优化.其中将误差从末层往前传递的过程需要链式法则(Cha ...

  5. 深度学习 《梯度消失和梯度爆炸》

    一:梯度消失 在深层网络中,一方面由于求导法则,计算越前面层次或者时刻的梯度,会出现很多的乘法运算,很容易导致梯度消失和梯度爆炸,另一方面还受到激活函数的影响,Sigmoid函数和tanh函数会出现梯 ...

  6. [深度学习-优化]梯度消失与梯度爆炸的原因以及解决方案

    首先让我们先来了解一个概念:什么是梯度不稳定呢? 概念:在深度神经网络中的梯度是不稳定的,在靠近输入层的隐藏层中或会消失,或会爆炸.这种不稳定性才是深度神经网络中基于梯度学习的根本问题. 产生梯度不稳 ...

  7. 深度学习实战 Tricks —— 梯度消失与梯度爆炸(gradient exploding)

    梯度爆炸:梯度过大会使得损失函数很难收敛,甚至导致梯度为 NaN,异常退出: 解决方案:gradient cliping 梯度消失:较前的层次很难对较后的层次产生影响,梯度更新失效: 解决方案:对于 ...

  8. 梯度消失和梯度爆炸问题解析

    前言 本文转载于梯度消失和梯度爆炸问题的最完整解析 作者丨奥雷利安 · 杰龙 来源丨机械工业出版社<机器学习实战:基于Scikit-Learn.Keras和TensorFlow> 目录 前 ...

  9. 【神经网络】梯度消失与梯度爆炸问题

    梯度消失与梯度爆炸问题 Glorot 和 He 初始化 我们需要信号在两个方向上正确流动:进行预测时,信号为正向:在反向传播梯度时,信号为反向.我们需要每层输出的方差等于输入的方差,并且在反方向流过某 ...

最新文章

  1. 如何使用区块链技术进行项目开发
  2. 2021年第十二届蓝桥杯 - 省赛 - C/C++大学C组 - D.相乘
  3. 10月编程语言排行榜,来了!
  4. 洛谷P2463 Sandy的卡片【后缀数组】【二分】
  5. HBase优化案例分析:Facebook Messages系统问题与解决方案
  6. 【毕业设计】基于Java的五子棋游戏的设计(源代码+论文)
  7. 群晖 半洗白_群晖6.17/6.21二合一引导启动系统盘
  8. c++课程设计(水)
  9. 数学 - 基本初等函数导数公式及求导法则
  10. EXCEL公式与函数
  11. PTA每日一题-Python-身份证校验
  12. VMMECH007_Thermal Stress in a Bar with Temperature Dependent Conductivity
  13. 即时通信和实时通信的区别
  14. unordered_set使用介绍
  15. ajax、php、json异步数据处理
  16. 第八届蓝桥杯个人赛赛后总结
  17. internet时间和域
  18. 【Java学习笔记】工厂模式
  19. 北京理工大学计算机学院赵曜,中国进出口银行2016年度拟接收毕业生情况公示...
  20. 【软件设计师】历年真题-模糊知识点备忘——15年上 上午真题

热门文章

  1. Net处理html页面元素工具类(HtmlAgilityPack.dll)的使用
  2. Thread线程的深刻理解和代理方法参数[有图有真相]
  3. mysql php ajax_PHP 和 AJAX MySQL 数据库实例
  4. keepalived vip mysql_mysql+keepalived高可用集群
  5. linux自动异地备份,Linux本地加异地自动备份方案
  6. in ms sql 集合参数传递_mybatis从入门到精通,第三篇《动态SQL》,干货满满
  7. Linux导出函数控制,linux 下仅导出指定函数的方法
  8. vsftpd设置被动模式_(1)vsftpd主、被动模式iptables配置方法
  9. Win10笔记本设置合盖不息屏的方法
  10. Win11有黑色边框怎么办 Win11有黑色边框的解决方法