1 函数介绍 (GRU)

对于输入序列中的每个元素,每一层计算以下函数:

其中是在t时刻的隐藏状态,是在t时刻的输入。σ是sigmoid函数,*是逐元素的哈达玛积

对于多层GRU 第l层的输入(l≥2)是之前一层的隐藏状态,乘以dropout 

2 输入参数介绍(GRU)

input_size 输入特征的大小
hidden_size 隐藏层h特征的大小
num_layers

GRU层数。

例如,设置 num_layers=2 意味着将两个 GRU 堆叠在一起形成一个堆叠的 GRU,第二个 GRU 接收第一个 GRU 的输出并计算最终结果。

默认值:1

bias

默认值True

上式的那些b是否为0,如果是False的话,那么这些b就都是0

batch_first

如果为 True,则输入和输出Tensor的维度为 (batch, seq, feature) 而不是 (seq, batch, feature)。

默认值:False

dropout

如果非零,则在除最后一层之外的每个 GRU 层的输出上引入一个 Dropout 层,dropout 概率等于 dropout。

默认值:0

bidirectional

如果是True,那么就变成双向GRU

默认值:False

3 使用举例(GRU)

3.1 输入tensor的维度

input:当batch_first=False的时候,维度为;否则是

h_0: 

3.2 输出tensor的维度

output:当batch_first=False的时候,维度为;否则是

h_n:

3.3 实例说明

import torchGRU=torch.nn.GRU(input_size=10,hidden_size=20,num_layers=20)input_tensor=torch.randn(5,3,10)
'''
输入的sequence长5
batch_size为3
输入sequence每一个元素的维度为10
'''
h0=torch.randn(1*20,3,20)
'''
第一个参数:单方向GRU(1),20层GRU(20)
第二个参数:batch_size
第三个参数:hidden_size的大小
'''
output,hn=GRU(input_tensor,h0)
output.shape,hn.shape
#(torch.Size([5, 3, 20]), torch.Size([20, 3, 20]))

4 torch.nn.LSTM

和GRU几乎完全一模一样,这里说几个不同的地方:

声明的时候,在bidirectional 后面还有一个参数proj_size,默认为0。如果这个参数为0,那么Hc和Hout的维度都是;如果参数大于0,那么Hc仍然是​​​​​​​,Hout变成

​​​​​​​

pytorch笔记:torch.nn.GRU torch.nn.LSTM相关推荐

  1. PyTorch 笔记(16)— torch.nn.Sequential、torch.nn.Linear、torch.nn.RelU

    PyTorch 中的 torch.nn 包提供了很多与实现神经网络中的具体功能相关的类,这些类涵盖了深度神经网络模型在搭建和参数优化过程中的常用内容,比如神经网络中的卷积层.池化层.全连接层这类层次构 ...

  2. torch的拼接函数_从零开始深度学习Pytorch笔记(13)—— torch.optim

    前文传送门: 从零开始深度学习Pytorch笔记(1)--安装Pytorch 从零开始深度学习Pytorch笔记(2)--张量的创建(上) 从零开始深度学习Pytorch笔记(3)--张量的创建(下) ...

  3. PyTorch 笔记(18)— torch.optim 优化器的使用

    到目前为止,代码中的神经网络权重的参数优化和更新还没有实现自动化,并且目前使用的优化方法都有固定的学习速率,所以优化函数相对简单,如果我们自己实现一些高级的参数优化算法,则优化函数部分的代码会变得较为 ...

  4. pytorch学习(五)---torch.nn模块

            本篇自学笔记来自于b站<PyTorch深度学习快速入门教程(绝对通俗易懂!)[小土堆]>,Up主讲的非常通俗易懂,文章下方有视频连接,如有需要可移步up主讲解视频,如有侵权 ...

  5. 深入理解Pytorch负对数似然函数(torch.nn.NLLLoss)和交叉熵损失函数(torch.nn.CrossEntropyLoss)

    在看Pytorch的交叉熵损失函数torch.nn.CrossEntropyLoss官方文档介绍中,给出的表达式如下.不免有点疑惑为何交叉熵损失的表达式是这个样子的 loss ⁡ ( y , clas ...

  6. [Pytorch系列-30]:神经网络基础 - torch.nn库五大基本功能:nn.Parameter、nn.Linear、nn.functioinal、nn.Module、nn.Sequentia

    作者主页(文火冰糖的硅基工坊):文火冰糖(王文兵)的博客_文火冰糖的硅基工坊_CSDN博客 本文网址:https://blog.csdn.net/HiWangWenBing/article/detai ...

  7. nn.Module、nn.Sequential和torch.nn.parameter学习笔记

    nn.Module.nn.Sequential和torch.nn.parameter是利用pytorch构建神经网络最重要的三个函数.搞清他们的具体用法是学习pytorch的必经之路. 目录 nn.M ...

  8. Pytorch:模型的保存与加载 torch.save()、torch.load()、torch.nn.Module.load_state_dict()

    Pytorch 保存和加载模型后缀:.pt 和.pth 1 torch.save() [source] 保存一个序列化(serialized)的目标到磁盘.函数使用了Python的pickle程序用于 ...

  9. Pytorch:核心模块,torch.nn与网络组成单元

    Pytorch: torch.nn 模块与网络组成单元 Copyright: Jingmin Wei, Pattern Recognition and Intelligent System, Scho ...

最新文章

  1. JSTL标签之核心标签
  2. Lvs Tun隧道模式配置
  3. 工具类—KeyValuePair
  4. matlab生成exe独立运行文件已破解(好用)
  5. 『飞秋』小项目心得交流
  6. sql子查询示例_学习SQL:SQL查询示例
  7. 2016 英语作文二
  8. 我来做百科(第七天)
  9. golang(4)使用beego + ace admin 开发后台系统 CRUD
  10. socks代理和http代理的区别_浅析socks代理如何使用TCP和UDP协议
  11. armbian编译安装mentohust 认证锐捷客户端
  12. 诺基亚pc远程服务器,用远程桌面把win10装进iphone —-40核256G内存的生产力工具随身带...
  13. P3369 普通平衡树模板 treap
  14. 391、Java框架46 -【Hibernate - 查询HQL、查询Criteria、查询标准SQL】 2020.10.19
  15. [渝粤教育] 西南科技大学 仓储与配送管理 在线考试复习资料
  16. 工具善其事,必先被苦逼的其器所钝伤然后打磨之才能利其器
  17. 深入学习Docker网络(看这篇就完全够了)
  18. 【Pyecharts50例】自定义饼图标签/显示百分比
  19. c#窗体编辑个人简历_C#开发工程师完整简历范文
  20. (软考)系统架构师大纲

热门文章

  1. angular2 安装
  2. ios开发中,User Defined Runtime Attributes的应用
  3. ESP-TOUCH编码规则及解码
  4. Android学习记录:SQLite数据库、res中raw的文件调用
  5. python绘图使用subplots出现标题重叠的解决方法
  6. PAT甲级1015 Reversible Primes :[C++题解]进制位、秦九韶算法、判质数
  7. mac phpstorm调试php,MAC下phpstorm20190302+Xdebug2.7断点调试PHP | 朱斌技术博客
  8. if vue 跳出_vue使用v-if v-show 页面闪烁
  9. python字典里可以放列表吗_学习python之列表及字典
  10. python测验4_python接口自动化测试四:代码发送HTTPS请求