普通RNN的缺陷—梯度消失和梯度爆炸
之前的RNN,无法很好地学习到时序数据的长期依赖关系。因为BPTT会发生梯度消失和梯度爆炸的问题。
RNN梯度消失和爆炸
对于RNN来说,输入时序数据xt时,RNN 层输出ht。这个ht称为RNN 层的隐藏状态,它记录过去的信息。
语言模型的任务是根据已经出现的单词预测下一个将要出现的单词。
学习正确解标签过程中,RNN层通过向过去传递有意义的梯度,能够学习时间方向上的依赖关系。如果这个梯度在中途变弱(甚至没有包含任何信息),权重参数将不会被更新,也就是所谓的RNN层无法学习长期的依赖关系。梯度的流动如下图绿色箭头。
随着时间增加,RNN会产生梯度变小(梯度消失)或梯度变大(梯度爆炸)。
RNN 层在时间方向上的梯度传播,如下图。
反向传播的梯度流经tanh、+、MatMul(矩阵乘积)运算。
+的反向传播,将上游传来的梯度原样传给下游,梯度值不变。
tanh的计算图如下。它将上游传来的梯度乘以tanh的导数传给下游。
y=tanh(x)的值及其导数的值如下图。导数值小于1,x越远离0,值越小。反向传播梯度经过tanh节点要乘上tanh的导数,这就导致梯度越来越小。
如果RNN层的激活函数使用ReLU,可以抑制梯度消失,当ReLU输入为x时,输出是max(0,x)。x大于0时,反向传播将上游的梯度原样传递到下游,梯度不会退化。
对于MatMul(矩阵乘积)节点。仅关注RNN层MatMul节点时的梯度反向传播如下图。每一次矩阵乘积计算都使用相同的权重Wh。
N = 2 # mini-batch的大小
H = 3 # 隐藏状态向量的维数
T = 20 # 时序数据的长度dh = np.ones((N, H))#初始化为所有元素均为 1 的矩阵,dh是梯度np.random.seed(3)Wh = np.random.randn(H, H)#梯度的大小随时间步长呈指数级增加,发生梯度爆炸
#Wh = np.random.randn(H, H) * 0.5
#梯度的大小随时间步长呈指数级减小,发生梯度消失,权重梯度不能被更新,模型无法学习长期的依赖关系
norm_list = []
for t in range(T):dh = np.dot(dh, Wh.T)#根据反向传播的 MatMul 节点的数量更新 dh 相应次数norm = np.sqrt(np.sum(dh**2)) / N#mini-batch(N)中的平均L2 范数,L2 范数对所有元素的平方和求平方根.norm_list.append(norm)#将各步的 dh 的大小(范数)添加到 norm_list 中print(norm_list)# 绘制图形
plt.plot(np.arange(len(norm_list)), norm_list)
plt.xticks([0, 4, 9, 14, 19], [1, 5, 10, 15, 20])
plt.xlabel('time step')
plt.ylabel('norm')
plt.show()
如果Wh是标量,由于Wh被反复乘了T次,当Wh大于1时,梯度呈指数级增加;当 Wh 小于1时,梯度呈指数级减小。
如果wh是矩阵,矩阵的奇异值表示数据的离散程度,根据奇异值(多个奇异值中的最大值)是否大于1,可以预测梯度大小的变化。奇异值比1大是梯度爆炸的必要非充分条件。
梯度裁剪gradients clipping
梯度裁剪(gradients clipping)是解决解决梯度爆炸的一个方法。
将神经网络用到的所有参数的梯度整合成一个,用g表示,将阈值设置为threshold,如果梯度g的L2范数大于等于该阈值,就按如下方式修正梯度。
dW1 = np.random.rand(3, 3) * 10
dW2 = np.random.rand(3, 3) * 10
grads = [dW1, dW2]
max_norm = 5.0#阈值def clip_grads(grads, max_norm):total_norm = 0for grad in grads:total_norm += np.sum(grad ** 2)total_norm = np.sqrt(total_norm)#L2 范数对所有元素的平方和求平方根rate = max_norm / (total_norm + 1e-6)if rate < 1:#如果梯度的L2范数total_norm大于等于阈值max_norm,rate是小于1的,此时就需要修正梯度for grad in grads:grad *= rateprint('before:', dW1.flatten())
clip_grads(grads, max_norm)
print('after:', dW1.flatten())
before: [7.14418135 3.58857143 7.82910303 8.04057218 8.8617387 1.899638863.0606848 8.14163088 5.25490409]
after: [1.43122195 0.71891263 1.56843501 1.61079946 1.77530697 0.380562130.61315903 1.63104494 1.05273561]
解决梯度消失
为了解决梯度消失,需要从根本上改变 RNN 层的结构。
LSTM 和GRU中增加了一种门结构,可以学习到时序数据的长期依赖关系。
普通RNN的缺陷—梯度消失和梯度爆炸相关推荐
- RNN、LSTM、GRU 的梯度消失及梯度爆炸
文章目录 RNN.LSTM.GRU 的梯度消失及梯度爆炸 RNN RNN 结构 前向传播 损失函数 后向传播(BPTT) LSTM LSTM 结构 前向传播 后向传播 GRU GRU 结构 前向传播 ...
- 深度学习——梯度消失、梯度爆炸
本文参考:深度学习之3--梯度爆炸与梯度消失 梯度消失和梯度爆炸的根源:深度神经网络结构.反向传播算法 目前优化神经网络的方法都是基于反向传播的思想,即根据损失函数计算的误差通过反向传播的方式,指导深 ...
- 【深度学习】梯度消失和梯度爆炸问题的最完整解析
作者丨奥雷利安 · 杰龙 来源丨机械工业出版社<机器学习实战:基于Scikit-Learn.Keras和TensorFlow> 编辑丨极市平台 1 梯度消失与梯度爆炸 正如我们在第10章中 ...
- sigmoid函数解决溢出_梯度消失和梯度爆炸及解决方法
一.为什么会产生梯度消失和梯度爆炸? 目前优化神经网络的方法都是基于BP,即根据损失函数计算的误差通过梯度反向传播的方式,指导深度网络权值的更新优化.其中将误差从末层往前传递的过程需要链式法则(Cha ...
- 深度学习 《梯度消失和梯度爆炸》
一:梯度消失 在深层网络中,一方面由于求导法则,计算越前面层次或者时刻的梯度,会出现很多的乘法运算,很容易导致梯度消失和梯度爆炸,另一方面还受到激活函数的影响,Sigmoid函数和tanh函数会出现梯 ...
- [深度学习-优化]梯度消失与梯度爆炸的原因以及解决方案
首先让我们先来了解一个概念:什么是梯度不稳定呢? 概念:在深度神经网络中的梯度是不稳定的,在靠近输入层的隐藏层中或会消失,或会爆炸.这种不稳定性才是深度神经网络中基于梯度学习的根本问题. 产生梯度不稳 ...
- 深度学习实战 Tricks —— 梯度消失与梯度爆炸(gradient exploding)
梯度爆炸:梯度过大会使得损失函数很难收敛,甚至导致梯度为 NaN,异常退出: 解决方案:gradient cliping 梯度消失:较前的层次很难对较后的层次产生影响,梯度更新失效: 解决方案:对于 ...
- 梯度消失和梯度爆炸问题解析
前言 本文转载于梯度消失和梯度爆炸问题的最完整解析 作者丨奥雷利安 · 杰龙 来源丨机械工业出版社<机器学习实战:基于Scikit-Learn.Keras和TensorFlow> 目录 前 ...
- 【神经网络】梯度消失与梯度爆炸问题
梯度消失与梯度爆炸问题 Glorot 和 He 初始化 我们需要信号在两个方向上正确流动:进行预测时,信号为正向:在反向传播梯度时,信号为反向.我们需要每层输出的方差等于输入的方差,并且在反方向流过某 ...
最新文章
- 如何使用区块链技术进行项目开发
- 2021年第十二届蓝桥杯 - 省赛 - C/C++大学C组 - D.相乘
- 10月编程语言排行榜,来了!
- 洛谷P2463 Sandy的卡片【后缀数组】【二分】
- HBase优化案例分析:Facebook Messages系统问题与解决方案
- 【毕业设计】基于Java的五子棋游戏的设计(源代码+论文)
- 群晖 半洗白_群晖6.17/6.21二合一引导启动系统盘
- c++课程设计(水)
- 数学 - 基本初等函数导数公式及求导法则
- EXCEL公式与函数
- PTA每日一题-Python-身份证校验
- VMMECH007_Thermal Stress in a Bar with Temperature Dependent Conductivity
- 即时通信和实时通信的区别
- unordered_set使用介绍
- ajax、php、json异步数据处理
- 第八届蓝桥杯个人赛赛后总结
- internet时间和域
- 【Java学习笔记】工厂模式
- 北京理工大学计算机学院赵曜,中国进出口银行2016年度拟接收毕业生情况公示...
- 【软件设计师】历年真题-模糊知识点备忘——15年上 上午真题
热门文章
- Net处理html页面元素工具类(HtmlAgilityPack.dll)的使用
- Thread线程的深刻理解和代理方法参数[有图有真相]
- mysql php ajax_PHP 和 AJAX MySQL 数据库实例
- keepalived vip mysql_mysql+keepalived高可用集群
- linux自动异地备份,Linux本地加异地自动备份方案
- in ms sql 集合参数传递_mybatis从入门到精通,第三篇《动态SQL》,干货满满
- Linux导出函数控制,linux 下仅导出指定函数的方法
- vsftpd设置被动模式_(1)vsftpd主、被动模式iptables配置方法
- Win10笔记本设置合盖不息屏的方法
- Win11有黑色边框怎么办 Win11有黑色边框的解决方法