文章目录

  • 前言
  • 一、注意力机制
    • 1.1注意力机制通俗理解
    • 1.2注意力机制计算公式
    • 1.3注意力机制计算过程
    • 1.4注意力机制代码
  • 二、自注意力机制
    • 2.1 注意力机制和自注意力机制的区别
    • 2.2 编码-译码中的attention
    • 2.3自注意力机制计算流程
  • 三、多头注意力机制
    • 3.1多头注意力机制计算过程
    • 3.2 多头自注意力机制计算过程
    • 3.3位置编码
  • 四、Vision Teansformer(ViT)
    • 4.1 Embedding层
    • 4.2 Encoder层
    • 4.3 MLP Head层
    • 4.4 ViT代码实现
  • 五、Swin Tranformer
  • 六、其他Transformer的改进
    • 6.1 Hybird ViT

前言

本文主要介绍:注意力机制、自注意力机制、多头注意力机制、ViT、Swin Tranformer、其他Transformer的改进,并配合代码实现。

参考链接:
(饭范仁义-AI编程)https://www.bilibili.com/video/BV1nL4y1j7hA?spm_id_from=333.999.0.0&vd_source=b2549fdee562c700f2b1f3f49065201b
(霹雳巴啦Wz)https://blog.csdn.net/qq_37541097/article/details/117691873


一、注意力机制

1.1注意力机制通俗理解

注意力机制本质上与人类对外界事物的观察机制相似。通常来说,人们在观察外界事物的时候,首先会比较关注比较倾向于观察事物某些重要的局部信息,然后再把不同区域的信息组合起来,从而形成一个对被观察事物的整体印象,实现关注重要有用信息,抑制其他无用信息
Attention机制最先应用在自然语言处理方面,主要是为了改进文本之间的编码方式,通过编码-解码之后能学习到更好的序列信息。

可以总体上分为两类:
聚焦式(focus)注意力:自上而下的有意识的注意力,主动注意——是指有预定目的、依赖任务的、主动有意识地聚焦于某一对象的注意力;
显著性(saliency-based)注意力:自下而上的有意识的注意力,被动注意——基于显著性的注意力是由外界刺激驱动的注意,不需要主动干预,也和任务无关;可以将max-pooling和门控(gating)机制来近似地看作是自下而上的基于显著性的注意力机制。
在人工神经网络中,注意力机制一般就特指聚焦式注意力。

1.2注意力机制计算公式

现在你可能还看不懂这个公式具体在讲什么,接下来我将详细简明的阐述。

第一阶段,需要三个指定的输入Q(query),K(key),V(value),可以引入不同函数和计算机制,根据Q和K,计算两者的相似性和相关性,d为K的维度dim。

第二阶段,引入类似的softmax的计算方式对第一阶段得分进行数值转换,一方面可以进行归一化,计算所有元素权重之和为1,另一方面可以通过softmax突出元素的权重。

第三阶段,通过计算结果a和V对应的权重系数,然后加权求和得到Attention数值。

(当输入的Q=K=V时,称作自注意力计算规则)。

举个例子:

Q(查询)和K(键)转置进行点乘(对于位置相乘求和),得到了各项查询的相似度,再除d,得到的是一个实数值,使用softmax将其变为权重(小于1的值),相似度权重x价值,就是求得的注意力。

1.3注意力机制计算过程

1.Input:输入Q、K、V三个向量;
2.a(i,j):每个qi分别和不同的kj乘,得a(i,j) = qi · kj;(应该是K的转置),a(i,j)为一个实数值。
3.除dim:为了梯度的稳定,Transformer使用了归一化,对a(i,j) 除以根号d,(d为k的维度);
4.softmax:对同一个i的a(i,j) ,施以softmax激活函数;
5.乘V:对于每个i,a(i,j)乘vj后求和,得到加权的每个输入向量ai的注意力评分bi;

q:代表query,后续会去和每一个k进行匹配(相乘)
k:代表key,后续会被每个q匹配(相乘)
v:代表从a 中提取得到的信息
后续q 和k 匹配的过程可以理解成计算两者的相关性,相关性越大对应v 的权重也就越大。

通过上述讲解,我们了解了单个qi是如何求注意力评分bi的,接下来仅需合并成矩阵,进行并行运算,一次求得多个输入的注意力评分矩阵B。

