【pytorch】nn.LSTM的使用
官方文档在这里。
LSTM具体不做介绍了,本篇只做pytorch的API使用介绍
torch.nn.LSTM(*args, **kwargs)
输入张量
输入参数为两个,一个为input Tensor,一个为隐藏特征和状态组成的tuple
Inputs: input, (h_0, c_0)
input只能为3维张量,其中当初始参数batch_first=True时,shape为(L,N,H);而当初始参数batch_first=False时,shape为(N,L,H);
公式
LSTM中weights
以下公式介绍都忽略bias
- 公式(1), 输入门
it=δ(Wiixt+Whiht−1)i_t = \delta(W_{ii}x_t+W_{hi}h_{t-1})it=δ(Wiixt+Whiht−1), LSTM中有关输入的参是是WiiW_{ii}Wii和WhiW_{hi}Whi - 公式(2),遗忘门
ft=δ(Wifxt+Whfht−1)f_t = \delta(W_{if}x_t+W_{hf}h_{t-1})ft=δ(Wifxt+Whfht−1), LSTM中有关输入的参数是WifW_{if}Wif和WhfW_{hf}Whf - 公式(3),细胞更新状态
gt=δ(Wigxt+Whght−1)g_t = \delta(W_{ig}x_t+W_{hg}h_{t-1})gt=δ(Wigxt+Whght−1), LSTM中有关输入的参数是WigW_{ig}Wig和WhgW_{hg}Whg - 公式(4),输出门
ot=δ(Wioxt+Whoht−1)o_t = \delta(W_{io}x_t+W_{ho}h_{t-1})ot=δ(Wioxt+Whoht−1), LSTM中有关输入的参数是WioW_{io}Wio和WhoW_{ho}Who
所以从输入张量和隐藏层张量来说,一共有两组参数
- input 组 {WiiW_{ii}Wii, WifW_{if}Wif,WigW_{ig}Wig,WioW_{io}Wio }
- hidden组 { WhiW_{hi}Whi,WhfW_{hf}Whf, WhgW_{hg}Whg, WhoW_{ho}Who }
这里就对应官网上的两个参数
因为hidden size为隐藏层特征输出长度,所以每个参数第一维度都是hidden size;然后每一组是把4个张量按照第一维度拼接,所以要乘以4
举例代码:
from torch import nnlstm = nn.LSTM(input_size=3, hidden_size=6, num_layers=1, bias=False)print('weight_ih_l0.shape = ', lstm.weight_ih_l0.shape, ', weight_hh_l0.shape = ' , lstm.weight_hh_l0.shape)
双向LSTM
如果要实现双向的LSTM,只需要增加参数bidirectional=True
双向的区别是LTSM参数中 hidden 参数会增加一个方向,即有来有回,所以要double以下。
举例代码
from torch import nnlstm = nn.LSTM(input_size=3, hidden_size=6, num_layers=2, bias=False, bidirectional=True)print('weight_ih_l0.shape = ', lstm.weight_ih_l0.shape, ', weight_ih_l0_reverse.shape = ', lstm.weight_ih_l0_reverse.shape,'\nweight_hh_l0.shape = ' , lstm.weight_hh_l0.shape, ', weight_hh_l0_reverse.shape = ', lstm.weight_hh_l0_reverse.shape)
主要是hh部分的最后一维增加了一倍。
多层的概念
LSTM中有个参数num_layers是设置层数的,如果num_layers大于1,则网络则会变成如下的拓展
有关讨论请参考这里。
每多一层,就多一组**LSTM.weight_ih_l[k]和LSTM.weight_hh_l[k]**参数
两层的网络里有 LSTM.weight_ih_l0、LSTM.weight_ih_l0和LSTM.weight_ih_l1和LSTM.weight_hh_l1
LSTM的计算示例代码
input_size 为3,hidden_size 为4,
x=torch.Tensor([0,0,1,1,1,1,0,0,0,0,0,0]).view(2,3,2)print(x,x.shape)h0=torch.zeros(1,3,4)
c0=torch.zeros(1,3,4)net=nn.LSTM(2, 4, bias=False)
print("net.weight_ih_l0=", net.weight_ih_l0, net.weight_ih_l0.shape)
print("net.weight_hh_l0=", net.weight_hh_l0, net.weight_hh_l0.shape)
y,_=net(x, (h0,c0))
print(y, y.shape)
tensor([[[0., 0.],[1., 1.],[1., 1.]],[[0., 0.],[0., 0.],[0., 0.]]]) torch.Size([2, 3, 2])
net.weight_ih_l0= Parameter containing:
tensor([[-0.1310, 0.0330],[-0.3115, -0.0417],[-0.1452, 0.0426],[-0.0096, 0.3305],[-0.1104, 0.1816],[ 0.1668, -0.3706],[ 0.4792, -0.3867],[-0.4565, 0.0688],[ 0.0975, -0.0737],[ 0.2898, 0.2739],[-0.3564, 0.2723],[-0.1759, 0.2534],[ 0.3471, -0.1051],[ 0.4057, 0.3256],[ 0.4224, 0.4646],[-0.0107, -0.2000]], requires_grad=True) torch.Size([16, 2])
net.weight_hh_l0= Parameter containing:
tensor([[ 0.4431, 0.4195, -0.4328, 0.0183],[ 0.4375, 0.0306, -0.0641, -0.2027],[-0.3726, -0.0434, -0.4403, 0.2741],[-0.2962, -0.2381, -0.4713, -0.1349],[-0.1447, -0.0184, 0.3634, -0.0840],[-0.4828, -0.2628, 0.4112, -0.0554],[ 0.2004, -0.4253, -0.1785, 0.4688],[ 0.2922, -0.1926, -0.2644, 0.2561],[ 0.4504, -0.3577, -0.2971, -0.2796],[ 0.0442, -0.0018, -0.3970, 0.3194],[ 0.2986, -0.2493, -0.4371, 0.2953],[-0.0342, -0.2422, -0.2986, 0.0775],[ 0.4996, 0.1421, -0.1665, 0.1231],[-0.3341, -0.0462, -0.1578, -0.1443],[-0.0194, 0.3979, 0.0734, -0.2276],[ 0.0581, -0.3744, 0.4785, -0.0365]], requires_grad=True) torch.Size([16, 4])
tensor([[[ 0.0000, 0.0000, 0.0000, 0.0000],[ 0.0064, 0.1402, -0.0282, 0.0200],[ 0.0064, 0.1402, -0.0282, 0.0200]],[[ 0.0000, 0.0000, 0.0000, 0.0000],[-0.0089, 0.0553, -0.0138, 0.0050],[-0.0089, 0.0553, -0.0138, 0.0050]]], grad_fn=<StackBackward>) torch.Size([2, 3, 4])
相关参考: Understanding LSTM and its diagram
【pytorch】nn.LSTM的使用相关推荐
- pytorch nn.LSTM()参数详解
输入数据格式: input(seq_len, batch, input_size) h0(num_layers * num_directions, batch, hidden_size) c0(num ...
- pytorch中的torch.nn.LSTM解析
文章目录 前言 多层LSTM 权重形状 batch_first 输入形状 输出形状 参考 前言 本文记录一下使用LSTM的一些心得. 多层LSTM 多层LSTM是这样: 而不是这样: 我们可以控制如下 ...
- 深度学习总结:tensorflow和pytorch关于RNN的对比,tf.nn.dynamic_rnn,nn.LSTM
tensorflow和pytorch关于RNN的对比: tf.nn.dynamic_rnn很难理解,他的意思只是用数据走一遍你搭建的RNN网络. 可以明显看出pytorch封装更高,更容易理解,动态图 ...
- pytorch中的nn.LSTM模块参数详解
直接去官网查看相关信息挺好的,但是为什么有的时候进不去 官网:https://pytorch.org/docs/stable/nn.html#torch.nn.LSTM 使用示例,在使用中解释参数 单 ...
- pytorch中nn.Embedding和nn.LSTM和nn.Linear
使用pytorch实现一个LSTM网络很简单,最基本的有三个要素:nn.Embedding, nn.LSTM, nn.Linear 基本框架为: class LSTMModel(nn.Module): ...
- 总结PYTORCH中nn.lstm(自官方文档整理 包括参数、实例)
参考pytorch官方文档 https://pytorch.org/docs/master/nn.html#torch.nn.LSTM 先上原图 | 这里是关键参数介绍 input_size:输入特征 ...
- pytorch笔记:torch.nn.GRU torch.nn.LSTM
1 函数介绍 (GRU) 对于输入序列中的每个元素,每一层计算以下函数: 其中是在t时刻的隐藏状态,是在t时刻的输入.σ是sigmoid函数,*是逐元素的哈达玛积 对于多层GRU 第l层的输入(l≥2 ...
- pytorch torch.nn.LSTM
应用 >>> rnn = nn.LSTM(10, 20, 2) >>> input = torch.randn(5, 3, 10) >>> h0 ...
- Pytorch的LSTM的理解
20211227 lstm和gru的区别 Pytorch实现LSTM案例学习(1)_ch206265的博客-CSDN博客_pytorch搭建lstm lstm案例 class torch.nn.LST ...
最新文章
- 命题模式持续在变 你变不变
- let 和 var 的区别
- UA MATH523A 实分析1 集合论基础7 一些度量空间基本概念
- 想不明白的时候可以干的十件事情
- Python 生成器 迭代器
- python如何安装扩展库openpyxl和numpy_Python第三方库之openpyxl(2)
- CF1088F Ehab and a weird weight formula(树上最优性问题、贪心+倍增)
- iOS:苹果内购实践
- ha linux 设置虚拟ip_如何在虚拟机中设置CentOS静态IP?
- react android 串口,Maix Bit(K210) 与上位机串口通信
- js 事件模型 + ( 事件类型 )
- shell学习之跳出循环
- 106. 数据库增删改的封装
- 向日葵远程控制使用方法
- 网络安全应急响应(文末附应急工具)
- LU列主元法解线性方程组
- 游戏里的攻防-检测与反检测
- 总结:读《程序员的自我修养》
- “数字孪生”语境下的城市:拼图模式与航向之争
- Multiplier和Finger的区别和优劣讨论
热门文章
- 第一讲 NLP和深度学习入门
- 基于python爬虫数据处理_基于Python爬虫的校园数据获取
- Science:“每周工作进展汇报”在博士培养中的作用
- Science子刊: 长期杀虫剂诱导选择下的宿主基因组与微生物组的共适应
- QIIME 2教程. 24Python命令行模式Artifact API(2021.2)
- linux 极简统计分析工具 datamash 必看教程
- 终极大招——怎么在学术会议上有所收获?
- Python使用matplotlib可视化时间序列季节图、使用季节图可以比较不同年份相同月份的数据差异、或者相同(年/月/周等)的时间序列在同一天的数据差异(Seasonal Plot)
- R语言plotly可视化:plotly可视化水平直方图(Horizontal Histogram)
- R语言使用ggplot2包使用geom_boxplot函数绘制基础分组水平箱图(boxplot)实战