输入数据格式:
input(seq_len, batch, input_size)
h0(num_layers * num_directions, batch, hidden_size)
c0(num_layers * num_directions, batch, hidden_size)

输出数据格式:
output(seq_len, batch, hidden_size * num_directions)
hn(num_layers * num_directions, batch, hidden_size)
cn(num_layers * num_directions, batch, hidden_size)

import torch
import torch.nn as nn
from torch.autograd import Variable

#构建网络模型---输入矩阵特征数input_size、输出矩阵特征数hidden_size、层数num_layers
inputs = torch.randn(5,3,10) ->(seq_len,batch_size,input_size)
rnn = nn.LSTM(10,20,2) -> (input_size,hidden_size,num_layers)
h0 = torch.randn(2,3,20) ->(num_layers* 1,batch_size,hidden_size)
c0 = torch.randn(2,3,20) ->(num_layers*1,batch_size,hidden_size)
num_directions=1 因为是单向LSTM
'''
Outputs: output, (h_n, c_n)
'''
output,(hn,cn) = rnn(inputs,(h0,c0))
1
2
3
4
5
6
7
8
9
10
11
12
13
14
batch_first: 输入输出的第一维是否为 batch_size,默认值 False。因为 Torch 中,人们习惯使用Torch中带有的dataset,dataloader向神经网络模型连续输入数据,这里面就有一个 batch_size 的参数,表示一次输入多少个数据。 在 LSTM 模型中,输入数据必须是一批数据,为了区分LSTM中的批量数据和dataloader中的批量数据是否相同意义,LSTM 模型就通过这个参数的设定来区分。 如果是相同意义的,就设置为True,如果不同意义的,设置为False。 torch.LSTM 中 batch_size 维度默认是放在第二维度,故此参数设置可以将 batch_size 放在第一维度。如:input 默认是(4,1,5),中间的 1 是 batch_size,指定batch_first=True后就是(1,4,5)。所以,如果你的输入数据是二维数据的话,就应该将 batch_first 设置为True;

inputs = torch.randn(5,3,10) :seq_len=5,bitch_size=3,input_size=10
我的理解:有3个句子,每个句子5个单词,每个单词用10维的向量表示;而句子的长度是不一样的,所以seq_len可长可短,这也是LSTM可以解决长短序列的特殊之处。只有seq_len这一参数是可变的。
关于hn和cn一些参数的详解看这里
而在遇到文本长度不一致的情况下,将数据输入到模型前的特征工程会将同一个batch内的文本进行padding使其长度对齐。但是对齐的数据在单向LSTM甚至双向LSTM的时候有一个问题,LSTM会处理很多无意义的填充字符,这样会对模型有一定的偏差,这时候就需要用到函数torch.nn.utils.rnn.pack_padded_sequence()以及torch.nn.utils.rnn.pad_packed_sequence()
详情解释看这里

BiLSTM
BILSTM是双向LSTM;将前向的LSTM与后向的LSTM结合成LSTM。视图举例如下:

​​​​​​​​​​​​LSTM结构推导:

更详细公式推导https://blog.csdn.net/songhk0209/article/details/71134698

GRU公式推导:(网上的图看着有点费劲,就自己画了个数据流图)

---------------------
作者:向阳争渡
来源:CSDN
原文:https://blog.csdn.net/yangyang_yangqi/article/details/84585998
版权声明:本文为博主原创文章,转载请附上博文链接!