1.Q和K转置进行点乘,除根号d,进softmax,得相关性矩阵
2.相关性矩阵乘V得注意力评分矩阵B

Attention机制的实质其实就是一个寻址(addressing)的过程,如上图所示:给定一个和任务相关的查询Query向量 q,通过计算与Key的注意力分布并附加在Value上,从而计算Attention Value,这个过程实际上是Attention机制缓解神经网络模型复杂度的体现:不需要将所有的N个输入信息都输入到神经网络进行计算,只需要从中选择一些和查询Query相关的信息输入给神经网络。

1.4注意力机制代码

# pytorch实现
import torch
import torch.nn as nn
import torch.nn.functional as F# 缩放点积注意力
class ScaledDotProductAttention(nn.Module):''' Scaled Dot-Product Attention '''def __init__(self, temperature, attn_dropout=0.1):super().__init__()# temperature是k的维度dkself.temperature = temperatureself.dropout = nn.Dropout(attn_dropout)#外部输入q、k、vdef forward(self, q, k, v, mask=None):# a = (q/dk) 与 k的转置 矩阵相乘attn = torch.matmul(q / self.temperature, k.transpose(2, 3))# 是否进行maskif mask is not None:attn = attn.masked_fill(mask == 0, -1e9)# softmax+dropout得到相似性矩阵attn = self.dropout(F.softmax(attn, dim=-1))# 相似性矩阵与v矩阵相乘,得注意力评价矩阵output = torch.matmul(attn, v)# 返回:注意力评价矩阵 和 相似性矩阵return output, attn

二、自注意力机制

2.1 注意力机制和自注意力机制的区别

自注意力机制:Query=Key=Value=输入

传统的Attention:
Q来自于外部,K、V
Q在Decoder目标处,K、V在Encoder源头处

self-Attention:
Q、K、V是对自身(self)输入的变换
Q、K、V在同一处(Decoder目标或Encoder源头处)

2.2 编码-译码中的attention

汉译英编码-译码模型:

无attention的编码-译码模型

有attention的编码-译码模型

2.3自注意力机制计算流程

1.Input:输入单词或图片xi;
2.Embedding:将单词、图片转化为转化成嵌入向量ai;
3.Querys、Keys、Values:a分别对Wq、Wk、Wv(这三个参数是可训练的,是共享的)矩阵乘法,得到Q、K、V三个向量;
4.a(i,j):每个qi分别和不同的kj乘,得a(i,j) = qi · kj;(应该是K的转置),a(i,j)为一个实数值。
5.除dim:为了梯度的稳定,Transformer使用了归一化,对a(i,j) 除以根号d,(d为k的维度);
6.softmax:对同一个i的a(i,j) ,施以softmax激活函数;
7.乘V:对于每个i,a(i,j)乘vj后求和,得到加权的每个输入向量ai的注意力评分bi;




矩阵计算:

1.X进行Embeding后得到输入矩阵A
2.A分别与Wq、Wk、Wv相乘得到Q、K、V矩阵
3.Q和K转置进行点乘,除根号d,进softmax,得相关性矩阵
4.相关性矩阵乘V得注意力评分矩阵B

self-attention就是对输入向量的权重进行调整。

三、多头注意力机制

刚刚已经聊完了Self-Attention模块,接下来再来看看Multi-Head Attention模块,实际使用中基本使用的还是Multi-Head Attention模块。其实只要懂了Self-Attention模块Multi-Head Attention模块就非常简单了,多头注意力就是对单头注意力的简单堆叠

3.1多头注意力机制计算过程

(无embeding操作)

就是和attention类似,将输入X分别通过多组不同的Wqi、Wki、Wvi得到多组不同的Qi、Ki、Vi,然后得到了不同的结果,进行拼接,通过线性层乘Wo得到与输入矩阵维度相等的结果。


3.2 多头自注意力机制计算过程


1.QKV分头:
对得到的qi、ki、vi按n个head(n=2)进行均分为q(i,j)、k(i,j)、v(i,j),(其中j=1~n)


2.对于每个 j 的q、k、v 是一个头,共分为n个头如上图的q(i,1)、k(i,1)、v(i,1)是一个head(i=1和2)

3.对每个head,执行self-attention的同样的操作,对每组q(i,j)、k(i,j)、v(i,j)求得 自注意力评分b(i,j).


