最近想了解一些关于LSTM的相关知识,在进行代码测试的时候,有个地方一直比较疑惑,关于LSTM的输入和输出问题。一直不清楚在pytorch里面该如何定义LSTM的输入和输出。首先看个pytorch官方的例子:

# 首先导入LSTM需要的相关模块

import torch

import torch.nn as nn # 神经网络模块

# 数据向量维数10, 隐藏元维度20, 2个LSTM层串联(如果是1,可以省略,默认为1)

rnn = nn.LSTM(10, 20, 2)

# 序列长度seq_len=5, batch_size=3, 数据向量维数=10

input = torch.randn(5, 3, 10)

# 初始化的隐藏元和记忆元,通常它们的维度是一样的

# 2个LSTM层,batch_size=3,隐藏元维度20

h0 = torch.randn(2, 3, 20)

c0 = torch.randn(2, 3, 20)

# 这里有2层lstm,output是最后一层lstm的每个词向量对应隐藏层的输出,其与层数无关,只与序列长度相关

# hn,cn是所有层最后一个隐藏元和记忆元的输出

output, (hn, cn) = rnn(input, (h0, c0))

1

2

3

4

5

6

7

8

9

10

11

12

13

14

15

16

17

18

# 首先导入LSTM需要的相关模块

importtorch

importtorch.nnasnn# 神经网络模块

# 数据向量维数10, 隐藏元维度20, 2个LSTM层串联(如果是1,可以省略,默认为1)

rnn=nn.LSTM(10,20,2)

# 序列长度seq_len=5, batch_size=3, 数据向量维数=10

input=torch.randn(5,3,10)

# 初始化的隐藏元和记忆元,通常它们的维度是一样的

# 2个LSTM层,batch_size=3,隐藏元维度20

h0=torch.randn(2,3,20)

c0=torch.randn(2,3,20)

# 这里有2层lstm,output是最后一层lstm的每个词向量对应隐藏层的输出,其与层数无关,只与序列长度相关

# hn,cn是所有层最后一个隐藏元和记忆元的输出

output,(hn,cn)=rnn(input,(h0,c0))

在这里如果我们打印output、hn、cn的shape,我们可以看到,torch的输出已经变成了定义中的20。

print(output.size(),hn.size(),cn.size())

torch.Size([5, 3, 20]) torch.Size([2, 3, 20]) torch.Size([2, 3, 20])

1

2

print(output.size(),hn.size(),cn.size())

torch.Size([5,3,20])torch.Size([2,3,20])torch.Size([2,3,20])

接着来看一下LSTM的参数都有哪些:

LSTM一共有7个参数,其中前三个是必须的,分别为:input_size, hidden_size, num_layers.

1 input_size

在这里首先对输入解释一下,nn.LSTM()的第一个参数为输入的序列维度,它对应着torch.randn()中的第三个参数10。可能有人不太明白这个这个函数是怎么回事,在这里解释一下:

torch.randn(5, 3, 10)会生成五组数据,每组数据有3行10列。如果用在视频中的话,这里的5等于每个视频抽取的帧数,如果视频分辨率为100*100,则第二个参数为10000,若视频为彩色三通道的话,第三个参数为3,即输入序列变为(5,10000,3),看一下这张图:

上图是一个完整的LSTM流程,上面生成的五组数据就对应了五个A,即一个LSTM中有五个神经元。

再举个例子,比如现在有5个句子,每个句子由3个单词组成,每个单词用10维的向量组成,这样参数为:seq_len=3, batch=5, input_size=10.

输入LSTM中的X数据格式尺寸为(seq_len, batch, input_size),此外h0和c0尺寸如下

h0(num_layers * num_directions, batch_size, hidden_size)

c0(num_layers * num_directions, batch_size, hidden_size)

2 hidden_size

对照上图可以看出,隐藏层数即为中间的节点数量。这个数量可以由用户自定义。

3 num_layers

这个是LSTM的层数,默认是1,如果我们设置为2的话,第一层计算得到h,然后把h作为输入,输给第二层。然后在最后输出最终的O。

4 bias

表示是否添加bias偏置,默认为true

5 batch_first

与LSTM的输入格式有关。

输入输出的第一维是否为 batch_size,默认值 False。因为 Torch 中,人们习惯使用Torch中带有的dataset,dataloader向神经网络模型连续输入数据,这里面就有一个 batch_size 的参数,表示一次输入多少个数据。 在 LSTM 模型中,输入数据必须是一批数据,为了区分LSTM中的批量数据和dataloader中的批量数据是否相同意义,LSTM 模型就通过这个参数的设定来区分。 如果是相同意义的,就设置为True,如果不同意义的,设置为False。 torch.LSTM 中 batch_size 维度默认是放在第二维度,故此参数设置可以将 batch_size 放在第一维度。如:input 默认是(4,1,5),中间的 1 是 batch_size,指定batch_first=True后就是(1,4,5)。所以,如果你的输入数据是二维数据的话,就应该将 batch_first 设置为True;

6 dropout

是否进行dropout操作,默认为0,输入值范围为0~1的小数,表示每次丢弃的百分比。一般用来防止过拟合。

7 bidirectional

是否进行双向RNN,默认为false。

运行模型:

运行模型的格式是这样写的。output, (hn, cn) = model(input, (h0, c0))

从形式上看,输入结构和输出结构是一样的。都是3个输入,3个输出。

参数1:你输入的数据团。好像必须是 3 维数据。但必须注意 batch_size 的位置。是第一维,还是第二维。默认是在第二维度。是不可变的维度。最后一个维度是行数据的个数。剩下的1维数据是可变的,这就是长短数据。默认放在第一维。

