一、问题背景

  在NLP的相关任务中,我们使用RNN或LSTM处理文本序列时,通常来说句子的长度是不一致的,我们常常采用的方法使用< PAD >(0)来补全至相同长度的序列。虽然这个时候序列的长度是一致的,但是序列中填充了许多无效值 0 ,这个时候喂给 RNN 进行 forward 计算,不仅1.浪费计算资源,最后得到的值2.可能还会存在误差
  因此,为了解决这样的问题,在将序列送给 RNN 进行处理之前,需要采用 nn.utils.rnn.pack_padded_sequence 进行压缩,压缩掉无效的填充值。序列经过 RNN 处理之后的输出仍然是压紧的序列,需要采用 pad_packed_sequence 把压紧的序列再填充回来,便于进行后续的处理。

二、使用方法

import torch
import torch.nn as nn
from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence
import torch.nn.functional as F# 将数据转换到GPU上
def to_cuda(x, use_cuda=True):if use_cuda and torch.cuda.is_available():x = x.cuda()return xclass EncoderRNN(nn.Module):def __init__(self, vocab_size, embed_size, hidden_size, dropout=None, \bidirectional=False, shared_embed=None, init_word_embed=None, rnn_type='lstm', use_cuda=True):super(EncoderRNN, self).__init__()if not rnn_type in ('lstm', 'gru'):raise RuntimeError('rnn_type is expected to be lstm or gru, got {}'.format(rnn_type))if bidirectional:print('[ Using bidirectional {} encoder ]'.format(rnn_type))else:print('[ Using {} encoder ]'.format(rnn_type))if bidirectional and hidden_size % 2 != 0:raise RuntimeError('hidden_size is expected to be even in the bidirectional mode!')self.dropout = dropoutself.rnn_type = rnn_typeself.use_cuda = use_cudaself.hidden_size = hidden_size // 2 if bidirectional else hidden_sizeself.num_directions = 2 if bidirectional else 1self.embed = shared_embed if shared_embed is not None else nn.Embedding(vocab_size, embed_size, padding_idx=0)model = nn.LSTM if rnn_type == 'lstm' else nn.GRUself.model = model(embed_size, self.hidden_size, 1, batch_first=True, bidirectional=bidirectional)if shared_embed is None:self.init_weights(init_word_embed)def init_weights(self, init_word_embed):if init_word_embed is not None:print('[ Using pretrained word embeddings ]')self.embed.weight.data.copy_(torch.from_numpy(init_word_embed))else:self.embed.weight.data.uniform_(-0.08, 0.08)def forward(self, x, x_len):"""x: [batch_size * max_length]x_len: [batch_size] 45423"""x = self.embed(x)if self.dropout:x = F.dropout(x, p=self.dropout, training=self.training)print("x = ", x)sorted_x_len, indx = torch.sort(x_len, dim=-1, descending=True)# print(sorted_x_len)# sort_x_len 是数据的真实长度,由大到小(这里是因为pack_padded_sequence函数中enforce_sorted参数默认为True,则输入的长度序列必须是降序排列)x = pack_padded_sequence(x[indx], sorted_x_len.data.tolist(), batch_first=True)print(x)h0 = to_cuda(torch.zeros(self.num_directions, x_len.size(0), self.hidden_size), self.use_cuda)if self.rnn_type == 'lstm':c0 = to_cuda(torch.zeros(self.num_directions, x_len.size(0), self.hidden_size), self.use_cuda)packed_h, (packed_h_t, _) = self.model(x, (h0, c0))print("1:",packed_h)print("2:",packed_h_t)if self.num_directions == 2:packed_h_t = torch.cat([packed_h_t[i] for i in range(packed_h_t.size(0))], -1)else:packed_h, packed_h_t = self.model(x, h0)if self.num_directions == 2:packed_h_t = packed_h_t.transpose(0, 1).contiguous().view(query_lengths.size(0), -1)hh, out_len = pad_packed_sequence(packed_h, batch_first=True)print("hh =", hh, out_len)# restore the sorting,把压紧的序列再填充回来o, inverse_indx = torch.sort(indx, 0)print(o, " ", inverse_indx)restore_hh = hh[inverse_indx]# restore_packed_h_t = packed_h_t[inverse_indx]return restore_hh # , restore_packed_h_t

  假设我们的输入是:queries、query_lengths

