前言

最近在由 TensorFlow 迁移至 Pytorch, 不得不说,真的香啊。 在写模型的时候发现 Pytorch 中处理变长序列与 TensorFlow 有很大的不同, 因此此处谈谈我自己的理解。

此外, 我对 LSTM, GRU 进行了二次加工, 将对变长序列的处理封装到内部细节中,感兴趣的可以看看:NLP-Pytorch

从 LSTM 谈起[1]

首先, 注意到这里LSTM的计算公式与我们常见的LSTM有所区别,虽然区别不大,但还是要提一下,因为后面的参数初始化会有所不同:

class torch.nn.LSTM(*args, **kwargs)

-- 参数列表:

-- input_size: x 的特征维度

-- hidden_size: 隐层的特征维度

-- num_layers: LSTM 层数,默认为1

-- bias: 是否采用 bias, 如果为False,则不采用。默认为True

-- batch_first: True, 则输入输出的数据格式为 [batch_size, seq_len, feature_dim],默认为False

-- dropout: dropout会在除最后一层外都进行dropout, 默认为0

-- bidirectional: 是否采用双向,默认为False

-- 输入数据:

-- input: [seq_len, batch_size, input_size], 输入的特征矩阵

-- h_0: [num_layers * num_directions, batch_size, hidden_size], 初始时 h 状态, 默认为0

-- c_0: [num_layers * num_directions, batch_size, hidden_size], 初始时 cell 状态, 默认为0

-- 输出数据:

-- output: [seq_len, batch_size, num_directions * hidden_size], 最后一层的所有隐层输出

-- h_n : [num_layers * num_directions, batch, hidden_size], 所有层的最后一个时刻隐层状态

-- c_n : [num_layers * num_directions, batch, hidden_size], 所有层的最后一格时刻的 cell 状态

-- W,b参数:

-- weight_ih_l[k]: 与输入x相关的第k层权重 W 参数, W_ii, W_if, W_ig, W_io

-- weight_hh_l[k]: 与上一时刻 h 相关的第k层权重参数, W_hi, W_hf, W_hg, W_ho

-- bias_ih_l[k]: 与输入x相关的第k层 b 参数, b_ii, b_if, b_ig, b_io

-- bias_hh_l[k]: 与上一时刻 h 相关的第k层 b 参数, b_hi, b_hf, b_hg, b_ho

需要注意的一点是, LSTM中所有的W,b 参数默认采用均匀分布 :

, 有一些初始化方法能够加速收敛过程, 因此,很多情况下我们需要自己初始化这些参数,我在 NLP-Pytorch: LSTM 对LSTM 进行了简要封装。

对比 GRU [1]

同样为了更鲜明的表明参数的初始化, 这里将 GRU 搬过来:

class torch.nn.GRU(*args, **kwargs)

-- 参数列表:与 LSTM 的一致, 不赘述了

-- 输入序列:input, h_0; 与 LSTM 差不多,只是省略了 cell 状态

-- 输出序列:output, h_n; 与 LSTM 差不多,只是省略了 cell 状态

-- W,b参数:

-- weight_ih_l[k]: 与输入x相关的第k层权重 W 参数, W_ir, W_iz, W_in

-- weight_hh_l[k]: 与上一时刻 h 相关的第k层权重参数, W_hr, W_hz, W_hn

-- bias_ih_l[k]: 与输入x相关的第k层 b 参数, b_ir, b_iz, b_in

-- bias_hh_l[k]: 与上一时刻 h 相关的第k层 b 参数, b_hr, b_hz, b_hn

与LSTM 一样, W,b参数的初始化默认采用均匀分布 :

如何处理变长序列?[2]

我们知道,在文本的处理过程中,句子的长度是不一的,对于这种数据,我们往往采用 将每个句子扩充到一样的长度, 然后我们就可以使用LSTM或GRU来处理了。

