深度学习之循环神经网络(8)长短时记忆网络(LSTM)

  • 0. LSTM原理
  • 1. 遗忘门
  • 2. 输入门
  • 3. 刷新Memory
  • 4. 输出门
  • 5. 小结

 循环神经网络除了训练困难,还有一个更严重的问题,那就是 短时记忆(Short-term memory)。考虑一个长句子:

今天天气太美好了,尽管路上发生了一件不愉快的事,…,我马上调整好状态,开开心心地准备迎接美好的一天。

根据我们的理解,之所以能够“ 开开心心地准备迎接美好的一天”,在于句子最开始处点明了“ 今天天气太美好了”。可见人类是能够很好地理解长句子的,但是循环神经网络却不一定。研究人员发现,循环神经网络在处理较长的句子时,往往只能够理解有限长度内的信息,而对于位于较长范围内的有用信息往往不能很好地利用起来。我们把这种现象叫做短时记忆。

 那么,能不能够延长这种短时记忆,使得循环神经网络可以有效利用较大范围内的训练数据,从而提升性能呢?1997年,瑞士人工智能科学家J u¨\ddot{\text{u}}u¨rgen Schmidhuber 提出了 长短时记忆网络(Long Short-Term Memory,简称LSTM)。LSTM相对于基础的RNN网络来说,记忆能力更强,更擅长处理较长的序列信号数据,LSTM提出后,被广泛应用在序列预测、自然语言处理等任务中,几乎取代了基础的RNN模型。

 接下来,我们将介绍更加流行、更加强大的LSTM网络。

0. LSTM原理

 基础的RNN网络结构如下图所示,上一个时间戳的状态向量ht−1\boldsymbol h_{t-1}ht−1​与当前时间戳的输入xt\boldsymbol x_txt​经过线性变换后,通过激活函数tanh\text{tanh}tanh后得到新的状态向量ht\boldsymbol h_tht​。

基础RNN结构框图

相对于基础的RNN网络只有一个状态向量ht\boldsymbol h_tht​,LSTM新增了一个状态向量ct\boldsymbol c_tct​,同时引入了门控(Gate)机制,通过门控单元来控制信息的遗忘和刷新,如下图所示:

LSTM结构框图

 在LSTM中,有两个状态向量c\boldsymbol cc和h\boldsymbol hh,其中c\boldsymbol cc作为LSTM的内部状态向量,可以理解为LSTM的内部状态向量Memory,而h\boldsymbol hh表示LSTM的输出向量。相对于基础的RNN来说,LSTM把内部Memory和输出分开为两个变量,同时利用三个门控:输入门(Input Gate)遗忘门(Forget Gate)输出门(Output Gate)来控制内部信息的流动。

 门控机制可以理解为控制数据流通量的一种手段,类比于水阀门:当水阀门全部打开时,水流畅通无阻地通过;当水阀门全部关闭时,水流完全被隔断。在LSTM中,阀门开合程度利用门控值向量g\boldsymbol gg表示,如下图所示,通过σ(g)σ(\boldsymbol g)σ(g)激活函数将门控制压缩到[0,1][0,1][0,1]之间的区间,当σ(g)=0σ(\boldsymbol g)=0σ(g)=0时,门控全部关闭,输出o=0\boldsymbol o=0o=0;当σ(g)=1σ(\boldsymbol g)=1σ(g)=1时,门控全部打开,输出o=x\boldsymbol o=\boldsymbol xo=x。通过门控机制可以较好地控制数据的流量程度。

门控机制

 下面我们分别来介绍三个门控的原理及其作用。

1. 遗忘门

 遗忘门作用于LSTM状态向量c\boldsymbol cc上面,用于控制上一个时间戳的记忆ct−1\boldsymbol c_{t-1}ct−1​对当前时间戳的影响。遗忘门的控制变量gf\boldsymbol g_fgf​由