queries = torch.tensor([[1,2,3,4,0],[2,3,4,5,6],[4,5,6,7,0],[5,6,0,0,0],[6,7,8,0,0]])
query_lengths = torch.tensor([4,5,4,2,3])Que_encoder = EncoderRNN(vocab_size=10, embed_size=4, hidden_size=4, \bidirectional=False, \rnn_type='lstm', \use_cuda=False)Q_r = Que_encoder(queries, query_lengths)
print("编码后的Que为:"Q_r)'''
其中的  x = pack_padded_sequence(x[indx], sorted_x_len.data.tolist(), batch_first=True), 输出的x为:
PackedSequence(data=tensor([[-0.0783, -0.0091, -0.0775,  0.0731],[ 0.0100, -0.0372, -0.0139, -0.0669],[ 0.0167,  0.0425, -0.0596, -0.0765],[ 0.0652, -0.0491,  0.0274, -0.0430],[ 0.0217,  0.0014,  0.0430, -0.0057],[-0.0014, -0.0411, -0.0739, -0.0768],[-0.0783, -0.0091, -0.0775,  0.0731],[ 0.0217,  0.0014,  0.0430, -0.0057],[ 0.0518,  0.0061,  0.0161,  0.0411],[ 0.0652, -0.0491,  0.0274, -0.0430],[ 0.0167,  0.0425, -0.0596, -0.0765],[-0.0014, -0.0411, -0.0739, -0.0768],[ 0.0652, -0.0491,  0.0274, -0.0430],[-0.0562,  0.0797, -0.0044, -0.0591],[ 0.0217,  0.0014,  0.0430, -0.0057],[ 0.0167,  0.0425, -0.0596, -0.0765],[ 0.0518,  0.0061,  0.0161,  0.0411],[ 0.0652, -0.0491,  0.0274, -0.0430]],grad_fn=<PackPaddedSequenceBackward>), batch_sizes=tensor([5, 5, 4, 3, 1]), sorted_indices=None, unsorted_indices=None)注:这样输入LSTM中就不包含0,每一个时间步中输入到lstm的batch大小分别为[5, 5, 4, 3, 1]。最终的Que的编码结果就是:[bts,seq_max,hidden_size] tensor([[[-0.1572, -0.0628,  0.1359,  0.0133],[-0.2072, -0.0917,  0.1704, -0.0067],[-0.2320, -0.1170,  0.1780, -0.0071],[-0.2383, -0.1252,  0.1847, -0.0125],[ 0.0000,  0.0000,  0.0000,  0.0000]],[[-0.1553, -0.0554,  0.1370, -0.0058],[-0.2124, -0.0962,  0.1708, -0.0013],[-0.2307, -0.1140,  0.1833, -0.0063],[-0.2367, -0.1202,  0.1781, -0.0072],[-0.2451, -0.1270,  0.1763, -0.0020]],[[-0.1570, -0.0598,  0.1424,  0.0083],[-0.2096, -0.0908,  0.1686,  0.0090],[-0.2342, -0.1123,  0.1744,  0.0117],[-0.2393, -0.1167,  0.1786,  0.0018],[ 0.0000,  0.0000,  0.0000,  0.0000]],[[-0.1567, -0.0560,  0.1341,  0.0124],[-0.2144, -0.0937,  0.1665,  0.0188],[ 0.0000,  0.0000,  0.0000,  0.0000],[ 0.0000,  0.0000,  0.0000,  0.0000],[ 0.0000,  0.0000,  0.0000,  0.0000]],[[-0.1595, -0.0614,  0.1344,  0.0208],[-0.2122, -0.0900,  0.1692,  0.0147],[-0.2282, -0.1101,  0.1799, -0.0051],[ 0.0000,  0.0000,  0.0000,  0.0000],[ 0.0000,  0.0000,  0.0000,  0.0000]]], grad_fn=<IndexBackward>)'''

三、pack_padded_sequence原理详解

pad_packed_sequence函数实际上是 pack_padded_sequence 函数的逆向操作。就是把压紧的序列再填充回来。

#详细的解释过程图如下:

