本文主要依据 Pytorch 中LSTM官方文档,对其中的模型参数输入输出进行详细解释。

目录

基本原理

模型参数 Parameters

输入Inputs: input, (h_0, c_0)

输出Outputs: output, (h_n, c_n)

变量Variables

备注


基本原理

首先我们看下面这个LSTM图, 对应于输入时间序列中每个步长的LSTM计算。

对应的公式计算公式如下:

其中  表示时刻  时刻的隐含状态, 表示时刻  上的记忆细胞, 表示时刻  的输入(对应于单个样本),表示隐含层在时刻的隐含状态或者是在起始时间o的初始隐含状态,, , 表示对应的输入、遗忘、 输出门。σ 表示的是sigmoid 函数,⊙ 表示哈达玛积(Hadamard product)。

对于含多个隐含层的LSTM,第  层()的输入​ 则对应的是前一层的隐含状态与丢弃dropout  的乘积,其中每一个 是Bernoulli随机变量(以参数dropout 的概率等于0)。

如果参数指定proj_size > 0,则将对LSTM使用投影。他的运作方式包括以下几步。首先,的维度将从hidden_size 转换为proj_size ( 的维度也会同时被改变)。第二,每一个层的隐含状态输出将与一个(可学习)的投影矩阵相乘: 。注意,这种投影模式同样对LSTM的输出有影响,即变成proj_size.

模型参数 Parameters

  • input_size – 输入变量x的特征数量

  • hidden_size – 隐含层h的特征数量(即层中隐含单元的个数)

  • num_layers – 隐含层的层数,比如说num_layers =2, 意味着这是包含两个LSTM层,默认值:1

  • bias – 如果为False, 表示不使用偏置权重 b_ih 和 b_hh。默认值为:True

  • batch_first – 如果为True,则输入和输出的tensor维度为从(seq, batch, feature)变成 (batch, seq, feature)。 注意,这个维度变化对隐含和细胞的层并不起做用。参见下面的Inputs/Outputs 部分的说明,默认值:False

  • dropout – 如果非0,则会给除最后一个LSTM层以外的其他层引入一个Dropout层,其对应的丢弃概率为dropout,默认值:0

  • bidirectional – 如果为True,则是一个双向的(bidirectional )的LSTM,默认值:False

  • proj_size – 如果>0, 则会使用相应投影大小的LSTM,默认值:0

输入Inputs: input, (h_0, c_0)

  • input:当batch_first = False 时形状为(L,N,H_in),当 batch_first = True 则为(N, L, H_in​) ,包含批量样本的时间序列输入。该输入也可是一个可变换长度的时间序序列,参见 torch.nn.utils.rnn.pack_padded_sequence() 或者是 torch.nn.utils.rnn.pack_sequence() 了解详情。

  • h_0:形状为(D∗num_layers, N, H_out),指的是包含每一个批量样本的初始隐含状态。如果模型未提供(h_0, c_0) ,默认为是全0矩阵。

  • c_0:形状为(D∗num_layers, N, H_cell), 指的是包含每一个批量样本的初始记忆细胞状态。 如果模型未提供(h_0, c_0) ,默认为是全0矩阵。

其中:

N = 批量大小

L  = 序列长度

D = 2 如果模型参数bidirectional = 2,否则为1

H_in = 输入的特征大小(input_size)

H_cell = 隐含单元数量(hidden_size)

H_out = proj_size, 如果proj_size > 0, 否则的话 = 隐含单元数量(hidden_size)

输出Outputs: output, (h_n, c_n)

  • output: 当batch_first = False 形状为(L, N, D∗H_out​) ,当batch_first = True 则为 (N, L, D∗H_out​) ,包含LSTM最后一层每一个时间步长  的输出特征()。如果输入的是torch.nn.utils.rnn.PackedSequence,则输出同样将是一个packed sequence。

  • h_n: 形状为(D∗num_layers, N, H_out​),包括每一个批量样本最后一个时间步的隐含状态。

  • c_n: 形状为(D∗num_layers, N, H_cell​),包括每一个批量样本最后一个时间步的记忆细胞状态。

变量Variables

  • ~LSTM.weight_ih_l[k] – 学习得到的第k层的 input-hidden 权重 (W_ii|W_if|W_ig|W_io),当k=0 时形状为 (4*hidden_size, input_size) 。 否则,形状为 (4*hidden_size, num_directions * hidden_size)

  • ~LSTM.weight_hh_l[k] –学习得到的第k层的 hidden -hidden 权重(W_hi|W_hf|W_hg|W_ho), 想形状为 (4*hidden_size, hidden_size)。如果 Proj_size > 0,则形状为 (4*hidden_size, proj_size)

  • ~LSTM.bias_ih_l[k] – 学习得到的第k层的input-hidden 的偏置 (b_ii|b_if|b_ig|b_io), 形状为 (4*hidden_size)

  • ~LSTM.bias_hh_l[k] – 学习得到的第k层的hidden -hidden 的偏置  (b_hi|b_hf|b_hg|b_ho), 形状为 (4*hidden_size)

  • ~LSTM.weight_hr_l[k] – 学习得到第k层投影权重,形状为 (proj_size, hidden_size)。仅仅在 proj_size > 0 时该参数有效。

