大家好,今天和各位分享一下长短时记忆网络 LSTM 的原理,并使用 Pytorch 从公式上实现 LSTM 层

上一节介绍了循环神经网络 RNN,感兴趣的可以看一下:https://blog.csdn.net/dgvv4/article/details/125424902

我的这个专栏中有许多 LSTM 的实战案例,便于大家巩固知识:https://blog.csdn.net/dgvv4/category_11712004.html


1. 引言

循环神经网络的记忆功能在处理时间序列问题上存在很大优势,但随着训练的不断进行,RNN 网络一直在不断的扩充记忆,致使 RNN 产生梯度消失以及梯度爆炸

为了解决RNN难以有效训练的问题,拥有选择记忆功能的 LSTM模型被提出。LSTM 是在 RNN 的基础上进行的改进,其既能学习数据中的长期依赖,又能解决梯度消失。LSTM 包含一个记忆单元和三个门结构,其中门结构分别是输入门、输出门和遗忘门。

LSTM 的工作过程如下:

首先由输入数据 X_t 前一时刻隐藏层的输出数据 h_t-1 共同作用于遗忘门,遗忘门对上述信息进行筛选,记忆时间序列中的重要特征信息,丢弃无关紧要的信息;然后将输入数据 x_t 以及前一时刻隐藏层的输出数据 h_t-1 作为输入门的输入信息,进行更新;其次记忆单元通过输入数据 X_t、前一时刻隐藏层的输出数据 h_t-1 以及前一时刻的记忆单元状态 C_t-1 对自身状态进行更新;最后将输入数据 X_t前一时刻隐藏层的输出数据 h_t-1 以及当前时刻的记忆单元状态 C_t 共同作用于输出门,输出当前时刻的隐藏层信息 h_t

LSTM 的结构图如下:


2. 原理解析

2.1 遗忘门

上一时刻的输出 h_t-1当前时刻的输入 X_t 结合,并通过 Sigmoid 函数计算得到一个阈值为 [0,1] 的张量 f_t,该 f_t 可以看作是对上一时刻的状态 C_t-1 的权重项,f_t 负责控制上一时刻状态需要被遗忘的程度

计算公式:

将公式展开,其中 W_if 是对当前时刻输入的特征提取,W_hf 是对前一时刻状态的特征提取,@ 代表矩阵相乘。


2.2 输入门

输入门是与 tanh 函数配合控制新信息加入的程度。在这个过程中,tanh 函数会给出一个新的候选向量 ,输入门为  中的每一项产生一个在 [0,1] 之间的值 i_t,控制新信息被加入的多少。

计算公式:

公式展开,其中 W_i 是对当前时刻输入的特征提取,W_h 是对前一时刻状态的特征提取,@ 代表矩阵相乘。

至此,模型已经计算了遗忘门的输出 f_t,和输入门的输出 i_t分别用来控制上一时刻的状态需要被遗忘的程度,和新增信息的规模,接下来可以根据这两个输出更新当前时刻的状态 C_t

计算公式,其中 * 代表张量之间逐元素相乘。


2.3 输出门

输出门用来过滤当前状态的某些信息,将其舍去。输出门的计算过程,将输入数据 X_t前一时刻隐藏层的输出数据 h_t-1 经过 sigmoid 函数,把每一项的值压缩到 [0-1] 之间作为过滤信息的权重项。然后与更新后的当前状态 C_t 逐元素相乘,

计算公式:

公式展开:


3. 代码实现

3.1 官方 API

torch.nn.LSTM() 参数如下:

lstm = nn.LSTM(input_size, hidden_size, num_layers, bias=True, batch_first=False)
'''
input_size: 每个单词使用多少长的向量来表示
hidden_size: 隐含层,经过LSTM层后每个单词用多长的向量来表示
num_layers: LSTM的层数
bias: 是否使用偏置项,默认为True,即 w@x+b
batch_first: 对于输入是否将batch放在axis=0的位置,默认False,即[seq_len, batch, feature_len]
'''

实例化单层 LSTM,做一次前向传播,查看输出信息