4. b(i,j)按照二维矩阵 拼接成B,B乘以Wo。( Wo的作用:是保证multi-head-self-attention输出的向量和输入的长度一致。)

Multi-head-self-attention最终效果:

3.3位置编码

位置编码要和ai相加,则shape的ai一样。


四、Vision Teansformer(ViT)


ViT由3个模块组成:
Linear Projection of Flattened Patches(Embedding层):Patch embedding+Position embedding+Class token输入Encoder层
Transformer Encoder(Encoder层):将上图右边的结构重复堆叠L次
MLP Head(最终用于分类的层结构):只提取Class token的输出,进行得到分类的结果

4.1 Embedding层

对于标准的Transformer模块,要求输入的是token(向量)序列,即二维矩阵[num_token, token_dim],如下图,token0-token9对应的都是向量,以ViT-B/16为例,每个token向量长度为768。


对于图像数据而言,其数据格式为 [H, W, C] 是三维矩阵明显不是Transformer想要的。所以需要先通过一个Embedding层来对三维数据变换为二维数据。如下图所示,首先将一张图片按给定大小分成一堆Patches(图片块)

以ViT-B/16为例,将大小224x224的输入图片按照16x16大小的Patch进行划分,划分后会得到196个Patches。接着通过线性映射将每个Patch映射到一维向量中,每个Patche数据shape为[16, 16, 3]通过映射得到一个长度为768的token向量(后面都直接称为token)。[16, 16, 3] -> [768]

在代码实现中,直接通过一个卷积层来实现。 以ViT-B/16为例,直接使用一个卷积核大小为16x16,stride为16,卷积核个数为768的卷积来实现。通过卷积[224, 224, 3] -> [14, 14, 768],然后把H以及W两个维度展平[W,H,C]->[W*H,C],如[14, 14, 768] -> [196, 768],此时正好变成了一个二维矩阵,正是Transformer想要的。

在输入Transformer Encoder之前注意需要加上图片类别 [class]token 放在positoin=0处以及叠加Position Embedding。 以ViT-B/16为例,就是一个长度为768的向量,与之前从图片中生成的tokens拼接在一起,Cat([1, 768], [196, 768]) -> [197, 768]。然后关于Position Embedding就是之前Transformer中讲到的Positional Encoding,这里的Position Embedding采用的是一个可训练的参数直接叠加在tokens上的(add),所以shape要一样。以ViT-B/16为例,刚刚拼接[class]token后shape是[197, 768],那么这里的Position Embedding的shape也是[197, 768]。对于Position Embedding作者也有做一系列对比试验,在源码中默认使用的是1D Pos. Emb。

图片中每个patch求得的token 都有一个位置编码,这些位置编码彼此间的余弦相似度如上图。黄色相似度高,蓝色相似度低。亮点就是对应该token的位置编码在原图中的位置。这就是最终学习到的位置编码。

4.2 Encoder层


Transformer Encoder其实就是堆叠Encoder Block重复 L次,Encoder Block,主要由以下几部分组成:

·Layer Norm,这种Normalization方法主要是针对NLP领域提出的,这里是对每个token进行Norm处理。




·Multi-Head Attention,这个结构之前在讲Transformer中很详细的讲过,不再赘述。
·Dropout/DropPath,在原论文的代码中是直接使用的Dropout层,在但实现代码中使用的是DropPath(stochastic depth),可能后者会更好一点。(不了解Droppath的可以看这篇介绍Droppath通俗易懂)
·MLP Block,如上图右侧所示,就是全连接+GELU激活函数+Dropout组成也非常简单,需要注意的是第一个全连接层会把输入节点个数翻4倍[197, 768] -> [197, 3072]第二个全连接层会还原回原节点个数[197, 3072] -> [197, 768]
·残差结构, 将输入与dropout层输出相加。

4.3 MLP Head层


其中pre-logits就是一个全连接层+tanh激活函数

下图是ViT-B/16的一个总体结构

4.4 ViT代码实现

