1. 为什么要用pack_padded_sequence

在使用深度学习特别是RNN(LSTM/GRU)进行序列分析时,经常会遇到序列长度不一样的情况,此时就需要对同一个batch中的不同序列使用padding的方式进行序列长度对齐(可以都填充为batch中最长序列的长度,也可以设置一个统一的长度,对所有序列长截短填),方便将训练数据输入到LSTM模型进行训练,填充后一个batch的序列可以统一处理,加快速度。但是此时会有一个问题,LSTM会对序列中非填充部分和填充部分同等看待,这样会影响模型训练的精度,应该告诉LSTM相关序列的padding情况,让LSTM只对非填充部分进行运算。此时,pytorch中的pack_padded_sequence就有了用武之地。

其实有时候,可以填充后直接做,影响有时也不是很大,使用pack_padded_sequence后效果可能会更好。

结合例子分析:

如果不用pack和pad操作会有一个问题,什么问题呢?比如上图,句子“Yes”只有一个单词,但是padding了多余的pad符号,这样会导致LSTM对它的表示通过了非常多无用的字符,这样得到的句子表示就会有误差,更直观的如下图:

那么我们正确的做法应该是怎么样呢?

在上面这个例子,我们想要得到的表示仅仅是LSTM过完单词"Yes"之后的表示,而不是通过了多个无用的“Pad”得到的表示:如下图:

torch.nn.utils.rnn.pack_padded_sequence()

这里的pack,理解成压紧比较好。 将一个 填充过的变长序列 压紧。(填充时候,会有冗余,所以压紧一下)

其中pack的过程为:(注意pack的形式,不是按行压,而是按列压)

pack之后,原来填充的 PAD(一般初始化为0)占位符被删掉了。

输入的形状可以是(T×B×* )。T是最长序列长度,Bbatch size*代表任意维度(可以是0)。如果batch_first=True的话,那么相应的 input size 就是 (B×T×*)

Variable中保存的序列,应该按序列长度的长短排序,长的在前,短的在后。即input[:,0]代表的是最长的序列,input[:, B-1]保存的是最短的序列。

NOTE: 只要是维度大于等于2的input都可以作为这个函数的参数。你可以用它来打包labels,然后用RNN的输出和打包后的labels来计算loss。通过PackedSequence对象的.data属性可以获取 Variable

参数说明:

  • input (Variable) – 变长序列 被填充后的 batch
  • lengths (list[int]) – Variable 中 每个序列的有效长度(即去掉pad的真实长度)。
  • batch_first (bool, optional) – 如果是True,input的形状应该是B*T*size

返回值:

一个PackedSequence 对象。

torch.nn.utils.rnn.pad_packed_sequence()

填充packed_sequence

上面提到的函数的功能是将一个填充后的变长序列压紧。 这个操作和pack_padded_sequence()是相反的。把压紧的序列再填充回来。填充时会初始化为0。

返回的Varaible的值的size是 T×B×*T 是最长序列的长度,B 是 batch_size,如果 batch_first=True,那么返回值是B×T×*

Batch中的元素将会以它们长度的逆序排列。

参数说明:

  • sequence (PackedSequence) – 将要被填充的 batch
  • batch_first (bool, optional) – 如果为True,返回的数据的格式为 B×T×*

返回值: 一个tuple,包含被填充后的序列,和batch中序列的长度列表

2. 小案

假设有demo.txt文件,包含下面5段文本/序列:

txt = ["Some people like to choose those who are different from themselves while others prefer those who are similar to themselves.","People choose friends in differrent ways.","For instance, if an active and energetic guy proposes to his equally active and energetic friends that they should have some activities, it is more likely that his will agree at once.","When people have friends similar to themselves, they and their friends chat, play, and do thing together natually and harmoniously.","The result is that they all can feel relaxed and can trully enjoy each other's company."]

使用下面的脚本将单词转换为索引,并填充为统一的长度:

