点击上方“小白学视觉”,选择加"星标"或“置顶

重磅干货,第一时间送达

最近在学习LSTM应用在时间序列的预测上,但是遇到一个很大的问题就是LSTM在传统BP网络上加上时间步后,其结构就很难理解了,同时其输入输出数据格式也很难理解,网络上有很多介绍LSTM结构的文章,但是都不直观,对初学者是非常不友好的。我也是苦苦冥思很久,看了很多资料和网友分享的LSTM结构图形才明白其中的玄机。

本文内容如下:

1、传统的BP网络和CNN网络
2、LSTM网络
3、LSTM的输入结构
4、pytorch中的LSTM
4.1 pytorch中定义的LSTM模型
4.2 喂给LSTM的数据格式
4.3 LSTM的output格式
5、LSTM和其他网络组合

传统的BP网络和CNN网络

BP网络和CNN网络没有时间维,和传统的机器学习算法理解起来相差无几,CNN在处理彩色图像的3通道时,也可以理解为叠加多层,图形的三维矩阵当做空间的切片即可理解,写代码的时候照着图形一层层叠加即可。如下图是一个普通的BP网络和CNN网络。

BP网络

CNN网络

图中的隐含层、卷积层、池化层、全连接层等,都是实际存在的,一层层前后叠加,在空间上很好理解,因此在写代码的时候,基本就是看图写代码,比如用keras就是:

# 示例代码,没有实际意义
model = Sequential()
model.add(Conv2D(32, (3, 3), activation='relu'))  # 添加卷积层
model.add(MaxPooling2D(pool_size=(2, 2)))         # 添加池化层
model.add(Dropout(0.25))                          # 添加dropout层model.add(Conv2D(32, (3, 3), activation='relu'))  # 添加卷积层
model.add(MaxPooling2D(pool_size=(2, 2)))         # 添加池化层
model.add(Dropout(0.25))                          # 添加dropout层....   # 添加其他卷积操作model.add(Flatten())                            # 拉平三维数组为2维数组
model.add(Dense(256, activation='relu'))        添加普通的全连接层
model.add(Dropout(0.5))
model.add(Dense(10, activation='softmax'))....  # 训练网络

LSTM网络

当我们在网络上搜索看LSTM结构的时候,看最多的是下面这张图:

RNN网络

这是RNN循环神经网络经典的结构图,LSTM只是对隐含层节点A做了改进,整体结构不变,因此本文讨论的也是这个结构的可视化问题。

中间的A节点隐含层,左边是表示只有一层隐含层的LSTM网络,所谓LSTM循环神经网络就是在时间轴上的循环利用,在时间轴上展开后得到右图。

看左图,很多同学以为LSTM是单输入、单输出,只有一个隐含神经元的网络结构,看右图,以为LSTM是多输入、多输出,有多个隐含神经元的网络结构,A的数量就是隐含层节点数量。

WTH?思维转不过来啊。这就是传统网络和空间结构的思维。

实际上,右图中,我们看Xt表示序列,下标t是时间轴,所以,A的数量表示的是时间轴的长度,是同一个神经元在不同时刻的状态(Ht),不是隐含层神经元个数。

我们知道,LSTM网络在训练时会使用上一时刻的信息,加上本次时刻的输入信息来共同训练。

举个简单的例子:在第一天我生病了(初始状态H0),然后吃药(利用输入信息X1训练网络),第二天好转但是没有完全好(H1),再吃药(X2),病情得到好转(H2),如此循环往复知道病情好转。因此,输入Xt是吃药,时间轴T是吃多天的药,隐含层状态是病情状况。因此我还是我,只是不同状态的我。

实际上,LSTM的网络是这样的:

LSTM网络结构

上面的图表示包含2个隐含层的LSTM网络,在T=1时刻看,它是一个普通的BP网络,在T=2时刻看也是一个普通的BP网络,只是沿时间轴展开后,T=1训练的隐含层信息H,C会被传递到下一个时刻T=2,如下图所示。上图中向右的五个常常的箭头,所的也是隐含层状态在时间轴上的传递。

注意,图中H表示隐藏层状态,C是遗忘门,后面会讲解它们的维度。

LSTM的输入结构

