官方文档在这里。

LSTM具体不做介绍了,本篇只做pytorch的API使用介绍

torch.nn.LSTM(*args, **kwargs)

输入张量

输入参数为两个,一个为input Tensor,一个为隐藏特征和状态组成的tuple

Inputs: input, (h_0, c_0)

input只能为3维张量,其中当初始参数batch_first=True时,shape为(L,N,H);而当初始参数batch_first=False时,shape为(N,L,H);

公式

LSTM中weights

以下公式介绍都忽略bias

  • 公式(1), 输入门
    it=δ(Wiixt+Whiht−1)i_t = \delta(W_{ii}x_t+W_{hi}h_{t-1})it​=δ(Wii​xt​+Whi​ht−1​), LSTM中有关输入的参是是WiiW_{ii}Wii​和WhiW_{hi}Whi​
  • 公式(2),遗忘门
    ft=δ(Wifxt+Whfht−1)f_t = \delta(W_{if}x_t+W_{hf}h_{t-1})ft​=δ(Wif​xt​+Whf​ht−1​), LSTM中有关输入的参数是WifW_{if}Wif​和WhfW_{hf}Whf​
  • 公式(3),细胞更新状态
    gt=δ(Wigxt+Whght−1)g_t = \delta(W_{ig}x_t+W_{hg}h_{t-1})gt​=δ(Wig​xt​+Whg​ht−1​), LSTM中有关输入的参数是WigW_{ig}Wig​和WhgW_{hg}Whg​
  • 公式(4),输出门
    ot=δ(Wioxt+Whoht−1)o_t = \delta(W_{io}x_t+W_{ho}h_{t-1})ot​=δ(Wio​xt​+Who​ht−1​), LSTM中有关输入的参数是WioW_{io}Wio​和WhoW_{ho}Who​

所以从输入张量和隐藏层张量来说,一共有两组参数

  1. input 组 {WiiW_{ii}Wii​, WifW_{if}Wif​,WigW_{ig}Wig​,WioW_{io}Wio​ }
  2. hidden组 { WhiW_{hi}Whi​,WhfW_{hf}Whf​, WhgW_{hg}Whg​, WhoW_{ho}Who​ }

这里就对应官网上的两个参数

因为hidden size为隐藏层特征输出长度,所以每个参数第一维度都是hidden size;然后每一组是把4个张量按照第一维度拼接,所以要乘以4

举例代码:

from torch import nnlstm = nn.LSTM(input_size=3, hidden_size=6, num_layers=1, bias=False)print('weight_ih_l0.shape = ', lstm.weight_ih_l0.shape, ', weight_hh_l0.shape = ' , lstm.weight_hh_l0.shape)

双向LSTM

如果要实现双向的LSTM,只需要增加参数bidirectional=True

双向的区别是LTSM参数中 hidden 参数会增加一个方向,即有来有回,所以要double以下。

举例代码


from torch import nnlstm = nn.LSTM(input_size=3, hidden_size=6, num_layers=2, bias=False, bidirectional=True)print('weight_ih_l0.shape = ', lstm.weight_ih_l0.shape, ', weight_ih_l0_reverse.shape = ', lstm.weight_ih_l0_reverse.shape,'\nweight_hh_l0.shape = ' , lstm.weight_hh_l0.shape, ', weight_hh_l0_reverse.shape = ', lstm.weight_hh_l0_reverse.shape)

主要是hh部分的最后一维增加了一倍。

多层的概念

LSTM中有个参数num_layers是设置层数的,如果num_layers大于1,则网络则会变成如下的拓展

有关讨论请参考这里。

每多一层,就多一组**LSTM.weight_ih_l[k]LSTM.weight_hh_l[k]**参数

两层的网络里有 LSTM.weight_ih_l0LSTM.weight_ih_l0LSTM.weight_ih_l1LSTM.weight_hh_l1


LSTM的计算示例代码

input_size 为3,hidden_size 为4,

x=torch.Tensor([0,0,1,1,1,1,0,0,0,0,0,0]).view(2,3,2)print(x,x.shape)h0=torch.zeros(1,3,4)
c0=torch.zeros(1,3,4)net=nn.LSTM(2, 4, bias=False)
print("net.weight_ih_l0=", net.weight_ih_l0, net.weight_ih_l0.shape)
print("net.weight_hh_l0=", net.weight_hh_l0, net.weight_hh_l0.shape)
y,_=net(x, (h0,c0))
print(y, y.shape)
tensor([[[0., 0.],[1., 1.],[1., 1.]],[[0., 0.],[0., 0.],[0., 0.]]]) torch.Size([2, 3, 2])
net.weight_ih_l0= Parameter containing:
tensor([[-0.1310,  0.0330],[-0.3115, -0.0417],[-0.1452,  0.0426],[-0.0096,  0.3305],[-0.1104,  0.1816],[ 0.1668, -0.3706],[ 0.4792, -0.3867],[-0.4565,  0.0688],[ 0.0975, -0.0737],[ 0.2898,  0.2739],[-0.3564,  0.2723],[-0.1759,  0.2534],[ 0.3471, -0.1051],[ 0.4057,  0.3256],[ 0.4224,  0.4646],[-0.0107, -0.2000]], requires_grad=True) torch.Size([16, 2])
net.weight_hh_l0= Parameter containing:
tensor([[ 0.4431,  0.4195, -0.4328,  0.0183],[ 0.4375,  0.0306, -0.0641, -0.2027],[-0.3726, -0.0434, -0.4403,  0.2741],[-0.2962, -0.2381, -0.4713, -0.1349],[-0.1447, -0.0184,  0.3634, -0.0840],[-0.4828, -0.2628,  0.4112, -0.0554],[ 0.2004, -0.4253, -0.1785,  0.4688],[ 0.2922, -0.1926, -0.2644,  0.2561],[ 0.4504, -0.3577, -0.2971, -0.2796],[ 0.0442, -0.0018, -0.3970,  0.3194],[ 0.2986, -0.2493, -0.4371,  0.2953],[-0.0342, -0.2422, -0.2986,  0.0775],[ 0.4996,  0.1421, -0.1665,  0.1231],[-0.3341, -0.0462, -0.1578, -0.1443],[-0.0194,  0.3979,  0.0734, -0.2276],[ 0.0581, -0.3744,  0.4785, -0.0365]], requires_grad=True) torch.Size([16, 4])
tensor([[[ 0.0000,  0.0000,  0.0000,  0.0000],[ 0.0064,  0.1402, -0.0282,  0.0200],[ 0.0064,  0.1402, -0.0282,  0.0200]],[[ 0.0000,  0.0000,  0.0000,  0.0000],[-0.0089,  0.0553, -0.0138,  0.0050],[-0.0089,  0.0553, -0.0138,  0.0050]]], grad_fn=<StackBackward>) torch.Size([2, 3, 4])

