pytorch中的torch.nn.LSTM解析
文章目录
- 前言
- 多层LSTM
- 权重形状
- batch_first
- 输入形状
- 输出形状
- 参考
前言
本文记录一下使用LSTM的一些心得。
多层LSTM
多层LSTM是这样:
而不是这样:
我们可以控制如下的参数来控制:
权重形状
上面的权重除了偏置可以归结为3类,即U(输入专用,就是上面那些含有i的W),V(目标输出专用),W(隐藏层之间专用,含有h的W)。不过,这里没有目标输出。所以只有两大类,各四个,一共8个。
U类矩阵的形状都是[input_size,hidden_size]
W类矩阵的形状都是[hidden_size,hidden_size]
为什么官方不写出V呢?因为这个东西你可以自己搞定,因为取决于你怎么定义。你可以直接将hhh作为该位置的输出,也可以再乘以一个线性层WvW_vWv作为该位置的输出,比较灵活,所以官方干脆不定义,让你自己搞定接下来应该怎么办。
batch_first
LSTM或者RNN系列的模型默认batch_first=False,即batch在第二个维度。因此在将数据送入LSTM之前,x的形状你必须确保为:(seq_len,batch_size,emb_size)。通常,我们不太习惯,所以一般我们使用batch_first=True这个参数,变成batch在第一个维度。
输入形状
Inputs: input, (h_0, c_0)
其中
1.input形状为:(seq_len,batch_size,embedding_size)。就是下面的这个黑色箭头的输入。
2.h_0是指如下黄色的东西,初始隐状态,形状是(1,batch_size,hidden_size)
3.c_0是LSTM特有的,即细胞。和上面一样,形状是(1,batch_size,hidden_size)。
输出形状
Outputs: output, (h_n, c_n)
1.output不是目标输出y,而是我们得到的一系列隐状态(这个output指的是lstm层),其形状为:(seq_len,batch_size,hidden_size)
2.h_n,即最后一个隐状态,上面output是隐状态序列,所以这个理论上是多余,可以通过Output切片得到。即h_n=output[len(output)-1]。可想而知,其形状为:(1,batch_size,hidden_size)
3.c_n,即最后一个细胞状态,这个不是多余的。不过其形状仍然为(1,batch_size,hidden_size)。
lstm=nn.LSTM(3,2)#embedding_size=3,hidden_size=2
a=torch.rand(2,1,3)#长度为2的句子,seq_len=2,batch_size=1,embedding_size=3
lstm(a)#我们没有给出(h_0,c_0)默认即为0.
测试验证:
我们发现,两个红色箭头是一样的,这验证了h_n是多余的,可以由前者切片得到。
参考
本文总参考是:pytorch的官方文档https://pytorch.org/docs/stable/generated/torch.nn.LSTM.html?highlight=lstm#torch.nn.LSTM。
pytorch中的torch.nn.LSTM解析相关推荐
- PyTorch中的torch.nn.Parameter() 详解
PyTorch中的torch.nn.Parameter() 详解 今天来聊一下PyTorch中的torch.nn.Parameter()这个函数,笔者第一次见的时候也是大概能理解函数的用途,但是具体实 ...
- Pytorch 学习(6):Pytorch中的torch.nn Convolution Layers 卷积层参数初始化
Pytorch 学习(6):Pytorch中的torch.nn Convolution Layers 卷积层参数初始化 class Conv1d(_ConvNd):......def __init ...
- pytorch笔记:torch.nn.GRU torch.nn.LSTM
1 函数介绍 (GRU) 对于输入序列中的每个元素,每一层计算以下函数: 其中是在t时刻的隐藏状态,是在t时刻的输入.σ是sigmoid函数,*是逐元素的哈达玛积 对于多层GRU 第l层的输入(l≥2 ...
- pytorch torch.nn.LSTM
应用 >>> rnn = nn.LSTM(10, 20, 2) >>> input = torch.randn(5, 3, 10) >>> h0 ...
- gather torch_浅谈Pytorch中的torch.gather函数的含义
pytorch中的gather函数 pytorch比tensorflow更加编程友好,所以准备用pytorch试着做最近要做的一些实验. 立个flag开始学习pytorch,新开一个分类整理学习pyt ...
- Pytorch LSTM初识(详解LSTM+torch.nn.LSTM()实现)1
pytorch LSTM1初识 目录 pytorch LSTM1初识 一.LSTM简介1
- Pytorch中如何理解RNN LSTM GRU的input(重点理解seq_len / time_steps)
在建立时序模型时,若使用keras,我们在Input的时候就会在shape内设置好sequence_length(后面简称seq_len),接着便可以在自定义的data_generator内进行个性化 ...
- 神经网路:pytorch中Variable和view参数解析
在PyTorch中计算图的特点总结如下: autograd根据用户对Variable的操作来构建其计算图. requires_grad variable默认是不需要被求导的,即requires_gra ...
- pytorch笔记:torch.nn.functional.pad
1 torch.nn.functional.pad函数 torch.nn.functional.pad是pytorch内置的tensor扩充函数,便于对数据集图像或中间层特征进行维度扩充 torch. ...
最新文章
- NVIDIA® TensorRT™ supports different data formats
- Python:给定一个不超过5位的正整数,判断有几位
- python3字典列表_python3入门(3)---列表、元组、字典、集合详解
- 20应用统计考研复试要点(part8)--应用多元分析
- [抄]外部奖励对内在动机的侵蚀
- mysql教程日志_MySQL日志
- Part2--排序算法类模板
- Protel入门教程
- PSV1000刷黑商
- 每月的第一个工作日执行的corn表达式
- uc android 面试题,一道新浪UC部门软件测试面试题
- C# 使用SharpGL-Perspective和LookAt
- 如何将ida中的悬浮窗口恢复原位
- 4.20 扣1送地狱火
- 破解微信 DB, 导出 Mac 微信聊天记录
- html一号店项目代码,项目一号店素材(html模板)
- 湾区潮涌·香港向前 | 香港科大副校长汪扬:用好一国两制制度优势,香港要成数字经济接轨世界桥梁...
- MQL--量化交易编程语言
- Android HTTPS请求总结
- 互联网首发 | 闲鱼程序员公开多年 Flutter 实践经验
热门文章
- 深度学习在单图像超分辨率上的应用:SRCNN、Perceptual loss、SRResNet
- Java的CountDownLatch和CyclicBarrier的理解和区别
- Yann LeCun主讲!纽约大学《深度学习》2021课程全部放出,附slides与视频
- 独家 | 如何通过TensorFlow 开发者资格考试(附链接)
- 一文概览图卷积网络基本结构和最新进展(附视频代码)
- 近期活动盘点:工业大数据讲座、大数据自杀风险感知讲座、数据法学研讨会、海外学者短期讲学(12.3-12.13)
- 舍友清华博士毕业,我建议他留在高校
- 写出漂亮 Python 代码的 20条准则
- 别说了,有画面了!Google文本生成图像取得新SOTA,CVPR2021已接收
- 谷歌的最新NLP模型,现在能陪你从诗词歌赋谈到人生哲学