wav2vec2 source code in torchAudio (Part 3): building the Transformer encoder
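Continuing from the previous installment, let's look at how wav2vec2 builds its Transformer encoder.
In our wav2vec2_model method, once the feature extraction model has been built, construction of the Transformer encoder begins.
Let's jump into the components._get_encoder method and take a look.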
encoder
Straight to the code:
```python
def _get_encoder(
    in_features: int,
    embed_dim: int,
    dropout_input: float,
    pos_conv_kernel: int,
    pos_conv_groups: int,
    num_layers: int,
    num_heads: int,
    attention_dropout: float,
    ff_interm_features: int,
    ff_interm_dropout: float,
    dropout: float,
    layer_norm_first: bool,
    layer_drop: float,
) -> Encoder:
    # Feature projection
    feature_projection = FeatureProjection(in_features, embed_dim, dropout_input)
    # Convolutional positional embedding
    pos_conv = ConvolutionalPositionalEmbedding(embed_dim, pos_conv_kernel, pos_conv_groups)
    # Start from an empty module list
    encoder_layers = nn.ModuleList()
    for _ in range(num_layers):
        # Multi-head self-attention
        attention = SelfAttention(
            embed_dim=embed_dim,
            num_heads=num_heads,
            dropout=attention_dropout,
        )
        # Position-wise feed-forward network
        feed_forward = FeedForward(
            io_features=embed_dim,
            intermediate_features=ff_interm_features,
            intermediate_dropout=ff_interm_dropout,
            output_dropout=dropout,
        )
        # One encoder layer = attention + feed-forward
        encoder_layers.append(
            EncoderLayer(
                attention=attention,
                dropout=dropout,
                layer_norm_first=layer_norm_first,
                feed_forward=feed_forward,
            )
        )
    # Positional embedding, dropout, the stack of encoder layers, LayerNorm placement, LayerDrop
    transformer = Transformer(
        pos_conv_embed=pos_conv,
        dropout=dropout,
        layers=encoder_layers,
        layer_norm_first=not layer_norm_first,
        layer_drop=layer_drop,
    )
    return Encoder(feature_projection, transformer)
```
Main steps:
- Feature projection (`feature_projection`)
- Positional embedding, mainly to encode position information
- Building the Transformer encoder layers
- Returning the encoder model
Now let's go through the pieces one at a time.
Feature projection
feature_projection = FeatureProjection(in_features, embed_dim, dropout_input)
`in_features` is the size of the output_channel of the last feature-extraction layer.
in_features=512
out_features=embed_dim=768
dropout_input=0.1
Feature projection uses the FeatureProjection class, so let's open it and see how it is structured.
Its code is as follows:
```python
from torch import nn
from torch.nn import Module


# Feature projection
class FeatureProjection(Module):
    # Linear projection model: input size, output size, dropout
    def __init__(
        self,
        in_features: int,
        out_features: int,
        dropout: float,
    ):
        super().__init__()
        # Layer normalization
        self.layer_norm = nn.LayerNorm(in_features)
        # Linear transformation
        self.projection = nn.Linear(
            in_features,
            out_features,
        )
        # Dropout
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """
        Args:
            x (Tensor):
                Feature Tensor. shape: ``[batch, frame, in_feature]``
        Returns:
            Tensor: Projected features. ``[batch, frame, out_feature]``.
        """
        # Normalize first, then project, then apply dropout
        x = self.layer_norm(x)
        x = self.projection(x)
        x = self.dropout(x)
        return x
```
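To make the normalize-then-project order concrete, here is a dependency-free sketch of what `FeatureProjection.forward` computes for a single frame. It is an illustration under stated assumptions, not torchaudio code: the toy sizes (8 to 4 instead of 512 to 768), the helper names, and the random weights are made up; dropout is omitted because `nn.Dropout` is the identity in eval mode; and the LayerNorm affine parameters are left at their initial values (gamma = 1, beta = 0). As in `nn.LayerNorm`, the variance is the biased one and `eps` is 1e-5.

```python
import math
import random


def layer_norm(x, eps=1e-5):
    """LayerNorm over the last dimension with gamma=1, beta=0 (initial values)."""
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)  # biased variance, like nn.LayerNorm
    return [(v - mean) / math.sqrt(var + eps) for v in x]


def linear(x, weight, bias):
    """y = W x + b, with weight stored as [out_features][in_features] like nn.Linear."""
    return [sum(w * v for w, v in zip(row, x)) + b for row, b in zip(weight, bias)]


def feature_projection(frame, weight, bias):
    """One frame: LayerNorm -> Linear (dropout omitted, i.e. eval mode)."""
    return linear(layer_norm(frame), weight, bias)


# Hypothetical toy sizes instead of 512 -> 768
in_features, out_features = 8, 4
rng = random.Random(0)
weight = [[rng.uniform(-0.1, 0.1) for _ in range(in_features)] for _ in range(out_features)]
bias = [0.0] * out_features
frame = [float(i) for i in range(in_features)]

normed = layer_norm(frame)          # zero mean, (almost) unit variance
projected = feature_projection(frame, weight, bias)  # length out_features
```

The LayerNorm runs on the 512-dimensional CNN output before the projection, so the linear layer always sees features on a comparable scale regardless of the convolution stack's output magnitude.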
Not much to say here, so on to the next one.
Convolutional positional embedding
pos_conv = ConvolutionalPositionalEmbedding(embed_dim, pos_conv_kernel, pos_conv_groups)
embed_dim=768
pos_conv_kernel=128
pos_conv_groups=16
The code:
```python
import logging

import torch
from torch import nn
from torch.nn import Module

_LG = logging.getLogger(__name__)


class ConvolutionalPositionalEmbedding(Module):
    """Positional embedding which is placed at the beginning of Transformer.

    Args:
        embed_dim (int): Feature dimension of the input Tensor.
        kernel_size (int): The number of frames to be use.
        groups (int): The number of groups in feature dimensions.
    """

    # Embedding dimension, convolution kernel size, number of groups of the grouped convolution
    def __init__(
        self,
        embed_dim: int,
        kernel_size: int,
        groups: int,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        # Grouped 1-D convolution over the time axis
        self.conv = nn.Conv1d(
            in_channels=embed_dim,
            out_channels=embed_dim,
            kernel_size=kernel_size,
            padding=kernel_size // 2,
            groups=groups,
        )
        # Weight normalization
        self.conv = nn.utils.weight_norm(self.conv, name="weight", dim=2)
        # An even kernel with padding kernel_size // 2 yields one extra output frame
        self.num_remove: int = 1 if kernel_size % 2 == 0 else 0

    def __prepare_scriptable__(self):
        for hook in self.conv._forward_pre_hooks.values():
            # The hook we want to remove is an instance of WeightNorm class, so
            # normally we would do `if isinstance(...)` but this class is not accessible
            # because of shadowing, so we check the module name directly.
            if hook.__module__ == "torch.nn.utils.weight_norm" and hook.__class__.__name__ == "WeightNorm":
                _LG.warning("Removing weight_norm from %s", self.__class__.__name__)
                torch.nn.utils.remove_weight_norm(self.conv)
        return self

    def forward(self, x):
        """
        Args:
            x (Tensor): shape ``[batch, frame, feature]``.

        Returns:
            Tensor: The resulting feature. Shape ``[batch, frame, feature]``.
        """
        x = x.transpose(-2, -1)
        x = self.conv(x)
        if self.num_remove > 0:
            x = x[..., : -self.num_remove]
        x = torch.nn.functional.gelu(x)
        x = x.transpose(-2, -1)
        return x
```
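The `num_remove` trim is easiest to see from the output-length formula in the `nn.Conv1d` documentation: L_out = floor((L_in + 2*padding - dilation*(kernel_size - 1) - 1) / stride) + 1. With kernel_size = 128 and padding = 64, the convolution returns one frame more than it received, which is exactly the frame that `x[..., : -self.num_remove]` drops. The sketch below (plain Python, hypothetical helper names) just evaluates that arithmetic.

```python
def conv1d_out_len(length, kernel_size, padding, stride=1, dilation=1):
    """Output length formula from the nn.Conv1d documentation."""
    return (length + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1


def pos_conv_out_len(length, kernel_size):
    """Length after the positional conv plus its num_remove trim."""
    padding = kernel_size // 2
    num_remove = 1 if kernel_size % 2 == 0 else 0
    return conv1d_out_len(length, kernel_size, padding) - num_remove


print(conv1d_out_len(100, 128, 64))  # 101: even kernel adds one frame
print(pos_conv_out_len(100, 128))    # 100: the extra frame is trimmed
print(pos_conv_out_len(100, 127))    # 100: an odd kernel needs no trim
```

So the trim is what keeps the positional embedding the same length as its input, which is required because the two are added together.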
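From the constructor we can see that the positional embedding is also produced by a Conv1d: a grouped convolution over the time axis whose GELU output is later added to the features, so position is encoded from each frame's local context rather than from an absolute index.
The weight normalization is not really there to prevent overfitting, as I first guessed (and admittedly did not fully understand). It reparameterizes the convolution weight as w = g · v / ||v||, decoupling its magnitude g from its direction v, which mainly helps optimization. With dim=2, a separate magnitude is kept for each kernel position.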
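A plain-Python sketch of the reparameterization that `nn.utils.weight_norm(..., dim=2)` sets up, assuming a weight laid out as [out_channels][in_channels / groups][kernel_size] like a `Conv1d` weight. Torch keeps a magnitude `g` with one entry per kernel position and a direction `v`, and recomputes w = g · v / ||v|| before each forward pass, with the norm taken over the other two dimensions. The helper name and toy numbers are illustrative only.

```python
import math


def weight_norm_dim2(v, g):
    """Recompose w = g * v / ||v||, where the norm is taken over dims 0 and 1
    separately for every index along dim 2 (the kernel position)."""
    out_ch, in_ch, k = len(v), len(v[0]), len(v[0][0])
    norms = [
        math.sqrt(sum(v[o][i][t] ** 2 for o in range(out_ch) for i in range(in_ch)))
        for t in range(k)
    ]
    return [
        [[g[t] * v[o][i][t] / norms[t] for t in range(k)] for i in range(in_ch)]
        for o in range(out_ch)
    ]


# Toy weight: 2 output channels, 1 input channel per group, kernel size 3
v = [[[3.0, 1.0, 0.0]], [[4.0, 0.0, 2.0]]]
g = [10.0, 2.0, 5.0]
w = weight_norm_dim2(v, g)
# Each kernel-position slice of w now has norm g[t]: the direction comes from v,
# the length from g, and the optimizer updates the two separately.
```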
SelfAttention
Next, let's look at the SelfAttention source code.
First, the arguments passed in:
attention = SelfAttention(embed_dim=embed_dim,num_heads=num_heads,dropout=attention_dropout,)
embed_dim=768, the output feature dimension
num_heads=12, i.e. 12 attention heads
attention_dropout=0.1
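As a quick check of these numbers, the head split and the size of the four projection layers can be worked out directly. This is plain arithmetic mirroring `SelfAttention.__init__`, not library code.

```python
embed_dim, num_heads = 768, 12

head_dim = embed_dim // num_heads
assert head_dim * num_heads == embed_dim  # same divisibility check as SelfAttention.__init__
scaling = head_dim ** -0.5                # 1 / sqrt(d_k)

# k_proj, v_proj, q_proj, out_proj: four nn.Linear(768, 768, bias=True)
attn_params = 4 * (embed_dim * embed_dim + embed_dim)

print(head_dim, scaling, attn_params)  # 64 0.125 2362368
```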
Then open the SelfAttention class to see its structure (the explanations are in the comments):
```python
from typing import Optional

import torch
from torch import Tensor, nn
from torch.nn import Module


class SelfAttention(Module):
    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        dropout: float = 0.0,
    ):
        super().__init__()
        # Dimension of each attention head
        head_dim = embed_dim // num_heads
        # embed_dim must be divisible by num_heads, otherwise the heads cannot be split evenly
        if head_dim * num_heads != embed_dim:
            raise ValueError(f"`embed_dim ({embed_dim})` is not divisible by `num_heads ({num_heads})`")

        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.dropout = torch.nn.Dropout(dropout)
        self.head_dim = head_dim

        # 1 / sqrt(d_k)
        self.scaling = self.head_dim ** -0.5

        # Linear projections producing k, v, q, plus the output projection
        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=True)
        self.v_proj = nn.Linear(embed_dim, embed_dim, bias=True)
        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=True)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=True)

    def forward(
        self,
        x: Tensor,
        attention_mask: Optional[Tensor] = None,
    ) -> Tensor:
        # Fail if x is not 3-D or its last dimension is not embed_dim (768)
        if x.ndim != 3 or x.shape[2] != self.embed_dim:
            raise ValueError(
                f"The expected input shape is (batch, sequence, embed_dim=={self.embed_dim}). " f"Found {x.shape}."
            )
        # Sizes of the three dimensions of x
        batch_size, length, embed_dim = x.size()
        # Check that the mask has the expected shape
        if attention_mask is not None:
            shape_ = (batch_size, 1, length, length)
            if attention_mask.size() != shape_:
                raise ValueError(
                    f"The expected attention mask shape is {shape_}. " f"Found {attention_mask.size()}."
                )

        # Per-head shape of q, k, v
        shape = (batch_size, length, self.num_heads, self.head_dim)
        # Project, split into heads, then move the head axis forward
        q = self.q_proj(x).view(*shape).transpose(2, 1)  # B, nH, L, Hd
        # k is laid out already transposed, so q @ k computes q k^T
        k = self.k_proj(x).view(*shape).permute(0, 2, 3, 1)  # B, nH, Hd, L
        v = self.v_proj(x).view(*shape).transpose(2, 1)  # B, nH, L, Hd

        # Scaled dot-product scores: q k^T / sqrt(d_k)
        weights = self.scaling * (q @ k)  # B, nH, L, L
        # Add the (additive) mask if one is given
        if attention_mask is not None:
            weights += attention_mask

        # Softmax over the key axis
        weights = torch.nn.functional.softmax(weights, dim=-1)
        # Dropout on the attention weights
        weights = self.dropout(weights)

        # Weighted sum of the values
        output = weights @ v  # B, nH, L, Hd
        # Merge the heads back together and apply the output projection
        output = output.transpose(2, 1).reshape(batch_size, length, embed_dim)

        output = self.out_proj(output)
        return output
```
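The core of `forward` for a single head can be sketched in plain Python. Because `k` is permuted to `[B, nH, Hd, L]`, `q @ k` is the usual q k^T, and the additive mask is applied after scaling and before the softmax. This is a hypothetical, framework-free illustration with toy 2-dimensional heads; head splitting, dropout and the output projection are left out.

```python
import math


def softmax(row):
    m = max(row)  # subtract the max for numerical stability
    exps = [math.exp(v - m) for v in row]
    s = sum(exps)
    return [e / s for e in exps]


def matmul(a, b):
    return [[sum(a[i][t] * b[t][j] for t in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def transpose(a):
    return [list(col) for col in zip(*a)]


def attention_head(q, k, v, mask=None):
    """softmax(q k^T / sqrt(d) + mask) v for one head; q, k, v are [L][d] lists."""
    scaling = len(q[0]) ** -0.5
    scores = [[scaling * s for s in row] for row in matmul(q, transpose(k))]  # [L][L]
    if mask is not None:
        scores = [[s + m for s, m in zip(row, mrow)] for row, mrow in zip(scores, mask)]
    weights = [softmax(row) for row in scores]
    return weights, matmul(weights, v)  # [L][L], [L][d]


q = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
k = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
v = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
# Additive mask that blocks key position 2 for every query
mask = [[0.0, 0.0, -10000.0] for _ in range(3)]
weights, out = attention_head(q, k, v, mask)
```

Each row of `weights` sums to 1, and the masked key column gets essentially zero weight, so its value vector does not contribute to `out`.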
FeedForward
Next, let's look at how the feed_forward module is built.
feed_forward = FeedForward(io_features=embed_dim,intermediate_features=ff_interm_features,intermediate_dropout=ff_interm_dropout,output_dropout=dropout,)
FeedForward takes the following arguments (the names are fancy, but they are still just feature sizes and dropout rates):
io_features=embed_dim=768
intermediate_features=ff_interm_features=3072
intermediate_dropout=ff_interm_dropout=0.0
output_dropout=dropout=0.1
Here is the FeedForward code:
```python
import torch
from torch import nn
from torch.nn import Module


class FeedForward(Module):
    """Layer that follows attention layer in encoder layer."""

    def __init__(
        self,
        io_features: int,
        intermediate_features: int,
        intermediate_dropout: float,
        output_dropout: float,
    ):
        super().__init__()
        # 768 -> 3072
        self.intermediate_dense = nn.Linear(io_features, intermediate_features)
        self.intermediate_dropout = nn.Dropout(intermediate_dropout)
        # 3072 -> 768
        self.output_dense = nn.Linear(intermediate_features, io_features)
        self.output_dropout = nn.Dropout(output_dropout)

    def forward(self, x):
        """
        Args:
            x (Tensor): shape: `(batch, sequence_length, io_features)`
        Returns:
            x (Tensor): shape: `(batch, sequence_length, io_features)`
        """
        x = self.intermediate_dense(x)
        x = torch.nn.functional.gelu(x)
        x = self.intermediate_dropout(x)

        x = self.output_dense(x)
        x = self.output_dropout(x)
        return x
```
Much simpler: just two linear layers with a GELU in between (and a dropout after each).
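For reference, a framework-free sketch of the same computation, with dropouts omitted as in eval mode. `torch.nn.functional.gelu` uses the exact erf-based GELU by default, which is what `gelu` below implements. The toy weights are hand-picked (hypothetical) so the result is easy to check: since gelu(x) - gelu(-x) = x, this particular network returns its input unchanged. The last line counts the parameters of the real 768 -> 3072 -> 768 block.

```python
import math


def gelu(x):
    """Exact (erf-based) GELU, the default form of torch.nn.functional.gelu."""
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))


def linear(x, weight, bias):
    return [sum(w * v for w, v in zip(row, x)) + b for row, b in zip(weight, bias)]


def feed_forward(x, w1, b1, w2, b2):
    """io -> intermediate -> GELU -> io (dropouts omitted, i.e. eval mode)."""
    hidden = [gelu(h) for h in linear(x, w1, b1)]
    return linear(hidden, w2, b2)


# Toy sizes: io_features=2, intermediate_features=4 (instead of 768 and 3072)
w1 = [[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0], [0.0, -1.0]]
b1 = [0.0] * 4
w2 = [[1.0, 0.0, -1.0, 0.0], [0.0, 1.0, 0.0, -1.0]]
b2 = [0.0, 0.0]
y = feed_forward([0.5, -2.0], w1, b1, w2, b2)  # approximately [0.5, -2.0]

# Parameters of the real block: two nn.Linear layers with bias
ffn_params = (768 * 3072 + 3072) + (3072 * 768 + 768)  # 4722432
```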
EncoderLayer
Now the third part: building `encoder_layers`.
encoder_layers.append(EncoderLayer(attention=attention,dropout=dropout,layer_norm_first=layer_norm_first,feed_forward=feed_forward,))
`encoder_layers` is an `nn.ModuleList`, so each element appended to it is one encoder layer.
Let's see what EncoderLayer is.
First, its arguments:
attention=attention: the multi-head attention module
dropout=dropout=0.1
layer_norm_first=layer_norm_first=False
feed_forward=feed_forward: the feed-forward module
The source:
```python
from typing import Optional

from torch import Tensor, nn
from torch.nn import Module


class EncoderLayer(Module):
    """A layer unit in encoder. Combines multihead self attention and feed forward."""

    def __init__(
        self,
        attention: Module,
        dropout: float,
        layer_norm_first: bool,
        feed_forward: Module,
    ):
        super().__init__()
        self.attention = attention
        self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(attention.embed_dim)
        self.layer_norm_first = layer_norm_first
        self.feed_forward = feed_forward
        self.final_layer_norm = nn.LayerNorm(attention.embed_dim)

    def forward(
        self,
        x: Tensor,
        attention_mask: Optional[Tensor] = None,
    ):
        """
        Args:
            x (Tensor): shape: `(batch, sequence_length, embed_dim)`
            attention_mask (Tensor or None, optional):
                shape: `(batch, 1, sequence_length, sequence_length)`
        """
        residual = x

        # Pre-norm: apply LayerNorm before attention
        if self.layer_norm_first:
            x = self.layer_norm(x)

        # Multi-head self-attention
        x = self.attention(x, attention_mask)
        x = self.dropout(x)
        # Residual connection
        x = residual + x

        if self.layer_norm_first:
            # Pre-norm feed-forward block
            x = x + self.feed_forward(self.final_layer_norm(x))
        else:
            # Post-norm: LayerNorm after each residual addition
            x = self.layer_norm(x)
            x = self.final_layer_norm(x + self.feed_forward(x))
        return x
```
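The two branches are easy to mix up, so here is a small trace of the call order. It is a plain-Python mirror of `EncoderLayer.forward` with dropout omitted, where every sub-module is replaced by a hypothetical identity function that records its name. With `layer_norm_first=False` (the value used here) the layer is post-norm: each residual addition is followed by a LayerNorm. With `True` it is pre-norm and the feed-forward residual output is left unnormalized, which is why `Transformer` receives `not layer_norm_first` and places its own LayerNorm accordingly.

```python
def encoder_layer_forward(x, attention, feed_forward, layer_norm, final_layer_norm,
                          layer_norm_first):
    """Mirror of EncoderLayer.forward with dropout omitted (eval mode)."""
    residual = x
    if layer_norm_first:          # pre-norm: normalize before attention
        x = layer_norm(x)
    x = residual + attention(x)
    if layer_norm_first:          # pre-norm FFN block; output is not normalized
        return x + feed_forward(final_layer_norm(x))
    x = layer_norm(x)             # post-norm: normalize after each residual add
    return final_layer_norm(x + feed_forward(x))


def trace(layer_norm_first):
    """Record the order in which the (identity) sub-modules are called."""
    calls = []

    def op(name):
        def f(x):
            calls.append(name)
            return x
        return f

    encoder_layer_forward(0.0, op("attention"), op("feed_forward"),
                          op("layer_norm"), op("final_layer_norm"), layer_norm_first)
    return calls


print(trace(False))  # ['attention', 'layer_norm', 'feed_forward', 'final_layer_norm']
print(trace(True))   # ['layer_norm', 'attention', 'final_layer_norm', 'feed_forward']
```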
Transformer
With the encoder layers done, let's look at how the Transformer class is structured.
Arguments first:
transformer = Transformer(pos_conv_embed=pos_conv,dropout=dropout,layers=encoder_layers,layer_norm_first=not layer_norm_first,layer_drop=layer_drop,)
pos_conv_embed=pos_conv: the positional embedding module
dropout=dropout=0.1
layers=encoder_layers: the ModuleList of encoder layers (the layers themselves, not their count)
layer_norm_first=not layer_norm_first=True
layer_drop=layer_drop=0.05
Now the Transformer code:
```python
from typing import List, Optional

import torch
from torch import Tensor, nn
from torch.nn import Module


class Transformer(Module):
    def __init__(
        self,
        pos_conv_embed: Module,
        dropout: float,
        layers: Module,
        layer_norm_first: bool,
        layer_drop: float,
    ):
        super().__init__()
        self.pos_conv_embed = pos_conv_embed
        self.layer_norm = nn.LayerNorm(pos_conv_embed.embed_dim)
        self.layer_norm_first = layer_norm_first
        self.layer_drop = layer_drop
        self.dropout = nn.Dropout(dropout)
        self.layers = layers

    def _preprocess(self, x: Tensor):
        # Add the convolutional positional embedding
        x = x + self.pos_conv_embed(x)

        # Normalize here, before the layers, if requested
        if self.layer_norm_first:
            x = self.layer_norm(x)

        x = self.dropout(x)
        return x

    def forward(
        self,
        x: Tensor,
        attention_mask: Optional[Tensor] = None,
    ):
        # Positional embedding (+ optional LayerNorm) and dropout
        x = self._preprocess(x)
        # Run the encoder layers; in training mode each layer is skipped with probability layer_drop
        for layer in self.layers:
            if not (self.training and torch.rand(1).item() <= self.layer_drop):
                x = layer(x, attention_mask)

        # Otherwise normalize once after the last layer
        if not self.layer_norm_first:
            x = self.layer_norm(x)

        return x

    def get_intermediate_outputs(
        self,
        x: Tensor,
        attention_mask: Optional[Tensor] = None,
        num_layers: Optional[int] = None,
    ) -> List[Tensor]:
        if num_layers is not None:
            if not 0 < num_layers <= len(self.layers):
                raise ValueError(f"`num_layers` must be between [1, {len(self.layers)}]")

        ret: List[Tensor] = []
        x = self._preprocess(x)
        for layer in self.layers:
            x = layer(x, attention_mask)
            ret.append(x)
            if num_layers is not None and len(ret) >= num_layers:
                return ret
        return ret
```
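What it does, in short: it combines the positional embedding with the stack of encoder layers. It also decides where the shared LayerNorm goes: because it receives `not layer_norm_first`, post-norm layers (False, as here) get a LayerNorm right after the positional embedding, while pre-norm layers get one after the last layer. In training mode it additionally applies LayerDrop, randomly skipping whole layers; `get_intermediate_outputs` always runs every layer and returns each layer's output.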
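LayerDrop in `Transformer.forward` can be illustrated without torch. In this sketch (a hypothetical helper, with Python's `random` standing in for `torch.rand`), each layer is skipped in training mode when a uniform draw in [0, 1) is at most `layer_drop`, so on average `num_layers * (1 - layer_drop)` layers run; in eval mode every layer runs. The twelve toy layers each add 1, so the result counts how many layers were executed.

```python
import random


def run_transformer_layers(x, layers, training, layer_drop, rng):
    """Mirror of the loop in Transformer.forward (LayerDrop)."""
    for layer in layers:
        if not (training and rng.random() <= layer_drop):
            x = layer(x)
    return x


layers = [lambda x: x + 1 for _ in range(12)]  # 12 toy "layers", each adds 1
rng = random.Random(0)

eval_out = run_transformer_layers(0, layers, training=False, layer_drop=0.05, rng=rng)   # 12
all_dropped = run_transformer_layers(0, layers, training=True, layer_drop=1.0, rng=rng)  # 0

trials = 10_000
avg = sum(run_transformer_layers(0, layers, True, 0.05, rng) for _ in range(trials)) / trials
# avg is close to 12 * (1 - 0.05) = 11.4
```

Because the draw happens per layer and per forward call, the network effectively trains many shallower sub-networks, while inference always uses the full depth.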
Encoder
Finally, the Encoder class. Its arguments:
Encoder(feature_projection, transformer)
feature_projection and transformer are both modules built above.
The Encoder code:
```python
from typing import List, Optional, Tuple

import torch
from torch import Tensor
from torch.nn import Module


class Encoder(Module):
    def __init__(
        self,
        feature_projection: Module,
        transformer: Module,
    ):
        super().__init__()
        self.feature_projection = feature_projection
        self.transformer = transformer

    # Project the features and build the padding mask
    def _preprocess(
        self,
        features: Tensor,
        lengths: Optional[Tensor] = None,
    ) -> Tuple[Tensor, Optional[Tensor]]:
        # Feature projection
        x = self.feature_projection(features)

        # Mask
        mask: Optional[Tensor] = None
        if lengths is not None:
            # Batch size and maximum length
            batch_size, max_len, _ = x.shape
            # create mask for padded elements and zero-out them
            mask = torch.arange(max_len, device=lengths.device).expand(batch_size, max_len) >= lengths[:, None]
            x[mask] = 0.0
            # extend the mask to attention shape and set weight
            mask = -10000.0 * mask[:, None, None, :].to(dtype=features.dtype)
            mask = mask.expand(batch_size, 1, max_len, max_len)
        return x, mask

    def forward(
        self,
        features: Tensor,
        lengths: Optional[Tensor] = None,
    ) -> Tensor:
        x, mask = self._preprocess(features, lengths)
        x = self.transformer(x, attention_mask=mask)
        return x

    def extract_features(
        self,
        features: Tensor,
        lengths: Optional[Tensor] = None,
        num_layers: Optional[int] = None,
    ) -> List[Tensor]:
        x, masks = self._preprocess(features, lengths)
        return self.transformer.get_intermediate_outputs(x, attention_mask=masks, num_layers=num_layers)
```
I didn't quite follow the mask setup; everything else was fine.
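The mask logic in `Encoder._preprocess` can be reproduced on a toy batch. For each utterance, frames whose index is at or beyond its length are padding, so `arange(max_len) >= lengths[:, None]` is True there. Those frames are zeroed, and the boolean mask becomes an additive one: -10000.0 on padded key columns and 0 elsewhere, broadcast to `[batch, 1, max_len, max_len]` so every query row sees the same key mask. Added to the attention scores before the softmax, -10000 drives the padded keys' weights to effectively zero. The plain-Python helpers below are hypothetical stand-ins for the tensor operations.

```python
import math


def padding_mask(lengths, max_len):
    """True where a frame index is >= its sequence length, i.e. padding."""
    return [[t >= n for t in range(max_len)] for n in lengths]


def additive_attention_mask(pad, neg=-10000.0):
    """Expand [B][L] padding flags to [B][1][L][L]: every query row gets `neg`
    on padded key columns, mirroring mask[:, None, None, :].expand(...)."""
    out = []
    for row in pad:
        key_bias = [neg if p else 0.0 for p in row]
        out.append([[list(key_bias) for _ in row]])
    return out


def softmax(row):
    m = max(row)
    exps = [math.exp(v - m) for v in row]
    s = sum(exps)
    return [e / s for e in exps]


lengths = [4, 2]  # two utterances, padded to max_len = 4
pad = padding_mask(lengths, 4)
mask = additive_attention_mask(pad)

# Equal raw scores for one query of utterance 1: the padded keys get ~0 weight
weights = softmax([s + m for s, m in zip([0.0] * 4, mask[1][0][0])])
```

Only the key columns are masked; queries at padded positions still produce outputs, but no valid frame attends to the padding.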
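Overall, though, this gives a good picture of how the whole model is put together.
Summary
So the complete wav2vec2_model structure is as follows:
Blue text marks explanations of the steps, not code. All of these objects exist only to assemble one complete model.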
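Recent findings
This model is meant to be used directly for ASR fine-tuning, i.e. no pretraining step is run. The pretrained weights of the whole model are downloaded in the step below:
As a result, we cannot see the pretraining loss, the Gumbel-softmax, or the quantization process.
The model can only be used as-is for fine-tuning.
That bothers me. I learned quite a lot while reading the code, but being forced to stop short of the full picture is frustrating. fairseq works the same way: you load a .pt file directly to get the model weights.
That's all for the wav2vec 2.0 fine-tuning model. What remains is the model invocation process and CTC decoding; I'll continue with those when I have time.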