目录

  • RNN
  • LSTM

参考一个很全的总结:
预训练语言模型的前世今生 - 从Word Embedding到BERT
RNN部分参考了这个:
循环神经网络
LSTM部分参考了这两个:
LSTM以及三重门,遗忘门,输入门,输出门
LSTM如何解决梯度消失与梯度爆炸

这儿对预训练模型又有了一点理解,也是之前在做VGG实验时在困惑的点,预训练模型在使用时可以有两种做法:一种是Frozen,将参数锁住,在下游应用时不再改变;另一种就是Fine-Tuning,即将参数初始化为预训练模型的参数,下游应用时这里的参数仍然可以改变。
好了进入正题:

RNN

RNN结构最大的特点就是融入了时序信息,其结构如下图所示:

左侧部分称为RNN的一个timestep,对于每一个时刻 t t t ,输入的 x t x_t xt 都可以计算出一个 h t h_t ht ,将该信息传入下一个时刻 t + 1 t+1 t+1 ,这个过程是一个前馈神经网络;接收完一个序列中所有时刻的数据之后从 x t x_t xt 时刻沿时间反向传播(BPTT)计算loss。
RNN的主体结构是 A A AA A A 的结构如下图所示,输入为 ( h t − 1 , x t ) (h_{t-1},x_t) (ht1,xt) ,两个权重矩阵 W h W_h WhW x W_x Wx 可以分开,也可以合并在一起是一个 W W W

可以看到,RNN解决了时序依赖问题,但这里的时序一般是短距离的,短距离依赖影响较大,长距离依赖影响很小(一般超过10步就无能为力了)。
导致长期依赖的原因,在于RNN训练时容易发生梯度爆炸和梯度消失。
梯度爆炸相对友好,因为这时程序会收到NaN错误,同时处理上也可以设置一个梯度阈值,当梯度超过这个阈值时进行截断。
对于梯度消失,主要采用以下三种方式:

  1. 合理地初始化权重值,使每个神经元尽可能不要取极大或极小值,以避开梯度消失的区域。
  2. 用ReLU代替sigmoid和tanh作为激活函数。
  3. 采用其它结构的RNNs,比如LTSM和GRU,这也是最流行的方法。

梯度消失原因:

前向传播过程包括:

  1. 隐藏状态: h ( t ) = σ ( z ( t ) ) = σ ( U x ( t ) + W h ( t − 1 ) + b ) h^{(t)}=\sigma (z^{(t)})=\sigma(Ux^{(t)}+Wh^{(t-1)}+b) h(t)=σ(z(t))=σ(Ux(t)+Wh(t1)+b) , 此处激活函数一般为 t a n h tanh tanh
  2. 模型输出: o ( t ) = V h ( t ) + c o^{(t)}=Vh^{(t)}+c o(t)=Vh(t)+c
  3. 预测输出: y ^ = σ ( o ( t ) ) \hat{y}=\sigma(o^{(t)}) y^=σ(o(t)) ,此处激活函数一般为 s o f t m a x softmax softmax
  4. 模型损失: L = ∑ t = 1 T L ( t ) L=\sum^T_{t=1}L^{(t)} L=t=1TL(t)