但仔细一想, 又有些不对,我们将句子扩充,那么扩充的信息必然会对我们的结果产生影响,这与我们正常的思路完全不同,虽然我自己做比较实验表明,二者之间差距并不明显,但我个人认为这是数据集与初始化的关系。

那么,Pytorch 中如何处理这种变长的情况,去掉 呢,答案就是 torch.nn.utils.rnn.pack_padded_sequence()以及 torch.nn.utils.rnn.pad_packed_sequence()

压缩序列

压缩序列所使用的API为torch.nn.utils.rnn.pack_padded_sequence(), 其目的是将多余的 去除,获得一个干净的,最初的序列。

pack_padded_sequence(...)

-- 参数列表:

-- input: 有 的 batch 序列

-- lengths: input 中每个序列的长度

-- batch_first: 如果为True, input 必须是 [batch_size, seq_len, input_size], 参见LSTM

-- enforce_sorted: 如果为True, 那么 input 中的序列需要按照 长度递减排列

-- 返回值:

一个 PackedSequence 对象

我们获得干净的序列之后,就可以将其放入 LSTM 中了, 具体可参见我的实现:NLP-Pytorch: LSTM

解压序列

我们通过 LSTM 对压缩后的序列处理后,还需要将压缩后的信息解压缩,本质上是将数据从PackedSequence 类型转化为Tensor ,此部分主要是做 Attention 的时候会用到。

pad_packed_sequence(...)

-- 参数列表:

-- sequence: 一个PackedSequence 对象

-- batch_first:

-- padding_value: padding 该序列的values

-- total_length: Padding 到多长, 一般为None

-- Returns:

-- output: 有 padding 信息的序列输出

-- output_lengths: 每个序列没有Padding之前的长度

最后

最后,知乎上有一个很有趣的问题:你在训练RNN的时候有哪些特殊的trick, 十分值得一看,感兴趣的可以自己做一做相关的实验,ok, 就酱。

Reference

