ECAPA_TDNN代码和论文细节分析

  • 一、数据部分(dataloader.py)
  • 二、网络结构(model.py)
    • 2.1 整体网络结构
    • 2.2 SpecAugment算法
    • 2.3 注意力统计池化
    • 2.4 SE Res2Blocks
      • 2.4.1 SE block
      • 2.4.2 res2net
    • 2.5 MFA多层特征聚合
  • 三、损失函数AAMsoftmax(loss.py)

来源:INTERSPEECH 2020
机构:比利时根特大学
论文地址:
源码地址:
论文阅读博客:ECAPA_TDNN 上

一、数据部分(dataloader.py)

  1. 数据集: Voxceleb2 5994个说话人
  2. 数据增强: 每个话语生成6个额外的样本
    (1) 结合MUSAN(嘈杂的人声,噪声)数据集提供的RIR数据集(混响)生成三个。
    (2) 利用Sox (tempo up, tempo down)和ffmpeg (alternating opus or aac compression) 生成三个。
    (3) SpecAugment算法:随机掩码。(在第二部分具体说明)

MUSAN数据集下载:wget https://www.openslr.org/resources/17/musan.tar.gz
RIR数据集下载:wget https://openslr.org/resources/28/rirs_noises.zip

  1. 相关代码
    数据增强过程如下:先对音频长度进行调整,再通过选择语句随机选择增强方式。
audio, sr = soundfile.read(self.data_list[index])
#将所有音频调整为一个长度
length = self.num_frames * 160 + 240
if audio.shape[0] <= length:shortage = length - audio.shape[0]audio = numpy.pad(audio, (0, shortage), 'wrap')
start_frame = numpy.int64(random.random()*(audio.shape[0]-length))
audio = audio[start_frame:start_frame + length]
audio = numpy.stack([audio],axis=0)
# 数据增强
augtype = random.randint(0,5)
if augtype == 0:   # Originalaudio = audio
elif augtype == 1: # Reverberation混响audio = self.add_rev(audio)
elif augtype == 2: # Babbleaudio = self.add_noise(audio, 'speech')
elif augtype == 3: # Musicaudio = self.add_noise(audio, 'music')
elif augtype == 4: # Noiseaudio = self.add_noise(audio, 'noise')
elif augtype == 5: # Television noiseaudio = self.add_noise(audio, 'speech')audio = self.add_noise(audio, 'music')
return torch.FloatTensor(audio[0]), self.data_label[index]

如下为混响增强,随机从数据集中选取混响音频,再增加混响音频的维度与人声音频保持一致,最后对人声音频和混响音频做一个卷积。

#添加混响
def add_rev(self, audio):rir_file    = random.choice(self.rir_files)rir, sr     = soundfile.read(rir_file)rir         = numpy.expand_dims(rir.astype(numpy.float),0) rir         = rir / numpy.sqrt(numpy.sum(rir**2))return signal.convolve(audio, rir, mode='full')[:,:self.num_frames * 160 + 240]

如下为噪声增强,先获得人声音频的db,随机选出n个噪声音频,将噪声音频长度调整至与人声音频一致,再获得噪声音频的db,随机获取noise信噪比,然后计算出噪声系数,并与噪声音频相乘。将所有噪声音频进行concatencate,再与人声音频叠加。

def add_noise(self, audio, noisecat):#numpy.mean(audio ** 2) 为信号功率clean_db    = 10 * numpy.log10(numpy.mean(audio ** 2)+1e-4) numnoise    = self.numnoise[noisecat]noiselist   = random.sample(self.noiselist[noisecat], random.randint(numnoise[0],numnoise[1]))noises = []for noise in noiselist:#假设噪声音频长度已调整至人声音频一致noise_db = 10 * numpy.log10(numpy.mean(noiseaudio ** 2)+1e-4) noisesnr   = random.uniform(self.noisesnr[noisecat][0],self.noisesnr[noisecat][1])#noiseaudio乘以噪声系数noises.append(numpy.sqrt(10 ** ((clean_db - noise_db - noisesnr) / 10)) * noiseaudio)noise = numpy.sum(numpy.concatenate(noises,axis=0),axis=0,keepdims=True)return noise + audio

二、网络结构(model.py)

2.1 整体网络结构

网络结构如下

  1. 数据增强
  2. TDNN block
  3. 多层特征聚合
  4. 注意力统计池化
  5. FC+BN
  6. 输出
