一、为什么RNN需要处理变长输入

假设我们有情感分析的例子,对每句话进行一个感情级别的分类,主体流程大概是下图所示:

思路比较简单,但是当我们进行batch个训练数据一起计算的时候,我们会遇到多个训练样例长度不同的情况,这样我们就会很自然的进行padding,将短句子padding为跟最长的句子一样。

比如向下图这样:

但是这会有一个问题,什么问题呢?比如上图,句子“Yes”只有一个单词,但是padding了5的pad符号,这样会导致LSTM对它的表示通过了非常多无用的字符,这样得到的句子表示就会有误差,更直观的如下图:

那么我们正确的做法应该是怎么样呢?

这就引出pytorch中RNN需要处理变长输入的需求了。在上面这个例子,我们想要得到的表示仅仅是LSTM过完单词"Yes"之后的表示,而不是通过了多个无用的“Pad”得到的表示:如下图:

二、pytorch中RNN如何处理变长padding

主要是用函数torch.nn.utils.rnn.pack_padded_sequence()以及torch.nn.utils.rnn.pad_packed_sequence()来进行的,分别来看看这两个函数的用法。

这里的pack,理解成压紧比较好。 将一个 填充过的变长序列 压紧。(填充时候,会有冗余,所以压紧一下)

输入的形状可以是(T×B×* )。T是最长序列长度,B是batch size,*代表任意维度(可以是0)。如果batch_first=True的话,那么相应的 input size 就是 (B×T×*)。

Variable中保存的序列,应该按序列长度的长短排序,长的在前,短的在后(特别注意需要进行排序)。即input[:,0]代表的是最长的序列,input[:, B-1]保存的是最短的序列。

参数说明:

input (Variable) – 变长序列 被填充后的 batch

lengths (list[int]) – Variable 中 每个序列的长度。(知道了每个序列的长度,才能知道每个序列处理到多长停止

batch_first (bool, optional) – 如果是True,input的形状应该是B*T*size。

返回值:

一个PackedSequence 对象。一个PackedSequence表示如下所示:

具体代码如下:

embed_input_x_packed = pack_padded_sequence(embed_input_x, sentence_lens, batch_first=True)
encoder_outputs_packed, (h_last, c_last) = self.lstm(embed_input_x_packed)

此时,返回的h_last和c_last就是剔除padding字符后的hidden state和cell state,都是Variable类型的。代表的意思如下(各个句子的表示,lstm只会作用到它实际长度的句子,而不是通过无用的padding字符,下图用红色的打钩来表示):

但是返回的output是PackedSequence类型的,可以使用:

encoder_outputs, _ = pad_packed_sequence(encoder_outputs_packed, batch_first=True)

将encoderoutputs在转换为Variable类型,得到的_代表各个句子的长度。

三、总结

这样综上所述,RNN在处理类似变长的句子序列的时候,我们就可以配套使用torch.nn.utils.rnn.pack_padded_sequence()以及torch.nn.utils.rnn.pad_packed_sequence()来避免padding对句子表示的影响

参考:

pytorch中如何处理RNN输入变长序列padding相关推荐

  1. PyTorch中使用LSTM处理变长序列

    使用LSTM算法处理的序列经常是变长的,这里介绍一下PyTorch框架下使用LSTM模型处理变长序列的方法.需要使用到PyTorch中torch.nn.utils包中的pack_padded_sequ ...

  2. NLP中各框架对变长序列的处理全解

    ©PaperWeekly 原创 · 作者|海晨威 学校|同济大学硕士生 研究方向|自然语言处理 在 NLP 中,文本数据大都是变长的,为了能够做 batch 的训练,需要 padding 到相同的长度 ...

  3. evaluate函数使用无效_使用Keras和Pytorch处理RNN变长序列输入的方法总结

    最近在使用Keras和Pytorch处理时间序列数据,在变长数据的输入处理上踩了很多坑.一般的通用做法都需要先将一个batch中的所有序列padding到同一长度,然后需要在网络训练时屏蔽掉paddi ...

  4. lstm 变长序列_keras在构建LSTM模型时对变长序列的处理操作

    我就废话不多说了,大家还是直接看代码吧~ print(np.shape(X))#(1920, 45, 20) X=sequence.pad_sequences(X, maxlen=100, paddi ...

  5. PyTorch中的数据输入和预处理

    文章目录 PyTorch中的数据输入和预处理 数据载入类 映射类型的数据集 torchvision工具包的使用 可迭代类型的数据集 总结 PyTorch中的数据输入和预处理 数据载入类 在使用PyTo ...

  6. lstm 变长序列_基于变长时间间隔LSTM方法的胎儿异常体重预测

    1 介绍 预测胎儿体重是产前监护的重要内容, 是医生对孕妇进行临床处理的重要依据. 近年来研究显示, 低体重儿的存活率和扛感染能力相对低下[, 并且与低智商有密切联系[. 而巨大儿则会引起胎儿宫内窘迫 ...

  7. 在C#中如何定义一个变长的结构数组?如果定义好了,如何获得当前数组的长度?...

    用ArrayList,他就相当于动态数组,用add方法添加元素,remove删除元素,count计算长度

  8. [深度学习] Pytorch中RNN/LSTM 模型小结

    目录 一 Liner 二 RNN 三 LSTM 四 LSTM 代码例子 概念介绍可以参考:[深度学习]理解RNN, GRU, LSTM 网络 Pytorch中所有模型分为构造参数和输入和输出构造参数两 ...

  9. pytorch中RNN注意事项(关于input和output维度)

    pytorch中RNN注意事项 batch_first为False的情况下,认为input的数据维度是(seq,batch,feature),output的数据维度(seq,batch,feature ...

最新文章

  1. REST接口设计规范
  2. Spring Boot集成CKEditor
  3. 使用pm2启动node文件_使用 PM2 管理nodejs进程
  4. C#委托、事件、消息(入门级)
  5. centos5.8安装mysql_Centos5.8上面用Shell脚本一键安装mysql5.5.25源码包
  6. 多线程 进度条 C# .net
  7. hive表定义(3种方式)
  8. 我的Go+语言初体验——(4)零基础学习 Go+ 爬虫
  9. 无连接可靠传输_FPC连接器的特点以及弹片微针模组的作用
  10. phpstrom配置Xdebug
  11. Android的SharedPreferences存取String和List<String>类型(在Activity和Fragment内使用)
  12. 想知道账号被封的感觉么?
  13. kibana日志收集
  14. 力扣题解:面试题 02.03. 删除中间节点
  15. 小程序该怎么去做引流和变现呢
  16. 在剧里跟“刘亦菲”学营销:地推撬动“社交播传”
  17. Jenkins教程(六)脚本与方法执行效果不合预期,如何及时中止pipeline
  18. Python—re正则表达式
  19. 怎样阅读论文(台湾彭明辉)
  20. 新建小程序项目提示:登录用户不是该小程序的开发者

热门文章

  1. php会话的销毁和退出,销毁PHP会话
  2. Go 知识点(09)— for select 作用于 channel
  3. OpenCV 笔记(01)— OpenCV 概念、整体架构、各模块主要功能
  4. centos使用yum快速安装java的方法
  5. 如何学习数据挖掘和数据科学的7个步骤
  6. hadoop问题小结
  7. 矩阵乘法的性能提升 AutoKernel
  8. LeetCode简单题之好对数的数目
  9. 至强® 平台配备先进遥测技术让您的数据中心更智能
  10. 使用多个推理芯片需要仔细规划