请参考Dual-Path-RNN-Pytorch的网络架构图。 这里我们单独把Segmentation部分拿来分析。 (文件:model_rnn.py

到达Segmentation时,输入的张量维度为[B,N,L], 其中B为Batch Size, N为特征维度, L为特征长度。


主体函数

函数主体如下,输入首先做padding

padding

padding函数如下:

该函数主要做两件事

  1. 将最后一维(特征长度)对齐,即如果长度要与P(hop size)的整数倍对齐
  2. 将对齐后张量前后各补长度为P的0

其中gap = K - (P + L % K) % K看着有点云里雾里,它的作用就是计算L的长度与P的奇数倍的差距(gap);如果有gap,就在最后补上,这样做的好处在下一小节再说。

假设输入长度(改为1维图像便于理解)为深蓝色,那么蓝色部分即为补全的gap;对齐后整个长度为**(2n+1)K**,然后两边各有长度为P的padding

Segmentation

padding后的输入张量最后一维是 (2n+3)P
然后对该张量分别取input[:, :, :- P]和input[:, :, P:],这样每个分割出来的最后一维长度都是 (2n+2)P(n+1)K,
这样将K再单独抽取一维出来,令l=n+1,
input1.shape=[B,N,l,K]input2.shape=[B,N,l,K]input1.shape=[B,N,l,K] \newline input2.shape=[B,N,l,K] input1.shape=[B,N,l,K]input2.shape=[B,N,l,K]
这样抽取是什么意义?

根据paper中的介绍,这里是分离出chunk,每一个K长度为一个chunk;

再将两部分chunk做合并,即==[B,N,2l,K] ==,这又是什么意思?

原因是这两部分的chunk是差了P取的,所以合并的时候每个chunk变成了2l,但是前后保持了P的hopping。

转置

最后将最后两维转置,目的应该是方便后续的处理.即[B,N,K,2l]
将S=2l,
shape = ==[B,N,K,S] ==

论文配图

我觉得论文里的配图还是挺清晰的,放在这里也同时希望方便大家理解

源码部分

下面把源码部分抽取出来,单独测试,

def padding(input, K):'''padding the audio timesK: chunks of lengthP: hop sizeinput: [B, N, L]'''B, N, L = input.shapeP = K // 2gap = K - (P + L % K) % Kprint(f'gap={K} - ({P}+{L}%{K})%{K}={gap}')if gap > 0:pad = torch.Tensor(torch.zeros(B, N, gap)).type(input.type())input = torch.cat([input, pad], dim=2)_pad = torch.Tensor(torch.zeros(B, N, P)).type(input.type())input = torch.cat([_pad, input, _pad], dim=2)return input, gapdef Segmentation(input, K):'''the segmentation stage splitsK: chunks of lengthP: hop sizeinput: [B, N, L]output: [B, N, K, S]'''B, N, L = input.shapeP = K // 2input, gap = padding(input, K)print('after padding: input.shape ',input.shape)# [B, N, K, S]input1 = input[:, :, :-P].contiguous().view(B, N, -1, K)print('input[:, :, :-P] shape ', input[:, :, :-P].shape)print('input1.shape ',input1.shape)input2 = input[:, :, P:].contiguous().view(B, N, -1, K)print('input[:, :, P:] shape ', input[:, :, P:].shape)print('input2.shape ',input2.shape)input = torch.cat([input1, input2], dim=3).view(B, N, -1, K).transpose(2, 3)print()return input.contiguous(), gapinput = torch.linspace(1, 800, 1200).view(1, 10, 120)K = 200output, gap = Segmentation(input, K)print(output.shape, gap)

【Dual-Path-RNN-Pytorch源码分析】Segmentation相关推荐

  1. PyTorch 源码分析:Optimizer类

    PyTorch对Optimizer类的实现大部分都在Python上,只有计算用到了C++的部分,所以还是可以继续分析的. 总览 Optimizer类是所有具体优化器类的一个基类.下面一幅图表示一下. ...

  2. Pytorch源码分析

    目录 命名空间/类/方法/函数/变量 torch.autograd.Function中的ctx参数 DDP(DistributedDataParallel)的构造函数 torch.floor(inpu ...

  3. Pytorch Mobile 之Android Demo源码分析

    现如今,在边缘设备上运行机器学习/深度学习变得越来越流行,它需要更低的时延. 而从Pytorch 1.3开始,我们就可以使用Pytorch将模型部署到Android或者ios设备中. Pytorch官 ...

  4. PyTorch 源码解读之 torch.serialization torch.hub

    作者 | 123456 来源 | OpenMMLab 编辑 | 极市平台 导读 本文解读基于PyTorch 1.7版本,对torch.serialization.torch.save和torch.hu ...

  5. Transformer-XL解读(论文 + PyTorch源码)

    前言 目前在NLP领域中,处理语言建模问题有两种最先进的架构:RNN和Transformer.RNN按照序列顺序逐个学习输入的单词或字符之间的关系,而Transformer则接收一整段序列,然后使用s ...

  6. 【Android SDM660源码分析】- 03 - UEFI XBL GraphicsOutput BMP图片显示流程

    [Android SDM660源码分析]- 03 - UEFI XBL GraphicsOutput BMP图片显示流程 1. GraphicsOutput.h 2. 显示驱动初化 DisplayDx ...

  7. NeRF 源码分析解读(一)

    NeRF 源码解读(一) 前言 NeRF 是三维视觉中新视图合成任务的启示性工作,最近领域内出现了许多基于 NeRF 的变种工作.本文以pytorch 版 NeRF 作为基础对 NeRF 的代码进行分 ...

  8. TCP拥塞控制算法BBR源码分析

      BBR是谷歌与2016年提出的TCP拥塞控制算法,在Linux4.9的patch中正式加入.该算法一出,瞬间引起了极大的轰动.在CSDN上也有众多大佬对此进行分析讨论,褒贬不一.   本文首先对源 ...

  9. 【Golang源码分析】Go Web常用程序包gorilla/mux的使用与源码简析

    目录[阅读时间:约10分钟] 一.概述 二.对比: gorilla/mux与net/http DefaultServeMux 三.简单使用 四.源码简析 1.NewRouter函数 2.HandleF ...

  10. ELMo解读(论文 + PyTorch源码)

    ELMo的概念也是很早就出了,应该是18年初的事情了.但我仍然是后知后觉,居然还是等BERT出来很久之后,才知道有这么个东西.这两天才仔细看了下论文和源码,在这里做一些记录,如果有不详实的地方,欢迎指 ...

最新文章

  1. Android中的MVP模式初步使用
  2. asp.net 浏览服务器文件
  3. C++ STL stirng的复制比较
  4. Linux当前终端走代理ip
  5. 【数据结构】----C语言实现栈操作
  6. oracle authentication_services,SQLNET.AUTHENTICATION_SERVICES= (NTS) 解释
  7. html无节日为空,这个生死相拥的节日_311.Html
  8. android 标题名字,说说 Android 的 Material Design 设计(五)——可折叠式标题栏
  9. MarkDown下载和安装图文教程
  10. VPS搭建zotero自动同步的webdav服务
  11. python requests timeout参数
  12. 枪林弹雨 该服务器维护中,枪林弹雨登陆BUG解决方法
  13. python科研作图系列-01热力图相关性分析
  14. 苹果IOS14版本自建服务器无法下载解决方法
  15. Mobx-action
  16. 正则表达式限制文本框只能输入中文或者英文或者数字
  17. 基于Linux的kfifo移植到STM32(支持os的互斥访问)
  18. IDEA繁体问题解决
  19. 合并报表编制采用的理论_合并财务报表的编制原理(转载)
  20. Cadence IC 模拟版图初学手记

热门文章

  1. 【论文解读】MV3D-Net、AVOD-Ne用于自动驾驶的多视图3D目标检测网络
  2. 安装关联vs2008的opencv
  3. day9 线程与进程、队列
  4. RDKit | 化合物描述符向量化及部分结构检索
  5. DrugVQA | 用视觉问答技术预测药物蛋白质相互作用
  6. 附录2:Numpy实例记录
  7. 在Dos下运行exe程序的时候出现找不到Cygwin1.dll文件的情况总结
  8. 【首轮官宣】中国肠道大会姊妹盛会,GUT 2022正式启航!
  9. 蚂蚁森林合种计划(2020.12.19更新,7天有效)
  10. 中科院微生物所高程组招聘助研3名(正式编制)