Contents

  • 1. Background
  • 2. Motivation
  • 3. Method
    • 3.1 Cross-scale Embedding Layer (CEL)
    • 3.2 CrossFormer Block
      • 3.2.1 Long Short Distance Attention (LSDA)
      • 3.2.2 Dynamic Position Bias (DPB)
    • 3.3 CrossFormer Variants
  • 4. Results
  • 5. Code

Paper: https://arxiv.org/pdf/2108.00154.pdf
Code: https://github.com/cheerss/CrossFormer

1. Background

Transformers have already seen some successful applications in computer vision. Most of these methods split the input image into patches, encode the patches into a sequence of features, and use attention inside the transformer to model relationships between the encoded features. However, vanilla self-attention is much more expensive here than in NLP, so a number of works have proposed more efficient variants of self-attention.

2. Motivation

For vision tasks, the objects in an image can differ greatly in scale, so relating two objects of very different sizes requires cross-scale attention. Few existing methods can build attention across features of different scales, for two reasons:

  • the input embeddings of every layer are all at the same scale, so there are no cross-scale features to attend over;
  • some methods sacrifice the small-scale (fine-grained) features of the embeddings for efficiency.

Based on this, the authors build the CrossFormer architecture.

3. Method

The overall architecture, shown in Figure 1, is a pyramid with four stages. Each stage consists of:

  • a cross-scale embedding layer (CEL)

    The CEL takes the output of the previous stage as input (the first one takes the raw image) and produces cross-scale embeddings. Except for stage 1, each CEL reduces the number of embeddings to 1/4 of its input and doubles their dimension.

  • several CrossFormer blocks


1. Cross-scale Embedding Layer (CEL)

The authors use a pyramid-style transformer, i.e. the model is split into multiple stages and a CEL is placed before each stage. The CEL takes the output of the previous stage as input and extracts patches with kernels of several scales; the resulting embeddings are linearly projected and concatenated, so the output can be treated as single-scale patch embeddings.

2. Long Short Distance Attention (LSDA)

The authors also propose LSDA as a replacement for vanilla self-attention. LSDA does not discard either the small-scale or the large-scale features, so cross-scale interaction remains possible. In LSDA, self-attention is split into two parts:

  • short-distance attention (SDA): builds attention between an embedding and the embeddings close to it
  • long-distance attention (LDA): builds attention between an embedding and the embeddings far away from it

3. The authors introduce a learnable module, dynamic position bias (DPB), to represent positions: its input is the distance between two embeddings and its output is the position bias between them. The commonly used relative position bias (RPB) is efficient but only works when all inputs have the same size, which makes it unsuitable for object detection.

3.1 Cross-scale Embedding Layer (CEL)

A CEL generates the input embeddings of each stage. As shown in Figure 2, before stage 1 the raw image is fed into the first CEL, and patches are extracted with four kernels of different sizes centered at the same positions; the four outputs are projected and concatenated into the embedding features.

One remaining question is how to choose the projected feature dimension for each patch size. Since large kernels bring a large computational cost, the authors give the large kernels a low output dimension and the small kernels a high one, as sketched below.
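
As a quick illustration of that allocation (my own snippet, following the splitting rule used in the released PatchEmbed code shown later; the variable names are mine), with embed_dim = 64 and kernel sizes [4, 8, 16, 32] the four projections receive the following output dimensions:

# Each larger kernel gets half the channel share of the previous one,
# except that the last kernel reuses the previous share so the pieces sum to embed_dim.
embed_dim = 64
kernel_sizes = [4, 8, 16, 32]

dims = []
for i, ks in enumerate(kernel_sizes):
    if i == len(kernel_sizes) - 1:
        dims.append(embed_dim // 2 ** i)        # largest kernel: 64 // 8 = 8
    else:
        dims.append(embed_dim // 2 ** (i + 1))  # 32, 16, 8
print(dims, sum(dims))                          # [32, 16, 8, 8] 64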

The CELs of the later stages work the same way: they take the previous stage's output as input, extract patches with kernels of different scales, linearly project the embeddings, and concatenate them, so the result can again be treated as single-scale patches.

3.2 CrossFormer Block

Each CrossFormer block contains one SDA (or one LDA) plus an MLP; SDA and LDA never appear together in the same block. A minimal sketch of this layout is given below.
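
The following is my own simplified sketch (not the repo's CrossFormerBlock, which additionally handles grouping, attention masks, and DropPath); grouped_attn stands in for the SDA or LDA module described next:

import torch.nn as nn

class BlockSketch(nn.Module):
    """Pre-norm residual block: grouped attention (SDA or LDA) followed by an MLP."""
    def __init__(self, dim, grouped_attn, mlp_ratio=4):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = grouped_attn              # an SDA or LDA module, chosen per block
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = nn.Sequential(
            nn.Linear(dim, dim * mlp_ratio),
            nn.GELU(),
            nn.Linear(dim * mlp_ratio, dim),
        )

    def forward(self, x):                     # x: [B, H*W, C]
        x = x + self.attn(self.norm1(x))      # SDA or LDA, plus residual
        x = x + self.mlp(self.norm2(x))       # MLP, plus residual
        return x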

3.2.1 Long Short Distance Attention (LSDA)

The authors split self-attention into two parts:

  • short-distance attention (SDA): builds attention between an embedding and the embeddings close to it
  • long-distance attention (LDA): builds attention between an embedding and the embeddings far away from it

1. For SDA, every $G \times G$ adjacent embeddings are grouped together; Figure 3(a) shows the case $G = 3$.

2. For LDA, with an input of size $S \times S$, the embeddings are sampled at a fixed interval $I$. In Figure 3(b), where $I = 3$, all embeddings in the red cells belong to one group and those in the yellow cells to another. The width and height of each group are $G = S / I$, which is again $G = 3$ here.

3. After grouping, both SDA and LDA apply vanilla self-attention within each group, which reduces the computational complexity from $O(S^4)$ to $O(S^2 G^2)$; see the short calculation below.
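
To get a feel for the saving, here is a quick back-of-the-envelope calculation of my own, using the stage-0 numbers that appear later in this post ($S = 56$, $G = 7$):

S, G = 56, 7
tokens = S * S                        # 3136 embeddings
full = tokens ** 2                    # vanilla self-attention: 9,834,496 score pairs
num_groups = (S // G) ** 2            # 64 groups of G*G = 49 embeddings
grouped = num_groups * (G * G) ** 2   # SDA/LDA: 64 * 49^2 = 153,664 score pairs
print(full // grouped)                # 64, i.e. (S/G)^2 times fewer pairs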

In Figure 3(b), the authors draw the patches of two embeddings. Their small-scale patches are not adjacent, so without the help of large-scale patches it is hard to judge how the two are related; in other words, if the two embeddings were built only from small patches, establishing their relationship would be difficult. In contrast, their adjacent large-scale patches provide enough context to connect them. Cross-scale attention therefore handles well the relationships that are mainly carried by large-scale patches.

3.2.2 Dynamic Position Bias (DPB)

Relative position bias (RPB) is commonly used to express the relative positions of embeddings; it is a bias added to the attention scores.

Although efficient, it only works when all input images have the same size, so it is not suitable for object detection.

The authors therefore propose DPB, whose structure is shown in Figure 3(c). Its input dimension is 2, namely $(\Delta x_{i,j}, \Delta y_{i,j})$, the coordinate offset between the $i$-th and $j$-th embeddings. It is built from three fully connected layers with layer norm and ReLU in between, and the intermediate dimension is $D/4$, where $D$ is the embedding dimension.
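
Below is a minimal sketch of such a module (the class name DPBSketch is mine, and I use the $D/4$ hidden width from the description above; the layer sequence mirrors the non-residual path of the DynamicPosBias class in the full code later, which wires its hidden widths somewhat more narrowly):

import torch
import torch.nn as nn

class DPBSketch(nn.Module):
    """Maps a 2-d relative offset (dx, dy) to a per-head position bias."""
    def __init__(self, dim, num_heads):
        super().__init__()
        hidden = dim // 4
        self.net = nn.Sequential(
            nn.Linear(2, hidden),
            nn.LayerNorm(hidden), nn.ReLU(inplace=True), nn.Linear(hidden, hidden),
            nn.LayerNorm(hidden), nn.ReLU(inplace=True), nn.Linear(hidden, hidden),
            nn.LayerNorm(hidden), nn.ReLU(inplace=True), nn.Linear(hidden, num_heads),
        )

    def forward(self, offsets):    # offsets: [num_offsets, 2], one (dx, dy) row per offset
        return self.net(offsets)   # [num_offsets, num_heads] biases, added to the attention scores

# e.g. all offsets inside a 7x7 group range over [-6, 6] in both directions
offsets = torch.stack(torch.meshgrid(torch.arange(-6, 7), torch.arange(-6, 7), indexing="ij"))
offsets = offsets.flatten(1).transpose(0, 1).float()    # [169, 2]
print(DPBSketch(dim=64, num_heads=2)(offsets).shape)    # torch.Size([169, 2])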

3.3 CrossFormer Variants

Table 1 lists the CrossFormer variants T/S/B/L, corresponding to tiny, small, base, and large.

4. Results


5. Code

After downloading the code, you can run it with the snippet below to see how CrossFormer is implemented.

import torch
from models.crossformer import CrossFormer

model = CrossFormer(
    img_size=224,
    patch_size=[4, 8, 16, 32],
    in_chans=3,
    num_classes=1000,
    embed_dim=64,
    depths=[1, 1, 8, 6],
    num_heads=[2, 4, 8, 16],
    group_size=[7, 7, 7, 7],
    mlp_ratio=4,
    qkv_bias=True,
    qk_scale=None,
    drop_rate=0.0,
    drop_path_rate=0.1,
    ape=False,
    patch_norm=True,
    use_checkpoint=False,
    merge_size=[[2, 4], [2, 4], [2, 4]],
)
model.eval()
input = torch.randn(1, 3, 224, 224)
output = model(input)

1. Implementation of patch embedding:

The input is the raw image; it is convolved with kernels of several sizes, the results are concatenated, and the concatenation is fed to the CrossFormer blocks.

class PatchEmbed(nn.Module):
    def __init__(self, img_size=224, patch_size=[4], in_chans=3, embed_dim=96, norm_layer=None):
        super().__init__()
        img_size = to_2tuple(img_size)
        # patch_size = to_2tuple(patch_size)
        patches_resolution = [img_size[0] // patch_size[0], img_size[0] // patch_size[0]]
        self.img_size = img_size
        self.patch_size = patch_size
        self.patches_resolution = patches_resolution
        self.num_patches = patches_resolution[0] * patches_resolution[1]
        self.in_chans = in_chans
        self.embed_dim = embed_dim

        self.projs = nn.ModuleList()
        for i, ps in enumerate(patch_size):
            if i == len(patch_size) - 1:
                dim = embed_dim // 2 ** i
            else:
                dim = embed_dim // 2 ** (i + 1)
            stride = patch_size[0]
            padding = (ps - patch_size[0]) // 2
            self.projs.append(nn.Conv2d(in_chans, dim, kernel_size=ps, stride=stride, padding=padding))
        if norm_layer is not None:
            self.norm = norm_layer(embed_dim)
        else:
            self.norm = None

    def forward(self, x):
        B, C, H, W = x.shape
        # FIXME look at relaxing size constraints
        assert H == self.img_size[0] and W == self.img_size[1], \
            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
        xs = []
        for i in range(len(self.projs)):
            tx = self.projs[i](x).flatten(2).transpose(1, 2)  # [1,32,56,56], [1,16,56,56], [1,8,56,56], [1,8,56,56]
            xs.append(tx)  # B Ph*Pw C -> xs[0]=[1, 3136, 32], xs[1]=[1, 3136, 16], xs[2]=[1, 3136, 8], xs[3]=[1, 3136, 8]
        x = torch.cat(xs, dim=2)  # [1, 3136, 64]
        if self.norm is not None:
            x = self.norm(x)
        return x

PatchEmbed(
  (projs): ModuleList(
    (0): Conv2d(3, 32, kernel_size=(4, 4), stride=(4, 4))
    (1): Conv2d(3, 16, kernel_size=(8, 8), stride=(4, 4), padding=(2, 2))
    (2): Conv2d(3, 8, kernel_size=(16, 16), stride=(4, 4), padding=(6, 6))
    (3): Conv2d(3, 8, kernel_size=(32, 32), stride=(4, 4), padding=(14, 14))
  )
  (norm): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
)

2. CrossFormer Stage

lsda_flag = 0  # the block uses SDA
lsda_flag = 1  # the block uses LDA

stage 0: input 56x56, 64-d, 1 SDA block
stage 1: input 28x28, 128-d, 1 SDA block
stage 2: input 14x14, 256-d, 4 SDA and 4 LDA blocks, alternating
stage 3: input 7x7, 512-d, 6 SDA blocks
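
As a small check of my own on how depths = [1, 1, 8, 6] and the alternating lsda_flag in Stage produce this layout; note that stage 3 works on a 7x7 map with group_size 7, so CrossFormerBlock falls back to SDA for every block there:

depths = [1, 1, 8, 6]
resolutions = [56, 28, 14, 7]
group_size = 7

for s, (depth, res) in enumerate(zip(depths, resolutions)):
    kinds = []
    for i in range(depth):
        lsda_flag = 0 if i % 2 == 0 else 1   # Stage alternates SDA / LDA
        if res <= group_size:                # the whole map fits in one group -> force SDA
            lsda_flag = 0
        kinds.append("SDA" if lsda_flag == 0 else "LDA")
    print(f"stage {s}: {kinds}")
# stage 0: ['SDA']
# stage 1: ['SDA']
# stage 2: ['SDA', 'LDA', 'SDA', 'LDA', 'SDA', 'LDA', 'SDA', 'LDA']
# stage 3: ['SDA', 'SDA', 'SDA', 'SDA', 'SDA', 'SDA']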

Implementation of the SDA / LDA grouping:

G = self.group_size          # 7
if self.lsda_flag == 0:  # 0 for SDA
    x = x.reshape(B, H // G, G, W // G, G, C).permute(0, 1, 3, 2, 4, 5)  # [1, 8, 7, 8, 7, 64] -> [1, 8, 8, 7, 7, 64]
else:  # 1 for LDA
    x = x.reshape(B, G, H // G, G, W // G, C).permute(0, 2, 4, 1, 3, 5)  # [1, 7, 2, 7, 2, 256] -> [1, 2, 2, 7, 7, 256]
x = x.reshape(B * H * W // G**2, G**2, C)  # [64, 49, 64] for SDA in stage 0, [4, 49, 256] for LDA in stage 2
# multi-head self-attention
x = self.attn(x, mask=self.attn_mask)      # nW*B, G*G, C -> [64, 49, 64]

Taking stage 0 as an example of SDA:

The input is 56x56. It is split into 8 groups along each row and each column, i.e. 8x8 = 64 groups, each containing 7x7 = 49 adjacent embeddings; self-attention is then computed within each group.
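
This can be verified with a toy tensor (a standalone snippet of mine that reproduces the SDA reshape from the code above):

import torch

B, H, W, C, G = 1, 56, 56, 64, 7
x = torch.randn(B, H, W, C)

# SDA: adjacent G x G windows become groups
sda = x.reshape(B, H // G, G, W // G, G, C).permute(0, 1, 3, 2, 4, 5)
sda = sda.reshape(B * H * W // G**2, G**2, C)
print(sda.shape)   # torch.Size([64, 49, 64]) -> 64 groups of 49 embeddings each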

Taking stage 2 as an example of LDA:

The input is 14x14. Embeddings are sampled at interval $I = 14 / 7 = 2$, i.e. one element every other row and column, so there are 2x2 = 4 groups, each containing 7x7 = 49 embeddings; self-attention is then computed within each group.
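
And the corresponding check for the LDA grouping (again a standalone snippet of mine):

import torch

B, H, W, C, G = 1, 14, 14, 256, 7
x = torch.randn(B, H, W, C)

# LDA: embeddings whose row/column indices share the same remainder modulo I = H // G land in one group
lda = x.reshape(B, G, H // G, G, W // G, C).permute(0, 2, 4, 1, 3, 5)
lda = lda.reshape(B * H * W // G**2, G**2, C)
print(lda.shape)   # torch.Size([4, 49, 256]) -> 4 groups of 49 embeddings each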

class Stage(nn.Module):
    """ CrossFormer blocks for one stage.

    Args:
        dim (int): Number of input channels.
        input_resolution (tuple[int]): Input resolution.
        depth (int): Number of blocks.
        num_heads (int): Number of attention heads.
        group_size (int): variable G in the paper, one group has GxG embeddings
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
        downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
    """

    def __init__(self, dim, input_resolution, depth, num_heads, group_size,
                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False,
                 patch_size_end=[4], num_patch_size=None):
        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.depth = depth
        self.use_checkpoint = use_checkpoint

        # build blocks
        self.blocks = nn.ModuleList()
        for i in range(depth):
            lsda_flag = 0 if (i % 2 == 0) else 1
            self.blocks.append(CrossFormerBlock(dim=dim, input_resolution=input_resolution,
                                                num_heads=num_heads, group_size=group_size,
                                                lsda_flag=lsda_flag,
                                                mlp_ratio=mlp_ratio,
                                                qkv_bias=qkv_bias, qk_scale=qk_scale,
                                                drop=drop, attn_drop=attn_drop,
                                                drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
                                                norm_layer=norm_layer,
                                                num_patch_size=num_patch_size))

        # patch merging layer
        if downsample is not None:
            self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer,
                                         patch_size=patch_size_end, num_input_patch_size=num_patch_size)
        else:
            self.downsample = None

    def forward(self, x):
        for blk in self.blocks:
            if self.use_checkpoint:
                x = checkpoint.checkpoint(blk, x)
            else:
                x = blk(x)
        if self.downsample is not None:
            x = self.downsample(x)
        return x

    def extra_repr(self) -> str:
        return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"

    def flops(self):
        flops = 0
        for blk in self.blocks:
            flops += blk.flops()
        if self.downsample is not None:
            flops += self.downsample.flops()
        return flops

The printed CrossFormer structure:

ModuleList((0): Stage(dim=64, input_resolution=(56, 56), depth=1(blocks): ModuleList((0): CrossFormerBlock(dim=64, input_resolution=(56, 56), num_heads=2, group_size=7, lsda_flag=0, mlp_ratio=4(norm1): LayerNorm((64,), eps=1e-05, elementwise_affine=True)(attn): Attention(dim=64, group_size=(7, 7), num_heads=2(pos): DynamicPosBias((pos_proj): Linear(in_features=2, out_features=4, bias=True)(pos1): Sequential((0): LayerNorm((4,), eps=1e-05, elementwise_affine=True)(1): ReLU(inplace=True)(2): Linear(in_features=4, out_features=4, bias=True))(pos2): Sequential((0): LayerNorm((4,), eps=1e-05, elementwise_affine=True)(1): ReLU(inplace=True)(2): Linear(in_features=4, out_features=4, bias=True))(pos3): Sequential((0): LayerNorm((4,), eps=1e-05, elementwise_affine=True)(1): ReLU(inplace=True)(2): Linear(in_features=4, out_features=2, bias=True)))(qkv): Linear(in_features=64, out_features=192, bias=True)(attn_drop): Dropout(p=0.0, inplace=False)(proj): Linear(in_features=64, out_features=64, bias=True)(proj_drop): Dropout(p=0.0, inplace=False)(softmax): Softmax(dim=-1))(drop_path): Identity()(norm2): LayerNorm((64,), eps=1e-05, elementwise_affine=True)(mlp): Mlp((fc1): Linear(in_features=64, out_features=256, bias=True)(act): GELU()(fc2): Linear(in_features=256, out_features=64, bias=True)(drop): Dropout(p=0.0, inplace=False))))(downsample): PatchMerging(input_resolution=(56, 56), dim=64(reductions): ModuleList((0): Conv2d(64, 64, kernel_size=(2, 2), stride=(2, 2))(1): Conv2d(64, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1)))(norm): LayerNorm((64,), eps=1e-05, elementwise_affine=True)))(1): Stage(dim=128, input_resolution=(28, 28), depth=1(blocks): ModuleList((0): CrossFormerBlock(dim=128, input_resolution=(28, 28), num_heads=4, group_size=7, lsda_flag=0, mlp_ratio=4(norm1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)(attn): Attention(dim=128, group_size=(7, 7), num_heads=4(pos): DynamicPosBias((pos_proj): Linear(in_features=2, out_features=8, bias=True)(pos1): Sequential((0): LayerNorm((8,), eps=1e-05, elementwise_affine=True)(1): ReLU(inplace=True)(2): Linear(in_features=8, out_features=8, bias=True))(pos2): Sequential((0): LayerNorm((8,), eps=1e-05, elementwise_affine=True)(1): ReLU(inplace=True)(2): Linear(in_features=8, out_features=8, bias=True))(pos3): Sequential((0): LayerNorm((8,), eps=1e-05, elementwise_affine=True)(1): ReLU(inplace=True)(2): Linear(in_features=8, out_features=4, bias=True)))(qkv): Linear(in_features=128, out_features=384, bias=True)(attn_drop): Dropout(p=0.0, inplace=False)(proj): Linear(in_features=128, out_features=128, bias=True)(proj_drop): Dropout(p=0.0, inplace=False)(softmax): Softmax(dim=-1))(drop_path): DropPath()(norm2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)(mlp): Mlp((fc1): Linear(in_features=128, out_features=512, bias=True)(act): GELU()(fc2): Linear(in_features=512, out_features=128, bias=True)(drop): Dropout(p=0.0, inplace=False))))(downsample): PatchMerging(input_resolution=(28, 28), dim=128(reductions): ModuleList((0): Conv2d(128, 128, kernel_size=(2, 2), stride=(2, 2))(1): Conv2d(128, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1)))(norm): LayerNorm((128,), eps=1e-05, elementwise_affine=True)))(2): Stage(dim=256, input_resolution=(14, 14), depth=8(blocks): ModuleList((0): CrossFormerBlock(dim=256, input_resolution=(14, 14), num_heads=8, group_size=7, lsda_flag=0, mlp_ratio=4(norm1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)(attn): Attention(dim=256, group_size=(7, 7), num_heads=8(pos): 
DynamicPosBias((pos_proj): Linear(in_features=2, out_features=16, bias=True)(pos1): Sequential((0): LayerNorm((16,), eps=1e-05, elementwise_affine=True)(1): ReLU(inplace=True)(2): Linear(in_features=16, out_features=16, bias=True))(pos2): Sequential((0): LayerNorm((16,), eps=1e-05, elementwise_affine=True)(1): ReLU(inplace=True)(2): Linear(in_features=16, out_features=16, bias=True))(pos3): Sequential((0): LayerNorm((16,), eps=1e-05, elementwise_affine=True)(1): ReLU(inplace=True)(2): Linear(in_features=16, out_features=8, bias=True)))(qkv): Linear(in_features=256, out_features=768, bias=True)(attn_drop): Dropout(p=0.0, inplace=False)(proj): Linear(in_features=256, out_features=256, bias=True)(proj_drop): Dropout(p=0.0, inplace=False)(softmax): Softmax(dim=-1))(drop_path): DropPath()(norm2): LayerNorm((256,), eps=1e-05, elementwise_affine=True)(mlp): Mlp((fc1): Linear(in_features=256, out_features=1024, bias=True)(act): GELU()(fc2): Linear(in_features=1024, out_features=256, bias=True)(drop): Dropout(p=0.0, inplace=False)))(1): CrossFormerBlock(dim=256, input_resolution=(14, 14), num_heads=8, group_size=7, lsda_flag=1, mlp_ratio=4(norm1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)(attn): Attention(dim=256, group_size=(7, 7), num_heads=8(pos): DynamicPosBias((pos_proj): Linear(in_features=2, out_features=16, bias=True)(pos1): Sequential((0): LayerNorm((16,), eps=1e-05, elementwise_affine=True)(1): ReLU(inplace=True)(2): Linear(in_features=16, out_features=16, bias=True))(pos2): Sequential((0): LayerNorm((16,), eps=1e-05, elementwise_affine=True)(1): ReLU(inplace=True)(2): Linear(in_features=16, out_features=16, bias=True))(pos3): Sequential((0): LayerNorm((16,), eps=1e-05, elementwise_affine=True)(1): ReLU(inplace=True)(2): Linear(in_features=16, out_features=8, bias=True)))(qkv): Linear(in_features=256, out_features=768, bias=True)(attn_drop): Dropout(p=0.0, inplace=False)(proj): Linear(in_features=256, out_features=256, bias=True)(proj_drop): Dropout(p=0.0, inplace=False)(softmax): Softmax(dim=-1))(drop_path): DropPath()(norm2): LayerNorm((256,), eps=1e-05, elementwise_affine=True)(mlp): Mlp((fc1): Linear(in_features=256, out_features=1024, bias=True)(act): GELU()(fc2): Linear(in_features=1024, out_features=256, bias=True)(drop): Dropout(p=0.0, inplace=False)))(2): CrossFormerBlock(dim=256, input_resolution=(14, 14), num_heads=8, group_size=7, lsda_flag=0, mlp_ratio=4(norm1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)(attn): Attention(dim=256, group_size=(7, 7), num_heads=8(pos): DynamicPosBias((pos_proj): Linear(in_features=2, out_features=16, bias=True)(pos1): Sequential((0): LayerNorm((16,), eps=1e-05, elementwise_affine=True)(1): ReLU(inplace=True)(2): Linear(in_features=16, out_features=16, bias=True))(pos2): Sequential((0): LayerNorm((16,), eps=1e-05, elementwise_affine=True)(1): ReLU(inplace=True)(2): Linear(in_features=16, out_features=16, bias=True))(pos3): Sequential((0): LayerNorm((16,), eps=1e-05, elementwise_affine=True)(1): ReLU(inplace=True)(2): Linear(in_features=16, out_features=8, bias=True)))(qkv): Linear(in_features=256, out_features=768, bias=True)(attn_drop): Dropout(p=0.0, inplace=False)(proj): Linear(in_features=256, out_features=256, bias=True)(proj_drop): Dropout(p=0.0, inplace=False)(softmax): Softmax(dim=-1))(drop_path): DropPath()(norm2): LayerNorm((256,), eps=1e-05, elementwise_affine=True)(mlp): Mlp((fc1): Linear(in_features=256, out_features=1024, bias=True)(act): GELU()(fc2): Linear(in_features=1024, out_features=256, 
bias=True)(drop): Dropout(p=0.0, inplace=False)))(3): CrossFormerBlock(dim=256, input_resolution=(14, 14), num_heads=8, group_size=7, lsda_flag=1, mlp_ratio=4(norm1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)(attn): Attention(dim=256, group_size=(7, 7), num_heads=8(pos): DynamicPosBias((pos_proj): Linear(in_features=2, out_features=16, bias=True)(pos1): Sequential((0): LayerNorm((16,), eps=1e-05, elementwise_affine=True)(1): ReLU(inplace=True)(2): Linear(in_features=16, out_features=16, bias=True))(pos2): Sequential((0): LayerNorm((16,), eps=1e-05, elementwise_affine=True)(1): ReLU(inplace=True)(2): Linear(in_features=16, out_features=16, bias=True))(pos3): Sequential((0): LayerNorm((16,), eps=1e-05, elementwise_affine=True)(1): ReLU(inplace=True)(2): Linear(in_features=16, out_features=8, bias=True)))(qkv): Linear(in_features=256, out_features=768, bias=True)(attn_drop): Dropout(p=0.0, inplace=False)(proj): Linear(in_features=256, out_features=256, bias=True)(proj_drop): Dropout(p=0.0, inplace=False)(softmax): Softmax(dim=-1))(drop_path): DropPath()(norm2): LayerNorm((256,), eps=1e-05, elementwise_affine=True)(mlp): Mlp((fc1): Linear(in_features=256, out_features=1024, bias=True)(act): GELU()(fc2): Linear(in_features=1024, out_features=256, bias=True)(drop): Dropout(p=0.0, inplace=False)))(4): CrossFormerBlock(dim=256, input_resolution=(14, 14), num_heads=8, group_size=7, lsda_flag=0, mlp_ratio=4(norm1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)(attn): Attention(dim=256, group_size=(7, 7), num_heads=8(pos): DynamicPosBias((pos_proj): Linear(in_features=2, out_features=16, bias=True)(pos1): Sequential((0): LayerNorm((16,), eps=1e-05, elementwise_affine=True)(1): ReLU(inplace=True)(2): Linear(in_features=16, out_features=16, bias=True))(pos2): Sequential((0): LayerNorm((16,), eps=1e-05, elementwise_affine=True)(1): ReLU(inplace=True)(2): Linear(in_features=16, out_features=16, bias=True))(pos3): Sequential((0): LayerNorm((16,), eps=1e-05, elementwise_affine=True)(1): ReLU(inplace=True)(2): Linear(in_features=16, out_features=8, bias=True)))(qkv): Linear(in_features=256, out_features=768, bias=True)(attn_drop): Dropout(p=0.0, inplace=False)(proj): Linear(in_features=256, out_features=256, bias=True)(proj_drop): Dropout(p=0.0, inplace=False)(softmax): Softmax(dim=-1))(drop_path): DropPath()(norm2): LayerNorm((256,), eps=1e-05, elementwise_affine=True)(mlp): Mlp((fc1): Linear(in_features=256, out_features=1024, bias=True)(act): GELU()(fc2): Linear(in_features=1024, out_features=256, bias=True)(drop): Dropout(p=0.0, inplace=False)))(5): CrossFormerBlock(dim=256, input_resolution=(14, 14), num_heads=8, group_size=7, lsda_flag=1, mlp_ratio=4(norm1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)(attn): Attention(dim=256, group_size=(7, 7), num_heads=8(pos): DynamicPosBias((pos_proj): Linear(in_features=2, out_features=16, bias=True)(pos1): Sequential((0): LayerNorm((16,), eps=1e-05, elementwise_affine=True)(1): ReLU(inplace=True)(2): Linear(in_features=16, out_features=16, bias=True))(pos2): Sequential((0): LayerNorm((16,), eps=1e-05, elementwise_affine=True)(1): ReLU(inplace=True)(2): Linear(in_features=16, out_features=16, bias=True))(pos3): Sequential((0): LayerNorm((16,), eps=1e-05, elementwise_affine=True)(1): ReLU(inplace=True)(2): Linear(in_features=16, out_features=8, bias=True)))(qkv): Linear(in_features=256, out_features=768, bias=True)(attn_drop): Dropout(p=0.0, inplace=False)(proj): Linear(in_features=256, out_features=256, bias=True)(proj_drop): 
Dropout(p=0.0, inplace=False)(softmax): Softmax(dim=-1))(drop_path): DropPath()(norm2): LayerNorm((256,), eps=1e-05, elementwise_affine=True)(mlp): Mlp((fc1): Linear(in_features=256, out_features=1024, bias=True)(act): GELU()(fc2): Linear(in_features=1024, out_features=256, bias=True)(drop): Dropout(p=0.0, inplace=False)))(6): CrossFormerBlock(dim=256, input_resolution=(14, 14), num_heads=8, group_size=7, lsda_flag=0, mlp_ratio=4(norm1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)(attn): Attention(dim=256, group_size=(7, 7), num_heads=8(pos): DynamicPosBias((pos_proj): Linear(in_features=2, out_features=16, bias=True)(pos1): Sequential((0): LayerNorm((16,), eps=1e-05, elementwise_affine=True)(1): ReLU(inplace=True)(2): Linear(in_features=16, out_features=16, bias=True))(pos2): Sequential((0): LayerNorm((16,), eps=1e-05, elementwise_affine=True)(1): ReLU(inplace=True)(2): Linear(in_features=16, out_features=16, bias=True))(pos3): Sequential((0): LayerNorm((16,), eps=1e-05, elementwise_affine=True)(1): ReLU(inplace=True)(2): Linear(in_features=16, out_features=8, bias=True)))(qkv): Linear(in_features=256, out_features=768, bias=True)(attn_drop): Dropout(p=0.0, inplace=False)(proj): Linear(in_features=256, out_features=256, bias=True)(proj_drop): Dropout(p=0.0, inplace=False)(softmax): Softmax(dim=-1))(drop_path): DropPath()(norm2): LayerNorm((256,), eps=1e-05, elementwise_affine=True)(mlp): Mlp((fc1): Linear(in_features=256, out_features=1024, bias=True)(act): GELU()(fc2): Linear(in_features=1024, out_features=256, bias=True)(drop): Dropout(p=0.0, inplace=False)))(7): CrossFormerBlock(dim=256, input_resolution=(14, 14), num_heads=8, group_size=7, lsda_flag=1, mlp_ratio=4(norm1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)(attn): Attention(dim=256, group_size=(7, 7), num_heads=8(pos): DynamicPosBias((pos_proj): Linear(in_features=2, out_features=16, bias=True)(pos1): Sequential((0): LayerNorm((16,), eps=1e-05, elementwise_affine=True)(1): ReLU(inplace=True)(2): Linear(in_features=16, out_features=16, bias=True))(pos2): Sequential((0): LayerNorm((16,), eps=1e-05, elementwise_affine=True)(1): ReLU(inplace=True)(2): Linear(in_features=16, out_features=16, bias=True))(pos3): Sequential((0): LayerNorm((16,), eps=1e-05, elementwise_affine=True)(1): ReLU(inplace=True)(2): Linear(in_features=16, out_features=8, bias=True)))(qkv): Linear(in_features=256, out_features=768, bias=True)(attn_drop): Dropout(p=0.0, inplace=False)(proj): Linear(in_features=256, out_features=256, bias=True)(proj_drop): Dropout(p=0.0, inplace=False)(softmax): Softmax(dim=-1))(drop_path): DropPath()(norm2): LayerNorm((256,), eps=1e-05, elementwise_affine=True)(mlp): Mlp((fc1): Linear(in_features=256, out_features=1024, bias=True)(act): GELU()(fc2): Linear(in_features=1024, out_features=256, bias=True)(drop): Dropout(p=0.0, inplace=False))))(downsample): PatchMerging(input_resolution=(14, 14), dim=256(reductions): ModuleList((0): Conv2d(256, 256, kernel_size=(2, 2), stride=(2, 2))(1): Conv2d(256, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1)))(norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)))(3): Stage(dim=512, input_resolution=(7, 7), depth=6(blocks): ModuleList((0): CrossFormerBlock(dim=512, input_resolution=(7, 7), num_heads=16, group_size=7, lsda_flag=0, mlp_ratio=4(norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)(attn): Attention(dim=512, group_size=(7, 7), num_heads=16(pos): DynamicPosBias((pos_proj): Linear(in_features=2, out_features=32, bias=True)(pos1): 
Sequential((0): LayerNorm((32,), eps=1e-05, elementwise_affine=True)(1): ReLU(inplace=True)(2): Linear(in_features=32, out_features=32, bias=True))(pos2): Sequential((0): LayerNorm((32,), eps=1e-05, elementwise_affine=True)(1): ReLU(inplace=True)(2): Linear(in_features=32, out_features=32, bias=True))(pos3): Sequential((0): LayerNorm((32,), eps=1e-05, elementwise_affine=True)(1): ReLU(inplace=True)(2): Linear(in_features=32, out_features=16, bias=True)))(qkv): Linear(in_features=512, out_features=1536, bias=True)(attn_drop): Dropout(p=0.0, inplace=False)(proj): Linear(in_features=512, out_features=512, bias=True)(proj_drop): Dropout(p=0.0, inplace=False)(softmax): Softmax(dim=-1))(drop_path): DropPath()(norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)(mlp): Mlp((fc1): Linear(in_features=512, out_features=2048, bias=True)(act): GELU()(fc2): Linear(in_features=2048, out_features=512, bias=True)(drop): Dropout(p=0.0, inplace=False)))(1): CrossFormerBlock(dim=512, input_resolution=(7, 7), num_heads=16, group_size=7, lsda_flag=0, mlp_ratio=4(norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)(attn): Attention(dim=512, group_size=(7, 7), num_heads=16(pos): DynamicPosBias((pos_proj): Linear(in_features=2, out_features=32, bias=True)(pos1): Sequential((0): LayerNorm((32,), eps=1e-05, elementwise_affine=True)(1): ReLU(inplace=True)(2): Linear(in_features=32, out_features=32, bias=True))(pos2): Sequential((0): LayerNorm((32,), eps=1e-05, elementwise_affine=True)(1): ReLU(inplace=True)(2): Linear(in_features=32, out_features=32, bias=True))(pos3): Sequential((0): LayerNorm((32,), eps=1e-05, elementwise_affine=True)(1): ReLU(inplace=True)(2): Linear(in_features=32, out_features=16, bias=True)))(qkv): Linear(in_features=512, out_features=1536, bias=True)(attn_drop): Dropout(p=0.0, inplace=False)(proj): Linear(in_features=512, out_features=512, bias=True)(proj_drop): Dropout(p=0.0, inplace=False)(softmax): Softmax(dim=-1))(drop_path): DropPath()(norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)(mlp): Mlp((fc1): Linear(in_features=512, out_features=2048, bias=True)(act): GELU()(fc2): Linear(in_features=2048, out_features=512, bias=True)(drop): Dropout(p=0.0, inplace=False)))(2): CrossFormerBlock(dim=512, input_resolution=(7, 7), num_heads=16, group_size=7, lsda_flag=0, mlp_ratio=4(norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)(attn): Attention(dim=512, group_size=(7, 7), num_heads=16(pos): DynamicPosBias((pos_proj): Linear(in_features=2, out_features=32, bias=True)(pos1): Sequential((0): LayerNorm((32,), eps=1e-05, elementwise_affine=True)(1): ReLU(inplace=True)(2): Linear(in_features=32, out_features=32, bias=True))(pos2): Sequential((0): LayerNorm((32,), eps=1e-05, elementwise_affine=True)(1): ReLU(inplace=True)(2): Linear(in_features=32, out_features=32, bias=True))(pos3): Sequential((0): LayerNorm((32,), eps=1e-05, elementwise_affine=True)(1): ReLU(inplace=True)(2): Linear(in_features=32, out_features=16, bias=True)))(qkv): Linear(in_features=512, out_features=1536, bias=True)(attn_drop): Dropout(p=0.0, inplace=False)(proj): Linear(in_features=512, out_features=512, bias=True)(proj_drop): Dropout(p=0.0, inplace=False)(softmax): Softmax(dim=-1))(drop_path): DropPath()(norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)(mlp): Mlp((fc1): Linear(in_features=512, out_features=2048, bias=True)(act): GELU()(fc2): Linear(in_features=2048, out_features=512, bias=True)(drop): Dropout(p=0.0, inplace=False)))(3): CrossFormerBlock(dim=512, 
input_resolution=(7, 7), num_heads=16, group_size=7, lsda_flag=0, mlp_ratio=4(norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)(attn): Attention(dim=512, group_size=(7, 7), num_heads=16(pos): DynamicPosBias((pos_proj): Linear(in_features=2, out_features=32, bias=True)(pos1): Sequential((0): LayerNorm((32,), eps=1e-05, elementwise_affine=True)(1): ReLU(inplace=True)(2): Linear(in_features=32, out_features=32, bias=True))(pos2): Sequential((0): LayerNorm((32,), eps=1e-05, elementwise_affine=True)(1): ReLU(inplace=True)(2): Linear(in_features=32, out_features=32, bias=True))(pos3): Sequential((0): LayerNorm((32,), eps=1e-05, elementwise_affine=True)(1): ReLU(inplace=True)(2): Linear(in_features=32, out_features=16, bias=True)))(qkv): Linear(in_features=512, out_features=1536, bias=True)(attn_drop): Dropout(p=0.0, inplace=False)(proj): Linear(in_features=512, out_features=512, bias=True)(proj_drop): Dropout(p=0.0, inplace=False)(softmax): Softmax(dim=-1))(drop_path): DropPath()(norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)(mlp): Mlp((fc1): Linear(in_features=512, out_features=2048, bias=True)(act): GELU()(fc2): Linear(in_features=2048, out_features=512, bias=True)(drop): Dropout(p=0.0, inplace=False)))(4): CrossFormerBlock(dim=512, input_resolution=(7, 7), num_heads=16, group_size=7, lsda_flag=0, mlp_ratio=4(norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)(attn): Attention(dim=512, group_size=(7, 7), num_heads=16(pos): DynamicPosBias((pos_proj): Linear(in_features=2, out_features=32, bias=True)(pos1): Sequential((0): LayerNorm((32,), eps=1e-05, elementwise_affine=True)(1): ReLU(inplace=True)(2): Linear(in_features=32, out_features=32, bias=True))(pos2): Sequential((0): LayerNorm((32,), eps=1e-05, elementwise_affine=True)(1): ReLU(inplace=True)(2): Linear(in_features=32, out_features=32, bias=True))(pos3): Sequential((0): LayerNorm((32,), eps=1e-05, elementwise_affine=True)(1): ReLU(inplace=True)(2): Linear(in_features=32, out_features=16, bias=True)))(qkv): Linear(in_features=512, out_features=1536, bias=True)(attn_drop): Dropout(p=0.0, inplace=False)(proj): Linear(in_features=512, out_features=512, bias=True)(proj_drop): Dropout(p=0.0, inplace=False)(softmax): Softmax(dim=-1))(drop_path): DropPath()(norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)(mlp): Mlp((fc1): Linear(in_features=512, out_features=2048, bias=True)(act): GELU()(fc2): Linear(in_features=2048, out_features=512, bias=True)(drop): Dropout(p=0.0, inplace=False)))(5): CrossFormerBlock(dim=512, input_resolution=(7, 7), num_heads=16, group_size=7, lsda_flag=0, mlp_ratio=4(norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)(attn): Attention(dim=512, group_size=(7, 7), num_heads=16(pos): DynamicPosBias((pos_proj): Linear(in_features=2, out_features=32, bias=True)(pos1): Sequential((0): LayerNorm((32,), eps=1e-05, elementwise_affine=True)(1): ReLU(inplace=True)(2): Linear(in_features=32, out_features=32, bias=True))(pos2): Sequential((0): LayerNorm((32,), eps=1e-05, elementwise_affine=True)(1): ReLU(inplace=True)(2): Linear(in_features=32, out_features=32, bias=True))(pos3): Sequential((0): LayerNorm((32,), eps=1e-05, elementwise_affine=True)(1): ReLU(inplace=True)(2): Linear(in_features=32, out_features=16, bias=True)))(qkv): Linear(in_features=512, out_features=1536, bias=True)(attn_drop): Dropout(p=0.0, inplace=False)(proj): Linear(in_features=512, out_features=512, bias=True)(proj_drop): Dropout(p=0.0, inplace=False)(softmax): Softmax(dim=-1))(drop_path): 
DropPath()(norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)(mlp): Mlp((fc1): Linear(in_features=512, out_features=2048, bias=True)(act): GELU()(fc2): Linear(in_features=2048, out_features=512, bias=True)(drop): Dropout(p=0.0, inplace=False)))))
)

3. Full model code

import torch
import torch.nn as nn
import torch.utils.checkpoint as checkpoint
from timm.models.layers import DropPath, to_2tuple, trunc_normal_class Mlp(nn.Module):def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):super().__init__()out_features = out_features or in_featureshidden_features = hidden_features or in_featuresself.fc1 = nn.Linear(in_features, hidden_features)self.act = act_layer()self.fc2 = nn.Linear(hidden_features, out_features)self.drop = nn.Dropout(drop)def forward(self, x):x = self.fc1(x)x = self.act(x)x = self.drop(x)x = self.fc2(x)x = self.drop(x)return xclass DynamicPosBias(nn.Module):def __init__(self, dim, num_heads, residual):super().__init__()self.residual = residualself.num_heads = num_headsself.pos_dim = dim // 4self.pos_proj = nn.Linear(2, self.pos_dim)self.pos1 = nn.Sequential(nn.LayerNorm(self.pos_dim),nn.ReLU(inplace=True),nn.Linear(self.pos_dim, self.pos_dim),)self.pos2 = nn.Sequential(nn.LayerNorm(self.pos_dim),nn.ReLU(inplace=True),nn.Linear(self.pos_dim, self.pos_dim))self.pos3 = nn.Sequential(nn.LayerNorm(self.pos_dim),nn.ReLU(inplace=True),nn.Linear(self.pos_dim, self.num_heads))def forward(self, biases):if self.residual:pos = self.pos_proj(biases) # 2Wh-1 * 2Ww-1, headspos = pos + self.pos1(pos)pos = pos + self.pos2(pos)pos = self.pos3(pos)else:pos = self.pos3(self.pos2(self.pos1(self.pos_proj(biases))))return posdef flops(self, N):flops = N * 2 * self.pos_dimflops += N * self.pos_dim * self.pos_dimflops += N * self.pos_dim * self.pos_dimflops += N * self.pos_dim * self.num_headsreturn flopsclass Attention(nn.Module):r""" Multi-head self attention module with dynamic position bias.Args:dim (int): Number of input channels.group_size (tuple[int]): The height and width of the group.num_heads (int): Number of attention heads.qkv_bias (bool, optional):  If True, add a learnable bias to query, key, value. Default: Trueqk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if setattn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0proj_drop (float, optional): Dropout ratio of output. 
Default: 0.0"""def __init__(self, dim, group_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.,position_bias=True):super().__init__()self.dim = dimself.group_size = group_size  # Wh, Wwself.num_heads = num_headshead_dim = dim // num_headsself.scale = qk_scale or head_dim ** -0.5self.position_bias = position_biasif position_bias:self.pos = DynamicPosBias(self.dim // 4, self.num_heads, residual=False)# generate mother-setposition_bias_h = torch.arange(1 - self.group_size[0], self.group_size[0])position_bias_w = torch.arange(1 - self.group_size[1], self.group_size[1])biases = torch.stack(torch.meshgrid([position_bias_h, position_bias_w]))  # 2, 2Wh-1, 2W2-1biases = biases.flatten(1).transpose(0, 1).float()self.register_buffer("biases", biases)# get pair-wise relative position index for each token inside the groupcoords_h = torch.arange(self.group_size[0])coords_w = torch.arange(self.group_size[1])coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Wwcoords_flatten = torch.flatten(coords, 1)  # 2, Wh*Wwrelative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Wwrelative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2relative_coords[:, :, 0] += self.group_size[0] - 1  # shift to start from 0relative_coords[:, :, 1] += self.group_size[1] - 1relative_coords[:, :, 0] *= 2 * self.group_size[1] - 1relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Wwself.register_buffer("relative_position_index", relative_position_index)self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)self.attn_drop = nn.Dropout(attn_drop)self.proj = nn.Linear(dim, dim)self.proj_drop = nn.Dropout(proj_drop)self.softmax = nn.Softmax(dim=-1)def forward(self, x, mask=None):"""Args:x: input features with shape of (num_groups*B, N, C)mask: (0/-inf) mask with shape of (num_groups, Wh*Ww, Wh*Ww) or None"""B_, N, C = x.shape # [64, 49, 64] for SDAqkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple) #  q.shape=k.shape=v.shape=[64, 2, 49, 32]q = q * self.scaleattn = (q @ k.transpose(-2, -1))if self.position_bias:pos = self.pos(self.biases) # 2Wh-1 * 2Ww-1, heads # [169, 2]# select position biasrelative_position_bias = pos[self.relative_position_index.view(-1)].view(self.group_size[0] * self.group_size[1], self.group_size[0] * self.group_size[1], -1)  # Wh*Ww,Wh*Ww,nHrelative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww # [2, 49, 49]attn = attn + relative_position_bias.unsqueeze(0)if mask is not None:nW = mask.shape[0]attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)attn = attn.view(-1, self.num_heads, N, N)attn = self.softmax(attn)else:attn = self.softmax(attn) # [64, 2, 49, 49]attn = self.attn_drop(attn)x = (attn @ v).transpose(1, 2).reshape(B_, N, C) # [64, 49, 64]x = self.proj(x)            # [64, 49, 64]x = self.proj_drop(x)return xdef extra_repr(self) -> str:return f'dim={self.dim}, group_size={self.group_size}, num_heads={self.num_heads}'def flops(self, N):# calculate flops for 1 group with token length of Nflops = 0# qkv = self.qkv(x)flops += N * self.dim * 3 * self.dim# attn = (q @ k.transpose(-2, -1))flops += self.num_heads * N * (self.dim // self.num_heads) * N#  x = (attn @ v)flops += self.num_heads * N * N * (self.dim // self.num_heads)# x = self.proj(x)flops += N * self.dim * self.dimif self.position_bias:flops += 
self.pos.flops(N)return flopsclass CrossFormerBlock(nn.Module):r""" CrossFormer Block.Args:dim (int): Number of input channels.input_resolution (tuple[int]): Input resulotion.num_heads (int): Number of attention heads.group_size (int): Group size.lsda_flag (int): use SDA or LDA, 0 for SDA and 1 for LDA.mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: Trueqk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.drop (float, optional): Dropout rate. Default: 0.0attn_drop (float, optional): Attention dropout rate. Default: 0.0drop_path (float, optional): Stochastic depth rate. Default: 0.0act_layer (nn.Module, optional): Activation layer. Default: nn.GELUnorm_layer (nn.Module, optional): Normalization layer.  Default: nn.LayerNorm"""def __init__(self, dim, input_resolution, num_heads, group_size=7, lsda_flag=0,mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,act_layer=nn.GELU, norm_layer=nn.LayerNorm, num_patch_size=1):super().__init__()self.dim = dimself.input_resolution = input_resolutionself.num_heads = num_headsself.group_size = group_sizeself.lsda_flag = lsda_flagself.mlp_ratio = mlp_ratioself.num_patch_size = num_patch_sizeif min(self.input_resolution) <= self.group_size:# if group size is larger than input resolution, we don't partition groupsself.lsda_flag = 0self.group_size = min(self.input_resolution)self.norm1 = norm_layer(dim)self.attn = Attention(dim, group_size=to_2tuple(self.group_size), num_heads=num_heads,qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop,position_bias=True)self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()self.norm2 = norm_layer(dim)mlp_hidden_dim = int(dim * mlp_ratio)self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)attn_mask = Noneself.register_buffer("attn_mask", attn_mask)def forward(self, x):H, W = self.input_resolution # [56, 56]B, L, C = x.shape            # [1, 3136, 64]assert L == H * W, "input feature has wrong size %d, %d, %d" % (L, H, W)shortcut = xx = self.norm1(x)x = x.view(B, H, W, C)       # [1, 56, 56, 64]# group embeddingsG = self.group_size          # 7if self.lsda_flag == 0: # 0 for SDAx = x.reshape(B, H // G, G, W // G, G, C).permute(0, 1, 3, 2, 4, 5)  # [1, 8, 7, 8, 7, 64] -> [1, 8, 8, 7, 7, 64]else: # 1 for LDAx = x.reshape(B, G, H // G, G, W // G, C).permute(0, 2, 4, 1, 3, 5)x = x.reshape(B * H * W // G**2, G**2, C) # [64, 49, 64] for SDA# multi-head self-attentionx = self.attn(x, mask=self.attn_mask)     # nW*B, G*G, C $ [64, 49, 64]# ungroup embeddingsx = x.reshape(B, H // G, W // G, G, G, C) # [1, 8, 8, 7, 7, 64]if self.lsda_flag == 0:x = x.permute(0, 1, 3, 2, 4, 5).reshape(B, H, W, C)else:x = x.permute(0, 3, 1, 4, 2, 5).reshape(B, H, W, C)x = x.view(B, H * W, C) # [1, 3136, 64]# FFNx = shortcut + self.drop_path(x)x = x + self.drop_path(self.mlp(self.norm2(x)))    # [1, 3136, 64]return xdef extra_repr(self) -> str:return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \f"group_size={self.group_size}, lsda_flag={self.lsda_flag}, mlp_ratio={self.mlp_ratio}"def flops(self):flops = 0H, W = self.input_resolution# norm1flops += self.dim * H * W# LSDAnW = H * W / self.group_size / self.group_sizeflops += nW * self.attn.flops(self.group_size * self.group_size)# mlpflops += 2 * H * W * self.dim * self.dim * self.mlp_ratio# norm2flops += self.dim * H * 
Wreturn flopsclass PatchMerging(nn.Module):r""" Patch Merging Layer.Args:input_resolution (tuple[int]): Resolution of input feature.dim (int): Number of input channels.norm_layer (nn.Module, optional): Normalization layer.  Default: nn.LayerNorm"""def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm, patch_size=[2], num_input_patch_size=1):super().__init__()self.input_resolution = input_resolutionself.dim = dimself.reductions = nn.ModuleList()self.patch_size = patch_sizeself.norm = norm_layer(dim)for i, ps in enumerate(patch_size):if i == len(patch_size) - 1:out_dim = 2 * dim // 2 ** ielse:out_dim = 2 * dim // 2 ** (i + 1)stride = 2padding = (ps - stride) // 2self.reductions.append(nn.Conv2d(dim, out_dim, kernel_size=ps, stride=stride, padding=padding))def forward(self, x):"""x: B, H*W, C"""H, W = self.input_resolutionB, L, C = x.shapeassert L == H * W, "input feature has wrong size"assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."x = self.norm(x)x = x.view(B, H, W, C).permute(0, 3, 1, 2)xs = []for i in range(len(self.reductions)):tmp_x = self.reductions[i](x).flatten(2).transpose(1, 2)xs.append(tmp_x)x = torch.cat(xs, dim=2)return xdef extra_repr(self) -> str:return f"input_resolution={self.input_resolution}, dim={self.dim}"def flops(self):H, W = self.input_resolutionflops = H * W * self.dimfor i, ps in enumerate(self.patch_size):if i == len(self.patch_size) - 1:out_dim = 2 * self.dim // 2 ** ielse:out_dim = 2 * self.dim // 2 ** (i + 1)flops += (H // 2) * (W // 2) * ps * ps * out_dim * self.dimreturn flopsclass Stage(nn.Module):""" CrossFormer blocks for one stage.Args:dim (int): Number of input channels.input_resolution (tuple[int]): Input resolution.depth (int): Number of blocks.num_heads (int): Number of attention heads.group_size (int): variable G in the paper, one group has GxG embeddingsmlp_ratio (float): Ratio of mlp hidden dim to embedding dim.qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: Trueqk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.drop (float, optional): Dropout rate. Default: 0.0attn_drop (float, optional): Attention dropout rate. Default: 0.0drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNormdownsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: Noneuse_checkpoint (bool): Whether to use checkpointing to save memory. 
Default: False."""def __init__(self, dim, input_resolution, depth, num_heads, group_size,mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False,patch_size_end=[4], num_patch_size=None):super().__init__()self.dim = dimself.input_resolution = input_resolutionself.depth = depthself.use_checkpoint = use_checkpoint# build blocksself.blocks = nn.ModuleList()for i in range(depth):lsda_flag = 0 if (i % 2 == 0) else 1self.blocks.append(CrossFormerBlock(dim=dim, input_resolution=input_resolution,num_heads=num_heads, group_size=group_size,lsda_flag=lsda_flag,mlp_ratio=mlp_ratio,qkv_bias=qkv_bias, qk_scale=qk_scale,drop=drop, attn_drop=attn_drop,drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,norm_layer=norm_layer,num_patch_size=num_patch_size))# patch merging layerif downsample is not None:self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer, patch_size=patch_size_end, num_input_patch_size=num_patch_size)else:self.downsample = Nonedef forward(self, x):for blk in self.blocks:if self.use_checkpoint:x = checkpoint.checkpoint(blk, x)else:x = blk(x)if self.downsample is not None:x = self.downsample(x) # [1, 784, 128]return xdef extra_repr(self) -> str:return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"def flops(self):flops = 0for blk in self.blocks:flops += blk.flops()if self.downsample is not None:flops += self.downsample.flops()return flopsclass PatchEmbed(nn.Module):r""" Image to Patch EmbeddingArgs:img_size (int): Image size.  Default: 224.patch_size (int): Patch token size. Default: [4].in_chans (int): Number of input image channels. Default: 3.embed_dim (int): Number of linear projection output channels. Default: 96.norm_layer (nn.Module, optional): Normalization layer. 
Default: None"""def __init__(self, img_size=224, patch_size=[4], in_chans=3, embed_dim=96, norm_layer=None):super().__init__()img_size = to_2tuple(img_size)# patch_size = to_2tuple(patch_size)patches_resolution = [img_size[0] // patch_size[0], img_size[0] // patch_size[0]]self.img_size = img_sizeself.patch_size = patch_sizeself.patches_resolution = patches_resolutionself.num_patches = patches_resolution[0] * patches_resolution[1]self.in_chans = in_chansself.embed_dim = embed_dimself.projs = nn.ModuleList()for i, ps in enumerate(patch_size):if i == len(patch_size) - 1:dim = embed_dim // 2 ** ielse:dim = embed_dim // 2 ** (i + 1)stride = patch_size[0]padding = (ps - patch_size[0]) // 2self.projs.append(nn.Conv2d(in_chans, dim, kernel_size=ps, stride=stride, padding=padding))if norm_layer is not None:self.norm = norm_layer(embed_dim)else:self.norm = Nonedef forward(self, x):B, C, H, W = x.shape# FIXME look at relaxing size constraintsassert H == self.img_size[0] and W == self.img_size[1], \f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."xs = []for i in range(len(self.projs)):tx = self.projs[i](x).flatten(2).transpose(1, 2) #[1,32,56,56],[1,16,56,56],[1,8,56,56],[1,8,56,56]xs.append(tx)  # B Ph*Pw C #xs[0]=[1, 3136, 32], xs[1]=[1, 3136, 16], xs[2]=[1, 3136, 8], xs[3]=[1, 3136, 8]x = torch.cat(xs, dim=2) # [1, 3136, 64]if self.norm is not None:x = self.norm(x)return xdef flops(self):Ho, Wo = self.patches_resolutionflops = 0for i, ps in enumerate(self.patch_size):if i == len(self.patch_size) - 1:dim = self.embed_dim // 2 ** ielse:dim = self.embed_dim // 2 ** (i + 1)flops += Ho * Wo * dim * self.in_chans * (self.patch_size[i] * self.patch_size[i])if self.norm is not None:flops += Ho * Wo * self.embed_dimreturn flopsclass CrossFormer(nn.Module):r""" CrossFormerA PyTorch impl of : `CrossFormer: A Versatile Vision Transformer Based on Cross-scale Attention`  -Args:img_size (int | tuple(int)): Input image size. Default 224patch_size (int | tuple(int)): Patch size. Default: 4in_chans (int): Number of input image channels. Default: 3num_classes (int): Number of classes for classification head. Default: 1000embed_dim (int): Patch embedding dimension. Default: 96depths (tuple(int)): Depth of each stage.num_heads (tuple(int)): Number of attention heads in different layers.group_size (int): Group size. Default: 7mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: Trueqk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: Nonedrop_rate (float): Dropout rate. Default: 0attn_drop_rate (float): Attention dropout rate. Default: 0drop_path_rate (float): Stochastic depth rate. Default: 0.1norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.ape (bool): If True, add absolute position embedding to the patch embedding. Default: Falsepatch_norm (bool): If True, add normalization after patch embedding. Default: Trueuse_checkpoint (bool): Whether to use checkpointing to save memory. 
Default: False"""def __init__(self, img_size=224, patch_size=[4], in_chans=3, num_classes=1000,embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24],group_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None,drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,norm_layer=nn.LayerNorm, ape=False, patch_norm=True,use_checkpoint=False, merge_size=[[2], [2], [2]], **kwargs):super().__init__()self.num_classes = num_classes       # 1000self.num_layers = len(depths)        # 4self.embed_dim = embed_dim           # 64self.ape = ape                       # Falseself.patch_norm = patch_norm         # Trueself.num_features = int(embed_dim * 2 ** (self.num_layers - 1)) # 512self.mlp_ratio = mlp_ratio           # 4# split image into non-overlapping patchesself.patch_embed = PatchEmbed(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,norm_layer=norm_layer if self.patch_norm else None)num_patches = self.patch_embed.num_patches                # 3136patches_resolution = self.patch_embed.patches_resolution  # [56, 56]self.patches_resolution = patches_resolution              # [56, 56]# absolute position embeddingif self.ape:self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))trunc_normal_(self.absolute_pos_embed, std=.02)self.pos_drop = nn.Dropout(p=drop_rate)# stochastic depthdpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]  # stochastic depth decay rule# build layersself.layers = nn.ModuleList()num_patch_sizes = [len(patch_size)] + [len(m) for m in merge_size]  # [4, 2, 2, 2]for i_layer in range(self.num_layers):patch_size_end = merge_size[i_layer] if i_layer < self.num_layers - 1 else Nonenum_patch_size = num_patch_sizes[i_layer]layer = Stage(dim=int(embed_dim * 2 ** i_layer),input_resolution=(patches_resolution[0] // (2 ** i_layer),patches_resolution[1] // (2 ** i_layer)),depth=depths[i_layer],num_heads=num_heads[i_layer],group_size=group_size[i_layer],mlp_ratio=self.mlp_ratio,qkv_bias=qkv_bias, qk_scale=qk_scale,drop=drop_rate, attn_drop=attn_drop_rate,drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],norm_layer=norm_layer,downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,use_checkpoint=use_checkpoint,patch_size_end=patch_size_end,num_patch_size=num_patch_size)self.layers.append(layer)self.norm = norm_layer(self.num_features)self.avgpool = nn.AdaptiveAvgPool1d(1)self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()self.apply(self._init_weights)# import pdb; pdb.set_trace()def _init_weights(self, m):if isinstance(m, nn.Linear):trunc_normal_(m.weight, std=.02)if isinstance(m, nn.Linear) and m.bias is not None:nn.init.constant_(m.bias, 0)elif isinstance(m, nn.LayerNorm):nn.init.constant_(m.bias, 0)nn.init.constant_(m.weight, 1.0)@torch.jit.ignoredef no_weight_decay(self):return {'absolute_pos_embed'}@torch.jit.ignoredef no_weight_decay_keywords(self):return {'relative_position_bias_table'}def forward_features(self, x):import pdb; pdb.set_trace()# input_x.shape=[1,3,224,224]x = self.patch_embed(x) # x.shape=[1, 3136, 64]if self.ape:x = x + self.absolute_pos_embedx = self.pos_drop(x) # dropoutfor layer in self.layers:x = layer(x)   # [1, 784, 128], [1, 196, 256], [1, 49, 512], [1, 49, 512]x = self.norm(x)  # B L C [1, 49, 512]x = self.avgpool(x.transpose(1, 2))  # B C 1 # [1, 512, 1]x = torch.flatten(x, 1)return xdef forward(self, x):x = self.forward_features(x)x = self.head(x)return xdef flops(self):flops = 0flops += self.patch_embed.flops()for i, layer 
in enumerate(self.layers):flops += layer.flops()flops += self.num_features * self.patches_resolution[0] * self.patches_resolution[1] // (2 ** self.num_layers)flops += self.num_features * self.num_classesreturn flops
