Pytorch LSTM模型 参数详解
本文主要依据 Pytorch 中LSTM官方文档,对其中的模型参数、输入、输出进行详细解释。
目录
基本原理
模型参数 Parameters
输入Inputs: input, (h_0, c_0)
输出Outputs: output, (h_n, c_n)
变量Variables
备注
基本原理
首先我们看下面这个LSTM图, 对应于输入时间序列中每个步长的LSTM计算。
对应的公式计算公式如下:
其中 表示时刻 时刻的隐含状态, 表示时刻 上的记忆细胞, 表示时刻 的输入(对应于单个样本),表示隐含层在时刻的隐含状态或者是在起始时间o的初始隐含状态,, , 表示对应的输入、遗忘、 输出门。σ 表示的是sigmoid 函数,⊙ 表示哈达玛积(Hadamard product)。
对于含多个隐含层的LSTM,第 层()的输入 则对应的是前一层的隐含状态与丢弃dropout 的乘积,其中每一个 是Bernoulli随机变量(以参数dropout 的概率等于0)。
如果参数指定proj_size > 0,则将对LSTM使用投影。他的运作方式包括以下几步。首先,的维度将从hidden_size 转换为proj_size ( 的维度也会同时被改变)。第二,每一个层的隐含状态输出将与一个(可学习)的投影矩阵相乘: 。注意,这种投影模式同样对LSTM的输出有影响,即变成proj_size.
模型参数 Parameters
input_size – 输入变量x的特征数量
hidden_size – 隐含层h的特征数量(即层中隐含单元的个数)
num_layers – 隐含层的层数,比如说num_layers =2, 意味着这是包含两个LSTM层,默认值:1
bias – 如果为False, 表示不使用偏置权重 b_ih 和 b_hh。默认值为:True
batch_first – 如果为True,则输入和输出的tensor维度为从(seq, batch, feature)变成 (batch, seq, feature)。 注意,这个维度变化对隐含和细胞的层并不起做用。参见下面的Inputs/Outputs 部分的说明,默认值:False
dropout – 如果非0,则会给除最后一个LSTM层以外的其他层引入一个Dropout层,其对应的丢弃概率为dropout,默认值:0
bidirectional – 如果为True,则是一个双向的(bidirectional )的LSTM,默认值:False
proj_size – 如果>0, 则会使用相应投影大小的LSTM,默认值:0
输入Inputs: input, (h_0, c_0)
input:当batch_first = False 时形状为(L,N,H_in),当 batch_first = True 则为(N, L, H_in) ,包含批量样本的时间序列输入。该输入也可是一个可变换长度的时间序序列,参见 torch.nn.utils.rnn.pack_padded_sequence() 或者是 torch.nn.utils.rnn.pack_sequence() 了解详情。
h_0:形状为(D∗num_layers, N, H_out),指的是包含每一个批量样本的初始隐含状态。如果模型未提供(h_0, c_0) ,默认为是全0矩阵。
c_0:形状为(D∗num_layers, N, H_cell), 指的是包含每一个批量样本的初始记忆细胞状态。 如果模型未提供(h_0, c_0) ,默认为是全0矩阵。
其中:
N = 批量大小
L = 序列长度
D = 2 如果模型参数bidirectional = 2,否则为1
H_in = 输入的特征大小(input_size)
H_cell = 隐含单元数量(hidden_size)
H_out = proj_size, 如果proj_size > 0, 否则的话 = 隐含单元数量(hidden_size)
输出Outputs: output, (h_n, c_n)
output: 当batch_first = False 形状为(L, N, D∗H_out) ,当batch_first = True 则为 (N, L, D∗H_out) ,包含LSTM最后一层每一个时间步长 的输出特征()。如果输入的是torch.nn.utils.rnn.PackedSequence,则输出同样将是一个packed sequence。
h_n: 形状为(D∗num_layers, N, H_out),包括每一个批量样本最后一个时间步的隐含状态。
c_n: 形状为(D∗num_layers, N, H_cell),包括每一个批量样本最后一个时间步的记忆细胞状态。
变量Variables
~LSTM.weight_ih_l[k] – 学习得到的第k层的 input-hidden 权重 (W_ii|W_if|W_ig|W_io),当k=0 时形状为 (4*hidden_size, input_size) 。 否则,形状为 (4*hidden_size, num_directions * hidden_size)
~LSTM.weight_hh_l[k] –学习得到的第k层的 hidden -hidden 权重(W_hi|W_hf|W_hg|W_ho), 想形状为 (4*hidden_size, hidden_size)。如果 Proj_size > 0,则形状为 (4*hidden_size, proj_size)
~LSTM.bias_ih_l[k] – 学习得到的第k层的input-hidden 的偏置 (b_ii|b_if|b_ig|b_io), 形状为 (4*hidden_size)
~LSTM.bias_hh_l[k] – 学习得到的第k层的hidden -hidden 的偏置 (b_hi|b_hf|b_hg|b_ho), 形状为 (4*hidden_size)
~LSTM.weight_hr_l[k] – 学习得到第k层投影权重,形状为 (proj_size, hidden_size)。仅仅在 proj_size > 0 时该参数有效。
备注
- 所有的权重和偏置的初始化方法均取值于:
- 对于双向 LSTMs,前向和后向的方向分别为0 和1。当batch_first = False 时,对两个方向的输出层的提取可以使用方式:output.view(seq_len, batch, num_directions, hidden_size)。
- 具体怎么使用,可以参考本人博文 从零开始实现,LSTM模型进行单变量时间序列预测
- 关于LSTM模型的结构如果还有不清晰的可以参考这篇博客:Pytorch实现的LSTM模型结构
Pytorch LSTM模型 参数详解相关推荐
- 【直播】陈安东,但扬:CNN模型搭建、训练以及LSTM模型思路详解
CNN模型搭建.训练以及LSTM模型思路详解 目前 Datawhale第24期组队学习 正在如火如荼的进行中.为了大家更好的学习"零基础入门语音识别(食物声音识别)"的课程设计者 ...
- 网络模型 LSTM模型内容详解
网络模型 LSTM模型内容详解
- [pytorch]yolov3.cfg参数详解(每层输出及route、yolo、shortcut层详解)
文章目录 Backbone(Darknet53) 第一次下采样(to 208) 第二次下采样(to 104) 第三次下采样(to 52) 第四次下采样(to 26) 第五次下采样(to 13) YOL ...
- pytorch中的nn.LSTM模块参数详解
直接去官网查看相关信息挺好的,但是为什么有的时候进不去 官网:https://pytorch.org/docs/stable/nn.html#torch.nn.LSTM 使用示例,在使用中解释参数 单 ...
- libsvm 训练后,模型参数详解
本节主要就是讲解利用libsvm-mat工具箱建立分类(回归模型)后,得到的模型model里面参数的意义,以及如果通过model得到相应模型的表达式,这里主要以分类问题为例子. 测试数据使用的是lib ...
- 数据挖掘介绍以及模型参数详解
http://www.cnblogs.com/pinard/ cc 转载于:https://www.cnblogs.com/wangleBlogs/p/6803978.html
- xgboost模型参数详解
- 机器人系统的基本概念及外部模型参数详解
目录 线控底盘介绍 遥控器说明 线控底盘使用操作 充电 上层传感器介绍 配置单 Xavier 介绍 RobSense简介(激光雷达) IMU简介 RealSense D435介绍 电气通讯拓扑连接说明 ...
- PyTorch实现AlexNet模型及参数详解
文章目录 一.卷积池化层原理 二.全连接层原理 三.模型参数详解 注:AlexNet论文错误点 1.卷积池化层1 (1)卷积运算 (2)分组 (3)激活函数层 (4)池化层 (5)归一化处理 (6)参 ...
最新文章
- vc6中进行多行注释和反注释的方法
- Win10如何取消开机密码
- gdb调试多进程程序
- python将列表中反序输出_python中sorted怎么反序排列
- python每隔几秒执行一次_Python设置程序等待时间
- 开源:Taurus.MVC 框架 (已支持.NET Core)
- docker pull下载很慢_假如服务器上没有 Docker 环境,你还能愉快的拉取容器镜像吗?...
- 引用和使用引用传递参数《二》
- 网易2018校招内推编程题 小易喜欢的数列
- python 前端开发_python和前端开发怎么抉择?
- 使用a标签下载文件时成了预览,并非是下载
- github1s 油猴插件
- Lync添加自定义菜单
- 北京大学可视化发展前沿研究生暑期学校Day4
- 关于Certificate、Provisioning Profile、App ID的介绍及其之间的关系
- hadoop入门教程免费下载
- JVM--类加载器详解
- 2012年百度实习生招聘-java开发
- 单模SIW的设计步骤
- 腾讯云云函数SCF—入门须知
热门文章
- 基于SpingBoot和Thymelaf框架的旅游网设计
- VM10装Mac OS X 10.9.3
- CPA十三--借款费用的内容(转载)
- SpringBoot12 QueryDSL01之QueryDSL介绍、springBoot项目中集成QueryDSL、利用QueryDSL实现单表RUD、新增类初始化逻辑...
- java调用https的webservice,https的wsdl
- 转:苹果企业级开发者账号申请流程
- dbeaver下载镜像站
- 2023年跨境电商行业研究报告
- SUSCTF2022misc——ra2
- 微信小程序+uni-app知识点总结