def forward(self, x, aug):#数据增强with torch.no_grad():x = self.torchfbank(x)+1e-6x = x.log()   x = x - torch.mean(x, dim=-1, keepdim=True)if aug == True:x = self.specaug(x)#相当于一个TDNN blockx = self.conv1(x)x = self.relu(x)x = self.bn1(x)#多层特征聚合x1 = self.layer1(x)x2 = self.layer2(x+x1)x3 = self.layer3(x+x1+x2)x = self.layer4(torch.cat((x1,x2,x3),dim=1))x = self.relu(x)#注意力统计池化t = x.size()[-1] global_x = torch.cat((x,torch.mean(x,dim=2,keepdim=True).repeat(1,1,t), torch.sqrt(torch.var(x,dim=2,keepdim=True).clamp(min=1e-4)).repeat(1,1,t)), dim=1)w = self.attention(global_x)mu = torch.sum(x * w, dim=2)sg = torch.sqrt( ( torch.sum((x**2) * w, dim=2) - mu**2 ).clamp(min=1e-4) )x = torch.cat((mu,sg),1)x = self.bn5(x)x = self.fc6(x)x = self.bn6(x)return x

2.2 SpecAugment算法

SpecAugment算法是一种添加掩码的数据增强算法,步骤如下:

  1. 预加重:PreEmphasis(torch.nn.Module)
  2. 提取梅尔
    torchaudio.transforms.MelSpectrogram(sample_rate=16000, n_fft=512, win_length=400, hop_length=160, f_min = 20, f_max = 7600, window_fn=torch.hamming_window, n_mels=80)
  3. 将梅尔进行零均值归一化,可以直接将Mask位置设为0
  4. 时间维度掩码
  5. 频率维度掩码

总代码:

 with torch.no_grad():#预加重和提取梅尔x = self.torchfbank(x)+1e-6#对数梅尔x = x.log()   x = x - torch.mean(x, dim=-1, keepdim=True)if aug == True:#添加掩码x = self.specaug(x)

掩码部分主要代码:
1.获取梅尔的维度,分别赋值为batch, fea, time
batch为每批次输入梅尔的数量;
fea为每一个梅尔的特征维度,这里应该为80;
time为每一个梅尔的时间维度
2.掩码的长度:生成[batch, 1, 1]维数组
3.掩码的位置:生成[batch, 1 ,1]维数组,根据长度和梅尔的维度调整
4.生成一个D维张量,并将其增加维度至[1,1,D]
5.根据掩码长度和掩码位置得到掩码:[batch, 1 , D] ->[batch, D] ->[batch, 1 , D] or [batch, D, 1]
6.将梅尔掩码的地方赋值为0

def mask_along_axis(self, x, dim):original_size = x.shapebatch, fea, time = x.shapeif dim == 1:D = feawidth_range = self.freq_mask_widthelse:D = timewidth_range = self.time_mask_widthmask_len = torch.randint(width_range[0], width_range[1], (batch, 1), device=x.device).unsqueeze(2)mask_pos = torch.randint(0, max(1, D - mask_len.max()), (batch, 1), device=x.device).unsqueeze(2)arange = torch.arange(D, device=x.device).view(1, 1, -1)mask = (mask_pos <= arange) * (arange < (mask_pos + mask_len))mask = mask.any(dim=1)if dim == 1:mask = mask.unsqueeze(2)else:mask = mask.unsqueeze(1)#用0填充张量x中对应mask位置处为True的元素x = x.masked_fill_(mask, 0.0)return x.view(*original_size)

2.3 注意力统计池化

主要是通过两个公式计算加权平均和加权标准差:
μc=∑tTαt,cht,c\mu_{c} = \sum^{T}_{t}\alpha_{t,c}h_{t,c} μc​=t∑T​αt,c​ht,c​
σc=∑tTαt,cht,c2−μc2\sigma_{c} = \sqrt{\sum^{T}_{t}\alpha_{t,c}h^{2}_{t,c}-\mu^2_{c}} σc​=t∑T​αt,c​ht,c2​−μc2​​
池化层的最终输出由加权平均μ和加权标准差σ的向量串联得到。