参数2:隐藏层数据,也必须是3维的,第一维:是LSTM的层数,第二维:是隐藏层的batch_size数,必须和输入数据的batch_size一致。第三维:是隐藏层节点数,必须和模型实例时的参数一致。

参数3:传递层数据,也必须是3维的,通常和参数2的设置一样。它的作用是LSTM内部循环中的记忆体,用来结合新的输入一起计算。

本文最后更新于2019年11月11日,已超过 1 年没有更新,如果文章内容或图片资源失效,请留言反馈,我们会及时处理,谢谢!

lstm 输入数据维度_[mcj]pytorch中LSTM的输入输出解释||LSTM输入输出详解相关推荐

  1. lstm 输入数据维度_理解Pytorch中LSTM的输入输出参数含义

    本文不会介绍LSTM的原理,具体可看如下两篇文章 Understanding LSTM Networks DeepLearning.ai学习笔记(五)序列模型 -- week1 循环序列模型 1.举个 ...

  2. python中的class怎么用_对python 中class与变量的使用方法详解

    python中的变量定义是很灵活的,很容易搞混淆,特别是对于class的变量的定义,如何定义使用类里的变量是我们维护代码和保证代码稳定性的关键. #!/usr/bin/python #encoding ...

  3. python中class变量_对python 中class与变量的使用方法详解

    python中的变量定义是很灵活的,很容易搞混淆,特别是对于class的变量的定义,如何定义使用类里的变量是我们维护代码和保证代码稳定性的关键. #!/usr/bin/python #encoding ...

  4. python中if语句的实例_对python中if语句的真假判断实例详解

    说明 在python中,if作为条件语句,当if后面的条件参数为真时,则执行后面的语句块,反之跳过,为了深入理解if语句,我们需要知道if语句的真假判断方式. 示例 在python交互器中,经过测试发 ...

  5. pythonbool类型数组生成_对numpy中布尔型数组的处理方法详解

    布尔数组的操作方式主要有两种,any用于查看数组中是否有True的值,而all则用于查看数组是否全都是True. 如果用于计算的时候,布尔量会被转换成1和0,True转换成1,False转换成0.通过 ...

  6. python中append函数解析_对python中的pop函数和append函数详解

    对python中的pop函数和append函数详解 pop()函数 1.描述 pop() 函数用于移除列表中的一个元素(默认最后一个元素),并且返回该元素的值. 语法 pop()方法语法: list. ...

  7. python中的pop函数和append函数_对python中的pop函数和append函数详解

    pop()函数 1.描述 pop() 函数用于移除列表中的一个元素(默认最后一个元素),并且返回该元素的值. 语法 pop()方法语法: list.pop(obj=list[-1]) 2.参数 obj ...

  8. php回调函数和匿名函数吗,php回调函数_关于php中匿名函数与回调函数的详解

    摘要 腾兴网为您分享:关于php中匿名函数与回调函数的详解,壹学车,小天才,尚游戏,厦门百姓等软件知识,以及微信一键转发工具,幸运抽奖系统,文字识别app,垃圾清理管家,王者荣耀起名神器,叮咚出行,世 ...

  9. python布尔型数组_对numpy中布尔型数组的处理方法详解

    布尔数组的操作方式主要有两种,any用于查看数组中是否有True的值,而all则用于查看数组是否全都是True. 如果用于计算的时候,布尔量会被转换成1和0,True转换成1,False转换成0.通过 ...

最新文章

  1. Jsp获得Map中map.put(2, bb);此类的value值
  2. python如何定义类_python类定义的讲解
  3. HA集群实现原理 切换 JAVA_HA(一)高可用集群原理
  4. HTML元素和标签的区别
  5. 史上最贵!iPhone 12S系列9月亮相,全系标配激光雷达
  6. 如何手动合成年度夜间灯光影像
  7. matlab的简单使用-matlab画f(x)=x^2+y^2的图像
  8. 联想服务器ts系列介绍,联想服务器ThinkServerTS230.ppt
  9. PDF密码可以破解吗?有没有PDF解密的方法
  10. 平行束滤波fbp_CT平行束和扇形束算法的转换.pptx
  11. MRR(Mean Reciprocal Rank)笔记
  12. 一款仿古文本编辑器---edit.exe
  13. matlab 数据分割,科学网—MATLAB把一个包含多个站点数据的文件分割到各个站点单独的文件夹 - 张乐乐的博文...
  14. C++ POCO库(访问数据库,版本问题,本人配置失败)
  15. error: OpenCV(4.1.2) ..\modules\imgcodecs\src\loadsave.cpp:715: error: (-215:Assertion failed) !_img
  16. SAP:在互联网时代帮助企业夺回数据
  17. 简单的局域网直播方案(OBS+Smart_rtmpd)
  18. xpath 定位同级倒数第二个元素
  19. 利用Python及OpenCv 识别车牌号
  20. 14个提高代码质量的好问题

热门文章

  1. 梅科尔工作室-赵凌志-鸿蒙笔记1
  2. 汇编语言求一组数中的最大值,最小值和总和(以10个数为例)
  3. 软件工程专业测试,软件工程专业测试科目组成及分值情况.PDF
  4. djay Pro AI Mac(DJ混音软件)
  5. matlab故障识别,基于MATLAB故障诊断系统设计.doc
  6. 【ESP32调试-快速入门】
  7. GitHub上受欢迎的Android UI Library-项目开发实战篇:带各类框架链接地址详细解说及使用方法
  8. java stroke_stroke用法
  9. 29 岁成为阿里巴巴P8,工作前5年完成晋升3连跳,他如何做到?
  10. 青春激扬,创意无限——记美和易思特色班软件设计大赛