import numpy as np
import torch
import torch.nn as nnvocab = {} #词到索引的映射字典
token_id = 1 #token_id=0 预留给填充符号
lengths = [] #存储每个文本的实际长度for l in txt:tokens = l.strip().split() #这里对英文分词 简单的按空格切分。(当然可以使用一些效果更好的分词工具,可以把标点分出来)print(tokens)lengths.append(len(tokens))for t in tokens:if t not in vocab:vocab[t] = token_idtoken_id += 1x = np.zeros((len(lengths), max(lengths))) #所有文本填充为最大的长度
l_no = 0for l in txt:tokens = l.strip().split()for i in range(len(tokens)):x[l_no, i] = vocab[tokens[i]]l_no += 1print(x)
print(x.shape)x = torch.tensor(x,requires_grad=True)
lengths = torch.Tensor(lengths)
print("lenghts:",lengths)#所有文本长度按从大到小排序 (降序),返回排序后的索引idx_sort
_, idx_sort = torch.sort(torch.Tensor(lengths), dim=0, descending=True)
print("idx_sort:",idx_sort)
#对索引idx_sort进行从小到大排序 ,返回排序后的索引 idx_unsort
_, idx_unsort = torch.sort(idx_sort, dim=0)
print("idx_unsort:",idx_unsort)x1 = x[idx_sort]#x中的各个文本 随着排序 即最长的文本在第一行...
lengths1 = list(lengths[idx_sort])#此时各个文本对应的长度(从大到小排序后)
print("lenghts1:",lengths1)
print("x1的形状与内容:")
print(x1)
print(x1.shape)
x2=x1[idx_unsort]
print("x2的形状与内容:")
print(x2)
print(x2.shape)

输出:

['Some', 'people', 'like', 'to', 'choose', 'those', 'who', 'are', 'different', 'from', 'themselves', 'while', 'others', 'prefer', 'those', 'who', 'are', 'similar', 'to', 'themselves.']
['People', 'choose', 'friends', 'in', 'differrent', 'ways.']
['For', 'instance,', 'if', 'an', 'active', 'and', 'energetic', 'guy', 'proposes', 'to', 'his', 'equally', 'active', 'and', 'energetic', 'friends', 'that', 'they', 'should', 'have', 'some', 'activities,', 'it', 'is', 'more', 'likely', 'that', 'his', 'will', 'agree', 'at', 'once.']
['When', 'people', 'have', 'friends', 'similar', 'to', 'themselves,', 'they', 'and', 'their', 'friends', 'chat,', 'play,', 'and', 'do', 'thing', 'together', 'natually', 'and', 'harmoniously.']
['The', 'result', 'is', 'that', 'they', 'all', 'can', 'feel', 'relaxed', 'and', 'can', 'trully', 'enjoy', 'each', "other's", 'company.'][[ 1.  2.  3.  4.  5.  6.  7.  8.  9. 10. 11. 12. 13. 14.  6.  7.  8. 15.4. 16.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.][17.  5. 18. 19. 20. 21.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.][22. 23. 24. 25. 26. 27. 28. 29. 30.  4. 31. 32. 26. 27. 28. 18. 33. 34.35. 36. 37. 38. 39. 40. 41. 42. 33. 31. 43. 44. 45. 46.][47.  2. 36. 18. 15.  4. 48. 34. 27. 49. 18. 50. 51. 27. 52. 53. 54. 55.27. 56.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.][57. 58. 40. 33. 34. 59. 60. 61. 62. 27. 60. 63. 64. 65. 66. 67.  0.  0.0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.]]
(5, 32)lenghts: tensor([20.,  6., 32., 20., 16.])idx_sort: tensor([2, 0, 3, 4, 1])
idx_unsort: tensor([1, 4, 0, 2, 3])lenghts1: [tensor(32.), tensor(20.), tensor(20.), tensor(16.), tensor(6.)]
x1的形状与内容:
tensor([[22., 23., 24., 25., 26., 27., 28., 29., 30.,  4., 31., 32., 26., 27.,28., 18., 33., 34., 35., 36., 37., 38., 39., 40., 41., 42., 33., 31.,43., 44., 45., 46.],[ 1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9., 10., 11., 12., 13., 14.,6.,  7.,  8., 15.,  4., 16.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,0.,  0.,  0.,  0.],[47.,  2., 36., 18., 15.,  4., 48., 34., 27., 49., 18., 50., 51., 27.,52., 53., 54., 55., 27., 56.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,0.,  0.,  0.,  0.],[57., 58., 40., 33., 34., 59., 60., 61., 62., 27., 60., 63., 64., 65.,66., 67.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,0.,  0.,  0.,  0.],[17.,  5., 18., 19., 20., 21.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,0.,  0.,  0.,  0.]], dtype=torch.float64, grad_fn=<IndexBackward>)
torch.Size([5, 32])
x2的形状与内容:
tensor([[ 1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9., 10., 11., 12., 13., 14.,6.,  7.,  8., 15.,  4., 16.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,0.,  0.,  0.,  0.],[17.,  5., 18., 19., 20., 21.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,0.,  0.,  0.,  0.],[22., 23., 24., 25., 26., 27., 28., 29., 30.,  4., 31., 32., 26., 27.,28., 18., 33., 34., 35., 36., 37., 38., 39., 40., 41., 42., 33., 31.,43., 44., 45., 46.],[47.,  2., 36., 18., 15.,  4., 48., 34., 27., 49., 18., 50., 51., 27.,52., 53., 54., 55., 27., 56.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,0.,  0.,  0.,  0.],[57., 58., 40., 33., 34., 59., 60., 61., 62., 27., 60., 63., 64., 65.,66., 67.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,0.,  0.,  0.,  0.]], dtype=torch.float64, grad_fn=<IndexBackward>)
torch.Size([5, 32])