# 得到时间帧
t = x.size()[-1]
# 获取时间帧维度的均值和标准差,然后串联原始数据
mean = torch.mean(x,dim=2,keepdim=True).repeat(1,1,t)
standrad = torch.sqrt(torch.var(x,dim=2,keepdim=True).clamp(min=1e-4)).repeat(1,1,t))
global_x = torch.cat((x, mean, standrad), dim=1)
#通过注意力网络得到注意力矩阵w
w = self.attention(global_x)
self.attention = nn.Sequential(nn.Conv1d(4608, 256, kernel_size=1),nn.ReLU(),nn.BatchNorm1d(256),nn.Tanh(), # I add this layernn.Conv1d(256, 1536, kernel_size=1),nn.Softmax(dim=2),)mu = torch.sum(x * w, dim=2)
sg = torch.sqrt( ( torch.sum((x**2) * w, dim=2) - mu**2 ).clamp(min=1e-4) )
x = torch.cat((mu,sg),1)

2.4 SE Res2Blocks

2.4.1 SE block

一维SE blocks,重新缩放帧级特征,得到通道的重要性。
过程为:
1.特征通过全局平均池化进行压缩
2.用两个全连接层,主要是为了应用relu和sigmoid(将输出映射至0和1)。第一个全连接层降低维度,第二个全连接层恢复维度。
3.输出为输入乘以权重矩阵。

 def __init__(self, channels, bottleneck=128):super(SEModule, self).__init__()self.se = nn.Sequential(#全局平均池化压缩为1个数nn.AdaptiveAvgPool1d(1),nn.Conv1d(channels, bottleneck, kernel_size=1, padding=0),nn.ReLU(),nn.Conv1d(bottleneck, channels, kernel_size=1, padding=0),nn.Sigmoid(),)def forward(self, input):#获得权重矩阵x = self.se(input)return input * x

2.4.2 res2net

res2net主要是利用细粒度的多尺度信息,产生多个感受野的组合。下面左图是res2net多尺度的具体做法,右图是本论文res2net模块的网络结构。
由左图可得,res2net将传统resnet中的3*3卷积进行了多尺度的解耦,在1 * 1卷积之后对通道进行分组,尺度越大计算开销越大。
由右图可知,包含了扩展卷积和前后密集层,第一个密集层用于降低维度,第二个密集层用于恢复维度,最后由SE模块缩放每一个通道。

本文所用的res2net采用了8尺度,在代码中x1是作为最后一个直接送到。

def forward(self, x):residual = x#############################out = self.conv1(x)out = self.relu(out)out = self.bn1(out)###########################################################这里是res2net的核心spx = torch.split(out, self.width, 1)#分块卷积计算for i in range(self.nums):if i==0:sp = spx[i]else:sp = sp + spx[i]sp = self.convs[i](sp)sp = self.relu(sp)sp = self.bns[i](sp)if i==0:out = spelse:out = torch.cat((out, sp), 1)#cat x1的块out = torch.cat((out, spx[self.nums]),1)##############################################################out = self.conv3(out)out = self.relu(out)out = self.bn3(out)###############################out = self.se(out)out += residualreturn out

2.5 MFA多层特征聚合

MFA是多层特征聚合,将SE Res2Blocks输出特征映射连接起来

整体代码

x1 = self.layer1(x)
x2 = self.layer2(x+x1)
x3 = self.layer3(x+x1+x2)
x = self.layer4(torch.cat((x1,x2,x3),dim=1))
x = self.relu(x)

三、损失函数AAMsoftmax(loss.py)

详细介绍见:https://blog.csdn.net/qq_39478403/article/details/116788113
加性角度边界损失最早用于人脸识别任务。
原理:最大化类间间距,最小化类内间距。
softmax loss在决策边界产生明显的模糊性,但是AAMsoftmax通过添加加性角度边距可以扩大类间的间隙。

  1. 归一化输入特征和FC层权重。令所得归一化特征xix_{i}xi​与第j类别的FC层权重点乘得到FC层的第j个输出cosθjcosθ_{j}cosθj​,将特征xix_{i}xi​预测为第j类的预测值
cosine = F.linear(F.normalize(x), F.normalize(self.weight))
  1. 根据正余弦公式计算sinθjsin\theta_{j}sinθj​
sine = torch.sqrt((1.0 - torch.mul(cosine, cosine)).clamp(0, 1))
  1. 根据当前夹角的正余弦,计算添加了加性角度边距m的cos(θ+m)cos(\theta+m)cos(θ+m)
phi = cosine * self.cos_m - sine * self.sin_m
  1. 松弛约束
phi = torch.where((cosine - self.th) > 0, phi, cosine - self.mm)
  1. 生成标签矩阵
one_hot = torch.zeros_like(cosine) #全0矩阵
one_hot.scatter_(1, label.view(-1, 1), 1) #在label索引上用1替换0
  1. 当输入特征x对应真实类别,采用新 Target Logit cos(θ_yi + m)
    其余并不对应输入特征x的真实类别的类,保持原有的logit
output = (one_hot * phi) + ((1.0 - one_hot) * cosine)
  1. 使用scale缩放新的logit
output = output * self.s
  1. 计算损失
loss = self.ce(output, label)
self.ce = nn.CrossEntropyLoss()

【ECAPA_TDNN 下 】代码和论文细节分析相关推荐

  1. 2022华数杯B题论文思路分析+完整代码(水下机器人组装计划)(一二问答案接出来和标准答案一样)(问题三四逼近正确答案)(完整论文,代码可直接跑)

    写在前面:学校最近搞数学建模竞赛培训,以2022华数杯B题作为训练题目,在查资料过程中发现网上没有哪一篇论文解出了正确答案,而我们组利用Lingo软件准确的解出了正确答案,但是在第三问时,由于决策的变 ...

  2. rcnn代码实现_Faster-RCNN论文细节原理解读+代码实现gluoncv(MXNet)

    Faster-RCNN开创了基于锚框(anchors)的目标检测框架,并且提出了RPN(Region proposal network),来生成RoI,用来取代之前的selective search方 ...

  3. Faster-RCNN论文细节原理解读+代码实现gluoncv(MXNet)

      Faster-RCNN开创了基于锚框(anchors)的目标检测框架,并且提出了RPN(Region proposal network),来生成RoI,用来取代之前的selective searc ...

  4. 【Java 并发编程】线程池机制 ( 线程池执行任务细节分析 | 线程池执行 execute 源码分析 | 先创建核心线程 | 再放入阻塞队列 | 最后创建非核心线程 )

    文章目录 一.线程池执行任务细节分析 二.线程池执行 execute 源码分析 一.线程池执行任务细节分析 线程池执行细节分析 : 核心线程数 101010 , 最大小成熟 202020 , 非核心线 ...

  5. 【Spring 工厂】工厂设计模式、第一个Spring程序细节分析、整合日志框架

    Spring 引言 什么是 Spring? 工厂设计模式 简单工厂的设计 通用工厂的设计 通用工厂的使用方式 第一个 Spring 程序 环境搭建 Spring 的核心API 程序开发 细节分析 Sp ...

  6. fastjson远程代码执行漏洞问题分析

    背景 fastjson远程代码执行安全漏洞(以下简称RCE漏洞),最早是官方在2017年3月份发出的声明, security_update_20170315 没错,强如阿里这样的公司也会有漏洞.代码是 ...

  7. 基于单片机的压力流量报警器(附代码+仿真+论文)

    基于单片机的压力流量报警器(附代码+仿真+论文) **==完整论文+代码+仿真可关注我在主页私我==** 摘要 关键字 第一章绪论 1.1课题背景及其意义 1.2 国内外的研究状况 1.3本文的主要研 ...

  8. 多层感知器用实际例子和Python代码进行解释情绪分析

    多层感知器用实际例子和Python代码进行解释情绪分析 多层感知器是一种学习线性和非线性数据之间关系的神经网络. 这是专门介绍深度学习系列的第一篇文章,深度学习是一组机器学习方法,其根源可以追溯到20 ...

  9. (01)ORB-SLAM2源码无死角解析-(24) 单目SFM地图初始化→CreateInitialMapMonocular()-细节分析:尺度不确定性

    讲解关于slam一系列文章汇总链接:史上最全slam从零开始,针对于本栏目讲解的(01)ORB-SLAM2源码无死角解析链接如下(本文内容来自计算机视觉life ORB-SLAM2 课程课件): (0 ...

最新文章

  1. 回文数:给你一个整数 x ,如果 x 是一个回文整数,返回 true ;否则,返回 false 。回文数是指正序(从左向右)和倒序(从右向左)读都是一样的整数。
  2. 存储过程授权给子用户
  3. 造车行业百年未有变局之下,一个「老玩家」开始了自己的赶超
  4. rsync备份之windows+linux
  5. 微服务化小团队集群的组织和管理
  6. 如何设计一门语言(四)——什么是坑(操作模板)
  7. 邓侃:深度强化学习“深”在哪里?
  8. [scala-spark]10. RDD转换操作
  9. MT7628/MT7688 修改串口2作为调试串口 所踩的坑
  10. 使用WeUI+JS 的label包含input触发两次的问题
  11. 常用JQuery插件整理
  12. mysql 恢复空密码_mysql 找回密码
  13. 一道PHP面试题,求两个文件的相对路径
  14. OpenGL——颜色混合 glBlendFunc函数
  15. 【机器学习】今天详细谈下Soft Margin SVM和 SVM正则化
  16. php jquery ajax实现用户名,php+jquery+ajax实现用户名验证
  17. ssl证书不可信 群晖_上海云盾 CDN 网站 SSL 证书过期更新不生效问题排查和解决...
  18. MQTT客户端代码(C语言)
  19. K3 WISE修改单据表头字段默认值
  20. Magic-api介绍及使用

热门文章

  1. webapi中使用token验证(JWT验证)
  2. 电商基本模块-促销服务
  3. 智能汽车大爆发,车企创新为何首选华为云?
  4. 沪市A股,kdj指标,api接口,API接口
  5. asp毕业设计——基于asp+access的房产信息管理系统设计与实现(毕业论文+程序源码)——房产信息管理系统
  6. 使用OBS双电脑直播(主副电脑)推流至哔哩哔哩
  7. matplotlib画图之设置横轴坐标上下限的两种方法
  8. 杨辉三角+判断奇数并求和
  9. 百度(baidu)、bing、sogo、360关键字 - 图片批量下载
  10. android 动态表情实现,基于Android开发支持表情的实现详解