-柚子皮-

RNN

Parameters
input_size – The number of expected features in the input x

hidden_size – The number of features in the hidden state h

num_layers – Number of recurrent layers. E.g., setting num_layers=2 would mean stacking two RNNs together to form a stacked RNN, with the second RNN taking in outputs of the first RNN and computing the final results. Default: 1 (the number of stacked layers)

nonlinearity – The non-linearity to use. Can be either 'tanh' or 'relu'. Default: 'tanh'

bias – If False, then the layer does not use bias weights b_ih and b_hh. Default: True

batch_first – If True, then the input and output tensors are provided as (batch, seq, feature). Default: False

dropout – If non-zero, introduces a Dropout layer on the outputs of each RNN layer except the last layer, with dropout probability equal to dropout. Default: 0

bidirectional – If True, becomes a bidirectional RNN. Default: False (whether to use a bidirectional RNN)

Note: the sequence length here is dynamic; it is not a constructor argument but is determined by the input tensor passed to the RNN (see the sketch below).
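
A minimal sketch of this behavior (the dimensions are chosen arbitrarily for illustration): the same nn.RNN module can consume inputs of different sequence lengths without any reconfiguration.

import torch
import torch.nn as nn

rnn = nn.RNN(10, 20, 2)            # input_size=10, hidden_size=20, num_layers=2
x_short = torch.randn(3, 4, 10)    # seq_len=3, batch=4, input_size=10
x_long = torch.randn(50, 4, 10)    # seq_len=50, same batch and input_size
out_short, _ = rnn(x_short)        # out_short: (3, 4, 20); h0 defaults to zeros
out_long, _ = rnn(x_long)          # out_long: (50, 4, 20)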

Inputs: input, h_0
Input shape: input of shape (seq_len, batch, input_size): tensor containing the features of the input sequence. The input can also be a packed variable-length sequence. See torch.nn.utils.rnn.pack_padded_sequence() or torch.nn.utils.rnn.pack_sequence() for details.

[https://blog.csdn.net/zwqjoy/article/details/86490098]
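
A small sketch of the packed-sequence path mentioned above (lengths and sizes are made up for illustration): two sequences of different lengths are zero-padded, packed, run through the RNN, and unpacked again.

import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

rnn = nn.RNN(10, 20, 2)
padded = torch.randn(5, 2, 10)                       # two sequences padded to seq_len=5, batch=2
lengths = [5, 3]                                     # true lengths, sorted in decreasing order
packed = pack_padded_sequence(padded, lengths)
packed_out, hn = rnn(packed)                         # hn: (2, 2, 20)
output, out_lens = pad_packed_sequence(packed_out)   # output: (5, 2, 20); padded steps are zeros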

Initial hidden state: h_0 of shape (num_layers * num_directions, batch, hidden_size): tensor containing the initial hidden state for each element in the batch. Defaults to zero if not provided. If the RNN is bidirectional, num_directions should be 2, else it should be 1. h_0 is the initial hidden state fed to every layer, so its num_layers dimension must match the RNN's num_layers.

Outputs: output, h_n
output of shape (seq_len, batch, num_directions * hidden_size): tensor containing the output features (h_t) from the last layer of the RNN, for each t. If a torch.nn.utils.rnn.PackedSequence has been given as the input, the output will also be a packed sequence. For the unpacked case, the directions can be separated using output.view(seq_len, batch, num_directions, hidden_size), with forward and backward being direction 0 and 1 respectively. Similarly, the directions can be separated in the packed case. This is the per-time-step ("top") output of the RNN.

h_n of shape (num_layers * num_directions, batch, hidden_size): tensor containing the hidden state for t = seq_len. Like output, the layers can be separated using h_n.view(num_layers, num_directions, batch, hidden_size). This is the final-step ("right-hand") output of the RNN; if the RNN is bidirectional there is also an output from the backward direction (the "left-hand" side), as in the sketch below.
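
A short sketch of the view() trick described above, using a bidirectional RNN with arbitrary sizes:

import torch
import torch.nn as nn

seq_len, batch, hidden_size, num_layers, num_directions = 5, 3, 20, 2, 2
rnn = nn.RNN(10, hidden_size, num_layers, bidirectional=True)
input = torch.randn(seq_len, batch, 10)
output, h_n = rnn(input)                                              # output: (5, 3, 40), h_n: (4, 3, 20)
out_dirs = output.view(seq_len, batch, num_directions, hidden_size)   # [..., 0, :] forward, [..., 1, :] backward
h_n_dirs = h_n.view(num_layers, num_directions, batch, hidden_size)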

For the full parameters and return values, see [https://pytorch.org/docs/stable/generated/torch.nn.RNN.html#torch.nn.RNN]

Example

import torch
import torch.nn as nn

rnn = nn.RNN(10, 20, 2)          # (input_size, hidden_size, num_layers)
input = torch.randn(5, 3, 10)    # (seq_len, batch, input_size)
h0 = torch.randn(2, 3, 20)       # (num_layers * num_directions, batch, hidden_size)
output, hn = rnn(input, h0)
print(output.size(), hn.size())
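
With these sizes the two prints are torch.Size([5, 3, 20]) and torch.Size([2, 3, 20]): output keeps every time step from the top layer, while hn holds only the final hidden state of each layer.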

LSTM

For the full parameters and return values, see [https://pytorch.org/docs/stable/generated/torch.nn.LSTM.html#torch.nn.LSTM]

Example

import torch
import torch.nn as nn

rnn = nn.LSTM(10, 20, 2)         # (input_size, hidden_size, num_layers)
input = torch.randn(5, 3, 10)    # (seq_len, batch, input_size)
h0 = torch.randn(2, 3, 20)       # (num_layers * num_directions, batch, hidden_size)
c0 = torch.randn(2, 3, 20)       # (num_layers * num_directions, batch, hidden_size)
output, (hn, cn) = rnn(input, (h0, c0))   # output: (seq_len, batch, num_directions * hidden_size)
print(output.size(), hn.size(), cn.size())
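
Here output.size() is (5, 3, 20) while hn and cn are each (2, 3, 20); compared with the plain RNN, the only interface difference is the extra cell state that travels alongside the hidden state.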

GRU

gru = nn.GRU(embed_size, hidden_size, n_layers, dropout=dropout, bidirectional=True)  # embed_size, hidden_size, n_layers, dropout are placeholders defined by the caller

For the full parameters, see: [https://pytorch.org/docs/stable/generated/torch.nn.GRU.html#gru]

Example

import torch
import torch.nn as nn

rnn = nn.GRU(2, 4, 2, bidirectional=True)   # (input_size, hidden_size, num_layers), bidirectional
input = torch.randn(2, 2, 2)                # (seq_len, batch, input_size)
h0 = torch.randn(4, 2, 4)                   # (num_layers * num_directions, batch, hidden_size)
output, hn = rnn(input, h0)
print(output)
print(hn)
print(output.size(), hn.size())
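
With these sizes, output is (2, 2, 8) (seq_len, batch, num_directions * hidden_size) and hn is (4, 2, 4). In an encoder it is common to concatenate the final forward and backward hidden states of the top layer into one summary vector; a minimal sketch building on the variables above (a common pattern, not something prescribed by the nn.GRU API):

hn_per_layer = hn.view(2, 2, 2, 4)                  # (num_layers, num_directions, batch, hidden_size)
fwd_last = hn_per_layer[-1, 0]                      # last layer, forward direction: (batch, hidden_size)
bwd_last = hn_per_layer[-1, 1]                      # last layer, backward direction: (batch, hidden_size)
summary = torch.cat([fwd_last, bwd_last], dim=1)    # (batch, 2 * hidden_size) = (2, 8)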

from: -柚子皮-

ref: [LSTM and GRU principles and PyTorch code, with notes on input/output sizes]