gf=σ(Wf[ht−1,xt]+bf)\boldsymbol g_f=σ(\boldsymbol W_f [\boldsymbol h_{t-1},\boldsymbol x_t ]+\boldsymbol b_f)gf​=σ(Wf​[ht−1​,xt​]+bf​)
产生,如下图所示:

遗忘门

 其中Wf\boldsymbol W_fWf​和bf\boldsymbol b_fbf​为遗忘门的参数张量,可由反向传播算法自动优化,σσσ激活函数,一般使用Sigmoid函数。当门控gf=1\boldsymbol g_f=1gf​=1时,遗忘门全部打开,LSTM接受上一个状态ct−1\boldsymbol c_{t-1}ct−1​,输出为0的向量。这也是遗忘门的名字由来。

 经过遗忘门后,LSTM的状态向量变为gfct−1\boldsymbol g_f \boldsymbol c_{t-1}gf​ct−1​。

2. 输入门

 输入门用于控制LSTM对输入的接收程度。首先通过对当前时间戳的输入xt\boldsymbol x_txt​和上一个时间戳的输出ht−1\boldsymbol h_{t-1}ht−1​做非线性变换得到新的输入向量c~t\tilde{\boldsymbol c}_tc~t​:
c~t=tanh⁡(Wc[ht−1,xt]+bc)\tilde{\boldsymbol c}_t=\text{tanh}⁡(\boldsymbol W_c [\boldsymbol h_{t-1},\boldsymbol x_t ]+\boldsymbol b_c)c~t​=tanh⁡(Wc​[ht−1​,xt​]+bc​)
其中Wc\boldsymbol W_cWc​和bc\boldsymbol b_cbc​为输入门的参数,需要通过反向传播算法自动优化,tanh\text{tanh}tanh为激活函数,用于将输入标准化到[−1,1][-1,1][−1,1]区间。c~t\tilde{\boldsymbol c}_tc~t​并不会全部刷新进入LSTM的Memory,而是通过输入门控制接受输入的量。输入门的控制变量同样来自于输入xt\boldsymbol x_txt​和输出ht−1\boldsymbol h_{t-1}ht−1​:
gi=σ(Wi[ht−1,xt]+bi)\boldsymbol g_i=σ(\boldsymbol W_i [\boldsymbol h_{t-1},\boldsymbol x_t ]+\boldsymbol b_i)gi​=σ(Wi​[ht−1​,xt​]+bi​)
其中Wi\boldsymbol W_iWi​和bi\boldsymbol b_ibi​为输入门的参数,需要通过反向传播算法自动优化,σσσ为激活函数,一般使用Sigmoid函数。输入门控制变量gi\boldsymbol g_igi​决定了LSTM对当前时间戳的新输入c~t\tilde{\boldsymbol c}_tc~t​的接受程度:当gi\boldsymbol g_igi​=0时,LSTM不接受任何新的输入c~t\tilde{\boldsymbol c}_tc~t​;当gi=1\boldsymbol g_i=1gi​=1时,LSTM全部接受新输入c~t\tilde{\boldsymbol c}_tc~t​,如下图所示:

输入门

 经过输入门后,待写入Memory的向量为gic~t\boldsymbol g_i \tilde{\boldsymbol c}_tgi​c~t​。

3. 刷新Memory

 在遗忘门和输入门的控制下,LSTM有选择地读取了上一个时间戳的记忆ct−1\boldsymbol c_{t-1}ct−1​和当前时间戳的新输入c~t\tilde{\boldsymbol c}_tc~t​,状态向量ct\boldsymbol c_tct​的刷新方式为:
ct=gic~t+gfct−1\boldsymbol c_t=\boldsymbol g_i \tilde{\boldsymbol c}_t+\boldsymbol g_f \boldsymbol c_{t-1}ct​=gi​c~t​+gf​ct−1​
得到的新状态向量ct\boldsymbol c_tct​即为当前时间戳的状态向量。如下图所示:

