文章目录

  • 前言
  • 多层LSTM
  • 权重形状
  • batch_first
  • 输入形状
  • 输出形状
  • 参考

前言

本文记录一下使用LSTM的一些心得。

多层LSTM

多层LSTM是这样:

而不是这样:

我们可以控制如下的参数来控制:

权重形状


上面的权重除了偏置可以归结为3类,即U(输入专用,就是上面那些含有i的W),V(目标输出专用),W(隐藏层之间专用,含有h的W)。不过,这里没有目标输出。所以只有两大类,各四个,一共8个。

U类矩阵的形状都是[input_size,hidden_size]
W类矩阵的形状都是[hidden_size,hidden_size]

为什么官方不写出V呢?因为这个东西你可以自己搞定,因为取决于你怎么定义。你可以直接将hhh作为该位置的输出,也可以再乘以一个线性层WvW_vWv作为该位置的输出,比较灵活,所以官方干脆不定义,让你自己搞定接下来应该怎么办。

batch_first

LSTM或者RNN系列的模型默认batch_first=False,即batch在第二个维度。因此在将数据送入LSTM之前,x的形状你必须确保为:(seq_len,batch_size,emb_size)。通常,我们不太习惯,所以一般我们使用batch_first=True这个参数,变成batch在第一个维度。

输入形状

Inputs: input, (h_0, c_0)

其中
1.input形状为:(seq_len,batch_size,embedding_size)。就是下面的这个黑色箭头的输入。

2.h_0是指如下黄色的东西,初始隐状态,形状是(1,batch_size,hidden_size)

3.c_0是LSTM特有的,即细胞。和上面一样,形状是(1,batch_size,hidden_size)。

输出形状

Outputs: output, (h_n, c_n)

1.output不是目标输出y,而是我们得到的一系列隐状态(这个output指的是lstm层),其形状为:(seq_len,batch_size,hidden_size)

2.h_n,即最后一个隐状态,上面output是隐状态序列,所以这个理论上是多余,可以通过Output切片得到。即h_n=output[len(output)-1]。可想而知,其形状为:(1,batch_size,hidden_size)

3.c_n,即最后一个细胞状态,这个不是多余的。不过其形状仍然为(1,batch_size,hidden_size)。

lstm=nn.LSTM(3,2)#embedding_size=3,hidden_size=2
a=torch.rand(2,1,3)#长度为2的句子,seq_len=2,batch_size=1,embedding_size=3
lstm(a)#我们没有给出(h_0,c_0)默认即为0.

测试验证:


我们发现,两个红色箭头是一样的,这验证了h_n是多余的,可以由前者切片得到。

参考

本文总参考是:pytorch的官方文档https://pytorch.org/docs/stable/generated/torch.nn.LSTM.html?highlight=lstm#torch.nn.LSTM。

pytorch中的torch.nn.LSTM解析相关推荐

  1. PyTorch中的torch.nn.Parameter() 详解

    PyTorch中的torch.nn.Parameter() 详解 今天来聊一下PyTorch中的torch.nn.Parameter()这个函数,笔者第一次见的时候也是大概能理解函数的用途,但是具体实 ...

  2. Pytorch 学习(6):Pytorch中的torch.nn Convolution Layers 卷积层参数初始化

    Pytorch 学习(6):Pytorch中的torch.nn  Convolution Layers  卷积层参数初始化 class Conv1d(_ConvNd):......def __init ...

  3. pytorch笔记:torch.nn.GRU torch.nn.LSTM

    1 函数介绍 (GRU) 对于输入序列中的每个元素,每一层计算以下函数: 其中是在t时刻的隐藏状态,是在t时刻的输入.σ是sigmoid函数,*是逐元素的哈达玛积 对于多层GRU 第l层的输入(l≥2 ...

  4. pytorch torch.nn.LSTM

    应用 >>> rnn = nn.LSTM(10, 20, 2) >>> input = torch.randn(5, 3, 10) >>> h0 ...

  5. gather torch_浅谈Pytorch中的torch.gather函数的含义

    pytorch中的gather函数 pytorch比tensorflow更加编程友好,所以准备用pytorch试着做最近要做的一些实验. 立个flag开始学习pytorch,新开一个分类整理学习pyt ...

  6. Pytorch LSTM初识(详解LSTM+torch.nn.LSTM()实现)1

    pytorch  LSTM1初识 目录 pytorch  LSTM1初识 ​​​​​​​​​​​​​​​​​​​​​ 一.LSTM简介1

  7. Pytorch中如何理解RNN LSTM GRU的input(重点理解seq_len / time_steps)

    在建立时序模型时,若使用keras,我们在Input的时候就会在shape内设置好sequence_length(后面简称seq_len),接着便可以在自定义的data_generator内进行个性化 ...

  8. 神经网路:pytorch中Variable和view参数解析

    在PyTorch中计算图的特点总结如下: autograd根据用户对Variable的操作来构建其计算图. requires_grad variable默认是不需要被求导的,即requires_gra ...

  9. pytorch笔记:torch.nn.functional.pad

    1 torch.nn.functional.pad函数 torch.nn.functional.pad是pytorch内置的tensor扩充函数,便于对数据集图像或中间层特征进行维度扩充 torch. ...

最新文章

  1. NVIDIA® TensorRT™ supports different data formats
  2. Python:给定一个不超过5位的正整数,判断有几位
  3. python3字典列表_python3入门(3)---列表、元组、字典、集合详解
  4. 20应用统计考研复试要点(part8)--应用多元分析
  5. [抄]外部奖励对内在动机的侵蚀
  6. mysql教程日志_MySQL日志
  7. Part2--排序算法类模板
  8. Protel入门教程
  9. PSV1000刷黑商
  10. 每月的第一个工作日执行的corn表达式
  11. uc android 面试题,一道新浪UC部门软件测试面试题
  12. C# 使用SharpGL-Perspective和LookAt
  13. 如何将ida中的悬浮窗口恢复原位
  14. 4.20 扣1送地狱火
  15. 破解微信 DB, 导出 Mac 微信聊天记录
  16. html一号店项目代码,项目一号店素材(html模板)
  17. 湾区潮涌·香港向前 | 香港科大副校长汪扬:用好一国两制制度优势,香港要成数字经济接轨世界桥梁...
  18. MQL--量化交易编程语言
  19. Android HTTPS请求总结
  20. 互联网首发 | 闲鱼程序员公开多年 Flutter 实践经验

热门文章

  1. 深度学习在单图像超分辨率上的应用:SRCNN、Perceptual loss、SRResNet
  2. Java的CountDownLatch和CyclicBarrier的理解和区别
  3. Yann LeCun主讲!纽约大学《深度学习》2021课程全部放出,附slides与视频
  4. 独家 | 如何通过TensorFlow 开发者资格考试(附链接)
  5. 一文概览图卷积网络基本结构和最新进展(附视频代码)
  6. 近期活动盘点:工业大数据讲座、大数据自杀风险感知讲座、数据法学研讨会、海外学者短期讲学(12.3-12.13)
  7. 舍友清华博士毕业,我建议他留在高校
  8. 写出漂亮 Python 代码的 20条准则
  9. 别说了,有画面了!Google文本生成图像取得新SOTA,CVPR2021已接收
  10. 谷歌的最新NLP模型,现在能陪你从诗词歌赋谈到人生哲学