RNN所有的timestep共享一套参数 U , V , W U,V,W U,V,W ,在RNN反向传播的过程中,需要计算 U , V , W U,V,W U,V,W 的梯度,以 W W W 为例,如下(这是一个链式求导…微积分全不会了好无语…):
∂ L ∂ W = ∑ t = 1 T ∂ L ∂ y ( T ) ∂ y ( T ) ∂ o ( T ) ∂ o ( T ) ∂ h ( T ) ( ∏ k = t + 1 T ∂ h ( k ) ∂ h ( k − 1 ) ) ∂ h ( t ) ∂ W = ∑ t = 1 T ∂ L ∂ y ( T ) ∂ y ( T ) ∂ o ( T ) ∂ o ( T ) ∂ h ( T ) ( ∏ k = t + 1 T tanh ⁡ ′ ( z ( k ) ) W ) ∂ h ( t ) ∂ W \begin{aligned} \frac{\partial L}{\partial W} &= \sum_{t=1}^T\frac{\partial L}{\partial y^{(T)}} \frac{\partial y^{(T)}}{\partial o^{(T)}} \frac{\partial o^{(T)}}{\partial h^{(T)}}(\prod_{k=t+1}^{T} \frac{\partial h^{(k)}}{\partial h^{(k-1)}}) \frac{\partial h^{(t)}}{\partial W}\\ &=\sum_{t=1}^T\frac{\partial L}{\partial y^{(T)}} \frac{\partial y^{(T)}}{\partial o^{(T)}} \frac{\partial o^{(T)}}{\partial h^{(T)}}(\prod_{k=t+1}^{T} \tanh' (z^{(k)})W) \frac{\partial h^{(t)}}{\partial W} \end{aligned} WL=t=1Ty(T)Lo(T)y(T)h(T)o(T)(k=t+1Th(k1)h(k))Wh(t)=t=1Ty(T)Lo(T)y(T)h(T)o(T)(k=t+1Ttanh(z(k))W)Wh(t)
对于公式中的 ( ∏ k = t + 1 T ∂ h ( k ) ∂ h ( k − 1 ) ) = ( ∏ k = t + 1 T tanh ⁡ ′ ( z ( k ) ) W ) (\prod_{k=t+1}^{T} \frac{\partial h^{(k)}}{\partial h^{(k-1)}})=(\prod_{k=t+1}^{T} \tanh' (z^{(k)})W) (k=t+1Th(k1)h(k))=(k=t+1Ttanh(z(k))W) ,tanh的导数总是小于1的,又因为是 ( T − ( t − + 1 ) ) (T-(t-+1)) (T(t+1)) 个timestep参数的连乘,所以如果 W W W 小于1,梯度就会消失;如果 W W W 的特征值大于1,梯度就会爆炸。
所以,RNN梯度消失的真正含义是,梯度被近距离(当 ( t + 1 ) (t+1) (t+1) 趋向于 T T T)的梯度主导,远距离会发生爆炸或消失,导致模型难以学到远距离的信息。
值得强调的是,RNN的这一缺陷并非理论上的,而是技术实践上的。换言之,RNN在理论上是一个优秀的模型,前提是我们能够找到一组合适的参数,然而实践上这组参数并不好找。

LSTM

先来大致看看LSTM相比RNN的结构改变是什么,多了一个传输状态:

这个图是LSTM的timestep:

根据这个图,LSTM的前向传播过程包括:

  1. 遗忘门:接收 t − 1 t-1 t1 时刻的状态 h t − 1 h_{t-1} ht1 以及当前的输入 x t x_t xt,经过sigmoid函数之后输出一个0到1之间的值,输出为: f t = σ ( W f h t − 1 + U f x t + b f ) f_t=\sigma(W_fh_{t-1}+U_fx_t+b_f) ft=σ(Wfht1+Ufxt+bf)
  2. 输入门:这里进行了两个操作,输出分别为: i t = σ ( W i h t − 1 + U i x t + b i ) i_t=\sigma(W_ih_{t-1}+U_ix_t+b_i) it=σ(Wiht1+Uixt+bi)C ~ t = tanh ⁡ ( W a h t − 1 + U a x t + b a ) \tilde C_t=\tanh(W_ah_{t-1}+U_ax_t+b_a) C~t=tanh(Waht1+Uaxt+ba)
  3. 当前状态:输出为: C t = C t − 1 ⊙ f t + i t ⊙ C ~ t C_t=C_{t-1} \odot f_t+i_t \odot \tilde C_t Ct=Ct1ft+itC~t
  4. 输出门:输出为: o t = σ ( W o h t − 1 + U o x t + b o ) o_t=\sigma(W_oh_{t-1}+U_ox_t+b_o) ot=σ(Woht1+Uoxt+bo)h t = o t ⊙ tanh ⁡ C t h_t=o_t \odot \tanh C_t ht=ottanhCt
  5. 预测输出: y ^ = σ ( V h t + c ) \hat y=\sigma (Vh_t+c) y^=σ(Vht+c)

对于三个门的作用如下图所示:

关于LSTM如何RNN中解决梯度消失或爆炸:
如上文中所述,RNN中引起梯度消失或爆炸的点在于:
∏ k = t + 1 T ∂ h ( k ) ∂ h ( k − 1 ) = ∏ k = t + 1 T tanh ⁡ ′ ( z ( k ) ) W \prod_{k=t+1}^{T} \frac{\partial h^{(k)}}{\partial h^{(k-1)}}=\prod_{k=t+1}^{T} \tanh' (z^{(k)})W k=t+1Th(k1)h(k)=k=t+1Ttanh(z(k))W
在LSTM中这个公式是这样的:
∏ k = t + 1 T ∂ h ( k ) ∂ h ( k − 1 ) = ∏ k = t + 1 T tanh ⁡ ′ σ ( W f X t + b f ) \prod_{k=t+1}^{T} \frac{\partial h^{(k)}}{\partial h^{(k-1)}}=\prod_{k=t+1}^{T} \tanh' \sigma(W_fX_t+b_f) k=t+1Th(k1)h(k)=k=t+1Ttanhσ(WfXt+bf)
如果设 Z = tanh ⁡ ( x ) σ ( y ) Z=\tanh (x)\sigma(y) Z=tanh(x)σ(y),其函数图像如下所示:

可以看到这个函数的值基本可以近似为0或1,这样就可以解决多个小于1或多个大于1的数相乘导致的梯度消失或梯度爆炸问题。
通过LSTM这种方式,除了在结构上天然地克服了梯度消失的问题,更重要的是能够具有更多的参数来控制模型;其参数量是RNN的四倍,能够更加精细地预测时间序列变量。

预训练语言模型(三):RNN和LSTM相关推荐

  1. 赠书 | 一文了解预训练语言模型

    来源 | 博文视点 头图 | 下载于视觉中国 近年来,在深度学习和大数据的支撑下,自然语言处理技术迅猛发展.而预训练语言模型把自然语言处理带入了一个新的阶段,也得到了工业界的广泛关注. 通过大数据预训 ...

  2. 【赠书】如何掌握好自然语言处理中的预训练语言模型?你需要读这本书

    ‍‍ 预训练语言模型属于人工智能领域中自然语言处理领域的一个细分,是自然语言处理领域的重要突破,得到了越来越广泛的关注,相关研究者和从业人员在实际应用的过程中,亟需一本理论翔实.代码细节充分的参考书. ...

  3. PyTorch学习(8)-问答系统、文本摘要和大规模预训练语言模型

    问答系统 1. SQuAD数据集 给定一段文字作为context,给定一个问题question,从context中寻找一段连续的文字(text span)作为问题的答案. 网址:https://raj ...

  4. 周末送新书 | 一文了解预训练语言模型!

    近年来,在深度学习和大数据的支撑下,自然语言处理技术迅猛发展. 而预训练语言模型把自然语言处理带入了一个新的阶段,也得到了工业界的广泛关注. 通过大数据预训练加小数据微调,自然语言处理任务的解决,无须 ...

  5. NLP判断语言情绪_网易严选nlp预训练语言模型的应用

    随着2018年底bert的发布,预训练(pre-train)成为nlp领域最为热门的方向之一,大规模的无监督语料加上少量有标注的语料成为了nlp模型的标配.本文将介绍几种常见的语言模型的基本原理和使用 ...

  6. 大模型系统和应用——Transformer预训练语言模型

    引言 最近在公众号中了解到了刘知远团队退出的视频课程<大模型交叉研讨课>,看了目录觉得不错,因此拜读一下. 观看地址: https://www.bilibili.com/video/BV1 ...

  7. 预训练语言模型在网易严选的应用

    导读:随着Bert的发布,预训练 ( pre-train ) 成为NLP领域最为热门的方向之一,大规模的无监督语料加上少量有标注的语料成为了NLP模型的标配.本文将介绍几种常见的语言模型的基本原理和使 ...

  8. 微软统一预训练语言模型UniLM 2.0解读

    微软研究院在2月底发布的UniLM 2.0(Unified Language Model \ 统一语言模型)文章,相比于19年上半年发布的UniLM 1.0,更加有效地整合了自然语言理解(NLU)与自 ...

  9. 预训练语言模型真的是世界模型?

    文 | 子龙 自GPT.BERT问世以来,预训练语言模型在NLP领域大放异彩,刷新了无数榜单,成为当前学界业界的心头爱,其主体结构--Transformer--也在逐步的运用于其他领域的任务中,常见的 ...

最新文章

  1. pycharm debug code -1073741819
  2. DOS获取局域网内所有正在使用的ip地址
  3. VTK:InfoVis之WordCloud
  4. QT信号与槽(自定义带参数的信号)
  5. ssis 数据转换_SSIS数据类型:高级编辑器的更改与数据转换的转换
  6. django-admin
  7. Adobe AIR for Android 缓存本地数据常用方法
  8. 根据系统创建文件路径
  9. 风火编程--繁体转简体
  10. 当你不能够再拥有,你唯一可以做的,就是令自己不要忘记
  11. rust 案例_深入浅出rust.pdf 高清版
  12. 服务器盘符修改不了怎么办,win10更换盘符报参数错误怎么办_win10系统盘符改不了参数错误解决方法...
  13. 电脑CPU型号是什么意思?
  14. 嘉曼服饰上市破发,大跌16%:公司市值37亿 刘溦家族色彩浓厚
  15. 恭喜谷歌迈出抗议大猪蹄子第一步
  16. 第8关:判断条件的相容性
  17. [收集整理]BT恶心诗全集
  18. 计算机组成原理(二) 计算机算术
  19. 评分卡模型验证常用指标
  20. matlab微分方程求法,matlab微分方程的求解的方法ppt课件

热门文章

  1. SEO中的ip、uv和pv的定义
  2. MySQL_02 快速入门 MySQL(SQL、PHP)大全
  3. 使用Urllib2制作有道翻译器
  4. C编程 求1+1/2+1/3+……+1/100的和
  5. Linux下使用libreoffice把doc转换成Pdf
  6. 可见即可爬:快速上手 Selenium
  7. 用python给游戏加上音效_pygame游戏之旅 添加icon和bgm音效的方法
  8. java设计模式--1.单例模式
  9. 基于android 10的国产手机,基于Android的国产手机UI比拼:ColorOS、MIUI、EMUI,你打算盘谁...
  10. 智联招聘的基于 Nebula Graph 的推荐实践分享