刷新Memory

4. 输出门

 LSTM的内部状态向量ct\boldsymbol c_tct​并不会直接用于输出,这一点和基础的RNN不一样。基础的RNN网络的状态向量h\boldsymbol hh既用于记忆,又用于输出,所以基础的RNN可以理解为状态向量c\boldsymbol cc和输出向量h\boldsymbol hh是同一个对象。在LSTM内部,状态向量并不会全部输出,而是在输出门的作用下有选择地输出。输出门的门控变量go\boldsymbol g_ogo​为:
go=σ(Wo[ht−1,xt]+bo)\boldsymbol g_o=σ(\boldsymbol W_o [\boldsymbol h_{t-1},\boldsymbol x_t ]+\boldsymbol b_o)go​=σ(Wo​[ht−1​,xt​]+bo​)
其中Wo\boldsymbol W_oWo​和bo\boldsymbol b_obo​为输出门的参数,同样需要通过反向传播算法自动优化,σσσ为激活函数,一般使用Sigmoid函数。当输出门go=0\boldsymbol g_o=0go​=0时,输出关闭,LSTM的内部记忆完全被隔断,无法用作输出,此时输出为0的向量;当输出门go=1\boldsymbol g_o=1go​=1时,输出完全打开,LSTM的状态向量ct\boldsymbol c_tct​全部用于输出。LSTM的输出由:
ht=go⋅tanh⁡(ct)\boldsymbol h_t=\boldsymbol g_o\cdot \text{tanh⁡}(\boldsymbol c_t)ht​=go​⋅tanh⁡(ct​)
产生,即内存向量ct\boldsymbol c_tct​经过tanh\text{tanh}tanh激活函数后与输入门作用,得到LSTM的输出。由于go∈[0,1]\boldsymbol g_o\in[0,1]go​∈[0,1],tanh⁡(ct)∈[−1,1]\text{tanh}⁡(\boldsymbol c_t )\in[-1,1]tanh⁡(ct​)∈[−1,1],因此LSTM的输出ht∈[−1,1]\boldsymbol h_t\in[-1,1]ht​∈[−1,1]。

输出门

5. 小结

 LSTM虽然状态向量和门控数量较多,计算流程相对复杂。但是由于每个门控功能清晰明确,每个状态的作用也比较好理解。这里将典型的门控行为列举出来,并解释其代码的LSTM行为,如下表所示:

输入门和遗忘门的典型行为

输入门控 遗忘门控 LSTM行为
0 1 只使用记忆
1 1 综合输入和记忆
0 0 清零记忆
1 0 输入覆盖记忆