lstm 变长序列_Pytorch 是如何处理变长序列的相关推荐

  1. PAT甲级1020变体:已知二叉树层序+中序序列,求后序遍历序列

    PAT甲级1020变体:已知二叉树层序+中序序列,求后序遍历序列 题目 输入格式 输出格式 输入样例 输出样例 代码 题目 已知二叉树层序+中序序列,求后序遍历序列. 输入格式 第一行给出该二叉树的节 ...

  2. 快抖“变长”、爱优腾“变短”

    8月24日,抖音宣布将逐步开放15分钟的视频发布能力,此前抖音已多次局部提升视频时长:无独有偶,快手在上个月开始内测长视频功能,时长限制在57秒以上,10分钟以内. 短视频在变长前,长视频早已在变短. ...

  3. mag6000变送器怎么使用_变送器的迁移原理和故障分析

    一.差压液位计的工作原理 差压液位计是利用容器内的液位改变时,液柱产生的静压也相应变化的原理工作的,当差压变送器一端接液相,另一端接气相时,根据流体静力学原理计算P=ρgh(式中h为液位高度,ρ为被测 ...

  4. 如何处理电脑长时间未操作出现的假死?

    如何处理电脑长时间未操作出现的假死? 我们平时经常会遇到由于长时间未操作电脑,再使用时只有鼠标光标可以移动,桌面上的图标无法响应,包括任务栏的程序,那么我们应该怎么处理比较得当呢? 尝试使用光盘插拔动 ...

  5. 各种骚操作试试 V7变回V5试试,直接变胖FAT,刷。。。2020-10-26

    本文稿原创,未经许可授权,不得转载!!!! 机器买回来还是很兴奋的... 查资料,看说明,学技术....终于!从廋变成胖的了....兴奋! 买了二个,一个胖一个廋,两全了,接入AC,配置,一开始不成功 ...

  6. 用final关键字修饰一个变量时,是引用不能变,还是引用的对象不能变

    使用final关键字修饰一个变量时,是引用不能变,还是引用的对象不能变 答: 使用final关键字修饰一个变量时,是指引用变量不能变,引用变量所指向的对象中的内容还是可以改变的. 代码 public ...

  7. 1tensorflow 实现端到端的OCR:二代身份证号识别 + 2tensorflow LSTM+CTC实现端到端的不定长数字串识别

    1tensorflow 实现端到端的OCR:二代身份证号识别 链接地址:https://www.jianshu.com/p/803642d0d8f8 2tensorflow LSTM+CTC实现端到端 ...

  8. DNA计算 与 肽展公式 推导 AOPM-A 变胸腺苷, AOPM-O尿胞变腺苷, AOPM-P尿胞变鸟苷, AOPM-M鸟腺苷的 S形螺旋纹 血氧峰 触发器分子式 严谨完整过程

    作者 罗瑶光 随着VECS[15][8]-IDUQ[14][9][10][11][12]完整解码, AOPM[7]就简单了.准备描述下. 作者已经拥有 肽展公式[12] A = V + S O = E ...

  9. 最长不下降子序列O(NlogN) 输出序列

    文章目录 不输出序列的思路 输出序列 代码 例题:导弹拦截 不输出序列的思路 我们对于O(n2)O(n^2)O(n2)的最长不下降子序列十分熟悉了. #include <bits/stdc++. ...

  10. 第十一章 会打电话 天涯变咫尺 不会打电话 咫尺变天涯

    第十一章 会打电话 天涯变咫尺 不会打电话 咫尺变天涯 董事长算什么 朱经理没错,错在他"使用电话"的技巧. 想一想:朱经理错在哪里,错在他不应该已经拿起电话了,还在继续骂小邱. ...

最新文章

  1. Shiro 权限框架使用总结
  2. phpstorm 提示请配置PHP解释器的解决办法
  3. Linux chattr 与 lsattr命令
  4. 打脸往事!罗振宇2015年谈乐视、暴风 一口毒奶“奶死”不少人?
  5. 怎么在服务器添加充值网站,云服务器怎么弄充值
  6. 【紫书第九章】动态规划(DP)常见模型汇总与DP问题分析方法
  7. 计算机对用户的操作做出反应,云南省计算机二级VB考试真题题库
  8. Setup Factory 点击uninstall.exe Invalid start mode : archive filename
  9. 台达plc自由口通讯_台达PLC串行通讯应用原理
  10. 怎样自学python_怎样自学Python?
  11. 推荐一个Github上很酷的开源项目——The Octo-Bouncer
  12. 芯片数据分析步骤4 标准化-affy
  13. 有什么好用的语音转文字软件?介绍三个语音文件转文字的软件
  14. c语言三日通 下载,sama
  15. Codeforces D. Omkar and Bed Wars
  16. oracle为表空间增加数据文件,创建Oracle表空间,增加数据文件的步骤
  17. 第三章:fog(恐惧感 fear,责任感obligation,罪恶感guilty)
  18. 使用seleinum模块动态爬取熊猫直播平台全部的主播房间。
  19. 头文件 string.h cstring string 区别
  20. 收集的多家大公司Java面试题

热门文章

  1. 【TSP】基于matlab自适应动态邻域布谷鸟混合算法求解旅行商问题【含Matlab源码 1513期】
  2. 【图像增强】基于matlab GUI图像双边滤波【含Matlab源码 1492期】
  3. 【语言去噪】基于matlab GUI IIR+FIR滤波器语音去噪【含Matlab源码 1027期】
  4. 【TSP】基于matlab粒子群算法求解旅行商问题【含Matlab源码 445期】
  5. 【车间调度】基于matlab GUI遗传算法求解车间调度问题【含Matlab源码 049期】
  6. 西蒙决策_西蒙的象棋因子
  7. datatable高效写入mysql_如何将DataTable批量写入数据库
  8. 问题1:VS2017:找不到 Windows SDK 版本10.0.17134.0
  9. 随笔记--Pycharm中Terminal字体大小的设置
  10. Ubuntun系统查看系统版本和Python版本的方法