由x2与原始x的形状是一样的,主要是因为下面两行

idx_sort: tensor([2, 0, 3, 4, 1])
idx_unsort: tensor([1, 4, 0, 2, 3])

x_packed = nn.utils.rnn.pack_padded_sequence(input=x1, lengths=lengths1, batch_first=True)
print(x_packed)

需要注意的是,pack_padded_sequence函数的参数,lengths需要从大到小排序(length1),x1已根据长度大小排好序(最长的序列在第一行…),batch_first如果设置为true,则x的第一维为batch_size,第二维为seq_length,否则相反。
打印x_packed如下:

PackedSequence(data=tensor([22.,  1., 47., 57., 17., 23.,  2.,  2., 58.,  5., 24.,  3., 36., 40.,18., 25.,  4., 18., 33., 19., 26.,  5., 15., 34., 20., 27.,  6.,  4.,59., 21., 28.,  7., 48., 60., 29.,  8., 34., 61., 30.,  9., 27., 62.,4., 10., 49., 27., 31., 11., 18., 60., 32., 12., 50., 63., 26., 13.,51., 64., 27., 14., 27., 65., 28.,  6., 52., 66., 18.,  7., 53., 67.,33.,  8., 54., 34., 15., 55., 35.,  4., 27., 36., 16., 56., 37., 38.,39., 40., 41., 42., 33., 31., 43., 44., 45., 46.], dtype=torch.float64,grad_fn=<PackPaddedSequenceBackward>), batch_sizes=tensor([5, 5, 5, 5, 5, 5, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 3, 3, 3, 3, 1, 1, 1, 1,1, 1, 1, 1, 1, 1, 1, 1]), sorted_indices=None, unsorted_indices=None)

他把x1的两个维度合并成了一个维度,原本x1(batch_size,max_seq_len)=(5,32),x_packed相当于对x1按列进行访问,并且忽略掉其中的填充值0;下面多出的batch_size有max_seq_len=32个数字,可以理解为对x1进行按列访问时,每一列非填充值的个数,可以看到刚开始的几列没有填充值(每个序列的开始部分),值为batch_size=5,后面由于有的序列不够长,逐渐出现填充值0,所以batch_size的大小逐渐变小<5,直到最后等于1,也就是只有那个batch中最长的序列还有非填充值,其余序列都是填充值0.

参考文献:

https://blog.csdn.net/sdu_hao/article/details/105408552

https://www.cnblogs.com/sbj123456789/p/9834018.html

https://www.cnblogs.com/luckyplj/p/13370072.html

Pytorch-> pack_padded_sequence()和pad_packed_sequence()相关推荐

  1. pack_padded_sequence 和 pad_packed_sequence

    参考: pytorch中如何处理RNN输入变长序列padding pack_padded_sequence 和 pad_packed_sequence 大佬们写的都非常nice啊

  2. 通过例子10分钟快速看懂pad_sequence、pack_padded_sequence以及pad_packed_sequence

    前言 import torch import torch.nn as nnfrom torch.nn.utils.rnn import pad_sequence from torch.nn.utils ...

  3. pack_padded_sequence和pad_packed_sequence详解

    先提供一个官网解读 https://pytorch.org/docs/1.0.1/nn.html#torch.nn.utils.rnn.pack_padded_sequence 在使用深度学习特别是L ...

  4. evaluate函数使用无效_使用Keras和Pytorch处理RNN变长序列输入的方法总结

    最近在使用Keras和Pytorch处理时间序列数据,在变长数据的输入处理上踩了很多坑.一般的通用做法都需要先将一个batch中的所有序列padding到同一长度,然后需要在网络训练时屏蔽掉paddi ...

  5. PyTorch中使用LSTM处理变长序列

    使用LSTM算法处理的序列经常是变长的,这里介绍一下PyTorch框架下使用LSTM模型处理变长序列的方法.需要使用到PyTorch中torch.nn.utils包中的pack_padded_sequ ...

  6. Pytorch使用实践,教程,库,调优,计算量,模型搭建

    参考文章: PyTorch官方教程中文版 http://pytorch123.com/ pytorch handbook是一本开源的书籍,目标是帮助那些希望和使用PyTorch进行深度学习开发和研究的 ...

  7. pytorch中torch.nn.utils.rnn相关sequence的pad和pack操作

    目录 一.pad_sequence 二.pack_padded_sequence 三.pad_packed_sequence 四.pack_sequence 自然语言处理任务中,模型的输入一般都是变长 ...

  8. NLP中各框架对变长序列的处理全解

    ©PaperWeekly 原创 · 作者|海晨威 学校|同济大学硕士生 研究方向|自然语言处理 在 NLP 中,文本数据大都是变长的,为了能够做 batch 的训练,需要 padding 到相同的长度 ...

  9. 让阿宅不再寂寞的聊天机器人

    阿宅爱上了阿美 在一个有星星的夜晚 飞机从头顶飞过 流星也划破那夜空 虽然说人生并没有什么意义 但是爱情确实让生活更加美丽 阿美嫁给了二富 在一个有香槟的晴天 豪车从眼前驶过 车笛也震动那烈阳 虽然说 ...

  10. 教你几招搞定 LSTMs 的独门绝技(附代码)

    本文为雷锋字幕组编译的技术博客,原标题 Taming LSTMs: Variable-sized mini-batches and why PyTorch is good for your healt ...

最新文章

  1. angular轮播图
  2. 关于动态规划,你想知道的都在这里了!
  3. java websocket修改为同步_初级Java程序员需要掌握哪些主流技术才能拿25K?
  4. Linux常用的50个命令
  5. 为什么 wait/notify/notifyAll 在 Object 类定义而不是 Thread 类?
  6. 从零开始搭建系统1.1——CentOs安装
  7. python numpy中ndarray.reshape函数参数-1是什么意思?(模糊控制、自动推理)
  8. JAVA struts2
  9. springboot默认数据源如何设置连接数_Spring Boot系列之配置数据库连接池
  10. 条件控制(if ) ( case)
  11. CentOS 安装go client调用Kubernetes API
  12. 数据结构—链表-单链表基本操作实现
  13. 【华为敏捷/DevOps实践】3. 如何开好站立会议
  14. 实现横向排列的几种方案
  15. python面板数据模型操作步骤_面板数据模型估计一般要做哪些步骤
  16. Python 手写体识别
  17. 证件照换底色+改变大小
  18. 1196踩方格—递推方法!
  19. mysql 清理relay日志_MySQL中binlog和relay log清理方式
  20. 获取Jenkins项目名称

热门文章

  1. 域控服务器导出证书,证书服务器(CA)的备份和还原
  2. GPU 编程与CG 语言之阳春白雪下里巴人——CG学习读书笔记之数学函数(之二)。
  3. 一味顺从的人没有好果子吃
  4. oj美元和人民币java_【牛客网OJ题】:人民币转换
  5. 麦吉尔大学计算机工程的世界排名,不只是知名大学:麦吉尔大学你需要知道这些!...
  6. 前端EChart图表转换为图片保存到服务器路径
  7. 休谟问题和金岳霖的回答
  8. 谷歌io大会2019_Google IO 2012的前5个精选
  9. 公安大数据系统具有哪些功能
  10. 量化金融入门笔记(一)