【Pytorch】pack_padded_sequence与pad_packed_sequence实战详解相关推荐

  1. 《Unity 4 3D开发实战详解》一6.7 物理引擎综合案例

    本节书摘来异步社区<Unity 4 3D开发实战详解>一书中的第6章,第6.7节,作者: 吴亚峰 , 杜化美 , 张月霞 , 索依娜 责编: 张涛,更多章节内容可以访问云栖社区" ...

  2. R语言基于forestplot包可视化森林图实战详解:美化的森林图:自定义字体设置、置信区间、坐标轴(刻度、标签、范围)、无效线去除、水平线、辅助线、box形状、色彩等

    R语言基于forestplot包可视化森林图实战详解:美化的森林图:自定义字体设置.置信区间.坐标轴(刻度.标签.范围).无效线去除.水平线.辅助线.box形状.色彩等 目录

  3. R语言使用survminer包生存分析及可视化(ggsurvplot)实战详解:从数据集导入、生存对象生成、ggsurvplot可视化参数配置、设置、可视化对比

    R语言使用survminer包生存分析及可视化(ggsurvplot)实战详解:从数据集导入.生存对象生成.ggsurvplot可视化参数配置.设置.可视化对比 目录 R语言使用survminer包生 ...

  4. R语言tidyr包gather()函数实战详解:数据收缩、从宽表到窄表

    R语言tidyr包gather()函数实战详解:数据收缩.从宽表到窄表 目录 R语言tidyr包gather()函数实战详解:数据收缩.从宽表到窄表 收缩两列数据

  5. R语言tidyr包spread()函数实战详解:数据裂变、从窄表到宽表

    R语言tidyr包spread()函数实战详解:数据裂变.从窄表到宽表 目录 R语言tidyr包spread()函数实战详解:数据裂变.从窄表到宽表

  6. R语言tidyr包Unite()函数实战详解:多个数据列合并为一列

    R语言tidyr包Unite()函数实战详解:多个数据列合并为一列 目录 R语言tidyr包Unite()函数实战详解:多个数据列合并为一列

  7. R语言tidyr包separate()函数实战详解:一列裂变为多列

    R语言tidyr包separate()函数实战详解:一列裂变为多列 目录 R语言tidyr包separate()函数实战详解:一列裂变为多列 一列裂变为两列

  8. 《oracle大型数据库系统在AIX/unix上的实战详解》讨论31: oracle、sybase 数据库的不同访问...

    <Oracle大型数据库系统在AIX/UNIX上的实战详解> 讨论31:  oracle.sybase 数据库的不同访问方式   文平. 用户来信要求更细节比较一下Oracle和sybas ...

  9. 《Java和Android开发实战详解》——2.5节良好的Java程序代码编写风格

    本节书摘来自异步社区<Java和Android开发实战详解>一书中的第2章,第2.5节良好的Java程序代码编写风格,作者 陈会安,更多章节内容可以访问云栖社区"异步社区&quo ...

  10. python 自动化-Python API 自动化实战详解(纯代码)

    主要讲如何在公司利用Python 搞API自动化. 1.分层设计思路 dataPool :数据池层,里面有我们需要的各种数据,包括一些公共数据等 config :基础配置 tools : 工具层 co ...

最新文章

  1. JetsonTX2上安装tensorflow的心酸史
  2. CMake手册详解 (十二)
  3. 【数据竞赛】2020腾讯广告算法大赛冠军方案分享及代码
  4. 【学习笔记】opencv的python接口 形态学操作 腐蚀 膨胀 通用形态学函数
  5. ikvm java转换成dll_利用IKVM.NET将Java jar包转换成可供C#调用的dll文件
  6. OpenMP 多核编程(转载)
  7. [Winform]只允许运行一个exe,如果已运行则将窗口置前
  8. 处理机调度的概念、层次
  9. Python画一个国旗
  10. Kubernetes1.91(K8s)安装部署过程(六)--node节点部署
  11. 人脸识别技术大起底,你了解多少?
  12. QQ电脑管家 vs 360 安全助手 (客观+主观)
  13. 如何修改计算机无线mac地址,如何修改无线网卡物理地址,求指教
  14. 钉钉机器人关键字自动回复_【原创新软件】办公引流机器人个人微信企业微信QQ通用的自动回复,群发助手...
  15. ipad怎样和计算机连接网络,ipad怎样连接电脑itunes
  16. 计算机应用的时间识别的,计算机人工智能识别关键技术及运用
  17. mysql语句大小写要求_mysql踩坑(一)-大小写规则
  18. IE浏览器下载文本文件(txt,csv等)
  19. 从首个「数实融合」公益球场,看元宇宙奏响创新「三重奏」
  20. 常见算法思想2:递推法

热门文章

  1. 某连锁酒店集团实行积分奖励计划,会员每次入住集团旗下酒店均可以获得一定积分,积分由欢迎积分加消费积分构成。其中欢迎积分跟酒店等级有关,具体标准如表2-1所示;消费积分跟每次入住消费金额有关,具体标准为
  2. npm 报错 426 Upgrade Required
  3. 【TWS使用系列1】如何从TWS的自选列表中添加/删除自选股?
  4. 蓦然回首,会员制CRM就在下里巴人处
  5. C++解压zip压缩文件
  6. Linux固态硬盘 设置写入缓存,Win10下的写入缓存策略严重影响SSD硬盘的性能!
  7. OSChina 周三乱弹 ——发福利的日子到了!来领妹子!
  8. 常用的27个Stata命令
  9. 豆果美食APP,看一下都给[Python爬虫爱好者]提供了哪些接口
  10. 如何查看网页元素的名称ID和其他信息