备注

  • 所有的权重和偏置的初始化方法均取值于:

  • 对于双向 LSTMs,前向和后向的方向分别为0 和1。当batch_first = False 时,对两个方向的输出层的提取可以使用方式:output.view(seq_len, batch, num_directions, hidden_size)。
  • 具体怎么使用,可以参考本人博文 从零开始实现,LSTM模型进行单变量时间序列预测
  • 关于LSTM模型的结构如果还有不清晰的可以参考这篇博客:Pytorch实现的LSTM模型结构

Pytorch LSTM模型 参数详解相关推荐

  1. 【直播】陈安东,但扬:CNN模型搭建、训练以及LSTM模型思路详解

    CNN模型搭建.训练以及LSTM模型思路详解 目前 Datawhale第24期组队学习 正在如火如荼的进行中.为了大家更好的学习"零基础入门语音识别(食物声音识别)"的课程设计者 ...

  2. 网络模型 LSTM模型内容详解

    网络模型 LSTM模型内容详解

  3. [pytorch]yolov3.cfg参数详解(每层输出及route、yolo、shortcut层详解)

    文章目录 Backbone(Darknet53) 第一次下采样(to 208) 第二次下采样(to 104) 第三次下采样(to 52) 第四次下采样(to 26) 第五次下采样(to 13) YOL ...

  4. pytorch中的nn.LSTM模块参数详解

    直接去官网查看相关信息挺好的,但是为什么有的时候进不去 官网:https://pytorch.org/docs/stable/nn.html#torch.nn.LSTM 使用示例,在使用中解释参数 单 ...

  5. libsvm 训练后,模型参数详解

    本节主要就是讲解利用libsvm-mat工具箱建立分类(回归模型)后,得到的模型model里面参数的意义,以及如果通过model得到相应模型的表达式,这里主要以分类问题为例子. 测试数据使用的是lib ...

  6. 数据挖掘介绍以及模型参数详解

    http://www.cnblogs.com/pinard/ cc 转载于:https://www.cnblogs.com/wangleBlogs/p/6803978.html

  7. xgboost模型参数详解

  8. 机器人系统的基本概念及外部模型参数详解

    目录 线控底盘介绍 遥控器说明 线控底盘使用操作 充电 上层传感器介绍 配置单 Xavier 介绍 RobSense简介(激光雷达) IMU简介 RealSense D435介绍 电气通讯拓扑连接说明 ...

  9. PyTorch实现AlexNet模型及参数详解

    文章目录 一.卷积池化层原理 二.全连接层原理 三.模型参数详解 注:AlexNet论文错误点 1.卷积池化层1 (1)卷积运算 (2)分组 (3)激活函数层 (4)池化层 (5)归一化处理 (6)参 ...

最新文章

  1. vc6中进行多行注释和反注释的方法
  2. Win10如何取消开机密码
  3. gdb调试多进程程序
  4. python将列表中反序输出_python中sorted怎么反序排列
  5. python每隔几秒执行一次_Python设置程序等待时间
  6. 开源:Taurus.MVC 框架 (已支持.NET Core)
  7. docker pull下载很慢_假如服务器上没有 Docker 环境,你还能愉快的拉取容器镜像吗?...
  8. 引用和使用引用传递参数《二》
  9. 网易2018校招内推编程题 小易喜欢的数列
  10. python 前端开发_python和前端开发怎么抉择?
  11. 使用a标签下载文件时成了预览,并非是下载
  12. github1s 油猴插件
  13. Lync添加自定义菜单
  14. 北京大学可视化发展前沿研究生暑期学校Day4
  15. 关于Certificate、Provisioning Profile、App ID的介绍及其之间的关系
  16. hadoop入门教程免费下载
  17. JVM--类加载器详解
  18. 2012年百度实习生招聘-java开发
  19. 单模SIW的设计步骤
  20. 腾讯云云函数SCF—入门须知

热门文章

  1. 基于SpingBoot和Thymelaf框架的旅游网设计
  2. VM10装Mac OS X 10.9.3
  3. CPA十三--借款费用的内容(转载)
  4. SpringBoot12 QueryDSL01之QueryDSL介绍、springBoot项目中集成QueryDSL、利用QueryDSL实现单表RUD、新增类初始化逻辑...
  5. java调用https的webservice,https的wsdl
  6. 转:苹果企业级开发者账号申请流程
  7. dbeaver下载镜像站
  8. 2023年跨境电商行业研究报告
  9. SUSCTF2022misc——ra2
  10. 微信小程序+uni-app知识点总结