The story continues right where the last post left off. This time we look at how wav2vec2 builds its transformer encoder.

In our wav2vec2_model function, once the feature-extraction model has been built, construction of the transformer encoder begins.

Let's jump into the components._get_encoder function and take a look.

encoder

Straight to the code:

def _get_encoder(
    in_features: int,
    embed_dim: int,
    dropout_input: float,
    pos_conv_kernel: int,
    pos_conv_groups: int,
    num_layers: int,
    num_heads: int,
    attention_dropout: float,
    ff_interm_features: int,
    ff_interm_dropout: float,
    dropout: float,
    layer_norm_first: bool,
    layer_drop: float,
) -> Encoder:
    # Feature projection
    feature_projection = FeatureProjection(in_features, embed_dim, dropout_input)
    # Convolutional positional embedding
    pos_conv = ConvolutionalPositionalEmbedding(embed_dim, pos_conv_kernel, pos_conv_groups)
    # Empty list that will hold the encoder layers
    encoder_layers = nn.ModuleList()
    for _ in range(num_layers):
        # Multi-head self-attention
        attention = SelfAttention(
            embed_dim=embed_dim,
            num_heads=num_heads,
            dropout=attention_dropout,
        )
        # Feed-forward network
        feed_forward = FeedForward(
            io_features=embed_dim,
            intermediate_features=ff_interm_features,
            intermediate_dropout=ff_interm_dropout,
            output_dropout=dropout,
        )
        # One encoder layer
        encoder_layers.append(
            EncoderLayer(
                attention=attention,
                dropout=dropout,
                layer_norm_first=layer_norm_first,
                feed_forward=feed_forward,
            )
        )
    # Positional embedding + dropout + the stack of encoder layers
    transformer = Transformer(
        pos_conv_embed=pos_conv,
        dropout=dropout,
        layers=encoder_layers,
        layer_norm_first=not layer_norm_first,
        layer_drop=layer_drop,
    )
    return Encoder(feature_projection, transformer)

Main steps:

  1. Feature projection (feature_projection)
  2. Convolutional positional embedding, mainly to encode position information
  3. Assembly of the transformer encoder structure
  4. Returning the Encoder model
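
For reference, here is what the call looks like with the concrete values quoted throughout this post (a sketch of my own; num_layers=12 matches the base configuration but is my assumption, since the post never states it explicitly):

encoder = _get_encoder(
    in_features=512,        # last output_channel of the feature extractor
    embed_dim=768,
    dropout_input=0.1,
    pos_conv_kernel=128,
    pos_conv_groups=16,
    num_layers=12,          # assumed: base-model depth
    num_heads=12,
    attention_dropout=0.1,
    ff_interm_features=3072,
    ff_interm_dropout=0.0,
    dropout=0.1,
    layer_norm_first=False,
    layer_drop=0.05,
)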

Now let's walk through each piece one by one.

Feature projection

feature_projection = FeatureProjection(in_features, embed_dim, dropout_input)

in_features is the size of the final output_channel from the feature-extraction stage.

in_features=512

out_features=embed_dim=768

dropout_input=0.1

The feature projection uses the FeatureProjection class; let's click into it and see how it's structured.

The code inside looks like this:

# Feature projection
class FeatureProjection(Module):
    # A linear projection model: layer norm -> linear -> dropout
    # (input size, output size, dropout rate)
    def __init__(
        self,
        in_features: int,
        out_features: int,
        dropout: float,
    ):
        super().__init__()
        # Layer normalization
        self.layer_norm = nn.LayerNorm(in_features)
        # Linear transformation
        self.projection = nn.Linear(in_features, out_features)
        # Dropout
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """
        Args:
            x (Tensor):
                Feature Tensor. shape: ``[batch, frame, in_feature]``
        Returns:
            Tensor: Projected features. ``[batch, frame, out_feature]``.
        """
        # Normalize first, then project, then dropout
        x = self.layer_norm(x)
        x = self.projection(x)
        x = self.dropout(x)
        return x
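
A quick shape check with the values above (a toy run of my own, not from the torchaudio source):

import torch

proj = FeatureProjection(in_features=512, out_features=768, dropout=0.1)
x = torch.randn(2, 100, 512)   # [batch, frame, in_feature]
print(proj(x).shape)           # torch.Size([2, 100, 768])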

Not much to say here… straight on to the next one.

Convolutional positional embedding

 pos_conv = ConvolutionalPositionalEmbedding(embed_dim, pos_conv_kernel, pos_conv_groups)

embed_dim=768

pos_conv_kernel=128

pos_conv_groups=16

The code:

class ConvolutionalPositionalEmbedding(Module):
    """Positional embedding which is placed at the beginning of Transformer.

    Args:
        embed_dim (int): Feature dimension of the input Tensor.
        kernel_size (int): The number of frames to be used.
        groups (int): The number of groups in feature dimensions.
    """
    # embed dimension, kernel size, and number of groups of the positional convolution
    def __init__(
        self,
        embed_dim: int,
        kernel_size: int,
        groups: int,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        # conv1d
        self.conv = nn.Conv1d(
            in_channels=embed_dim,
            out_channels=embed_dim,
            kernel_size=kernel_size,
            padding=kernel_size // 2,
            groups=groups,
        )
        # Weight normalization
        self.conv = nn.utils.weight_norm(self.conv, name="weight", dim=2)
        self.num_remove: int = 1 if kernel_size % 2 == 0 else 0

    def __prepare_scriptable__(self):
        for hook in self.conv._forward_pre_hooks.values():
            # The hook we want to remove is an instance of WeightNorm class, so
            # normally we would do `if isinstance(...)` but this class is not accessible
            # because of shadowing, so we check the module name directly.
            if hook.__module__ == "torch.nn.utils.weight_norm" and hook.__class__.__name__ == "WeightNorm":
                _LG.warning("Removing weight_norm from %s", self.__class__.__name__)
                torch.nn.utils.remove_weight_norm(self.conv)
        return self

    def forward(self, x):
        """
        Args:
            x (Tensor): shape ``[batch, frame, feature]``.
        Returns:
            Tensor: The resulting feature. Shape ``[batch, frame, feature]``.
        """
        x = x.transpose(-2, -1)
        x = self.conv(x)
        if self.num_remove > 0:
            x = x[..., : -self.num_remove]
        x = torch.nn.functional.gelu(x)
        x = x.transpose(-2, -1)
        return x

From the constructor you can see that the positional embedding is also obtained with a conv1d.

The weight_norm call then reparameterizes the convolution weights into a magnitude and a direction; it is usually explained as stabilizing and speeding up training rather than as an anti-overfitting trick.
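
The num_remove bookkeeping exists because, with an even kernel size and padding = kernel_size // 2, the convolution output is one frame longer than its input, so one frame must be trimmed. A quick check with the values above (my own toy run):

import torch
import torch.nn as nn

# out_len = in_len + 2*(128 // 2) - 128 + 1 = in_len + 1 for an even kernel
x = torch.randn(1, 768, 100)   # [batch, feature, frame], i.e. already transposed
conv = nn.Conv1d(768, 768, kernel_size=128, padding=64, groups=16)
print(conv(x).shape)           # torch.Size([1, 768, 101]) -> one frame to remove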

SelfAttention

Next, the SelfAttention source code.

First, the arguments passed in:

attention = SelfAttention(
    embed_dim=embed_dim,
    num_heads=num_heads,
    dropout=attention_dropout,
)

embed_dim=768, the output feature dimension

num_heads=12, i.e. 12 attention heads

attention_dropout=0.1

Then step into the SelfAttention class to see its structure (I've written the explanations as comments below):

class SelfAttention(Module):
    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        dropout: float = 0.0,
    ):
        super().__init__()
        # Dimension of each attention head
        head_dim = embed_dim // num_heads
        # If the division leaves a remainder, the configuration is invalid
        if head_dim * num_heads != embed_dim:
            raise ValueError(f"`embed_dim ({embed_dim})` is not divisible by `num_heads ({num_heads})`")
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.dropout = torch.nn.Dropout(dropout)
        self.head_dim = head_dim
        # 1 / sqrt(d_k)
        self.scaling = self.head_dim ** -0.5
        # Projections producing k, v and q
        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=True)
        self.v_proj = nn.Linear(embed_dim, embed_dim, bias=True)
        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=True)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=True)

    def forward(
        self,
        x: Tensor,
        attention_mask: Optional[Tensor] = None,
    ) -> Tensor:
        # Fail if x is not 3-D or its last dimension is not embed_dim
        if x.ndim != 3 or x.shape[2] != self.embed_dim:
            raise ValueError(
                f"The expected input shape is (batch, sequence, embed_dim=={self.embed_dim}). Found {x.shape}."
            )
        # Unpack the three dimensions of x
        batch_size, length, embed_dim = x.size()
        # Validate the attention mask
        if attention_mask is not None:
            shape_ = (batch_size, 1, length, length)
            if attention_mask.size() != shape_:
                raise ValueError(f"The expected attention mask shape is {shape_}. Found {attention_mask.size()}.")
        # Per-head shape for q, k and v
        shape = (batch_size, length, self.num_heads, self.head_dim)
        # Reshape q into `shape`, then move the head axis forward
        q = self.q_proj(x).view(*shape).transpose(2, 1)  # B, nH, L, Hd
        # k is laid out transposed, ready for the matmul
        k = self.k_proj(x).view(*shape).permute(0, 2, 3, 1)  # B, nH, Hd, L
        v = self.v_proj(x).view(*shape).transpose(2, 1)  # B, nH, L, Hd
        # Scaled dot product, the term that feeds the softmax
        weights = self.scaling * (q @ k)  # B, nH, L, L
        # Add the attention mask, if given
        if attention_mask is not None:
            weights += attention_mask
        # Softmax
        weights = torch.nn.functional.softmax(weights, dim=-1)
        # Dropout
        weights = self.dropout(weights)
        # Weighted sum with v
        output = weights @ v  # B, nH, L, Hd
        # Merge heads and project out
        output = output.transpose(2, 1).reshape(batch_size, length, embed_dim)
        output = self.out_proj(output)
        return output
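
Per head this computes softmax(scaling * q @ kᵀ) @ v with scaling = 1/√head_dim, and the output shape matches the input. A quick check (my own toy run, assuming the torchaudio imports such as Tensor and Optional are in scope):

import torch

attn = SelfAttention(embed_dim=768, num_heads=12, dropout=0.1)
x = torch.randn(2, 50, 768)    # [batch, sequence, embed_dim]
print(attn(x).shape)           # torch.Size([2, 50, 768])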

FeedForward

Next, let's look at how the feed_forward model is obtained.

feed_forward = FeedForward(
    io_features=embed_dim,
    intermediate_features=ff_interm_features,
    intermediate_dropout=ff_interm_dropout,
    output_dropout=dropout,
)

The FeedForward class takes the following arguments (fancy names, but they are still just feature counts):

io_features=embed_dim=768

intermediate_features=ff_interm_features=3072

intermediate_dropout=ff_interm_dropout=0.0

output_dropout=dropout=0.1

Here's the FeedForward code:

class FeedForward(Module):
    """Layer that follows attention layer in encoder layer."""
    def __init__(
        self,
        io_features: int,
        intermediate_features: int,
        intermediate_dropout: float,
        output_dropout: float,
    ):
        super().__init__()
        # 768 -> 3072
        self.intermediate_dense = nn.Linear(io_features, intermediate_features)
        self.intermediate_dropout = nn.Dropout(intermediate_dropout)
        # 3072 -> 768
        self.output_dense = nn.Linear(intermediate_features, io_features)
        self.output_dropout = nn.Dropout(output_dropout)

    def forward(self, x):
        """
        Args:
            x (Tensor): shape: `(batch, sequence_length, io_features)`
        Returns:
            x (Tensor): shape: `(batch, sequence_length, io_features)`
        """
        x = self.intermediate_dense(x)
        x = torch.nn.functional.gelu(x)
        x = self.intermediate_dropout(x)
        x = self.output_dense(x)
        x = self.output_dropout(x)
        return x

Wow, much simpler: just two linear layers with a GELU sandwiched in between.
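
A shape check confirms the docstring: input and output shapes are identical (my own toy run):

import torch

ff = FeedForward(io_features=768, intermediate_features=3072,
                 intermediate_dropout=0.0, output_dropout=0.1)
x = torch.randn(2, 50, 768)
print(ff(x).shape)             # torch.Size([2, 50, 768])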

EncoderLayer

Now for part three: building encoder_layers.

encoder_layers.append(
    EncoderLayer(
        attention=attention,
        dropout=dropout,
        layer_norm_first=layer_norm_first,
        feed_forward=feed_forward,
    )
)

encoder_layers is an nn.ModuleList(), so what gets appended to it are the individual encoder layers of our model.

Let's see what the EncoderLayer class is.

First, the arguments passed in:

attention=attention, the multi-head self-attention module built above

dropout=dropout=0.1

layer_norm_first=layer_norm_first=False

feed_forward=feed_forward, the feed-forward module built above

The source code:

class EncoderLayer(Module):
    """A layer unit in encoder. Combines multihead self attention and feed forward."""
    def __init__(
        self,
        attention: Module,
        dropout: float,
        layer_norm_first: bool,
        feed_forward: Module,
    ):
        super().__init__()
        self.attention = attention
        self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(attention.embed_dim)
        self.layer_norm_first = layer_norm_first
        self.feed_forward = feed_forward
        self.final_layer_norm = nn.LayerNorm(attention.embed_dim)

    def forward(
        self,
        x: Tensor,
        attention_mask: Optional[Tensor] = None,
    ):
        """
        Args:
            x (Tensor): shape: `(batch, sequence_length, embed_dim)`
            attention_mask (Tensor or None, optional):
                shape: `(batch, 1, sequence_length, sequence_length)`
        """
        residual = x
        # Pre-norm variant: normalize before attention
        if self.layer_norm_first:
            x = self.layer_norm(x)
        # Multi-head self-attention
        x = self.attention(x, attention_mask)
        x = self.dropout(x)
        # Residual connection
        x = residual + x
        # Pre-norm vs. post-norm around the feed-forward block
        if self.layer_norm_first:
            x = x + self.feed_forward(self.final_layer_norm(x))
        else:
            x = self.layer_norm(x)
            x = self.final_layer_norm(x + self.feed_forward(x))
        return x
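
So layer_norm_first toggles between the pre-norm layout (normalize before each sub-block, residuals added to the raw stream) and the post-norm layout used by this fine-tuning config (normalize after each residual sum). A quick usage check built from the modules above (my own toy run):

import torch

layer = EncoderLayer(
    attention=SelfAttention(embed_dim=768, num_heads=12, dropout=0.1),
    dropout=0.1,
    layer_norm_first=False,    # post-norm, as in this config
    feed_forward=FeedForward(768, 3072, 0.0, 0.1),
)
x = torch.randn(2, 50, 768)
print(layer(x).shape)          # torch.Size([2, 50, 768])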

Transformer

With the encoder layer done, let's see how the Transformer class is structured.

Arguments first:

transformer = Transformer(
    pos_conv_embed=pos_conv,
    dropout=dropout,
    layers=encoder_layers,
    layer_norm_first=not layer_norm_first,
    layer_drop=layer_drop,
)

pos_conv_embed=pos_conv, the positional-embedding module built above

dropout=dropout=0.1

layers=encoder_layers, the list of encoder layers (not a layer count)

layer_norm_first=not layer_norm_first=True

layer_drop=layer_drop=0.05

And the Transformer code:

class Transformer(Module):
    def __init__(
        self,
        pos_conv_embed: Module,
        dropout: float,
        layers: Module,
        layer_norm_first: bool,
        layer_drop: float,
    ):
        super().__init__()
        self.pos_conv_embed = pos_conv_embed
        self.layer_norm = nn.LayerNorm(pos_conv_embed.embed_dim)
        self.layer_norm_first = layer_norm_first
        self.layer_drop = layer_drop
        self.dropout = nn.Dropout(dropout)
        self.layers = layers

    def _preprocess(self, x: Tensor):
        # Add the positional embedding
        x = x + self.pos_conv_embed(x)
        # Normalize first, if requested
        if self.layer_norm_first:
            x = self.layer_norm(x)
        x = self.dropout(x)
        return x

    def forward(
        self,
        x: Tensor,
        attention_mask: Optional[Tensor] = None,
    ):
        # Positional embedding
        x = self._preprocess(x)
        # Run every encoder layer; during training a layer is randomly
        # skipped with probability `layer_drop` (LayerDrop)
        for layer in self.layers:
            if not (self.training and torch.rand(1).item() <= self.layer_drop):
                x = layer(x, attention_mask)
        # Final normalization
        if not self.layer_norm_first:
            x = self.layer_norm(x)
        return x

    def get_intermediate_outputs(
        self,
        x: Tensor,
        attention_mask: Optional[Tensor] = None,
        num_layers: Optional[int] = None,
    ) -> List[Tensor]:
        if num_layers is not None:
            if not 0 < num_layers <= len(self.layers):
                raise ValueError(f"`num_layers` must be between [1, {len(self.layers)}]")
        ret: List[Tensor] = []
        x = self._preprocess(x)
        for layer in self.layers:
            x = layer(x, attention_mask)
            ret.append(x)
            if num_layers is not None and len(ret) >= num_layers:
                return ret
        return ret

Main work: in short, it glues the positional embedding and the multi-layer encoder stack together. Note that _get_encoder passes layer_norm_first=not layer_norm_first, so with the post-norm encoder layers used here the Transformer itself normalizes right after adding the positional embedding rather than at the end of the stack.
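
One detail worth isolating: layer_drop implements LayerDrop, i.e. during training each encoder layer is skipped entirely with probability layer_drop. A minimal illustration of the test used in forward (my own toy run, not from the source):

import torch

layer_drop, training = 0.05, True
executed = 0
for _ in range(12):            # pretend there are 12 encoder layers
    if not (training and torch.rand(1).item() <= layer_drop):
        executed += 1
print(executed)                # usually 11 or 12; at eval time always 12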

Encoder

Finally, the Encoder class; let's look at its arguments.

Encoder(feature_projection, transformer)

feature_projection and transformer are both modules we built above.

The Encoder code:

class Encoder(Module):
    def __init__(
        self,
        feature_projection: Module,
        transformer: Module,
    ):
        super().__init__()
        self.feature_projection = feature_projection
        self.transformer = transformer

    # Build the attention mask for the features
    def _preprocess(
        self,
        features: Tensor,
        lengths: Optional[Tensor] = None,
    ) -> Tuple[Tensor, Optional[Tensor]]:
        # Feature projection
        x = self.feature_projection(features)

        mask: Optional[Tensor] = None
        if lengths is not None:
            # Batch size and maximum sequence length
            batch_size, max_len, _ = x.shape
            # create mask for padded elements and zero-out them
            mask = torch.arange(max_len, device=lengths.device).expand(batch_size, max_len) >= lengths[:, None]
            x[mask] = 0.0
            # extend the mask to attention shape and set weight
            mask = -10000.0 * mask[:, None, None, :].to(dtype=features.dtype)
            mask = mask.expand(batch_size, 1, max_len, max_len)
        return x, mask

    def forward(
        self,
        features: Tensor,
        lengths: Optional[Tensor] = None,
    ) -> Tensor:
        x, mask = self._preprocess(features, lengths)
        x = self.transformer(x, attention_mask=mask)
        return x

    def extract_features(
        self,
        features: Tensor,
        lengths: Optional[Tensor] = None,
        num_layers: Optional[int] = None,
    ) -> List[Tensor]:
        x, masks = self._preprocess(features, lengths)
        return self.transformer.get_intermediate_outputs(x, attention_mask=masks, num_layers=num_layers)

I found the mask setup the trickiest part, so let's unpack it: frames beyond each sequence's true length are zeroed out in the features, and the boolean mask is then turned into an additive bias of -10000.0 at those key positions, so after the softmax inside every attention layer the padded frames receive essentially zero attention weight. Everything else is straightforward.
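
To make the mask concrete, here is a toy run of the same arithmetic (my own example):

import torch

lengths = torch.tensor([2, 4])   # true lengths in a batch of 2, max_len = 4
mask = torch.arange(4).expand(2, 4) >= lengths[:, None]
print(mask)
# tensor([[False, False,  True,  True],
#         [False, False, False, False]])
# Multiplying by -10000.0 and adding this to the attention scores makes the
# padded key positions vanish after the softmax.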

But all in all, we now have a picture of how the whole model is constructed.
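
As an end-to-end sanity check, here is a toy forward pass through the assembled encoder (my own example, reusing the encoder built in the sketch near the top of the post):

import torch

features = torch.randn(2, 100, 512)   # [batch, frame, in_features] from the feature extractor
lengths = torch.tensor([100, 60])     # sample 2 is padding after frame 60
print(encoder(features, lengths).shape)   # torch.Size([2, 100, 768])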

Summary

So, the overall wav2vec2_model structure is as follows:

(The original post shows a diagram here; the blue text in it explains the steps and is not code.) All of the classes above exist to construct one complete model.

A recent discovery

This model is meant for direct ASR fine-tuning, which means no pretraining step is needed. The pretrained weights for the whole model are also downloaded in a step shown in the original post (screenshot not reproduced here).

As a result, we don't get to see the pretraining loss setup, the Gumbel-softmax, or the quantization process.

We can only use this model directly for fine-tuning tasks.

That leaves me a bit unsatisfied: I learned plenty while reading the code, but being unable to finish the story feels stifling. fairseq works the same way, by the way: it simply loads a .pt file to obtain the model weights.

That's all for the wav2vec2.0 fine-tuning model. What remains is the model-invocation process and the CTC decoding step; I'll keep digging when I have time.