相关参考: Understanding LSTM and its diagram

【pytorch】nn.LSTM的使用相关推荐

  1. pytorch nn.LSTM()参数详解

    输入数据格式: input(seq_len, batch, input_size) h0(num_layers * num_directions, batch, hidden_size) c0(num ...

  2. pytorch中的torch.nn.LSTM解析

    文章目录 前言 多层LSTM 权重形状 batch_first 输入形状 输出形状 参考 前言 本文记录一下使用LSTM的一些心得. 多层LSTM 多层LSTM是这样: 而不是这样: 我们可以控制如下 ...

  3. 深度学习总结:tensorflow和pytorch关于RNN的对比,tf.nn.dynamic_rnn,nn.LSTM

    tensorflow和pytorch关于RNN的对比: tf.nn.dynamic_rnn很难理解,他的意思只是用数据走一遍你搭建的RNN网络. 可以明显看出pytorch封装更高,更容易理解,动态图 ...

  4. pytorch中的nn.LSTM模块参数详解

    直接去官网查看相关信息挺好的,但是为什么有的时候进不去 官网:https://pytorch.org/docs/stable/nn.html#torch.nn.LSTM 使用示例,在使用中解释参数 单 ...

  5. pytorch中nn.Embedding和nn.LSTM和nn.Linear

    使用pytorch实现一个LSTM网络很简单,最基本的有三个要素:nn.Embedding, nn.LSTM, nn.Linear 基本框架为: class LSTMModel(nn.Module): ...

  6. 总结PYTORCH中nn.lstm(自官方文档整理 包括参数、实例)

    参考pytorch官方文档 https://pytorch.org/docs/master/nn.html#torch.nn.LSTM 先上原图 | 这里是关键参数介绍 input_size:输入特征 ...

  7. pytorch笔记:torch.nn.GRU torch.nn.LSTM

    1 函数介绍 (GRU) 对于输入序列中的每个元素,每一层计算以下函数: 其中是在t时刻的隐藏状态,是在t时刻的输入.σ是sigmoid函数,*是逐元素的哈达玛积 对于多层GRU 第l层的输入(l≥2 ...

  8. pytorch torch.nn.LSTM

    应用 >>> rnn = nn.LSTM(10, 20, 2) >>> input = torch.randn(5, 3, 10) >>> h0 ...

  9. Pytorch的LSTM的理解

    20211227 lstm和gru的区别 Pytorch实现LSTM案例学习(1)_ch206265的博客-CSDN博客_pytorch搭建lstm lstm案例 class torch.nn.LST ...

最新文章

  1. 命题模式持续在变 你变不变
  2. let 和 var 的区别
  3. UA MATH523A 实分析1 集合论基础7 一些度量空间基本概念
  4. 想不明白的时候可以干的十件事情
  5. Python 生成器 迭代器
  6. python如何安装扩展库openpyxl和numpy_Python第三方库之openpyxl(2)
  7. CF1088F Ehab and a weird weight formula(树上最优性问题、贪心+倍增)
  8. iOS:苹果内购实践
  9. ha linux 设置虚拟ip_如何在虚拟机中设置CentOS静态IP?
  10. react android 串口,Maix Bit(K210) 与上位机串口通信
  11. js 事件模型 + ( 事件类型 )
  12. shell学习之跳出循环
  13. 106. 数据库增删改的封装
  14. 向日葵远程控制使用方法
  15. 网络安全应急响应(文末附应急工具)
  16. LU列主元法解线性方程组
  17. 游戏里的攻防-检测与反检测
  18. 总结:读《程序员的自我修养》
  19. “数字孪生”语境下的城市:拼图模式与航向之争
  20. Multiplier和Finger的区别和优劣讨论

热门文章

  1. 第一讲 NLP和深度学习入门
  2. 基于python爬虫数据处理_基于Python爬虫的校园数据获取
  3. Science:“每周工作进展汇报”在博士培养中的作用
  4. Science子刊: 长期杀虫剂诱导选择下的宿主基因组与微生物组的共适应
  5. QIIME 2教程. 24Python命令行模式Artifact API(2021.2)
  6. linux 极简统计分析工具 datamash 必看教程
  7. 终极大招——怎么在学术会议上有所收获?
  8. Python使用matplotlib可视化时间序列季节图、使用季节图可以比较不同年份相同月份的数据差异、或者相同(年/月/周等)的时间序列在同一天的数据差异(Seasonal Plot)
  9. R语言plotly可视化:plotly可视化水平直方图(Horizontal Histogram)
  10. R语言使用ggplot2包使用geom_boxplot函数绘制基础分组水平箱图(boxplot)实战