pytorch nn.LSTM()参数详解
输入数据格式:
input(seq_len, batch, input_size)
h0(num_layers * num_directions, batch, hidden_size)
c0(num_layers * num_directions, batch, hidden_size)
输出数据格式:
output(seq_len, batch, hidden_size * num_directions)
hn(num_layers * num_directions, batch, hidden_size)
cn(num_layers * num_directions, batch, hidden_size)
import torch
import torch.nn as nn
from torch.autograd import Variable
#构建网络模型---输入矩阵特征数input_size、输出矩阵特征数hidden_size、层数num_layers
inputs = torch.randn(5,3,10) ->(seq_len,batch_size,input_size)
rnn = nn.LSTM(10,20,2) -> (input_size,hidden_size,num_layers)
h0 = torch.randn(2,3,20) ->(num_layers* 1,batch_size,hidden_size)
c0 = torch.randn(2,3,20) ->(num_layers*1,batch_size,hidden_size)
num_directions=1 因为是单向LSTM
'''
Outputs: output, (h_n, c_n)
'''
output,(hn,cn) = rnn(inputs,(h0,c0))
1
2
3
4
5
6
7
8
9
10
11
12
13
14
batch_first: 输入输出的第一维是否为 batch_size,默认值 False。因为 Torch 中,人们习惯使用Torch中带有的dataset,dataloader向神经网络模型连续输入数据,这里面就有一个 batch_size 的参数,表示一次输入多少个数据。 在 LSTM 模型中,输入数据必须是一批数据,为了区分LSTM中的批量数据和dataloader中的批量数据是否相同意义,LSTM 模型就通过这个参数的设定来区分。 如果是相同意义的,就设置为True,如果不同意义的,设置为False。 torch.LSTM 中 batch_size 维度默认是放在第二维度,故此参数设置可以将 batch_size 放在第一维度。如:input 默认是(4,1,5),中间的 1 是 batch_size,指定batch_first=True后就是(1,4,5)。所以,如果你的输入数据是二维数据的话,就应该将 batch_first 设置为True;
inputs = torch.randn(5,3,10) :seq_len=5,bitch_size=3,input_size=10
我的理解:有3个句子,每个句子5个单词,每个单词用10维的向量表示;而句子的长度是不一样的,所以seq_len可长可短,这也是LSTM可以解决长短序列的特殊之处。只有seq_len这一参数是可变的。
关于hn和cn一些参数的详解看这里
而在遇到文本长度不一致的情况下,将数据输入到模型前的特征工程会将同一个batch内的文本进行padding使其长度对齐。但是对齐的数据在单向LSTM甚至双向LSTM的时候有一个问题,LSTM会处理很多无意义的填充字符,这样会对模型有一定的偏差,这时候就需要用到函数torch.nn.utils.rnn.pack_padded_sequence()以及torch.nn.utils.rnn.pad_packed_sequence()
详情解释看这里
BiLSTM
BILSTM是双向LSTM;将前向的LSTM与后向的LSTM结合成LSTM。视图举例如下:
LSTM结构推导:
更详细公式推导https://blog.csdn.net/songhk0209/article/details/71134698
GRU公式推导:(网上的图看着有点费劲,就自己画了个数据流图)
---------------------
作者:向阳争渡
来源:CSDN
原文:https://blog.csdn.net/yangyang_yangqi/article/details/84585998
版权声明:本文为博主原创文章,转载请附上博文链接!
pytorch nn.LSTM()参数详解相关推荐
- tf.nn.softmax参数详解以及作用
tf.nn.softmax参数详解以及作用 参考地址:https://zhuanlan.zhihu.com/p/93054123 tf.nn.softmax(logits,axis=None,name ...
- [pytorch]yolov3.cfg参数详解(每层输出及route、yolo、shortcut层详解)
文章目录 Backbone(Darknet53) 第一次下采样(to 208) 第二次下采样(to 104) 第三次下采样(to 52) 第四次下采样(to 26) 第五次下采样(to 13) YOL ...
- PyTorch nn.GRU 使用详解
我们看官方文档一些参数介绍,以及如下一个简单例子: 看完之后,还是一脸懵逼: 输入什么鬼? 输出又什么鬼? (这里我先把官网中 h0 去掉了,便于大家先理解更重要的概念) import torch f ...
- pytorch教程之nn.Module类详解——使用Module类来自定义网络层
前言:前面介绍了如何自定义一个模型--通过继承nn.Module类来实现,在__init__构造函数中申明各个层的定义,在forward中实现层之间的连接关系,实际上就是前向传播的过程. 事实上,在p ...
- PyTorch实现AlexNet模型及参数详解
文章目录 一.卷积池化层原理 二.全连接层原理 三.模型参数详解 注:AlexNet论文错误点 1.卷积池化层1 (1)卷积运算 (2)分组 (3)激活函数层 (4)池化层 (5)归一化处理 (6)参 ...
- 【PyTorch】nn.Conv2d函数详解
文章目录 1. 函数语法格式 2. 参数解释 3. 尺寸关系 4. 使用案例 5. nn.functional.conv2d 1. 函数语法格式 CONV2D官方链接 torch.nn.Conv2d( ...
- pytorch MSELoss参数详解
pytorch MSELoss参数详解 import torch import numpy as np loss_fn = torch.nn.MSELoss(reduce=False, size_av ...
- pytorch之torch.nn.Conv2d()函数详解
文章目录 一.官方文档介绍 二.torch.nn.Conv2d()函数详解 参数详解 参数dilation--扩张卷积(也叫空洞卷积) 参数groups--分组卷积 三.代码实例 一.官方文档介绍 官 ...
- pytorch教程之nn.Module类详解——使用Module类来自定义模型
pytorch教程之nn.Module类详解--使用Module类来自定义模型_MIss-Y的博客-CSDN博客_nn是什么意思前言:pytorch中对于一般的序列模型,直接使用torch.nn.Se ...
最新文章
- 0x54. 动态规划 - 树形DP(习题详解 × 12)
- JSON 之 SuperObject(16): 实例 - 解析 Google 关键字搜索排名
- 关于Hyper-V备份的四大注意事项
- 高质量C /C编程指南---第1章 文件机关
- 斯坦福大学的机器学习课程,浓缩成6张速查表
- html+link+点击次数,使用正则表达式,取得点击次数,函数抽离(示例代码)
- python天下无敌表情包_这套打遍天下无敌手的“算我输”表情包 从哪儿蹦出来的?...
- AJAX框架眼镜穿搭夏天,30度的夏天,男生应该如何穿搭?看这9种时尚组合!
- 迁移操作系统:如何把系统迁移到固态硬盘SSD?
- 1.3 欠/过拟合,局部加权回归(Loess/LWR)及Python实现(基于随机梯度下降)
- RAR Extractor - The Unarchiver Pro for mac(解压缩软件)
- 静态代码检查-CheckStyle
- 1、Basic4android简介
- win7锁定计算机自动关机,Win7电脑老是自动关机怎么解决?
- [开源] OpWeb 框架 --快速高效的实时交互框架(更新至 0.0.4.0)
- 基于STM32的智能风扇的制作
- 程序员真人秀又来了!呼兰当主持挑灯狂补知识,SSS大佬本科竟是药学,清华朱军张敏等加入导师团...
- 如何解决 【eclipse】中注释时乱码的问题
- 网站接入第三方微博登录—PHP
- 飞思卡尔智能车—电磁循迹(节能组)
热门文章
- centos 非root用户(普通用户)替换yum安装软件方法
- MySQL中对varchar类型排序问题的解决
- 方差协方差以及协方差矩阵
- Tomcat 1099端口占用重启无效,查不到进程,改换端口无效解决方案
- 推荐系统的个性化排名
- FCN与U-Net语义分割算法
- 从单一图像中提取文档图像:ICCV2019论文解读
- 解决:Plugin ‘maven-compiler-plugin:3.1‘ not found
- 2021年大数据常用语言Scala(二十五):函数式编程 排序
- DCN-2655 ssh 远程登陆配置