CV领域Transformer这一篇就够了(原理详解+pytorch代码复现)
文章目录
- 前言
- 一、注意力机制
- 1.1注意力机制通俗理解
- 1.2注意力机制计算公式
- 1.3注意力机制计算过程
- 1.4注意力机制代码
- 二、自注意力机制
- 2.1 注意力机制和自注意力机制的区别
- 2.2 编码-译码中的attention
- 2.3自注意力机制计算流程
- 三、多头注意力机制
- 3.1多头注意力机制计算过程
- 3.2 多头自注意力机制计算过程
- 3.3位置编码
- 四、Vision Teansformer(ViT)
- 4.1 Embedding层
- 4.2 Encoder层
- 4.3 MLP Head层
- 4.4 ViT代码实现
- 五、Swin Tranformer
- 六、其他Transformer的改进
- 6.1 Hybird ViT
前言
本文主要介绍:注意力机制、自注意力机制、多头注意力机制、ViT、Swin Tranformer、其他Transformer的改进,并配合代码实现。
参考链接:
(饭范仁义-AI编程)https://www.bilibili.com/video/BV1nL4y1j7hA?spm_id_from=333.999.0.0&vd_source=b2549fdee562c700f2b1f3f49065201b
(霹雳巴啦Wz)https://blog.csdn.net/qq_37541097/article/details/117691873
一、注意力机制
1.1注意力机制通俗理解
注意力机制本质上与人类对外界事物的观察机制相似。通常来说,人们在观察外界事物的时候,首先会比较关注比较倾向于观察事物某些重要的局部信息,然后再把不同区域的信息组合起来,从而形成一个对被观察事物的整体印象,实现关注重要有用信息,抑制其他无用信息。
Attention机制最先应用在自然语言处理方面,主要是为了改进文本之间的编码方式,通过编码-解码之后能学习到更好的序列信息。
可以总体上分为两类:
聚焦式(focus)注意力:自上而下的有意识的注意力,主动注意——是指有预定目的、依赖任务的、主动有意识地聚焦于某一对象的注意力;
显著性(saliency-based)注意力:自下而上的有意识的注意力,被动注意——基于显著性的注意力是由外界刺激驱动的注意,不需要主动干预,也和任务无关;可以将max-pooling和门控(gating)机制来近似地看作是自下而上的基于显著性的注意力机制。
在人工神经网络中,注意力机制一般就特指聚焦式注意力。
1.2注意力机制计算公式
现在你可能还看不懂这个公式具体在讲什么,接下来我将详细简明的阐述。
第一阶段,需要三个指定的输入Q(query),K(key),V(value),可以引入不同函数和计算机制,根据Q和K,计算两者的相似性和相关性,d为K的维度dim。
第二阶段,引入类似的softmax的计算方式对第一阶段得分进行数值转换,一方面可以进行归一化,计算所有元素权重之和为1,另一方面可以通过softmax突出元素的权重。
第三阶段,通过计算结果a和V对应的权重系数,然后加权求和得到Attention数值。
(当输入的Q=K=V时,称作自注意力计算规则)。
举个例子:
Q(查询)和K(键)转置进行点乘(对于位置相乘求和),得到了各项查询的相似度,再除d,得到的是一个实数值,使用softmax将其变为权重(小于1的值),相似度权重x价值,就是求得的注意力。
1.3注意力机制计算过程
1.Input:输入Q、K、V三个向量;
2.a(i,j):每个qi分别和不同的kj乘,得a(i,j) = qi · kj;(应该是K的转置),a(i,j)为一个实数值。
3.除dim:为了梯度的稳定,Transformer使用了归一化,对a(i,j) 除以根号d,(d为k的维度);
4.softmax:对同一个i的a(i,j) ,施以softmax激活函数;
5.乘V:对于每个i,a(i,j)乘vj后求和,得到加权的每个输入向量ai的注意力评分bi;
q:代表query,后续会去和每一个k进行匹配(相乘)
k:代表key,后续会被每个q匹配(相乘)
v:代表从a 中提取得到的信息
后续q 和k 匹配的过程可以理解成计算两者的相关性,相关性越大对应v 的权重也就越大。
通过上述讲解,我们了解了单个qi是如何求注意力评分bi的,接下来仅需合并成矩阵,进行并行运算,一次求得多个输入的注意力评分矩阵B。
1.Q和K转置进行点乘,除根号d,进softmax,得相关性矩阵
2.相关性矩阵乘V得注意力评分矩阵B
Attention机制的实质其实就是一个寻址(addressing)的过程,如上图所示:给定一个和任务相关的查询Query向量 q,通过计算与Key的注意力分布并附加在Value上,从而计算Attention Value,这个过程实际上是Attention机制缓解神经网络模型复杂度的体现:不需要将所有的N个输入信息都输入到神经网络进行计算,只需要从中选择一些和查询Query相关的信息输入给神经网络。
1.4注意力机制代码
# pytorch实现
import torch
import torch.nn as nn
import torch.nn.functional as F# 缩放点积注意力
class ScaledDotProductAttention(nn.Module):''' Scaled Dot-Product Attention '''def __init__(self, temperature, attn_dropout=0.1):super().__init__()# temperature是k的维度dkself.temperature = temperatureself.dropout = nn.Dropout(attn_dropout)#外部输入q、k、vdef forward(self, q, k, v, mask=None):# a = (q/dk) 与 k的转置 矩阵相乘attn = torch.matmul(q / self.temperature, k.transpose(2, 3))# 是否进行maskif mask is not None:attn = attn.masked_fill(mask == 0, -1e9)# softmax+dropout得到相似性矩阵attn = self.dropout(F.softmax(attn, dim=-1))# 相似性矩阵与v矩阵相乘,得注意力评价矩阵output = torch.matmul(attn, v)# 返回:注意力评价矩阵 和 相似性矩阵return output, attn
二、自注意力机制
2.1 注意力机制和自注意力机制的区别
自注意力机制:Query=Key=Value=输入
传统的Attention:
Q来自于外部,K、V
Q在Decoder目标处,K、V在Encoder源头处self-Attention:
Q、K、V是对自身(self)输入的变换
Q、K、V在同一处(Decoder目标或Encoder源头处)
2.2 编码-译码中的attention
汉译英编码-译码模型:
无attention的编码-译码模型
有attention的编码-译码模型
2.3自注意力机制计算流程
1.Input:输入单词或图片xi;
2.Embedding:将单词、图片转化为转化成嵌入向量ai;
3.Querys、Keys、Values:a分别对Wq、Wk、Wv(这三个参数是可训练的,是共享的)矩阵乘法,得到Q、K、V三个向量;
4.a(i,j):每个qi分别和不同的kj乘,得a(i,j) = qi · kj;(应该是K的转置),a(i,j)为一个实数值。
5.除dim:为了梯度的稳定,Transformer使用了归一化,对a(i,j) 除以根号d,(d为k的维度);
6.softmax:对同一个i的a(i,j) ,施以softmax激活函数;
7.乘V:对于每个i,a(i,j)乘vj后求和,得到加权的每个输入向量ai的注意力评分bi;
矩阵计算:
1.X进行Embeding后得到输入矩阵A
2.A分别与Wq、Wk、Wv相乘得到Q、K、V矩阵
3.Q和K转置进行点乘,除根号d,进softmax,得相关性矩阵
4.相关性矩阵乘V得注意力评分矩阵B
self-attention就是对输入向量的权重进行调整。
三、多头注意力机制
刚刚已经聊完了Self-Attention模块,接下来再来看看Multi-Head Attention模块,实际使用中基本使用的还是Multi-Head Attention模块。其实只要懂了Self-Attention模块Multi-Head Attention模块就非常简单了,多头注意力就是对单头注意力的简单堆叠。
3.1多头注意力机制计算过程
(无embeding操作)
就是和attention类似,将输入X分别通过多组不同的Wqi、Wki、Wvi得到多组不同的Qi、Ki、Vi,然后得到了不同的结果,进行拼接,通过线性层乘Wo得到与输入矩阵维度相等的结果。
3.2 多头自注意力机制计算过程
1.QKV分头:
对得到的qi、ki、vi按n个head(n=2)进行均分为q(i,j)、k(i,j)、v(i,j),(其中j=1~n)
2.对于每个 j 的q、k、v 是一个头,共分为n个头,如上图的q(i,1)、k(i,1)、v(i,1)是一个head(i=1和2)
3.对每个head,执行self-attention的同样的操作,对每组q(i,j)、k(i,j)、v(i,j)求得 自注意力评分b(i,j).
4. b(i,j)按照二维矩阵 拼接成B,B乘以Wo。( Wo的作用:是保证multi-head-self-attention输出的向量和输入的长度一致。)
Multi-head-self-attention最终效果:
3.3位置编码
位置编码要和ai相加,则shape的ai一样。
四、Vision Teansformer(ViT)
ViT由3个模块组成:
Linear Projection of Flattened Patches(Embedding层):Patch embedding+Position embedding+Class token输入Encoder层
Transformer Encoder(Encoder层):将上图右边的结构重复堆叠L次
MLP Head(最终用于分类的层结构):只提取Class token的输出,进行得到分类的结果
4.1 Embedding层
对于标准的Transformer模块,要求输入的是token(向量)序列,即二维矩阵[num_token, token_dim],如下图,token0-token9对应的都是向量,以ViT-B/16为例,每个token向量长度为768。
对于图像数据而言,其数据格式为 [H, W, C] 是三维矩阵明显不是Transformer想要的。所以需要先通过一个Embedding层来对三维数据变换为二维数据。如下图所示,首先将一张图片按给定大小分成一堆Patches(图片块)。
以ViT-B/16为例,将大小224x224的输入图片按照16x16大小的Patch进行划分,划分后会得到196个Patches。接着通过线性映射将每个Patch映射到一维向量中,每个Patche数据shape为[16, 16, 3]通过映射得到一个长度为768的token向量(后面都直接称为token)。[16, 16, 3] -> [768]
在代码实现中,直接通过一个卷积层来实现。 以ViT-B/16为例,直接使用一个卷积核大小为16x16,stride为16,卷积核个数为768的卷积来实现。通过卷积[224, 224, 3] -> [14, 14, 768]
,然后把H以及W两个维度展平[W,H,C]->[W*H,C],如[14, 14, 768] -> [196, 768]
,此时正好变成了一个二维矩阵,正是Transformer想要的。
在输入Transformer Encoder之前注意需要加上图片类别 [class]token 放在positoin=0处以及叠加Position Embedding。 以ViT-B/16为例,就是一个长度为768的向量,与之前从图片中生成的tokens拼接在一起,Cat([1, 768], [196, 768]) -> [197, 768]
。然后关于Position Embedding就是之前Transformer中讲到的Positional Encoding,这里的Position Embedding采用的是一个可训练的参数是直接叠加在tokens上的(add),所以shape要一样。以ViT-B/16为例,刚刚拼接[class]token后shape是[197, 768],那么这里的Position Embedding的shape也是[197, 768]。对于Position Embedding作者也有做一系列对比试验,在源码中默认使用的是1D Pos. Emb。
图片中每个patch求得的token 都有一个位置编码,这些位置编码彼此间的余弦相似度如上图。黄色相似度高,蓝色相似度低。亮点就是对应该token的位置编码在原图中的位置。这就是最终学习到的位置编码。
4.2 Encoder层
Transformer Encoder其实就是堆叠Encoder Block重复 L次,Encoder Block,主要由以下几部分组成:
·Layer Norm,这种Normalization方法主要是针对NLP领域提出的,这里是对每个token进行Norm处理。
·Multi-Head Attention,这个结构之前在讲Transformer中很详细的讲过,不再赘述。
·Dropout/DropPath,在原论文的代码中是直接使用的Dropout层,在但实现代码中使用的是DropPath(stochastic depth),可能后者会更好一点。(不了解Droppath的可以看这篇介绍Droppath通俗易懂)
·MLP Block,如上图右侧所示,就是全连接+GELU激活函数+Dropout组成也非常简单,需要注意的是第一个全连接层会把输入节点个数翻4倍[197, 768] -> [197, 3072]
,第二个全连接层会还原回原节点个数[197, 3072] -> [197, 768]
·残差结构, 将输入与dropout层输出相加。
4.3 MLP Head层
其中pre-logits就是一个全连接层+tanh激活函数。
下图是ViT-B/16的一个总体结构:
4.4 ViT代码实现
"""
original code from rwightman:
https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
"""
from functools import partial
from collections import OrderedDictimport torch
import torch.nn as nndef drop_path(x, drop_prob: float = 0., training: bool = False):if drop_prob == 0. or not training:return xkeep_prob = 1 - drop_probshape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNetsrandom_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)random_tensor.floor_() # binarizeoutput = x.div(keep_prob) * random_tensorreturn outputclass DropPath(nn.Module):"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""def __init__(self, drop_prob=None):super(DropPath, self).__init__()self.drop_prob = drop_probdef forward(self, x):return drop_path(x, self.drop_prob, self.training)# PatchEmbedding层(通过卷积实现)
class PatchEmbed(nn.Module):"""2D Image to Patch Embedding"""def __init__(self, img_size=224, patch_size=16, in_c=3, embed_dim=768, norm_layer=None):super().__init__()img_size = (img_size, img_size) # img_size图片大小patch_size = (patch_size, patch_size) # patch_size图像块大小(也是卷积核大小)self.img_size = img_sizeself.patch_size = patch_sizeself.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1]) # //表取整除self.num_patches = self.grid_size[0] * self.grid_size[1]# 定义卷积层proj,in_c输入通道数(rgb3通道),embed_dim卷积核个数(卷积层输出通道数)self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=patch_size, stride=patch_size)# 如果norm_layer不为空,则进行正则化,self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()def forward(self, x):# 输入图像X# assert检查输入图像大小,B(batch_size), C(channel), H(height), W(weight)B, C, H, W = x.shapeassert H == self.img_size[0] and W == self.img_size[1], \f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."# proj(卷积)# flatten(压平H,W): [B, C, H, W] -> [B, C, HW]# transpose(交换后两维): [B, C, HW] -> [B, HW, C]x = self.proj(x).flatten(2).transpose(1, 2)x = self.norm(x)return x# Encoder Block中的MultiHead-Self-Attention层
class Attention(nn.Module):def __init__(self,dim, # 输入token的dimnum_heads=8, # head数qkv_bias=False, # 生成qkv不用baisqk_scale=None, # None时使用:根号dk分之一attn_drop_ratio=0., # dropout率proj_drop_ratio=0.): # dropout率super(Attention, self).__init__()self.num_heads = num_headshead_dim = dim // num_heads # 分头:计算每个head均分得到的q,k,v个数self.scale = qk_scale or head_dim ** -0.5 # qk_scale是根号下head_dim分之一,就是q*k转置后乘的那个:根号dk分之一self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) # 通过qkv全连接层:(q,k,v)=X·(Wq,Wk,Wv),一次并行求得qkv# 全连接层:in_features输入特征个数=dim,out_features输出特征个数(全连接层节点个数)=dim*3self.attn_drop = nn.Dropout(attn_drop_ratio)self.proj = nn.Linear(dim, dim) # 通过proj全连接层:B=B·Wo,进行bij->bi拼接后的映射self.proj_drop = nn.Dropout(proj_drop_ratio)def forward(self, x):# [batch_size, num_patches + 1, total_embed_dim]# (num_patches + 1的1是class token,num_patches + 1个向量大小都是total_embed_dim)B, N, C = x.shape# qkv(): -> [batch_size, num_patches + 1, 3 * total_embed_dim]# reshape分qkv分头: -> [batch_size, num_patches + 1, 3, num_heads, embed_dim_per_head]# permute调序: -> [3, batch_size, num_heads, num_patches + 1, embed_dim_per_head]qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)# 切片q、k、v,都是[batch_size, num_heads, num_patches + 1, embed_dim_per_head]q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)# transpose:原q、k、v-> [batch_size, num_heads, embed_dim_per_head, num_patches + 1]# @ 矩阵乘法: 多维矩阵乘法只乘最后两维 -> [batch_size, num_heads, num_patches + 1, num_patches + 1]# q乘k转置,乘根号dkattn = (q @ k.transpose(-2, -1)) * self.scale# dim=-1表示attn在每一行进行softmax处理attn = attn.softmax(dim=-1)attn = self.attn_drop(attn)# @ 矩阵乘法: -> [batch_size, num_heads, num_patches + 1, embed_dim_per_head]# transpose: -> [batch_size, num_patches + 1, num_heads, embed_dim_per_head]# reshape: -> [batch_size, num_patches + 1, total_embed_dim]# softmax(qk相似度) 乘 v,reshape进行bij->bi的拼接映射x = (attn @ v).transpose(1, 2).reshape(B, N, C)# 通过proj全连接层:B=B·Wo映射x = self.proj(x)x = self.proj_drop(x)return x# Encoder Block中的MLP(两个全连接层)
class Mlp(nn.Module):"""MLP as used in Vision Transformer, MLP-Mixer and related networks"""# in_features输入特征个数,hidden_features第一个全连接层节点个数,out_features第二个全连接层节点个数,act_layer激活函数def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):super().__init__()out_features = out_features or in_features # out_features=None,in_featureshidden_features = hidden_features or in_features # 同上self.fc1 = nn.Linear(in_features, hidden_features)self.act = act_layer() # GELU激活函数self.fc2 = nn.Linear(hidden_features, out_features)self.drop = nn.Dropout(drop)def forward(self, x):x = self.fc1(x)x = self.act(x)x = self.drop(x)x = self.fc2(x)x = self.drop(x)return x# Encoder Block
class Block(nn.Module):def __init__(self,dim, # token 的dimnum_heads, # head数mlp_ratio=4., # mlp中第一个全连接层的节点个数是输入的4倍qkv_bias=False, # 是否使用biasqk_scale=None, # 根号dk分之一drop_ratio=0., # attention中的drop_out率attn_drop_ratio=0., # attention中的drop_out率drop_path_ratio=0., # Encoder Block中的drop_path率act_layer=nn.GELU, # 激活函数norm_layer=nn.LayerNorm): # normalization使用LayerNormsuper(Block, self).__init__()# 实例化LayerNorm层self.norm1 = norm_layer(dim)# 实例化Attention层self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop_ratio=attn_drop_ratio, proj_drop_ratio=drop_ratio)# NOTE: drop path for stochastic depth, we shall see if this is better than dropout here# 实例化DropPath层self.drop_path = DropPath(drop_path_ratio) if drop_path_ratio > 0. else nn.Identity()# 实例化LayerNorm层self.norm2 = norm_layer(dim)mlp_hidden_dim = int(dim * mlp_ratio)# 实例化Mlp层self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop_ratio)def forward(self, x):# 这里的+=都引入了恒等映射的残差思想x = x + self.drop_path(self.attn(self.norm1(x)))x = x + self.drop_path(self.mlp(self.norm2(x)))return x# ViT
class VisionTransformer(nn.Module):def __init__(self, img_size=224, patch_size=16, in_c=3, num_classes=1000,embed_dim=768, depth=12, num_heads=12, mlp_ratio=4.0, qkv_bias=True,qk_scale=None, representation_size=None, distilled=False, drop_ratio=0.,attn_drop_ratio=0., drop_path_ratio=0., embed_layer=PatchEmbed, norm_layer=None,act_layer=None):"""Args:img_size (int, tuple): input image sizepatch_size (int, tuple): patch sizein_c (int): number of input channelsnum_classes (int): number of classes for classification headembed_dim (int): embedding dimension,patch embedding层卷积核个数depth (int): 是Encoder重复次数,depth of transformernum_heads (int): number of attention headsmlp_ratio (int): ratio of mlp hidden dim to embedding dimqkv_bias (bool): enable bias for qkv if Trueqk_scale (float): override default qk scale of head_dim ** -0.5 if setrepresentation_size (Optional[int]):是否构建MLP层的pre-logits,enable and set representation layer (pre-logits) to this value if setdistilled (bool): 为了兼容搭建DeiT的参数,model includes a distillation token and head as in DeiT modelsdrop_ratio (float): dropout rateattn_drop_ratio (float): attention dropout ratedrop_path_ratio (float): stochastic depth rateembed_layer (nn.Module): patch embedding layernorm_layer: (nn.Module): normalization layer"""super(VisionTransformer, self).__init__()self.num_classes = num_classes # 分类数self.num_features = self.embed_dim = embed_dim # num_features for consistency with other modelsself.num_tokens = 2 if distilled else 1 # num_tokens默认为1norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6) # normalization默认为LayerNorm# partial函数的功能就是:把一个函数的某些参数给默认固定住,返回一个新的函数act_layer = act_layer or nn.GELU # activate function默认为GELU# patch_embed层self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_c=in_c, embed_dim=embed_dim)num_patches = self.patch_embed.num_patches# class token初始化第一个1是batch_sizeself.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))# 不用管DeiT模型的dist_tokenself.dist_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if distilled else None# Position embedding位置编码初始化self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))# dropout层self.pos_drop = nn.Dropout(p=drop_ratio)# 生成一个drop_path率的序列dpr,共depth个,大小从0到drop_path_ratio递增dpr = [x.item() for x in torch.linspace(0, drop_path_ratio, depth)] # stochastic depth decay rule# 构建depth个连续的Encoder blockself.blocks = nn.Sequential(*[Block(dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,drop_ratio=drop_ratio, attn_drop_ratio=attn_drop_ratio, drop_path_ratio=dpr[i],norm_layer=norm_layer, act_layer=act_layer)for i in range(depth)])# Encoder block后的norm_layerself.norm = norm_layer(embed_dim)# Representation layer是否构建MLP层的pre-logitsif representation_size and not distilled:self.has_logits = Trueself.num_features = representation_size# pre-logits就是一个全连接层+tanh激活函数# embed_dim输入节点个数,representation_size输出节点个数self.pre_logits = nn.Sequential(OrderedDict([("fc", nn.Linear(embed_dim, representation_size)),("act", nn.Tanh())]))else:self.has_logits = Falseself.pre_logits = nn.Identity()# Classifier head(s),最后一层全连接层分类,num_features输入节点个数,num_classes输出节点个数self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()# 后面不用看,是DeiT模型的self.head_dist = Noneif distilled:self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity()# Weight init,初始化pos_embed# trunc_normal_利用正态分布生成一个点,点在[a, b]区间之内nn.init.trunc_normal_(self.pos_embed, std=0.02)# 后面不用看,是DeiT模型的if self.dist_token is not None:nn.init.trunc_normal_(self.dist_token, std=0.02)# Weight init,初始化cls_tokennn.init.trunc_normal_(self.cls_token, std=0.02)# 调用vit初始函数self.apply(_init_vit_weights)def forward_features(self, x):# patch embedding# [B, C, H, W] -> [B, num_patches, embed_dim]x = self.patch_embed(x) # [B, 196, 768]# class token# [1, 1, 768] -> [B, 1, 768]cls_token = self.cls_token.expand(x.shape[0], -1, -1)# concat拼接cls_token和patch_token# ViT中dist_token为None,执行第一个if self.dist_token is None: # ViT中dist_token为Nonex = torch.cat((cls_token, x), dim=1) # [B, 197, 768]else:x = torch.cat((cls_token, self.dist_token.expand(x.shape[0], -1, -1), x), dim=1)# 加上位置编码position embeddingx = self.pos_drop(x + self.pos_embed)# 现在的token=[class + patch]+ position# encoder block层x = self.blocks(x)# normalization层x = self.norm(x)# MPL的pre_logits# ViT中dist_token为None,执行第一个if self.dist_token is None:return self.pre_logits(x[:, 0]) # 只返回class token输出的列else:return x[:, 0], x[:, 1]def forward(self, x):# 返回class token输出的列x = self.forward_features(x)# head_dist为None,执行elseif self.head_dist is not None:x, x_dist = self.head(x[0]), self.head_dist(x[1])if self.training and not torch.jit.is_scripting():# during inference, return the average of both classifier predictionsreturn x, x_distelse:return (x + x_dist) / 2else:x = self.head(x) # 最后的全连接层输出分类结果return xdef _init_vit_weights(m):"""ViT weight initialization:param m: module"""if isinstance(m, nn.Linear):nn.init.trunc_normal_(m.weight, std=.01)if m.bias is not None:nn.init.zeros_(m.bias)elif isinstance(m, nn.Conv2d):nn.init.kaiming_normal_(m.weight, mode="fan_out")if m.bias is not None:nn.init.zeros_(m.bias)elif isinstance(m, nn.LayerNorm):nn.init.zeros_(m.bias)nn.init.ones_(m.weight)# 至此我们已经完成了ViT所有模块的编写
# ——————————————————————————————————————————————————————————————————————————————————————————————————————————————————————def vit_base_patch16_224(num_classes: int = 1000):"""ViT-Base model (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).ImageNet-1k weights @ 224x224, source https://github.com/google-research/vision_transformer.weights ported from official Google JAX impl:链接: https://pan.baidu.com/s/1zqb08naP0RPqqfSXfkB2EA 密码: eu9f"""model = VisionTransformer(img_size=224,patch_size=16,embed_dim=768,depth=12,num_heads=12,representation_size=None,num_classes=num_classes)return modeldef vit_base_patch16_224_in21k(num_classes: int = 21843, has_logits: bool = True):"""ViT-Base model (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.weights ported from official Google JAX impl:https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_patch16_224_in21k-e5005f0a.pth"""model = VisionTransformer(img_size=224,patch_size=16,embed_dim=768,depth=12,num_heads=12,representation_size=768 if has_logits else None,num_classes=num_classes)return modeldef vit_base_patch32_224(num_classes: int = 1000):"""ViT-Base model (ViT-B/32) from original paper (https://arxiv.org/abs/2010.11929).ImageNet-1k weights @ 224x224, source https://github.com/google-research/vision_transformer.weights ported from official Google JAX impl:链接: https://pan.baidu.com/s/1hCv0U8pQomwAtHBYc4hmZg 密码: s5hl"""model = VisionTransformer(img_size=224,patch_size=32,embed_dim=768,depth=12,num_heads=12,representation_size=None,num_classes=num_classes)return modeldef vit_base_patch32_224_in21k(num_classes: int = 21843, has_logits: bool = True):"""ViT-Base model (ViT-B/32) from original paper (https://arxiv.org/abs/2010.11929).ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.weights ported from official Google JAX impl:https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_patch32_224_in21k-8db57226.pth"""model = VisionTransformer(img_size=224,patch_size=32,embed_dim=768,depth=12,num_heads=12,representation_size=768 if has_logits else None,num_classes=num_classes)return modeldef vit_large_patch16_224(num_classes: int = 1000):"""ViT-Large model (ViT-L/16) from original paper (https://arxiv.org/abs/2010.11929).ImageNet-1k weights @ 224x224, source https://github.com/google-research/vision_transformer.weights ported from official Google JAX impl:链接: https://pan.baidu.com/s/1cxBgZJJ6qUWPSBNcE4TdRQ 密码: qqt8"""model = VisionTransformer(img_size=224,patch_size=16,embed_dim=1024,depth=24,num_heads=16,representation_size=None,num_classes=num_classes)return modeldef vit_large_patch16_224_in21k(num_classes: int = 21843, has_logits: bool = True):"""ViT-Large model (ViT-L/16) from original paper (https://arxiv.org/abs/2010.11929).ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.weights ported from official Google JAX impl:https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_patch16_224_in21k-606da67d.pth"""model = VisionTransformer(img_size=224,patch_size=16,embed_dim=1024,depth=24,num_heads=16,representation_size=1024 if has_logits else None,num_classes=num_classes)return modeldef vit_large_patch32_224_in21k(num_classes: int = 21843, has_logits: bool = True):"""ViT-Large model (ViT-L/32) from original paper (https://arxiv.org/abs/2010.11929).ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.weights ported from official Google JAX impl:https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_patch32_224_in21k-9046d2e7.pth"""model = VisionTransformer(img_size=224,patch_size=32,embed_dim=1024,depth=24,num_heads=16,representation_size=1024 if has_logits else None,num_classes=num_classes)return modeldef vit_huge_patch14_224_in21k(num_classes: int = 21843, has_logits: bool = True):"""ViT-Huge model (ViT-H/14) from original paper (https://arxiv.org/abs/2010.11929).ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.NOTE: converted weights not currently available, too large for github release hosting."""model = VisionTransformer(img_size=224,patch_size=14,embed_dim=1280,depth=32,num_heads=16,representation_size=1280 if has_logits else None,num_classes=num_classes)return model
五、Swin Tranformer
六、其他Transformer的改进
6.1 Hybird ViT
先用Resnet50特征提取,再用ViT进一步处理分类。
其中Resnet50部分做出了一些修改;
epoch较多时,混合模型模型反而效果不如纯正的ViT。
CV领域Transformer这一篇就够了(原理详解+pytorch代码复现)相关推荐
- 玩转Mysql系列 - 第22篇:mysql索引原理详解
Mysql系列的目标是:通过这个系列从入门到全面掌握一个高级开发所需要的全部技能. 欢迎大家加我微信itsoku一起交流java.算法.数据库相关技术. 这是Mysql系列第22篇. 背景 使用mys ...
- 自问自答学ArrayList,看这篇就够了,详解问答
前言 在之前的几篇文章里面,我主要都是推荐了一些工具类,为的就是让大家可以提高开发效率,但是我们在提高开发效率,也应该提高代码的执行效率,注重代码的质量.如何提高,其中的一个好办法就是阅读源码,知其然 ...
- arp 项删除失败: 请求的操作需要提升。_ccna必懂篇-arp协议工作原理详解。
本次呢,要说的是arp协议,那么什么是arp协议呢?有什么作用呢? 什么是arp ARP(Address Resolution Protocol)地址解析协议,地址解析协议由互联网工程任务组(IETF ...
- 玩转Luat 进阶篇②——远程升级功能原理详解
目录 一.简介 二.从云端获取升级包(新固件) 2.1 合宙官方服务器openluat 2.2 阿里云物联网平台 三.本地固件更新运行 3.1 合宙4G模块 3.1.1 合宙4G模块的Flash分区 ...
- Transformer 初识:模型结构+attention原理详解
Transformer 初识:模型结构+原理详解 参考资源 前言 1.整体结构 1.1 输入: 1.2 Encoder 和 Decoder的结构 1.3 Layer normalization Bat ...
- [Python从零到壹] 四十七.图像增强及运算篇之腐蚀和膨胀详解
欢迎大家来到"Python从零到壹",在这里我将分享约200篇Python系列文章,带大家一起去学习和玩耍,看看Python这个有趣的世界.所有文章都将结合案例.代码和作者的经验讲 ...
- 离线强化学习(Offline RL)系列3: (算法篇)策略约束 - BEAR算法原理详解与实现
论文信息:Stabilizing Off-Policy Q-Learning via Bootstrapping Error Reduction 本文由UC Berkeley的Sergey Levin ...
- PowerShell攻防进阶篇:nishang工具用法详解
PowerShell攻防进阶篇:nishang工具用法详解 导语:nishang,PowerShell下并肩Empire,Powersploit的神器. 开始之前,先放出个下载地址! 下载地址:htt ...
- 离线强化学习(Offline RL)系列3: (算法篇) IQL(Implicit Q-learning)算法详解与实现
[更新记录] 论文信息:Ilya Kostrikov, Ashvin Nair, Sergey Levine: "Offline Reinforcement Learning with Im ...
最新文章
- golang 同一个包中函数互相调用报错 undefined 以及在 VSCode 中配置右键执行整个包文件
- 【强势来袭】Node.js(nodejs)实现“一口多用”(含用户创建、登录、鉴权token) 一个文件解决所有常态化需求
- xilinx LVDS使用注意事项
- JQuery UI之Autocomplete(3)属性与事件
- electron 打包 php,electron 将现有vue项目改成支持electron桌面端应用
- 推荐eclipse插件Properties Editor
- Facebook 最新可佩戴 AR 设备、AR 设备未来五年市场扩张、语音社交新创Swell等|Decode the Week...
- Java学习笔记之 IO包 字符流
- 计算机科学与技术专业《计算机网络原理》课程实验指导书,计算机科学导论,课程实验指导书解读.pdf...
- 低学历的非要考研,多半输得更惨
- jQuery初识 - jQuery是什么
- linux mysql 修改root密码_MySQL忘了root密码,如何修改?
- underscore.js依赖库函数分析一(遍历)
- Oracle字符集设置
- java一键生成海报_小程序生成海报(java后端)
- VS2015,错误RC1015: 无法打开包含文件afxres.h
- python中判断无向图是否有环_数据结构与算法:17 图
- PWM控制的基本原理
- Linux的网络编程面试题汇总
- 彻底弄懂base64的编码与解码原理