为了更好理解LSTM结构,还必须理解LSTM的数据输入情况。仿照3通道图像的样子,在加上时间轴后的多样本的多特征的不同时刻的数据立方体如下图所示:

三维数据立方体

右边的图是我们常见模型的输入,比如XGBOOST,lightGBM,决策树等模型,输入的数据格式都是这种(N*F)的矩阵,而左边是加上时间轴后的数据立方体,也就是时间轴上的切片,它的维度是(N*T*F),第一维度是样本数,第二维度是时间,第三维度是特征数,如下图所示:

这样的数据立方体很多,比如天气预报数据,把样本理解成城市,时间轴是日期,特征是天气相关的降雨风速PM2.5等,这个数据立方体就很好理解了。在NLP里面,一句话会被embedding成一个矩阵,词与词的顺序是时间轴T,索引多个句子的embedding三维矩阵如下图所示:

pytorch中的LSTM

4.1 pytorch中定义的LSTM模型

pytorch中定义的LSTM模型的参数如下

class torch.nn.LSTM(*args, **kwargs)
参数有:input_size:x的特征维度hidden_size:隐藏层的特征维度num_layers:lstm隐层的层数,默认为1bias:False则bihbih=0和bhhbhh=0. 默认为Truebatch_first:True则输入输出的数据格式为 (batch, seq, feature)dropout:除最后一层,每一层的输出都进行dropout,默认为: 0bidirectional:True则为双向lstm默认为False

结合前面的图形,我们一个个看。

(1)input_size:x的特征维度,就是数据立方体中的F,在NLP中就是一个词被embedding后的向量长度,如下图所示:

(2)hidden_size:隐藏层的特征维度(隐藏层神经元个数),如下图所示,我们有两个隐含层,每个隐藏层的特征维度都是5。注意,非双向LSTM的输出维度等于隐藏层的特征维度。

(3)num_layers:lstm隐层的层数,上面的图我们定义了2个隐藏层。

(4)batch_first:用于定义输入输出维度,后面再讲。

(5)bidirectional:是否是双向循环神经网络,如下图是一个双向循环神经网络,因此在使用双向LSTM的时候我需要特别注意,正向传播的时候有(Ht, Ct),反向传播也有(Ht', Ct'),前面我们说了非双向LSTM的输出维度等于隐藏层的特征维度,而双向LSTM的输出维度是隐含层特征数*2,而且H,C的维度是时间轴长度*2。

4.2 喂给LSTM的数据格式

pytorch中LSTM的输入数据格式默认如下:

input(seq_len, batch, input_size)
参数有:seq_len:序列长度,在NLP中就是句子长度,一般都会用pad_sequence补齐长度batch:每次喂给网络的数据条数,在NLP中就是一次喂给网络多少个句子input_size:特征维度,和前面定义网络结构的input_size一致。

前面也说到,如果LSTM的参数 batch_first=True,则要求输入的格式是:

input(batch, seq_len, input_size)

刚好调换前面两个参数的位置。其实这是比较好理解的数据形式,下面以NLP中的embedding向量说明如何构造LSTM的输入。

之前我们的embedding矩阵如下图:

如果把batch放在第一位,则三维矩阵的形式如下:

其转换过程如下图所示:

看懂了吗,这就是输入数据的格式,是不是很简单。

LSTM的另外两个输入是 h0 和 c0,可以理解成网络的初始化参数,用随机数生成即可。

h0(num_layers * num_directions, batch, hidden_size)
c0(num_layers * num_directions, batch, hidden_size)
参数:num_layers:隐藏层数num_directions:如果是单向循环网络,则num_directions=1,双向则num_directions=2batch:输入数据的batchhidden_size:隐藏层神经元个数

注意,如果我们定义的input格式是:

input(batch, seq_len, input_size)

则H和C的格式也是要变的:

h0(batc,num_layers * num_directions, h, hidden_size)
c0(batc,num_layers * num_directions, h, hidden_size)

4.3 LSTM的output格式

LSTM的输出是一个tuple,如下:

output,(ht, ct) = net(input)output: 最后一个状态的隐藏层的神经元输出ht:最后一个状态的隐含层的状态值ct:最后一个状态的隐含层的遗忘门值

output的默认维度是:

