pytorch torch.nn.RNN
应用
rnn = nn.RNN(10, 20, 2)
input = torch.randn(5, 3, 10)
h0 = torch.randn(2, 3, 20)
output, hn = rnn(input, h0)
API
CLASS torch.nn.RNN(*args, **kwargs)
ht=tanh(Wihxt+bih+Whhht−1+bhhh_t=tanh(W_{ih}x_t+b_{ih}+W_{hh}h_{t-1}+b_{hh}ht=tanh(Wihxt+bih+Whhht−1+bhh
hth_tht:the hidden state at time t
xtx_txt:the input at time t
ht−1h_{t-1}ht−1:the hidden state of the previous layer at time t-1
如果nonlinearity
是relu
,则会替换tanh
类
参数 | 描述 |
---|---|
input_size | The number of expected features in the input x |
hidden_size | The number of features in the hidden state h |
num_layers | Number of recurrent layers,Default: 1 |
nonlinearity | The non-linearity to use.Default: ‘tanh’ |
bias | If False, then the layer does not use bias weights b_ih and b_hh. Default: True |
batch_first | If True, then the input and output tensors are provided as (batch, seq, feature). Default: False |
bidirectional | If True, becomes a bidirectional RNN. Default: False |
input_size:是RNN的维度,注意不是句子或序列的长度,而是句子的一个词,或序列的一个元素的维度。比如,词向量的维度。比如说NLP中你需要把一个单词输入到RNN中,这个单词的编码是300维的,那么这个input_size就是300.
hidden_size:每个RNN的节点实际上就是一个BP网络,包含输入层,隐含层,输出层。这里就是指隐藏层的节点个数。
num_layers:如果num_layer=2的话,表示两个RNN堆叠在一起。
参考:https://www.cnblogs.com/dhName/p/11760610.html
对象
输入:
参数 | 描述 |
---|---|
input of shape (seq_len, batch, input_size) | The input can also be a packed variable length sequence. See torch.nn.utils.rnn.pack_padded_sequence() or torch.nn.utils.rnn.pack_sequence() for details. |
h_0 of shape (num_layers * num_directions, batch, hidden_size) | tensor containing the initial hidden state for each element in the batch. Defaults to zero if not provided. If the RNN is bidirectional, num_directions should be 2, else it should be 1. |
输出:
参数 | 描述 |
---|---|
output of shape (seq_len, batch, num_directions * hidden_size): | |
h_n of shape (num_layers * num_directions, batch, hidden_size): |
参考:
https://www.icode9.com/content-4-622959.html
https://zhuanlan.zhihu.com/p/59772104
只是batch长度要求相同,但不同batch则不需要相同?
https://zhuanlan.zhihu.com/p/97378498
https://www.cnblogs.com/lindaxin/p/8052043.html
https://www.jianshu.com/p/f5b816750839
https://www.jianshu.com/p/efe045c24a93
https://zhuanlan.zhihu.com/p/161972223
https://zhuanlan.zhihu.com/p/34418001?edition=yidianzixun&utm_source=yidianzixun&yidian_docid=0IVwLf60
https://www.cnblogs.com/jiangkejie/p/13141664.html
https://zhuanlan.zhihu.com/p/64527432
pytorch torch.nn.RNN相关推荐
- (pytorch-深度学习)使用pytorch框架nn.RNN实现循环神经网络
使用pytorch框架nn.RNN实现循环神经网络 首先,读取周杰伦专辑歌词数据集. import time import math import numpy as np import torch f ...
- PyTorch : torch.nn.xxx 和 torch.nn.functional.xxx
PyTorch : torch.nn.xxx 和 torch.nn.functional.xxx 在写 PyTorch 代码时,我们会发现在 torch.nn.xxx 和 torch.nn.funct ...
- torch.nn.RNN基本用法
#torch.nn.RNN CLASS torch.nn.RNN(*args, **kwargs) **实现的功能:**实现一个用tanh或者ReLU作为非线性成分的Elman RNN(两种RNN中的 ...
- pytorch torch.nn.MSELoss
应用 # 1.计算绝对差总和:|0-1|^2+|1-1|^2+|2-1|^2+|3-1|^2=6 # 2.求平均: 6/4 =1.5 import torch import torch.nn as n ...
- pytorch torch.nn.Module.register_buffer
API register_buffer(name: str, tensor: Optional[torch.Tensor], persistent: bool = True) → None 注册buf ...
- pytorch torch.nn.TransformerEncoderLayer
API CLASS torch.nn.TransformerEncoderLayer(d_model, nhead, dim_feedforward=2048, dropout=0.1, activa ...
- pytorch torch.nn.LSTM
应用 >>> rnn = nn.LSTM(10, 20, 2) >>> input = torch.randn(5, 3, 10) >>> h0 ...
- pytorch torch.nn.TransformerEncoder
API CLASS torch.nn.TransformerEncoder(encoder_layer, num_layers, norm=None) TransformerEncoder is a ...
- pytorch torch.nn.Embedding
词嵌入矩阵,可以加载使用word2vector,glove API CLASS torch.nn.Embedding(num_embeddings: int, embedding_dim: int, ...
最新文章
- Pidgin Portable 使用点滴
- 有个产品经理女朋友是一种什么样的体验?
- LintCode-第k大元素
- [Material Design] 教你做一个Material风格、动画的button(MaterialButton)
- Blogger建立Blog部落格​​ - Blog透视镜
- 长语音识别体验_如何为语音体验写作
- 2021年河南高考成绩排名查询一分一段表,2018河南高考一分一段统计表,查排名必备!...
- 搭建Cacti监控系统(三)-- 监控Linux 主机
- 大屏监控系统实战(2)-后台工程搭建
- 个人总结——学期总结
- 驴妈妈、途牛们该如何收割亲子游市场的红利?
- CSS动态样式---基础-控制是否添加CSS类
- python二进制处理详述
- Java中删除文件或文件夹的几种方法
- 万娟 白话大数据和机械学习_《白话大数据与机器学习》.pdf
- 我所认知的世界,不是Fragmention,而是Think
- Classic界面chatter中的子选项卡配置
- 洛谷—P3387 【模板】缩点
- soapui 乱码_接口测试-soapui-中文乱码总结
- QPSK调制解调过程,包括串并转换,电平转换,载波调制,相干解调,抽样判决等
热门文章
- simulink快捷键_从EPB模型谈谈Simulink代码生成
- Python选择结构注意事项
- Python类中公开方法、私有方法和特殊方法的继承原理
- 下载安装vs2019详细版
- 7.18自学c++笔记
- java注册系统服务_奇葩需求:springboot项目注册为windows系统服务并设置开机自启...
- mysql dump 导出表_误删库,别跑路!教你一招MySQL 数据恢复
- 计算机网络中的语法 语义 时序的概念,网络iso协议及语义语法时序详解
- java发邮件的框架_Java的Spring框架中实现发送邮件功能的核心代码示例
- c# 计算圆锥的体积_急求用c#计算圆柱体和圆锥体的体积的代码,下面是要求: