在pytorch中,经常会需要通过batch进行批量处理数据,由于每个batch中各个样本之间存在差异,经常会需要进行先padding后mask的操作。

尤其是在自然语言处理任务中,每个batch中的每个句子是不等长的。一般都是先通过填充0的方式将每个batch中每一句padding成和最长的句子等长的形式;再模型中或者计算loss的时候,再将padding成0的部分mask掉,从而避免padding带来的影响。

所以,每个batch在读取数据时,需要保存下该batch中每个句子的长度(以及最长的句子长度),也就是 lengths 和 max_len;通过保存的句子长度就可以得到每个batch对应的mask,然后就可以对数据输出进行mask操作。

如何通过batch中每个句子的长度lengths 和最长的句子长度 max_len 获得mask呢?可以通过如下的自定义函数 get_mask_from_lengths(lengths, max_len) 获得。具体如下:

import torchdef get_mask_from_lengths(lengths, max_len=None):'''param:lengths --- [Batch_size]return:mask --- [Batch_size, max_len]'''batch_size = lengths.shape[0]  if max_len is None:max_len = torch.max(lengths).item()  ids = torch.arange(0, max_len).unsqueeze(0).expand(batch_size, -1)# ids = torch.arange(0, max_len).unsqueeze(0).expand(batch_size, -1).cuda()  ## 实际需要注意devicemask = ids >= lengths.unsqueeze(1).expand(-1, max_len)    ## True 或 Falsereturn mask## 实例
## batch_size = 4 , 每句话的长度分别是 [2, 4, 3, 2]
lengths = torch.tensor([2,4,3,2])
mask = get_mask_from_lengths(lengths)
print(mask)tensor([[False, False,  True,  True],[False, False, False, False],[False, False, False,  True],[False, False,  True,  True]])

实际应用场景:

比如,语音合成时,每个batch中每个句子的音素长度每个batch中每个句子的梅尔谱的长度等等。

pytorch 中 利用自定义函数 get_mask_from_lengths(lengths, max_len)获取每个batch的mask相关推荐

  1. oracle体育成绩字段,在Excel中利用自定义函数处理体育达标成绩

    一.建立标准查分表 首先是根据<国家体育锻炼标准评分表>以16岁男子(高中一年级)为例,在Excel中建立标准评分表,把工作表命名为"评分表",建立该表的目的是为了编制 ...

  2. Py之matplotlib:在matplotlib库中利用legend函数创建自定义图例(代码实现)

    Py之matplotlib:在matplotlib库中利用legend函数创建自定义图例(代码实现) 目录 matplotlib库中利用legend函数创建自定义图例 原始图像 在原始图像上创建自定义 ...

  3. 关于PyTorch中的register_forward_hook()函数未能执行其中hook函数的问题

    关于PyTorch中的register_forward_hook()函数未能执行其中hook函数的问题 Hook 是 PyTorch 中一个十分有用的特性.利用它,我们可以不必改变网络输入输出的结构, ...

  4. Pytorch中的collate_fn函数用法

    Pytorch中的collate_fn函数用法 官方的解释:   Puts each data field into a tensor with outer dimension batch size ...

  5. javascript利用自定义函数向页面输出自定义的表格,在调用函数时通过传递的参数指定表格的行数

    利用自定义函数向页面中输出自定义的表格 <!DOCTYPE html> <html lang="en"> <head><script ty ...

  6. ffmpeg php 快速播放,怎么在PHP中利用FFmpeg函数对视频播放的时长进行获取

    怎么在PHP中利用FFmpeg函数对视频播放的时长进行获取 发布时间:2020-12-18 16:02:20 来源:亿速云 阅读:96 作者:Leah 这篇文章给大家介绍怎么在PHP中利用FFmpeg ...

  7. Entity Framework 6 Recipes 2nd Edition(10-5)译 - 在存储模型中使用自定义函数

    10-5. 在存储模型中使用自定义函数 问题 想在模型中使用自定义函数,而不是存储过程. 解决方案 假设我们数据库里有成员(members)和他们已经发送的信息(messages) 关系数据表,如Fi ...

  8. 在 Apache Spark 中利用 HyperLogLog 函数实现高级分析

    在 Apache Spark 中利用 HyperLogLog 函数实现高级分析 预聚合是高性能分析中的常用技术,例如,每小时100亿条的网站访问数据可以通过对常用的查询纬度进行聚合,被降低到1000万 ...

  9. php 模板 自定义函数调用,thinkphp模板中使用自定义函数

    注意:自定义函数要放在项目应用目录/common/common.php中. 这里是关键. 模板变量的函数调用格式:{$varname|function1|function2=arg1,arg2,### ...

最新文章

  1. java高深技术总结_一名25K以上的高薪Java程序员总结出的技术以及学习技能
  2. Spring Boot Admin 2.5.5 发布,支持在线重启服务
  3. Tensorboard安装和访问(pytorch+MobaXterm)
  4. spring源码分析之spring-core-env
  5. IOS学习之UINavigationController详解与使用(一)添加UIBarButtonItem
  6. .NET分布式大规模计算利器-Orleans(一)
  7. python预处理标准化_tensorflow预处理:数据标准化的几种方法
  8. 2021高考厦门一中成绩查询,2021年厦门中考成绩和分数线什么时候公布(附查询入口)...
  9. G - Hard problem CodeForces - 706C DP
  10. getch和getchar的区别
  11. 安装了Python2.X和Python3.X后Python2.X IDLE打不开解决办法总结
  12. NWT内斗:为了还不值钱的股份
  13. 文件版本转换( AutoCAD、3dMax、SketchUp高版本转低版本 )
  14. Opencv python之车辆识别项目(附代码)
  15. 数据库同步利器 otter 双A同步配置
  16. Mysql的远程连接设置
  17. 获取元素在屏幕的相对位置
  18. oracle序列号的使用
  19. js 人民币小写金额转换为大写
  20. 二向箔-百日打卡writeup26-30

热门文章

  1. 谷歌自动驾驶正式入华,能否掀起“鲶鱼效应”?
  2. ruby中的符号_Ruby中的凡人和不朽符号
  3. 番红-固绿染色(植物)
  4. matlab产生均匀白噪声,各种分布白噪声的产生matlab.pdf
  5. Camera2 APP Flash 打闪流程及原理分析
  6. wex5 发布apk以及更新
  7. 《中国通史》纪录片100集笔记(持更)
  8. 苹果蓝牙连接不上是什么原因_无线网连接不上 原因很多,总有一个办法解决你的问题...
  9. 手机闪存速度测试工具,AndroBench
  10. 挂载ISO镜像文件作为本地yum源