output(seq_len, batch, hidden_size * num_directions)
ht(num_layers * num_directions, batch, hidden_size)
ct(num_layers * num_directions, batch, hidden_size)

和input的情况类似,如果我们前面定义的input格式是:

input(batch, seq_len, input_size)

则ht和ct的格式也是要变的:

ht(batc,num_layers * num_directions, h, hidden_size)
ct(batc,num_layers * num_directions, h, hidden_size)

说了这么多,我们回过头来看看ht和ct在哪里,请看下图:

output在哪里?请看下图:

LSTM和其他网络组合

还记得吗,output的维度等于隐藏层神经元的个数,即hidden_size,在一些时间序列的预测中,会在output后,接上一个全连接层,全连接层的输入维度等于LSTM的hidden_size,之后的网络处理就和BP网络相同了,如下图:

用pytorch实现上面的结构:

import torch
from torch import nnclass RegLSTM(nn.Module):def __init__(self):super(RegLSTM, self).__init__()# 定义LSTMself.rnn = nn.LSTM(input_size, hidden_size, hidden_num_layers)# 定义回归层网络,输入的特征维度等于LSTM的输出,输出维度为1self.reg = nn.Sequential(nn.Linear(hidden_size, 1))def forward(self, x):x, (ht,ct) = self.rnn(x)seq_len, batch_size, hidden_size= x.shapex = y.view(-1, hidden_size)x = self.reg(x)x = x.view(seq_len, batch_size, -1)return x

当然,有些模型则是将输出当做另一个LSTM的输入,或者使用隐藏层ht,ct的信息进行建模,不一而足。

好了,以上就是我对LSTM的一些学习心得,看完记得关注点赞。

参考链接:

https://zhuanlan.zhihu.com/p/94757947

https://zhuanlan.zhihu.com/p/59862381

https://zhuanlan.zhihu.com/p/36455374

https://www.zhihu.com/question/41949741/answer/318771336

https://blog.csdn.net/android_ruben/article/details/80206792

下载1:OpenCV-Contrib扩展模块中文版教程

在「小白学视觉」公众号后台回复:扩展模块中文教程即可下载全网第一份OpenCV扩展模块教程中文版,涵盖扩展模块安装、SFM算法、立体视觉、目标跟踪、生物视觉、超分辨率处理等二十多章内容。

下载2:Python视觉实战项目52讲

在「小白学视觉」公众号后台回复:Python视觉实战项目即可下载包括图像分割、口罩检测、车道线检测、车辆计数、添加眼线、车牌识别、字符识别、情绪检测、文本内容提取、面部识别等31个视觉实战项目,助力快速学校计算机视觉。

下载3:OpenCV实战项目20讲

在「小白学视觉」公众号后台回复:OpenCV实战项目20讲即可下载含有20个基于OpenCV实现20个实战项目,实现OpenCV学习进阶。

交流群

欢迎加入公众号读者群一起和同行交流,目前有SLAM、三维视觉、传感器、自动驾驶、计算摄影、检测、分割、识别、医学影像、GAN、算法竞赛等微信群(以后会逐渐细分),请扫描下面微信号加群,备注:”昵称+学校/公司+研究方向“,例如:”张三 + 上海交大 + 视觉SLAM“。请按照格式备注,否则不予通过。添加成功后会根据研究方向邀请进入相关微信群。请勿在群内发送广告,否则会请出群,谢谢理解~