深度学习之循环神经网络(8)长短时记忆网络(LSTM)相关推荐

  1. 深度学习代码实战演示_Tensorflow_卷积神经网络CNN_循环神经网络RNN_长短时记忆网络LSTM_对抗生成网络GAN

    前言 经过大半年断断续续的学习和实践,终于将深度学习的基础知识看完了,虽然还有很多比较深入的内容没有涉及到,但也是感觉收获满满.因为是断断续续的学习做笔记写代码跑实验,所以笔记也零零散散的散落在每个角 ...

  2. 长短时记忆神经网络python代码_零基础入门深度学习(6) - 长短时记忆网络(LSTM)

    无论即将到来的是大数据时代还是人工智能时代,亦或是传统行业使用人工智能在云上处理大数据的时代,作为一个有理想有追求的程序员,不懂深度学习(Deep Learning)这个超热的技术,会不会感觉马上就o ...

  3. 深度学习之长短时记忆网络(LSTM)

    本文转自<零基础入门深度学习>系列文章,阅读原文请移步这里 之前我们介绍了循环神经网络以及它的训练算法.我们也介绍了循环神经网络很难训练的原因,这导致了它在实际应用中,很难处理长距离的依赖 ...

  4. 深度学习(7) - 长短时记忆网络(LSTM)

    长短时记忆网络是啥 我们首先了解一下长短时记忆网络产生的背景.回顾一下零基础入门深度学习(5) - 循环神经网络中推导的,误差项沿时间反向传播的公式: 我们可以根据下面的不等式,来获取的模的上界(模可 ...

  5. 小常识10: 循环神经网络(RNN)与长短时记忆网络LSTM简介。

    小常识10:  循环神经网络(RNN)与长短时记忆网络LSTM简介. 本文目的:在计算机视觉(CV)中,CNN 通过局部连接/权值共享/池化操作/多层次结构逐层自动的提取特征,适应于处理如图片类的网格 ...

  6. 深度学习之循环神经网络(11-b)GRU情感分类问题代码

    深度学习之循环神经网络(11-b)GRU情感分类问题代码 1. Cell方式 代码 运行结果 2. 层方式 代码 运行结果 1. Cell方式 代码 import os import tensorfl ...

  7. 深度学习之循环神经网络(11-a)LSTM情感分类问题代码

    深度学习之循环神经网络(11-a)LSTM情感分类问题代码 1. Cell方式 代码 运行结果 2. 层方式 代码 运行结果 1. Cell方式 代码 import os import tensorf ...

  8. 深度学习之循环神经网络(10)GRU简介

    深度学习之循环神经网络(10)GRU简介 1. 复位门 2. 更新门 3. GRU使用方法  LSTM具有更长的记忆能力,在大部分序列任务上面都取得了比基础RNN模型更好的性能表现,更重要的是,LST ...

  9. 深度学习之循环神经网络(7)梯度裁剪

    深度学习之循环神经网络(7)梯度裁剪 1. 张量限幅 2. 限制范数 3. 全局范数裁剪 梯度弥散  梯度爆炸可以通过 梯度裁剪(Gradient Clipping)的方式在一定程度上的解决.梯度裁剪 ...

最新文章

  1. 【物理方程】物理学中最难的方程之一,你知道多少?
  2. 如何使用XML作为小型数据库
  3. Mybatis的xml文件中$ 与 #的区别
  4. 算法题-大数相乘问题
  5. c++ primer 5th第13章拷贝控制知识点和自编习题答案
  6. js实现星级评分效果(非常规5个li代码)
  7. tomcat的wget链接_Linux(jdk安装tomcat安装nginx安装gcc/wget)
  8. JavaFX官方教程(十五)之A Xylophone.java
  9. JsonData工具类
  10. RedEngine11
  11. python中function函数的用法_Python中Function(函数)和methon(方法)
  12. ios获取区域服务器信息,ios获取服务器数据
  13. JUC和线程池的详细讲解
  14. javascript基础知识练习题
  15. 计算几何小结 我对计算几何的理解以及叉积和点积
  16. No Route to Host from master/192.168.2.131 to master:9000 failed on socket t
  17. 怎样去除图片水印?教你一个一键去除水印的方法
  18. [论文写作笔记] C8 讨论用于增加论文厚度
  19. 安装ubuntu16.04 14.04 登录时一直显示紫色问题
  20. [附源码]java毕业设计图书借阅系统

热门文章

  1. 《深入浅出数据分析》第十二章——R语言lattice数据包
  2. 矽谷真假U盘测试软件,要闻回顾_科技时代_新浪网
  3. C++17下map不常用的接口函数汇总
  4. f12控制台如何查看consul_基于 Consul 的 Go Micro 客户端服务发现是如何实现的
  5. 计算机二级操作范文,计算机二级考试(范文).doc
  6. 211计算机实力末尾的学校,实力最弱的十所985大学是哪几所?选择末尾985好还是选211好?...
  7. 在LINQ to SQL中使用Translate方法以及修改查询用SQL
  8. shell切割日志脚本
  9. Unity 8 和 Snap 将会是 Ubuntu 的未来
  10. Kafka学习-入门