"""
original code from rwightman:
https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
"""
from functools import partial
from collections import OrderedDictimport torch
import torch.nn as nndef drop_path(x, drop_prob: float = 0., training: bool = False):if drop_prob == 0. or not training:return xkeep_prob = 1 - drop_probshape = (x.shape[0],) + (1,) * (x.ndim - 1)  # work with diff dim tensors, not just 2D ConvNetsrandom_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)random_tensor.floor_()  # binarizeoutput = x.div(keep_prob) * random_tensorreturn outputclass DropPath(nn.Module):"""Drop paths (Stochastic Depth) per sample  (when applied in main path of residual blocks)."""def __init__(self, drop_prob=None):super(DropPath, self).__init__()self.drop_prob = drop_probdef forward(self, x):return drop_path(x, self.drop_prob, self.training)# PatchEmbedding层(通过卷积实现)
class PatchEmbed(nn.Module):"""2D Image to Patch Embedding"""def __init__(self, img_size=224, patch_size=16, in_c=3, embed_dim=768, norm_layer=None):super().__init__()img_size = (img_size, img_size)  # img_size图片大小patch_size = (patch_size, patch_size)  # patch_size图像块大小(也是卷积核大小)self.img_size = img_sizeself.patch_size = patch_sizeself.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])  # //表取整除self.num_patches = self.grid_size[0] * self.grid_size[1]# 定义卷积层proj,in_c输入通道数(rgb3通道),embed_dim卷积核个数(卷积层输出通道数)self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=patch_size, stride=patch_size)# 如果norm_layer不为空,则进行正则化,self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()def forward(self, x):# 输入图像X# assert检查输入图像大小,B(batch_size), C(channel), H(height), W(weight)B, C, H, W = x.shapeassert H == self.img_size[0] and W == self.img_size[1], \f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."# proj(卷积)# flatten(压平H,W): [B, C, H, W] -> [B, C, HW]# transpose(交换后两维): [B, C, HW] -> [B, HW, C]x = self.proj(x).flatten(2).transpose(1, 2)x = self.norm(x)return x# Encoder Block中的MultiHead-Self-Attention层
class Attention(nn.Module):def __init__(self,dim,   # 输入token的dimnum_heads=8,  # head数qkv_bias=False,  # 生成qkv不用baisqk_scale=None,  # None时使用:根号dk分之一attn_drop_ratio=0.,  # dropout率proj_drop_ratio=0.):  # dropout率super(Attention, self).__init__()self.num_heads = num_headshead_dim = dim // num_heads  # 分头:计算每个head均分得到的q,k,v个数self.scale = qk_scale or head_dim ** -0.5  # qk_scale是根号下head_dim分之一,就是q*k转置后乘的那个:根号dk分之一self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)  # 通过qkv全连接层:(q,k,v)=X·(Wq,Wk,Wv),一次并行求得qkv# 全连接层:in_features输入特征个数=dim,out_features输出特征个数(全连接层节点个数)=dim*3self.attn_drop = nn.Dropout(attn_drop_ratio)self.proj = nn.Linear(dim, dim)  # 通过proj全连接层:B=B·Wo,进行bij->bi拼接后的映射self.proj_drop = nn.Dropout(proj_drop_ratio)def forward(self, x):# [batch_size, num_patches + 1, total_embed_dim]# (num_patches + 1的1是class token,num_patches + 1个向量大小都是total_embed_dim)B, N, C = x.shape# qkv(): -> [batch_size, num_patches + 1, 3 * total_embed_dim]# reshape分qkv分头: -> [batch_size, num_patches + 1, 3, num_heads, embed_dim_per_head]# permute调序: -> [3, batch_size, num_heads, num_patches + 1, embed_dim_per_head]qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)# 切片q、k、v,都是[batch_size, num_heads, num_patches + 1, embed_dim_per_head]q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)# transpose:原q、k、v-> [batch_size, num_heads, embed_dim_per_head, num_patches + 1]# @ 矩阵乘法: 多维矩阵乘法只乘最后两维 -> [batch_size, num_heads, num_patches + 1, num_patches + 1]# q乘k转置,乘根号dkattn = (q @ k.transpose(-2, -1)) * self.scale# dim=-1表示attn在每一行进行softmax处理attn = attn.softmax(dim=-1)attn = self.attn_drop(attn)# @ 矩阵乘法: -> [batch_size, num_heads, num_patches + 1, embed_dim_per_head]# transpose: -> [batch_size, num_patches + 1, num_heads, embed_dim_per_head]# reshape: -> [batch_size, num_patches + 1, total_embed_dim]# softmax(qk相似度) 乘 v,reshape进行bij->bi的拼接映射x = (attn @ v).transpose(1, 2).reshape(B, N, C)# 通过proj全连接层:B=B·Wo映射x = self.proj(x)x = self.proj_drop(x)return x# Encoder Block中的MLP(两个全连接层)
class Mlp(nn.Module):"""MLP as used in Vision Transformer, MLP-Mixer and related networks"""# in_features输入特征个数,hidden_features第一个全连接层节点个数,out_features第二个全连接层节点个数,act_layer激活函数def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):super().__init__()out_features = out_features or in_features  # out_features=None,in_featureshidden_features = hidden_features or in_features  # 同上self.fc1 = nn.Linear(in_features, hidden_features)self.act = act_layer()  # GELU激活函数self.fc2 = nn.Linear(hidden_features, out_features)self.drop = nn.Dropout(drop)def forward(self, x):x = self.fc1(x)x = self.act(x)x = self.drop(x)x = self.fc2(x)x = self.drop(x)return x# Encoder Block
class Block(nn.Module):def __init__(self,dim,  # token 的dimnum_heads,  # head数mlp_ratio=4.,  # mlp中第一个全连接层的节点个数是输入的4倍qkv_bias=False,  # 是否使用biasqk_scale=None,  # 根号dk分之一drop_ratio=0.,  # attention中的drop_out率attn_drop_ratio=0.,  # attention中的drop_out率drop_path_ratio=0.,  # Encoder Block中的drop_path率act_layer=nn.GELU,  # 激活函数norm_layer=nn.LayerNorm):  # normalization使用LayerNormsuper(Block, self).__init__()# 实例化LayerNorm层self.norm1 = norm_layer(dim)# 实例化Attention层self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop_ratio=attn_drop_ratio, proj_drop_ratio=drop_ratio)# NOTE: drop path for stochastic depth, we shall see if this is better than dropout here# 实例化DropPath层self.drop_path = DropPath(drop_path_ratio) if drop_path_ratio > 0. else nn.Identity()# 实例化LayerNorm层self.norm2 = norm_layer(dim)mlp_hidden_dim = int(dim * mlp_ratio)# 实例化Mlp层self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop_ratio)def forward(self, x):# 这里的+=都引入了恒等映射的残差思想x = x + self.drop_path(self.attn(self.norm1(x)))x = x + self.drop_path(self.mlp(self.norm2(x)))return x# ViT
class VisionTransformer(nn.Module):def __init__(self, img_size=224, patch_size=16, in_c=3, num_classes=1000,embed_dim=768, depth=12, num_heads=12, mlp_ratio=4.0, qkv_bias=True,qk_scale=None, representation_size=None, distilled=False, drop_ratio=0.,attn_drop_ratio=0., drop_path_ratio=0., embed_layer=PatchEmbed, norm_layer=None,act_layer=None):"""Args:img_size (int, tuple): input image sizepatch_size (int, tuple): patch sizein_c (int): number of input channelsnum_classes (int): number of classes for classification headembed_dim (int): embedding dimension,patch embedding层卷积核个数depth (int): 是Encoder重复次数,depth of transformernum_heads (int): number of attention headsmlp_ratio (int): ratio of mlp hidden dim to embedding dimqkv_bias (bool): enable bias for qkv if Trueqk_scale (float): override default qk scale of head_dim ** -0.5 if setrepresentation_size (Optional[int]):是否构建MLP层的pre-logits,enable and set representation layer (pre-logits) to this value if setdistilled (bool): 为了兼容搭建DeiT的参数,model includes a distillation token and head as in DeiT modelsdrop_ratio (float): dropout rateattn_drop_ratio (float): attention dropout ratedrop_path_ratio (float): stochastic depth rateembed_layer (nn.Module): patch embedding layernorm_layer: (nn.Module): normalization layer"""super(VisionTransformer, self).__init__()self.num_classes = num_classes  # 分类数self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other modelsself.num_tokens = 2 if distilled else 1  # num_tokens默认为1norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)  # normalization默认为LayerNorm# partial函数的功能就是:把一个函数的某些参数给默认固定住,返回一个新的函数act_layer = act_layer or nn.GELU  # activate function默认为GELU# patch_embed层self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_c=in_c, embed_dim=embed_dim)num_patches = self.patch_embed.num_patches# class token初始化第一个1是batch_sizeself.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))# 不用管DeiT模型的dist_tokenself.dist_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if distilled else None# Position embedding位置编码初始化self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))# dropout层self.pos_drop = nn.Dropout(p=drop_ratio)# 生成一个drop_path率的序列dpr,共depth个,大小从0到drop_path_ratio递增dpr = [x.item() for x in torch.linspace(0, drop_path_ratio, depth)]  # stochastic depth decay rule# 构建depth个连续的Encoder blockself.blocks = nn.Sequential(*[Block(dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,drop_ratio=drop_ratio, attn_drop_ratio=attn_drop_ratio, drop_path_ratio=dpr[i],norm_layer=norm_layer, act_layer=act_layer)for i in range(depth)])# Encoder block后的norm_layerself.norm = norm_layer(embed_dim)# Representation layer是否构建MLP层的pre-logitsif representation_size and not distilled:self.has_logits = Trueself.num_features = representation_size# pre-logits就是一个全连接层+tanh激活函数# embed_dim输入节点个数,representation_size输出节点个数self.pre_logits = nn.Sequential(OrderedDict([("fc", nn.Linear(embed_dim, representation_size)),("act", nn.Tanh())]))else:self.has_logits = Falseself.pre_logits = nn.Identity()# Classifier head(s),最后一层全连接层分类,num_features输入节点个数,num_classes输出节点个数self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()# 后面不用看,是DeiT模型的self.head_dist = Noneif distilled:self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity()# Weight init,初始化pos_embed# trunc_normal_利用正态分布生成一个点,点在[a, b]区间之内nn.init.trunc_normal_(self.pos_embed, std=0.02)# 后面不用看,是DeiT模型的if self.dist_token is not None:nn.init.trunc_normal_(self.dist_token, std=0.02)# Weight init,初始化cls_tokennn.init.trunc_normal_(self.cls_token, std=0.02)# 调用vit初始函数self.apply(_init_vit_weights)def forward_features(self, x):# patch embedding# [B, C, H, W] -> [B, num_patches, embed_dim]x = self.patch_embed(x)  # [B, 196, 768]# class token# [1, 1, 768] -> [B, 1, 768]cls_token = self.cls_token.expand(x.shape[0], -1, -1)# concat拼接cls_token和patch_token# ViT中dist_token为None,执行第一个if self.dist_token is None:  # ViT中dist_token为Nonex = torch.cat((cls_token, x), dim=1)  # [B, 197, 768]else:x = torch.cat((cls_token, self.dist_token.expand(x.shape[0], -1, -1), x), dim=1)# 加上位置编码position embeddingx = self.pos_drop(x + self.pos_embed)# 现在的token=[class + patch]+ position# encoder block层x = self.blocks(x)# normalization层x = self.norm(x)# MPL的pre_logits# ViT中dist_token为None,执行第一个if self.dist_token is None:return self.pre_logits(x[:, 0])  # 只返回class token输出的列else:return x[:, 0], x[:, 1]def forward(self, x):# 返回class token输出的列x = self.forward_features(x)# head_dist为None,执行elseif self.head_dist is not None:x, x_dist = self.head(x[0]), self.head_dist(x[1])if self.training and not torch.jit.is_scripting():# during inference, return the average of both classifier predictionsreturn x, x_distelse:return (x + x_dist) / 2else:x = self.head(x)  # 最后的全连接层输出分类结果return xdef _init_vit_weights(m):"""ViT weight initialization:param m: module"""if isinstance(m, nn.Linear):nn.init.trunc_normal_(m.weight, std=.01)if m.bias is not None:nn.init.zeros_(m.bias)elif isinstance(m, nn.Conv2d):nn.init.kaiming_normal_(m.weight, mode="fan_out")if m.bias is not None:nn.init.zeros_(m.bias)elif isinstance(m, nn.LayerNorm):nn.init.zeros_(m.bias)nn.init.ones_(m.weight)# 至此我们已经完成了ViT所有模块的编写
# ——————————————————————————————————————————————————————————————————————————————————————————————————————————————————————def vit_base_patch16_224(num_classes: int = 1000):"""ViT-Base model (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).ImageNet-1k weights @ 224x224, source https://github.com/google-research/vision_transformer.weights ported from official Google JAX impl:链接: https://pan.baidu.com/s/1zqb08naP0RPqqfSXfkB2EA  密码: eu9f"""model = VisionTransformer(img_size=224,patch_size=16,embed_dim=768,depth=12,num_heads=12,representation_size=None,num_classes=num_classes)return modeldef vit_base_patch16_224_in21k(num_classes: int = 21843, has_logits: bool = True):"""ViT-Base model (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.weights ported from official Google JAX impl:https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_patch16_224_in21k-e5005f0a.pth"""model = VisionTransformer(img_size=224,patch_size=16,embed_dim=768,depth=12,num_heads=12,representation_size=768 if has_logits else None,num_classes=num_classes)return modeldef vit_base_patch32_224(num_classes: int = 1000):"""ViT-Base model (ViT-B/32) from original paper (https://arxiv.org/abs/2010.11929).ImageNet-1k weights @ 224x224, source https://github.com/google-research/vision_transformer.weights ported from official Google JAX impl:链接: https://pan.baidu.com/s/1hCv0U8pQomwAtHBYc4hmZg  密码: s5hl"""model = VisionTransformer(img_size=224,patch_size=32,embed_dim=768,depth=12,num_heads=12,representation_size=None,num_classes=num_classes)return modeldef vit_base_patch32_224_in21k(num_classes: int = 21843, has_logits: bool = True):"""ViT-Base model (ViT-B/32) from original paper (https://arxiv.org/abs/2010.11929).ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.weights ported from official Google JAX impl:https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_patch32_224_in21k-8db57226.pth"""model = VisionTransformer(img_size=224,patch_size=32,embed_dim=768,depth=12,num_heads=12,representation_size=768 if has_logits else None,num_classes=num_classes)return modeldef vit_large_patch16_224(num_classes: int = 1000):"""ViT-Large model (ViT-L/16) from original paper (https://arxiv.org/abs/2010.11929).ImageNet-1k weights @ 224x224, source https://github.com/google-research/vision_transformer.weights ported from official Google JAX impl:链接: https://pan.baidu.com/s/1cxBgZJJ6qUWPSBNcE4TdRQ  密码: qqt8"""model = VisionTransformer(img_size=224,patch_size=16,embed_dim=1024,depth=24,num_heads=16,representation_size=None,num_classes=num_classes)return modeldef vit_large_patch16_224_in21k(num_classes: int = 21843, has_logits: bool = True):"""ViT-Large model (ViT-L/16) from original paper (https://arxiv.org/abs/2010.11929).ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.weights ported from official Google JAX impl:https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_patch16_224_in21k-606da67d.pth"""model = VisionTransformer(img_size=224,patch_size=16,embed_dim=1024,depth=24,num_heads=16,representation_size=1024 if has_logits else None,num_classes=num_classes)return modeldef vit_large_patch32_224_in21k(num_classes: int = 21843, has_logits: bool = True):"""ViT-Large model (ViT-L/32) from original paper (https://arxiv.org/abs/2010.11929).ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.weights ported from official Google JAX impl:https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_patch32_224_in21k-9046d2e7.pth"""model = VisionTransformer(img_size=224,patch_size=32,embed_dim=1024,depth=24,num_heads=16,representation_size=1024 if has_logits else None,num_classes=num_classes)return modeldef vit_huge_patch14_224_in21k(num_classes: int = 21843, has_logits: bool = True):"""ViT-Huge model (ViT-H/14) from original paper (https://arxiv.org/abs/2010.11929).ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.NOTE: converted weights not currently available, too large for github release hosting."""model = VisionTransformer(img_size=224,patch_size=14,embed_dim=1280,depth=32,num_heads=16,representation_size=1280 if has_logits else None,num_classes=num_classes)return model

五、Swin Tranformer

六、其他Transformer的改进

6.1 Hybird ViT

先用Resnet50特征提取,再用ViT进一步处理分类。

其中Resnet50部分做出了一些修改;


epoch较多时,混合模型模型反而效果不如纯正的ViT。

CV领域Transformer这一篇就够了(原理详解+pytorch代码复现)相关推荐

  1. 玩转Mysql系列 - 第22篇:mysql索引原理详解

    Mysql系列的目标是:通过这个系列从入门到全面掌握一个高级开发所需要的全部技能. 欢迎大家加我微信itsoku一起交流java.算法.数据库相关技术. 这是Mysql系列第22篇. 背景 使用mys ...

  2. 自问自答学ArrayList,看这篇就够了,详解问答

    前言 在之前的几篇文章里面,我主要都是推荐了一些工具类,为的就是让大家可以提高开发效率,但是我们在提高开发效率,也应该提高代码的执行效率,注重代码的质量.如何提高,其中的一个好办法就是阅读源码,知其然 ...

  3. arp 项删除失败: 请求的操作需要提升。_ccna必懂篇-arp协议工作原理详解。

    本次呢,要说的是arp协议,那么什么是arp协议呢?有什么作用呢? 什么是arp ARP(Address Resolution Protocol)地址解析协议,地址解析协议由互联网工程任务组(IETF ...

  4. 玩转Luat 进阶篇②——远程升级功能原理详解

    目录 一.简介 二.从云端获取升级包(新固件) 2.1 合宙官方服务器openluat 2.2 阿里云物联网平台 三.本地固件更新运行 3.1 合宙4G模块 3.1.1 合宙4G模块的Flash分区 ...

  5. Transformer 初识:模型结构+attention原理详解

    Transformer 初识:模型结构+原理详解 参考资源 前言 1.整体结构 1.1 输入: 1.2 Encoder 和 Decoder的结构 1.3 Layer normalization Bat ...

  6. [Python从零到壹] 四十七.图像增强及运算篇之腐蚀和膨胀详解

    欢迎大家来到"Python从零到壹",在这里我将分享约200篇Python系列文章,带大家一起去学习和玩耍,看看Python这个有趣的世界.所有文章都将结合案例.代码和作者的经验讲 ...

  7. 离线强化学习(Offline RL)系列3: (算法篇)策略约束 - BEAR算法原理详解与实现

    论文信息:Stabilizing Off-Policy Q-Learning via Bootstrapping Error Reduction 本文由UC Berkeley的Sergey Levin ...

  8. PowerShell攻防进阶篇:nishang工具用法详解

    PowerShell攻防进阶篇:nishang工具用法详解 导语:nishang,PowerShell下并肩Empire,Powersploit的神器. 开始之前,先放出个下载地址! 下载地址:htt ...

  9. 离线强化学习(Offline RL)系列3: (算法篇) IQL(Implicit Q-learning)算法详解与实现

    [更新记录] 论文信息:Ilya Kostrikov, Ashvin Nair, Sergey Levine: "Offline Reinforcement Learning with Im ...

最新文章

  1. golang 同一个包中函数互相调用报错 undefined 以及在 VSCode 中配置右键执行整个包文件
  2. 【强势来袭】Node.js(nodejs)实现“一口多用”(含用户创建、登录、鉴权token) 一个文件解决所有常态化需求
  3. xilinx LVDS使用注意事项
  4. JQuery UI之Autocomplete(3)属性与事件
  5. electron 打包 php,electron 将现有vue项目改成支持electron桌面端应用
  6. 推荐eclipse插件Properties Editor
  7. Facebook 最新可佩戴 AR 设备、AR 设备未来五年市场扩张、语音社交新创Swell等|Decode the Week...
  8. Java学习笔记之 IO包 字符流
  9. 计算机科学与技术专业《计算机网络原理》课程实验指导书,计算机科学导论,课程实验指导书解读.pdf...
  10. 低学历的非要考研,多半输得更惨
  11. jQuery初识 - jQuery是什么
  12. linux mysql 修改root密码_MySQL忘了root密码,如何修改?
  13. underscore.js依赖库函数分析一(遍历)
  14. Oracle字符集设置
  15. java一键生成海报_小程序生成海报(java后端)
  16. VS2015,错误RC1015: 无法打开包含文件afxres.h
  17. python中判断无向图是否有环_数据结构与算法:17 图
  18. PWM控制的基本原理
  19. Linux的网络编程面试题汇总
  20. 彻底弄懂base64的编码与解码原理

热门文章

  1. MYSQL5.7在Linux系统详细安装步骤
  2. 分享几个快速加微信粉丝的方法
  3. 湖南大学计算机专业保研,湖南大学各专业保研率
  4. BurpSuite抓包手机模拟器APP
  5. 《攻壳机动队》+押井守
  6. springboot整合rocketMQ记录 实现发送普通消息,延时消息
  7. Praat脚本-005 | 标注文件批量增加层级
  8. 【数学物理方法】定解问题——数物方程的导出(列泛定方程)
  9. 阿里云域名持有者过户
  10. (复健计划)Python列表