import torch
from torch import nn# 定义参数
batch = 3  # 现在有3个句子
seq_len = 10  # 每个句子有10个单词
feature_len = 100  # 每个单词用长度为100的向量来表示
hidden_len = 20  # 经过LSTM层后每个单词用长度为20的向量来表示# 当前时刻的输入 [batch, seq_len, feature_len]
inputs = torch.randn(batch, seq_len, feature_len)# 上一时刻的状态 [batch, hidden_len]
h0 = torch.randn(batch, hidden_len)
c0 = torch.randn(batch, hidden_len)# 实例化LSTM层
lstm = nn.LSTM(input_size=feature_len, hidden_size=hidden_len, num_layers=1, batch_first=True)# c:最后一个单词更新的状态,[num_layer, batch, hidden_size]
# h:最后一个单词的输出,[num_layer, batch, hidden_size]
# out: 整体输出结果,[batch, seq_len, hidden_size]
out, (h,c) = lstm(inputs)print('out:', out.shape,  # [3, 10, 20]'h:', h.shape,      # [1, 3, 20]'c:', c.shape)      # [1, 3, 20]# 查看权重信息
for k,v in lstm.named_parameters():print(k, v.shape)'''
weight_ih_l0 torch.Size([80, 100])
weight_hh_l0 torch.Size([80, 20])
bias_ih_l0 torch.Size([80])
bias_hh_l0 torch.Size([80])
'''

3.2 自定义函数

接下来根据第二小节解释过的公式,从原理上实现一个 LSTM 层,主要就是6个公式的计算,还要注意张量的shape变化。

代码实现如下:

import torch
from torch import nn'''
inputs: 当前时刻的输入 [batch, seq_len, feature_len]
c0: 上一时刻的状态,[batch, hidden_len]
h0: 上一时刻的输出,[batch, hidden_len]w_ih, b_ih: 对当前时刻输入的特征矩阵和偏置
w_hh, b_hh: 对上一时刻状态的特征矩阵和偏置w_ih.shape=[4*hdiien_size, feature_len]
w_hh.shape=[4*hdiien_size, hidden_len]
b.shape=[4*hidden_size]
'''# ------------------------------------------------------------- #
#(1)自定义LSTM模型
# ------------------------------------------------------------- #
def lstm_forward(inputs, initial_states, w_ih, w_hh, b_ih, b_hh):h0, c0 = initial_states  # 获取初始状态# batch代表序列个数,seq_len代表每个序列有多少个样本,feature_len代表每个样本有多少个特征batch, seq_len, feature_len = inputs.shape  # 获取输入的形状# 获取隐含层个数,根据公式由4个W拼接而成hidden_len = w_ih.shape[0] // 4   # weight_ih_l0 torch.Size([80, 100])# 初始化输出层 [batch, seq_len, hidden_len]outputs = torch.zeros(batch, seq_len, hidden_len)# 在LSTM中不断更新上一时刻的状态pre_h, pre_c = h0, c0# 扩充w的维度==>[b, 4*hdiien_size, feature_len]batch_w_ih = w_ih.unsqueeze(0).tile(batch, 1, 1)# ==>[b, 4*hdiien_size, hidden_len]batch_w_hh = w_hh.unsqueeze(0).tile(batch, 1, 1)# 遍历每个序列中的每个单词for t in range(seq_len):# 获取当前时刻的输入张量x = inputs[:, t, :]  # [b, feature_len]# 三维矩阵相乘 [b, 4*hdiien_size, feature_len] @ [b, feature_len, 1]w_time_x = torch.bmm(batch_w_ih, x.unsqueeze(-1))  # [b, 4*hidden_len, 1]w_time_x = w_time_x.squeeze(-1)  # [b, 4*hidden_len]# 状态的矩阵相乘 [b, 4*hdiien_size, hidden_len] @ [b, hidden_len, 1]w_time_h_pre = torch.bmm(batch_w_hh, pre_h.unsqueeze(-1))  # [b, 4*hidden_size, 1]w_time_h_pre = w_time_h_pre.squeeze(-1)   # [b, 4*hidden_size]# 取前1/4用作输入门(i)i_t = w_time_x[:, :hidden_len] + b_ih[:hidden_len] + w_time_h_pre[:, :hidden_len] + b_hh[:hidden_len]i_t = torch.sigmoid(i_t)# 遗忘门(f)f_t = w_time_x[:, hidden_len:hidden_len*2] + b_ih[hidden_len:hidden_len*2] + w_time_h_pre[:, hidden_len:hidden_len*2] + b_hh[hidden_len:hidden_len*2]f_t = torch.sigmoid(f_t)# 细胞门(g)g_t = w_time_x[:, hidden_len*2:hidden_len*3] + b_ih[hidden_len*2:hidden_len*3] + w_time_h_pre[:, hidden_len*2:hidden_len*3] + b_hh[hidden_len*2:hidden_len*3]g_t = torch.tanh(g_t)# 输出门(o)o_t = w_time_x[:, hidden_len*3:] + b_ih[hidden_len*3:] + w_time_h_pre[:, hidden_len*3:] + b_hh[hidden_len*3:]o_t = torch.tanh(o_t)# 状态(c)pre_c = f_t * pre_c + i_t * g_t# 当前时刻lstm的输出(h)pre_h = o_t * torch.tanh(pre_c)# 更新输出层outputs[:, t, :] = pre_h# 返回输出、最后一个时刻的输出h,状态creturn outputs, (pre_h, pre_c)# ------------------------------------------------------------- #
#(2)前向传播
# ------------------------------------------------------------- #
batch = 3  # 3个句子
seq_len = 10  # 序列长度,每个句子有10个单词
feature_len = 100  # 特征个数,一个单词用长度为100的向量来表示
hidden_len = 20  # 隐含层,经过LSTM层后用长度为20的向量来表示# 构造输入层 [batch, seq_len, feature_len]
inputs = torch.randn(batch, seq_len, feature_len)# 初始状态,不需要训练 [batch, hidden_len]
h0 = torch.randn(batch, hidden_len)
c0 = torch.randn(batch, hidden_len)# 构造权重
w_ih = torch.randn(hidden_len*4, feature_len)  # [80, 100]
w_hh = torch.randn(hidden_len*4, hidden_len)   # [80, 100]
# 构造偏执
b_ih = torch.randn(hidden_len*4)   # [80]
b_hh = torch.randn(hidden_len*4)   # [80]# lstm层计算结果
outputs, (final_h, final_c) = lstm_forward(inputs, (h0, c0), w_ih, w_hh, b_ih, b_hh)'''
outputs: 所有句子的输出,[batch,seq_len, hidden_len]
pre_h: 最后一次个单词的输出,[batch, hidden_len]
pre_c: 最后一个单词的状态,[batch, hidden_len]
'''print('outputs.shape:', outputs.shape,    # [3, 10, 20]'pre_h.shape:', final_h.shape,      # [3, 20]'pre_c.shape:', final_c.shape)      # [3, 20]

【深度学习理论】(7) 长短时记忆网络 LSTM相关推荐

  1. 深度学习(7) - 长短时记忆网络(LSTM)

    长短时记忆网络是啥 我们首先了解一下长短时记忆网络产生的背景.回顾一下零基础入门深度学习(5) - 循环神经网络中推导的,误差项沿时间反向传播的公式: 我们可以根据下面的不等式,来获取的模的上界(模可 ...

  2. 长短时记忆神经网络python代码_零基础入门深度学习(6) - 长短时记忆网络(LSTM)

    无论即将到来的是大数据时代还是人工智能时代,亦或是传统行业使用人工智能在云上处理大数据的时代,作为一个有理想有追求的程序员,不懂深度学习(Deep Learning)这个超热的技术,会不会感觉马上就o ...

  3. 深度学习之长短时记忆网络(LSTM)

    本文转自<零基础入门深度学习>系列文章,阅读原文请移步这里 之前我们介绍了循环神经网络以及它的训练算法.我们也介绍了循环神经网络很难训练的原因,这导致了它在实际应用中,很难处理长距离的依赖 ...

  4. 小常识10: 循环神经网络(RNN)与长短时记忆网络LSTM简介。

    小常识10:  循环神经网络(RNN)与长短时记忆网络LSTM简介. 本文目的:在计算机视觉(CV)中,CNN 通过局部连接/权值共享/池化操作/多层次结构逐层自动的提取特征,适应于处理如图片类的网格 ...

  5. 长短时记忆网络(LSTM)的训练

    长短时记忆网络的训练 熟悉我们这个系列文章的同学都清楚,训练部分往往比前向计算部分复杂多了.LSTM的前向计算都这么复杂,那么,可想而知,它的训练算法一定是非常非常复杂的.现在只有做几次深呼吸,再一头 ...

  6. 长短时记忆网络LSTM

    网络介绍 长短时记忆网络(Long short time memory network, LSTM)是RNN的重要变体,解决了RNN无法长距离依赖的问题,同时缓了RNN的梯度爆炸问题.LSTM由遗忘门 ...

  7. 长短时记忆网络(LSTM)部分组件(六)

    在前面的几篇文章中试着实现了CNN,RNN的一些组件,这里继续学习LSTM,也是是实现部分组件,旨在学习其LSTM的原理. 具体参考: https://www.zybuluo.com/hanbingt ...

  8. 深度学习代码实战演示_Tensorflow_卷积神经网络CNN_循环神经网络RNN_长短时记忆网络LSTM_对抗生成网络GAN

    前言 经过大半年断断续续的学习和实践,终于将深度学习的基础知识看完了,虽然还有很多比较深入的内容没有涉及到,但也是感觉收获满满.因为是断断续续的学习做笔记写代码跑实验,所以笔记也零零散散的散落在每个角 ...

  9. 多元经验模态分解_交通运输|基于小波分解和长短时记忆网络的地铁进站量短时预测...

    山东科学 ›› 2019, Vol. 32 ›› Issue (4): 56-63.doi: 10.3976/j.issn.1002-4026.2019.04.008 摘要: 针对城市地铁车站进站客流 ...

最新文章

  1. bzoj 3339 莫队
  2. vue 带全选和多选的表格怎么写_vue中使用计算属性巧妙的实现多选框的“全选”...
  3. VI常用使用命令 为初次接触VI 的兄弟们献微利
  4. 学习python用哪个app-Python和R:学哪个好?
  5. 算法导论6.1-2习题解答
  6. Oracle 快速插入1000万条数据的实现方式
  7. LightSpeed 的Left Join Bug解决方案
  8. Hadoop YARN:调度性能优化实践【转】
  9. python常用8大算法
  10. 使用QtService实现Qt后台服务程序
  11. 数据类型的判断 --Object.prototype.toString.call(obj)精准检测对象类型
  12. 模糊查询关键字不区分大小写_SQL简单查询
  13. 调参方法论:如何提高机器学习模型的性能?
  14. 会说话的狗狗本电脑版_「电脑知识」硬件狗狗专业电脑硬件检测跑分工具免安装单文件版|电脑硬件|电脑|显卡|狗狗|操作系统...
  15. 中安证件识别系统介绍
  16. 新概念51单片机c语言教程考试题,新概念51单片机C语言教程例题.doc
  17. 微云html网页,微云收藏在哪里_以及腾讯微云收藏网页版怎么用? - 软件教程 - 格子啦...
  18. JavaScript中的作用域及作用域链
  19. C++中UTF-8, Unicode, GB2312转换及有无BOM相关问题
  20. ccc加拿大计算机竞赛在线评测系统,加拿大计算机竞赛简介

热门文章

  1. Logback (7) | Missing integer token, that is %i, in FileNamePattern
  2. ubuntu16.04 画图软件kolourpaint
  3. aws mysql 升级_AWS最新核心武器:升级主数据库转换工具
  4. Tarjan算法求无向图割边割点、最近公共祖先的总结
  5. ABLIC Inc.推出S-35710/20(I-系列)唤醒计时器IC
  6. 史上最糟糕的两个变量名
  7. ubuntu 20.04 Linux下查看当前文件夹的大小
  8. 特征检测与特征匹配方法汇总
  9. String类型相加随笔
  10. 通过 zerotier 访问所在局域网的其他设备