Pytorch中如何理解RNN LSTM GRU的input(重点理解seq_len / time_steps)
在建立时序模型时,若使用keras,我们在Input的时候就会在shape内设置好sequence_length(后面简称seq_len),接着便可以在自定义的data_generator内进行个性化的使用。这个值同时也就是time_steps,它代表了RNN内部的cell的数量,有点懵的朋友可以再去看看RNN的相关内容:个人总结:从RNN(内含BPTT以及梯度消失/爆炸)到 LSTM(内含GRU)Seq2Seq Attention
所以设定好这个值是很重要的事情,它和batch_size,feature_dimensions(在词向量的时候就是embedding_size了)构成了我们Input的三大维度,无论是keras/tensorflow,亦或是Pytorch,本质上都是这样。
牵涉到这个问题是听说Pytorch自由度更高,最近在做实验的时候开始尝试用Pytorch了,写完代码跑通后,过了段时间才意识到,好像没有用到seq_len这个参数,果然是Keras用多了的后遗症?!(果然是博主比较蠢!)检查了一下才发现,DataLoader生成数据的时候,默认生成为(batch_size, 1, feature_dims)。(这里无视了batch_size和seq_len的顺序,在建立模型的时候,比如nn.LSTM有个batch_first的参数,它决定了谁前谁后,但这不是我们这里讨论的重点)。
所以我们的seq_len/time_steps被默认成了1,这是在使用Pytorch的时候容易发生的问题,由于Keras先天的接口设置在Input时就让我们无脑设置seq_len,这反而不会成为我们在使用Keras时发生的问题,而Pytorch没有让我们在哪里设置这个参数,所以一不小心可能就忽视了。
好了,接下来就来找找问题怎么出现的,又怎么解决。果然问题还是出现在了DataLoader,在__getitem__(self, index)这里,决定了我们如何取出数据,这里我发现我自己还是用的最简单的方式一条一条取的。
def __getitem__(self, idx):return self.input[idx], self.target[idx]
完全没有意识到Torch需要在这里进行seq_len的修饰,接下来该怎么解决呢,首先看看我们希望的“取数据方式”。
假如我们有id = 1,2,3,4,5,6,7,8,9,10一共10个sample。
假设我们设定seq_len是3。
那现在数据的形式应该为1-2-3,2-3-4,3-4-5,4-5-6,5-6-7,6-7-8,7-8-9,8-9-10,9-10-0,10-0-0(最后两个数据不完整,进行补零)的10个数据。这是我们真正有了seq_len这个参数,带有“循环”这个概念,要放进RNN等序列模型中进行处理的数据。所以之前说seq_len被我默认弄成了1,那就是把1,2,3,4,5,6,7,8,9,10这样形式的10个数据分别放进了模型训练,自然在DataLoader里取数据的size就成了(batch_size, 1, feature_dims),而我们现在取数据才会是(batch_size, 3, feature_dims)。
假设我们设定batch_size为2。
那我们取出第一个batch为1-2-3,2-3-4。这个batch的size就是(2,3,feature_dims)了。我们把这个玩意儿喂进模型。
接下来第二个batch为3-4-5,4-5-6。
第三个batch为5-6-7,6-7-8。
第四个batch为7-8-9,8-9-10。
第五个batch为9-10-0,10-0-0。我们的数据一共生成了5个batch。
可以看到,num_batch = num_samples / batch_size(这里没有进行向上或向下取整是因为在某些地方可以设置是否需要那些不完整的被进行补零的batch),seq_len仍然不会影响最后生成的batch的数量,只有batch_size和num_samples会对batch的数量进行影响。
可能忽略了feature_dims仅凭借id来代表数据难以理解,那换种方式看看,假如feature_dims为6:
data_ = [[1, 10, 11, 15, 9, 100],[2, 11, 12, 16, 9, 100],[3, 12, 13, 17, 9, 100],[4, 13, 14, 18, 9, 100],[5, 14, 15, 19, 9, 100],[6, 15, 16, 10, 9, 100],[7, 15, 16, 10, 9, 100],[8, 15, 16, 10, 9, 100],[9, 15, 16, 10, 9, 100],[10, 15, 16, 10, 9, 100]]
仍然设置seq_len为3,batch_size为2。
这时我们的第一个batch为
tensor([[[ 1., 10., 11., 15., 9., 100.],[ 2., 11., 12., 16., 9., 100.],[ 3., 12., 13., 17., 9., 100.]],[[ 2., 11., 12., 16., 9., 100.],[ 3., 12., 13., 17., 9., 100.],[ 4., 13., 14., 18., 9., 100.]]])
这就是刚刚的1-2-3,2-3-4嘛。
而最后一个batch为
tensor([[[ 9., 15., 16., 10., 9., 100.],[ 10., 15., 16., 10., 9., 100.],[ 0., 0., 0., 0., 0., 0.]],[[ 10., 15., 16., 10., 9., 100.],[ 0., 0., 0., 0., 0., 0.],[ 0., 0., 0., 0., 0., 0.]]])
最后放上Demo,由于每个人的数据甚至loss等等都不一样,不过大家应该能够从Demo中得到一些如何针对自己的project进行修改的点子。
# -*- coding: utf-8 -*-import torch
import torch.utils.data as Data
import torch.nn as nn
import torchvision.transforms as transforms
import numpy as np
### Demo datasetdata_ = [[1, 10, 11, 15, 9, 100],[2, 11, 12, 16, 9, 100],[3, 12, 13, 17, 9, 100],[4, 13, 14, 18, 9, 100],[5, 14, 15, 19, 9, 100],[6, 15, 16, 10, 9, 100],[7, 15, 16, 10, 9, 100],[8, 15, 16, 10, 9, 100],[9, 15, 16, 10, 9, 100],[10, 15, 16, 10, 9, 100]]### Demo Dataset classclass DemoDatasetLSTM(Data.Dataset):"""Support class for the loading and batching of sequences of samplesArgs:dataset (Tensor): Tensor containing all the samplessequence_length (int): length of the analyzed sequence by the LSTMtransforms (object torchvision.transform): Pytorch's transforms used to process the data"""## Constructordef __init__(self, dataset, sequence_length=1, transforms=None):self.dataset = datasetself.seq_len = sequence_lengthself.transforms = transforms## Override total dataset's length getterdef __len__(self):return self.dataset.__len__()## Override single items' getterdef __getitem__(self, idx):if idx + self.seq_len > self.__len__():if self.transforms is not None:item = torch.zeros(self.seq_len, self.dataset[0].__len__())item[:self.__len__()-idx] = self.transforms(self.dataset[idx:])return item, itemelse:item = []item[:self.__len__()-idx] = self.dataset[idx:]return item, itemelse:if self.transforms is not None:return self.transforms(self.dataset[idx:idx+self.seq_len]), self.transforms(self.dataset[idx:idx+self.seq_len])else:return self.dataset[idx:idx+self.seq_len], self.dataset[idx:idx+self.seq_len]### Helper for transforming the data from a list to Tensordef listToTensor(list):tensor = torch.empty(list.__len__(), list[0].__len__())for i in range(list.__len__()):tensor[i, :] = torch.FloatTensor(list[i])return tensor### Dataloader instantiation# Parameters
seq_len = 3
batch_size = 2
data_transform = transforms.Lambda(lambda x: listToTensor(x))dataset = DemoDatasetLSTM(data_, seq_len, transforms=data_transform)
data_loader = Data.DataLoader(dataset, batch_size, shuffle=False)for data in data_loader:x, _ = dataprint(x)print('\n')
Pytorch中如何理解RNN LSTM GRU的input(重点理解seq_len / time_steps)相关推荐
- RNN,LSTM,GRU基本原理的个人理解重点
20210626 循环神经网络_霜叶的博客-CSDN博客 LSTM的理解 - 走看看 重点 深入LSTM结构 首先使用LSTM的当前输入 (x^t)和上一个状态传递下来的 (h^{t-1}) 拼接训练 ...
- RNN LSTM GRU 代码实战 ---- 简单的文本生成任务
RNN LSTM GRU 代码实战 ---- 简单的文本生成任务 import torch if torch.cuda.is_available():# Tell PyTorch to use the ...
- RNN, LSTM, GRU, SRU, Multi-Dimensional LSTM, Grid LSTM, Graph LSTM系列解读
RNN/Stacked RNN rnn一般根据输入和输出的数目分为5种 一对一 最简单的rnn 一对多 Image Captioning(image -> sequence of words) ...
- RNN,LSTM,GRU计算方式及优缺点
本文主要参考李宏毅老师的视频介绍RNN相关知识,主要包括两个部分: 分别介绍Navie RNN,LSTM,GRU的结构 对比这三者的优缺点 1.RNN,LSTM,GRU结构及计算方式 1.1 Navi ...
- DL之RNN/LSTM/GRU:RNN/LSTM/GRU算法动图对比、TF代码定义之详细攻略
DL之RNN/LSTM/GRU:RNN/LSTM/GRU算法动图对比.TF代码定义之详细攻略 目录 RNN.LSTM.GRU算法对比 1.RNN/LSTM/GRU对比 2.RNN/LSTM/GRU动图 ...
- DL之LSTM:LSTM算法论文简介(原理、关键步骤、RNN/LSTM/GRU比较、单层和多层的LSTM)、案例应用之详细攻略
DL之LSTM:LSTM算法论文简介(原理.关键步骤.RNN/LSTM/GRU比较.单层和多层的LSTM).案例应用之详细攻略 目录 LSTM算法简介 1.LSTM算法论文 1.1.LSTM算法相关论 ...
- 图解 RNN, LSTM, GRU
参考: Illustrated Guide to Recurrent Neural Networks Illustrated Guide to LSTM's and GRU's: A step by ...
- python中size_x的意思,对pytorch中x = x.view(x.size(0), -1) 的理解说明
在pytorch的CNN代码中经常会看到 x.view(x.size(0), -1) 首先,在pytorch中的view()函数就是用来改变tensor的形状的,例如将2行3列的tensor变为1行6 ...
- [PyTorch] rnn,lstm,gru中输入输出维度
本文中的RNN泛指LSTM,GRU等等 CNN中和RNN中batchSize的默认位置是不同的. CNN中:batchsize的位置是position 0. RNN中:batchsize的位置是pos ...
最新文章
- Marshal.ReleaseComObject
- 电脑主板维修_自学电脑主板维修第48讲
- python csv使用_python CSV模块的使用
- 收藏 : 50个Excel逆天功能,一秒变“表哥”
- P2513-[HAOI2009]逆序对数列【dp,前缀和】
- java坦克大战源码下载
- UrlRewriter 伪url的配置
- python中or是什么意思-Python 中 (,|)和(and,or)之间的区别
- psql json操作符合函数
- 云计算之openstack(N版)neutron网络服务最佳实践
- php文件包含漏洞的危害,php文件包含漏洞小结
- python爬虫执行js代码_python爬虫执行js代码-execjs
- PHP100的php教程批量打包下载
- 面试题:CSS3实现折角效果
- 该微信用户未开启“公众号安全助手”的消息接收功能,请先开启后再绑定的解决办法
- php短信功能实现原理,基于信息熵原理分词的php实现
- [初级理论]给老婆做测试培训-02
- 解决Virtualbox安装系统界面显示不全问题
- ALV中的回车事件相应及添加F4帮助
- 截图工具GifCam简单使用教程
热门文章
- 计算机硬盘硬件的配置问题,磁盘硬件配置问题windows无法正常启动怎么解决
- 谷歌翻拉取别的分支_如何将品牌分支机构的位置添加到Google地图
- matlab 向量变标量,MATLAB变量——标量,向量,矩阵
- Automatic differentiation in PyTorch
- 04_Initial Design/Floorplan实操2021-09-08上午
- 洛谷P1868 饥饿的奶牛 题解
- USACO 奶牛食品(最大流)
- 数据分析之路的尽头是创业?
- 公众号留言板怎么开通
- 微信公众号留言评论功能最新开通信息讲解(内附留言功能开通视频信息讲解链接)...