简单介绍长短期记忆网络 - LSTM
文章目录
- 一、引言
- 1.1 什么是LSTM
- 二、循环神经网络RNN
- 2.1 为什么需要RNN
- 三、长短时记忆神经网络LSTM
- 3.1 为什么需要LSTM
- 3.2 LSTM结构分析
- 3.3 LSTM背后的核心思想
- 3.4 LSTM的运行机制
- 3.5 LSTM如何避免梯度下降
- 四、入门例子
- 五、总结
- 六、参考资料
一、引言
1.1 什么是LSTM
首先看看百科的解释。
长短期记忆(英语:Long Short-Term Memory,LSTM)是一种时间循环神经网络(RNN),论文首次发表于1997年。由于独特的设计结构,LSTM适合于处理和预测时间序列中间隔和延迟非常长的重要事件。1
为了更好地理解长短期记忆网络 - LSTM(下文简称LSTM),可以先了解循环神经网络-RNN(下文简称RNN)的相关知识,这里有一些相关的文章。LSTM只是RNN的一个变种,LSTM是为了解决RNN中的梯度消失的问题而提出的。
二、循环神经网络RNN
2.1 为什么需要RNN
人的思想是有记忆延续性。比如当你在阅读这篇文章,你会根据你曾经对每个字的理解来理解这篇文章的字,而不是每次都要思考一个字在这篇文章的语境下到底如何理解(从一个字或词的多种解释来选择一个符合当下语境的解释)。
举个例子:要识别这么一个句子:
The cat, which already ate cakes, () full.2
假设对其中的单词从左到右一个一个地处理,前面已经cat的识别结果是一个单数名词,到后边()里的内容,到底是填were 还是 was,那么就需要根据前边cat的识别结果进行判断。这就是RNN需要做的。
使用神经网络来预测句子中下一个字的解释。传统的神经网络在模型训练好了以后,在输入层给定一个x,通过网络之后就能在输出层得到特定的y。利用这个模型可以通过训练拟合任意函数,但是只能单独的取处理一个个的输入,前一个输出和后一个输出是完全没有关系的。
神经网络的结构如下:
但是,在理解一句话的意思的时候,一个字的意思是跟前面的字相关联的,即前面的输出和后面的输出是有关系的。所以仅仅利用这样的模型是不够的的,为了解决这个问题,有人提出了RNN。
RNN模型构造:
RNN神经网络示意图:
蓝色部分的是隐藏层,RNN利用隐藏层将信息向后传递。
我们来看看RNN隐藏层里发生了什么,将上图按时间线展开3:
符号 | 意义 |
---|---|
X | 一个向量,输入层的值 |
S | 一个向量,隐藏层的值 |
O | 一个向量,输出层的值 |
U | 输入层到隐藏层的权重矩阵 |
V | 隐藏层到输出层的权重矩阵 |
W | 隐藏层上一次的值作为这一次输入的权重 |
再给出一个更具体的图,给出各层元素的对应关系
现在看上去就比较清楚了,这个网络在 t 时刻接收到输入 xtx_txt 之后,隐藏层的值是 sts_tst ,输出值是 oto_tot 。关键一点是,sts_tst 的值不仅仅取决于 xtx_txt ,还取决于 st−1s_{t-1}st−1 。 我们可以用下面的公式来表示RNN的计算方法:
用公式表示如下:
Ot=g(V⋅St)O_t = g(V·S_t) Ot=g(V⋅St)
St=f(U⋅Xt+W⋅St−1)S_t = f(U·X_t + W ·S_{t-1}) St=f(U⋅Xt+W⋅St−1)
注意:为了简单说明问题,偏置都没有包含在公式里面。
这样,就可以做到的在一个序列中根据前面的输出来影响后面的输出。
三、长短时记忆神经网络LSTM
3.1 为什么需要LSTM
回到我们的例子:
The cat, which already ate …, () full.
这个例子与之前的例子稍微有一些不同,这里的cat 和()之间已经相隔了较长的一段距离,这时候用RNN来处理这样的长期信息就不太合适。
因为RNN在反向传播阶段有梯度消失等问题不能处理长依赖问题,这里的梯度消失是由于RNN在计算过程中使用链式法则。
具体来说,RNN使用覆盖的方式来计算状态:St=f(St−1,xt)S_t = f(S_{t-1},x_t)St=f(St−1,xt),这类似于复合函数,根据链式求导的法则,复合函数求导:设fff 和 ggg 为 xxx 的可导函数,则(f∘g)′(x)=f′(g(x))g′(x)(f \circ g)'(x) = f'(g(x))g'(x)(f∘g)′(x)=f′(g(x))g′(x),这是一种连乘的方式,如果导数小于或大于1,会发生梯度下降以及梯度爆炸。梯度爆炸可以通过剪枝算法解决,但是梯度消失却没办法解决。
梯度消失可能不太好理解,可以简单理解为RNN中后边输入的数据影响越大,前面的数据的影响小,因此不能处理长期信息。后来,有学者在一篇论文Long Short-Term Memory 4 提出了LSTM,LSTM通过选择性地保留信息,有效地缓解了梯度消失以及梯度下降的问题,可以说LSTM正是为了适合学习长期依赖而产生的。
3.2 LSTM结构分析
回顾一下RNN的模型构造:
可以看到,RNN循环网络模型的链式结构非常简单,通常仅含有一个tanh层。
LSTM模型构造:
而LSTM的链式结构中,循环单元结构不同,里边有四个神经网络层。
先来解释一下图中符号含义:
符号 | 含义 |
---|---|
黄色矩形 | 神经网络层 |
粉色圆 | 结点操作,比如向量相加 |
箭头 | 从一个结点的输出到另外的结点的输入 |
箭头合并 | 链接 |
箭头分叉 | 内容复制后副本流向不同的位置 |
LSTM结构(图右)和普通RNN的主要输入输出区别如下所示:
相比RNN只有一个传递状态 hth^tht , LSTM有两个传输状态,一个 ctc^tct (cell state), 和一个 hth^tht (hidden state)。(RNN中的 hth^tht 对应LSTM中的 CtC^tCt)
3.3 LSTM背后的核心思想
LSTM的核心思想,LSTM的关键是细胞状态(cell state),即下图中上边的水平线。cell state像是一条传送带,它贯穿整条链,其中只发生一些小的线性作用。信息流过这条线而不改变是非常容易的。5 改变cell state需要三个门的相互配合。
如下图所示:
LSTM删除或添加信息到cell state,是由被称为门的结构控制的。LSTM中有三个门,“遗忘门” “输入门” 以及“输出门”,用来保护和更新cell的状态。
门是筛选信息的方法,由一个sigmoid网络层和一个点乘操作组成。
如下图:
sigmoid层作为激活函数,将输出控制在(0,1)区间内,Sigmoid的函数图形如下:
可以看到,绝大多数的值都是接近0或者接近1的。利用这一个性质,0 表示不允许任何通过,1 表示允许一切通过。
3.4 LSTM的运行机制
第一步,需要决定从cell state中丢弃什么样的信息,这个由“遗忘门”的sigmoid层决定。根据输入ht−1h_{t-1}ht−1 和 xtx_txt,得到的输出是0和1之间的数。0 代表“完全保留这个值”,1代表“完全丢弃这个值”。
回到开始的例子,原来的主语是"cat",之后遇到了一个新的主语"cats"。这时需要把之前的"cat"给忘掉,以便确定接下来是要使用"were",而不是"was"。如下图:
第二步,需要决定在cell state里存储什么样的信息。这一步划分为两个部分,一是称为“输入门”的sigmoid层决定哪些数据需要更新。然后,tanh层创建一个新的候选值向量C~t\widetilde{C}_tCt,这些值能加入state中。第二部分,需要将这两个部分合并以实现对state的更新。
在例子中,这里对应于把新的"cats"加入到"cell state"中,以替代需要遗忘的"cat"。如下图:
在决定好需要遗忘的以及需要加入的记忆之后,就可以把旧的cell state Ct−1C_{t-1}Ct−1更新到新的cell state CtC_tCt。 这一步中,把旧的state Ct−1C_{t-1}Ct−1 与ftf_tft 相乘,遗忘先前决定遗忘的东西,之后加上新的记忆信息 it∗C~ti_t \ast \widetilde{C}_tit∗Ct。这里为了体现对状态值的更新度是有限制的,可以把iti_tit当成一个权重。如下图:
最后,需要决定输出。这个输出将会基于cell state ,这是一个过滤后的值。首先,使用“输出门”的sigmoid层决定输出cell state的哪些部分的。然后,将cell state放入tanh(将数值限制在-1到1),最后将结果与sigmoid门的输出相乘,这样就可以只输出需要的部分。如下图:
3.5 LSTM如何避免梯度下降
上边提到了RNN中的梯度下降以及梯度爆炸问题,是是因为在计算过程中使用链式法则,使用了乘积。而在LSTM中,状态是通过累加的方式来计算,St=∑τ=1tΔSτS_t = \sum_{\tau =1}^t \Delta S_{\tau}St=∑τ=1tΔSτ。这样的计算,就不是复合函数的形式,它的导数也就不是乘积的形式,就不会发生梯度消失的情况。
四、入门例子
下面给出LSTM的一个入门实例-根据前9年的数据预测后3年的客流6,感谢原作者的代码,完整的代码见GithubYonv1943。这里简单说一下这个代码实例的结果,需要了解更加详细的代码细节可以看看原作者的原文详解。
考虑有一组某机场1949年~1960年12年共144个月的客流量数据。使用这个数据中的前9年的客流量来预测后3年的客流量,再和实际的数据进行比对,可以看出LSTM的对这类具有时序关系的拟合效果。
结果图:
- 数据:机场1949~1960年12年共144个月的客流量数据。数据具有三个维度[客运量,年份,月份]。其中前75%(前9年)的数据作为训练集,后25%(后3年)的数据作为测试集。
- 纵坐标:标准化处理:变量值与平均数的差除以标准差,给出数值的相对位置。横坐标为月数。
- 图解释:竖直黑线左边是训练集(前9年)。右边(后3年)红色的是预测数值,蓝色的是实际数值。
可以看到在这个LSTM对这个数据集的拟合效果是比较好的,在这样的实际场景中,可以利用LSTM这样的工具来对客流量做一个预测,以便对客运高峰等情况做好预备方案。
五、总结
- RNN的计算中存在多个偏导数连乘,导致梯度消失或梯度爆炸,难以处理长依赖的信息。
- LSTM通过三个选择性地保留信息,可以选择最近的信息或者很久之前的信息。
- LSTM更新cell state是采用了线性求和的计算,因此不会出现梯度消失问题,可以处理长期依赖的信息。
六、参考资料
长短期记忆 ↩︎
吴恩达深度学习课程 ↩︎
一文搞懂RNN(循环神经网络)基础篇 ↩︎
Long Short-Term Memory ↩︎
Understanding LSTM Networks ↩︎
LSTM入门例子:根据前9年的数据预测后3年的客流(PyTorch实现) ↩︎
简单介绍长短期记忆网络 - LSTM相关推荐
- 『NLP学习笔记』长短期记忆网络LSTM介绍
长短期记忆网络LSTM介绍 文章目录 一. 循环神经网络 二. 长期依赖问题 三. LSTM 网络 四. LSTM 背后的核心理念 4.1 忘记门 4.2 输入门 4.3 输出门 五. LSTM总结( ...
- MATLAB-基于长短期记忆网络(LSTM)的SP500的股票价格预测 股价预测 matlab实战 数据分析 数据可视化 时序数据预测 变种RNN 股票预测
MATLAB-基于长短期记忆网络(LSTM)的SP500的股票价格预测 股价预测 matlab实战 数据分析 数据可视化 时序数据预测 变种RNN 股票预测 摘要 近些年,随着计算机技术的不断发展,神 ...
- 基于长短期记忆网络(LSTM)对股票价格的涨跌幅度进行预测
完整代码:https://download.csdn.net/download/qq_38735017/87536579 为对股票价格的涨跌幅度进行预测,本文使用了基于长短期记忆网络(LSTM)的方法 ...
- 1014长短期记忆网络(LSTM)
长短期记忆网络(LSTM) 长期以来,隐变量模型存在着长期信息保存和短期输入缺失的问题,解决这个问题最早的方法之一就是 LSTM 发明于90年代 使用的效果和 GRU 相差不大,但是使用的东西更加复杂 ...
- 长短期记忆网络(LSTM)学习笔记
文章目录 0 前言 1 LSTM与RNN的异同 2 LSTM结构细节 2.1 细胞状态 2.2 遗忘门 2.3 输入门 2.4 输出门 3 总结 4 LSTM的变体 4.1 Adding " ...
- 长短期记忆网络 LSTM
这里写目录标题 1. LSTM介绍 1.1 什么是LSTM 1.2 LSTM相较于RNN的优势 1.3 LSTM的结构图 1.3.1 LSTM的核心思想 1.3.2 LSTM的遗忘门 1.3.3 LS ...
- 白话机器学习-长短期记忆网络LSTM
一 背景 既然有了RNN,为何又需要LSTM呢? 循环神经网络RNN的网络结构使得它可以使用历史信息来帮助当前的决策.例如使用之前出现的单词来加强对当前文字的理解.可以解决传统神经网络模型不能充分利用 ...
- 长短期记忆网络LSTM
1. LSTM是循环神经网络的一个变体可以有效的解决简单循环神经网络的梯度消失和梯度爆炸的问题. 2. 改进方面: 新的内部状态 Ct专门进行线性的循环信息传递,同时(非线性的)输出信息给隐藏层的外部 ...
- keras实现简单lstm_四十二.长短期记忆网络(LSTM)过程和keras实现股票预测
一.概述 传统循环网络RNN可以通过记忆体实现短期记忆进行连续数据的预测,但是,当连续数据的序列边长时,会使展开时间步过长,在反向传播更新参数的过程中,梯度要按时间步连续相乘,会导致梯度消失或者梯度爆 ...
最新文章
- 0x02.基本算法 — 递推与递归
- java存储字节,java 数目字转化成字节存储算法
- 腾讯游戏主美:二次元卡通渲染有哪些黑科技?
- 纯 js 导出 excel
- DEDECMS 5.6整合Discuz_X1.5的方法
- 数易云备开启虚拟机备份新时代
- 创意三维c4d形式设计节气海报学习案例
- 全栈python_Pyodide:在浏览器端实现Python全栈科学计算
- 算法练习——聪明的情侣
- java jdbc_详解Java基础知识——JDBC
- UE中使用正则表达式的一些技巧
- 使用uniapp获取当前位置
- Vueb报错[WDS] Errors while compiling. Reload prevented
- Puppeteer之Pyppeteer——浏览某短视频,获取点赞和评论,收藏,转发数(5)
- 前端学习:jQuery学习--Day03
- mysql tablespace is missing for table_Mysql报错:Tablespace is missing for table ‘db_rsk/XXX”
- 生活中的统计概率思维
- MultipartFile 转换为File
- php属于c,c语言属于哪个?php还是java?
- 子查询和关联查询 效率
热门文章
- SpringBoot2.0学习笔记 使用Actualor监控项目运行状态
- 电脑录屏软件哪个好用?3款屏幕录制大师分享!
- JY62陀螺仪的联调用STM32CubeMX
- Spring-IoC-03
- 如何选定搭建个人独立博客工具
- 搜索排序LambdaMART中Lambda的计算过程java版本
- animate.css 动画库的下载与使用
- CTFHub-web前置技能-请求方式、302跳转、cookie、基础认证、响应包源代码
- Cesium实践(4)——空间数据加载
- redis学习--三种特殊数据类型,GEO地理位置,HyperLogLog,BitMap