pytorch nn.LSTM()参数详解相关推荐

  1. tf.nn.softmax参数详解以及作用

    tf.nn.softmax参数详解以及作用 参考地址:https://zhuanlan.zhihu.com/p/93054123 tf.nn.softmax(logits,axis=None,name ...

  2. [pytorch]yolov3.cfg参数详解(每层输出及route、yolo、shortcut层详解)

    文章目录 Backbone(Darknet53) 第一次下采样(to 208) 第二次下采样(to 104) 第三次下采样(to 52) 第四次下采样(to 26) 第五次下采样(to 13) YOL ...

  3. PyTorch nn.GRU 使用详解

    我们看官方文档一些参数介绍,以及如下一个简单例子: 看完之后,还是一脸懵逼: 输入什么鬼? 输出又什么鬼? (这里我先把官网中 h0 去掉了,便于大家先理解更重要的概念) import torch f ...

  4. pytorch教程之nn.Module类详解——使用Module类来自定义网络层

    前言:前面介绍了如何自定义一个模型--通过继承nn.Module类来实现,在__init__构造函数中申明各个层的定义,在forward中实现层之间的连接关系,实际上就是前向传播的过程. 事实上,在p ...

  5. PyTorch实现AlexNet模型及参数详解

    文章目录 一.卷积池化层原理 二.全连接层原理 三.模型参数详解 注:AlexNet论文错误点 1.卷积池化层1 (1)卷积运算 (2)分组 (3)激活函数层 (4)池化层 (5)归一化处理 (6)参 ...

  6. 【PyTorch】nn.Conv2d函数详解

    文章目录 1. 函数语法格式 2. 参数解释 3. 尺寸关系 4. 使用案例 5. nn.functional.conv2d 1. 函数语法格式 CONV2D官方链接 torch.nn.Conv2d( ...

  7. pytorch MSELoss参数详解

    pytorch MSELoss参数详解 import torch import numpy as np loss_fn = torch.nn.MSELoss(reduce=False, size_av ...

  8. pytorch之torch.nn.Conv2d()函数详解

    文章目录 一.官方文档介绍 二.torch.nn.Conv2d()函数详解 参数详解 参数dilation--扩张卷积(也叫空洞卷积) 参数groups--分组卷积 三.代码实例 一.官方文档介绍 官 ...

  9. pytorch教程之nn.Module类详解——使用Module类来自定义模型

    pytorch教程之nn.Module类详解--使用Module类来自定义模型_MIss-Y的博客-CSDN博客_nn是什么意思前言:pytorch中对于一般的序列模型,直接使用torch.nn.Se ...

最新文章

  1. 0x54. 动态规划 - 树形DP(习题详解 × 12)
  2. JSON 之 SuperObject(16): 实例 - 解析 Google 关键字搜索排名
  3. 关于Hyper-V备份的四大注意事项
  4. 高质量C /C编程指南---第1章 文件机关
  5. 斯坦福大学的机器学习课程,浓缩成6张速查表
  6. html+link+点击次数,使用正则表达式,取得点击次数,函数抽离(示例代码)
  7. python天下无敌表情包_这套打遍天下无敌手的“算我输”表情包 从哪儿蹦出来的?...
  8. AJAX框架眼镜穿搭夏天,30度的夏天,男生应该如何穿搭?看这9种时尚组合!
  9. 迁移操作系统:如何把系统迁移到固态硬盘SSD?
  10. 1.3 欠/过拟合,局部加权回归(Loess/LWR)及Python实现(基于随机梯度下降)
  11. RAR Extractor - The Unarchiver Pro for mac(解压缩软件)
  12. 静态代码检查-CheckStyle
  13. 1、Basic4android简介
  14. win7锁定计算机自动关机,Win7电脑老是自动关机怎么解决?
  15. [开源] OpWeb 框架 --快速高效的实时交互框架(更新至 0.0.4.0)
  16. 基于STM32的智能风扇的制作
  17. 程序员真人秀又来了!呼兰当主持挑灯狂补知识,SSS大佬本科竟是药学,清华朱军张敏等加入导师团...
  18. 如何解决 【eclipse】中注释时乱码的问题
  19. 网站接入第三方微博登录—PHP
  20. 飞思卡尔智能车—电磁循迹(节能组)

热门文章

  1. centos 非root用户(普通用户)替换yum安装软件方法
  2. MySQL中对varchar类型排序问题的解决
  3. 方差协方差以及协方差矩阵
  4. Tomcat 1099端口占用重启无效,查不到进程,改换端口无效解决方案
  5. 推荐系统的个性化排名
  6. FCN与U-Net语义分割算法
  7. 从单一图像中提取文档图像:ICCV2019论文解读
  8. 解决:Plugin ‘maven-compiler-plugin:3.1‘ not found
  9. 2021年大数据常用语言Scala(二十五):函数式编程 排序
  10. DCN-2655 ssh 远程登陆配置