漂亮,LSTM模型结构的可视化相关推荐

  1. 收藏 | LSTM模型结构的可视化

    点上方蓝字计算机视觉联盟获取更多干货 在右上方 ··· 设为星标 ★,与你不见不散 仅作学术分享,不代表本公众号立场,侵权联系删除 转载于:知乎  作者 | master苏 链接 | https:// ...

  2. Pytorch实现的LSTM模型结构

    LSTM模型结构 1.LSTM模型结构 2.LSTM网络 3.LSTM的输入结构 4.Pytorch中的LSTM 4.1.pytorch中定义的LSTM模型 4.2.喂给LSTM的数据格式 4.3.L ...

  3. lstm结构图_LSTM模型结构的可视化

    目录: 1.传统的BP网络和CNN网络 2.LSTM网络 3.LSTM的输入结构 4.pytorch中的LSTM 4.1 pytorch中定义的LSTM模型 4.2 喂给LSTM的数据格式 4.3 L ...

  4. LSTM模型结构讲解

    人类并不是每时每刻都从一片空白的大脑开始他们的思考.在你阅读这篇文章时候,你都是基于自己已经拥有的对先前所见词的理解来推断当前词的真实含义.我们不会将所有的东西都全部丢弃,然后用空白的大脑进行思考.我 ...

  5. LSTM模型与前向反向传播算法

    点击上方"小白学视觉",选择加"星标"或"置顶" 重磅干货,第一时间送达 前  言 在循环神经网络(RNN)模型与前向反向传播算法中,我们总 ...

  6. Python中利用LSTM模型进行时间序列预测分析

    时间序列模型 时间序列预测分析就是利用过去一段时间内某事件时间的特征来预测未来一段时间内该事件的特征.这是一类相对比较复杂的预测建模问题,和回归分析模型的预测不同,时间序列模型是依赖于事件发生的先后顺 ...

  7. Pytorch LSTM模型 参数详解

    本文主要依据 Pytorch 中LSTM官方文档,对其中的模型参数.输入.输出进行详细解释. 目录 基本原理 模型参数 Parameters 输入Inputs: input, (h_0, c_0) 输 ...

  8. 【深度学习入门到精通系列】模型结构可视化神器Netron(连.pth都可以~!)

    文章目录 1 Netron 2 说明 1 Netron 目前的Netron支持主流各种框架的模型结构可视化工作,我直接给出gayhub链接: https://github.com/lutzroeder ...

  9. NLP-阅读理解-2015:MRC模型-指导机器去阅读并理解【开篇之作】【完形填空任务】【第一次构建大批量有监督机器阅读理解训练语料】【三种模型结构:LSTM、Attention、Impatient】

    <原始论文:Teaching Machines to Read and Comprehend> 作者想要研究的问题是什么?一一在当下神经网络迅速发展的时代,如何针对机器阅读理解提出一个网络 ...

最新文章

  1. 中psr_PSR-SX900测评:雅马哈升级幅度较大的高品质编曲键盘
  2. Spreadsheet Tracking
  3. 3.4 SE55表维护生成器
  4. #pragma once 和 #ifndef ... #define ... #endif 的区别【转载】
  5. 8款惊艳的HTML5粒子动画特效
  6. DevExpress gridcontrol添加了复选框删除选中的多行/批量删除的方法
  7. 期刊计算机仿真地址在哪,计算机仿真杂志社地址
  8. 函数在区间连续可以推出什么_A-22 函数的点连续、单侧连续、区间连续
  9. MySQL检测 explain解析
  10. CentOS首次安装,网络环境配置
  11. 医疗知识图谱NLP项目,实体规模4.4万,实体关系规模30万
  12. python模拟登录人人
  13. 161212 笔记--无线传感网络中的MAC协议
  14. 详细讲解半加器、全加器、四位全加器,并使用FPGA实现半加器、全加器
  15. Nodejs+MongoDB+WebRTC搭建视频通话协同应用
  16. 最新!SPDK宣布在NVMe-oF Fabrics中支持TCP transport
  17. vscode json文件编辑工具
  18. python根据x轴、y轴坐标在坐标轴里画出曲线图
  19. 编程练习:既是完全平方数又有两位数字相同的三位数
  20. 图像滤镜处理算法:灰度、黑白、底片、浮雕

热门文章

  1. 首次!腾讯全面公开整体开源路线图
  2. 如何为回归问题选择最合适的机器学习方法?
  3. UC伯克利开源照片“隐写术”StegaStamp,打印照片能当二维码用!| 技术头条
  4. Grid R-CNN解读:商汤最新目标检测算法,定位精度超越Faster R-CNN
  5. 6月机器学习热文TOP10,精选自1400篇文章
  6. Spring Boot 实现接口幂等性的 4 种方案!还有谁不会?
  7. 字节跳动一面:i++ 是线程安全的吗?
  8. 目标检测模型从训练到部署!
  9. 本硕非科班,单模型获得亚军!
  10. 现金奖励+实习offer!数据库大赛来了