简介

主页:https://github. com/microsoft/Swin-Transformer.
Swin Transformer 是 2021 ICCV最佳论文,屠榜了各大CV任务,性能优于DeiT、ViT和EfficientNet等主干网络,已经替代经典的CNN架构,成为了计算机视觉领域通用的backbone。

Swin Transformer 基于了ViT模型的思想,创新性的引入了滑动窗口机制,让模型能够学习到跨窗口的信息,同时也。同时通过下采样层,使得模型能够处理超分辨率的图片,节省计算量以及能够关注全局和局部的信息

ViT 开启了transformer在视觉领域的征途,但是transformer从自然语言领域应用到计算机视觉领域还有两大挑战:

  • 视觉实体的方差较大,例如同一个物体,拍摄角度不同,转化为二进制后的图片就会具有很大的差异。同时在不同场景下视觉 Transformer 性能未必很好。
  • 图像分辨率高,像素点多,如果采用ViT模型,自注意力的计算量会与像素的平方成正比。

针对上述两个问题,论文中提出了一种基于滑动窗口机制,具有层级设计(下采样层) 的 Swin Transformer。

其中滑窗操作包括不重叠的 local window,和重叠的 cross-window。将注意力计算限制在一个窗口(window size固定)中,一方面能引入 CNN 卷积操作的局部性,另一方面能大幅度节省计算量,它只和窗口数量成线性关系。

通过下采样的层级设计,能够逐渐增大感受野,从而使得注意力机制也能够注意到全局的特征

从上图可知,Swin Transformer 思想是实现 ViT 到类似卷积模式的转变,这样的结构模式能适用于各类视觉任务,真正成为视觉领域通用的backbone。

实现原理


模型整体采取了层次化的设计

  • 在输入开始的时候做了一个Patch Partition,即ViT中的Patch Embedding操作,通过Patch_size为4的卷积层将图片切成一个个Patch,并嵌入到Embedding,将embedding_size 转变为48(可以将 CV 中图片的通道数理解为NLP中token的词嵌入长度)
  • 第一行Stage中通过Linear Embedding 调整通道数为 C
  • 在后3个Stage均由Patch Merging 和 多个 Swin Transformer Block组成
  • Patch Merging 模块主要在每个Stage一开始降低图片分辨率,进行下采样操作
  • Swin Transformer Block如上图右边所示,主要是LayerNorm,Window Attention,Shifted Window Attention和MLP组成

Patch Merging 总是在两个Swin Transformer Block之间执行下采样,最后一个Stage不需要下采样操作,之间通过后续的全连接层与 target label 计算损失。

class SwinTransformer(nn.Module):r""" Swin TransformerA PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows`  -https://arxiv.org/pdf/2103.14030Args:img_size (int | tuple(int)): Input image size. Default 224patch_size (int | tuple(int)): Patch size. Default: 4in_chans (int): Number of input image channels. Default: 3num_classes (int): Number of classes for classification head. Default: 1000embed_dim (int): Patch embedding dimension. Default: 96depths (tuple(int)): Depth of each Swin Transformer layer.num_heads (tuple(int)): Number of attention heads in different layers.window_size (int): Window size. Default: 7mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: Trueqk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: Nonedrop_rate (float): Dropout rate. Default: 0attn_drop_rate (float): Attention dropout rate. Default: 0drop_path_rate (float): Stochastic depth rate. Default: 0.1norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.ape (bool): If True, add absolute position embedding to the patch embedding. Default: Falsepatch_norm (bool): If True, add normalization after patch embedding. Default: Trueuse_checkpoint (bool): Whether to use checkpointing to save memory. Default: Falsefused_window_process (bool, optional): If True, use one kernel to fused window shift & window partition for acceleration, similar for the reversed part. Default: False"""def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000,embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24],window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None,drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,norm_layer=nn.LayerNorm, ape=False, patch_norm=True,use_checkpoint=False, fused_window_process=False, **kwargs):super().__init__()self.num_classes = num_classesself.num_layers = len(depths)self.embed_dim = embed_dimself.ape = apeself.patch_norm = patch_normself.num_features = int(embed_dim * 2 ** (self.num_layers - 1))self.mlp_ratio = mlp_ratio# split image into non-overlapping patchesself.patch_embed = PatchEmbed(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,norm_layer=norm_layer if self.patch_norm else None)num_patches = self.patch_embed.num_patchespatches_resolution = self.patch_embed.patches_resolutionself.patches_resolution = patches_resolution# absolute position embeddingif self.ape:self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))trunc_normal_(self.absolute_pos_embed, std=.02)self.pos_drop = nn.Dropout(p=drop_rate)# stochastic depthdpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]  # stochastic depth decay rule# build layersself.layers = nn.ModuleList()for i_layer in range(self.num_layers):layer = BasicLayer(dim=int(embed_dim * 2 ** i_layer),input_resolution=(patches_resolution[0] // (2 ** i_layer),patches_resolution[1] // (2 ** i_layer)),depth=depths[i_layer],num_heads=num_heads[i_layer],window_size=window_size,mlp_ratio=self.mlp_ratio,qkv_bias=qkv_bias, qk_scale=qk_scale,drop=drop_rate, attn_drop=attn_drop_rate,drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],norm_layer=norm_layer,downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,use_checkpoint=use_checkpoint,fused_window_process=fused_window_process)self.layers.append(layer)self.norm = norm_layer(self.num_features)self.avgpool = nn.AdaptiveAvgPool1d(1)self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()self.apply(self._init_weights)def _init_weights(self, m):if isinstance(m, nn.Linear):trunc_normal_(m.weight, std=.02)if isinstance(m, nn.Linear) and m.bias is not None:nn.init.constant_(m.bias, 0)elif isinstance(m, nn.LayerNorm):nn.init.constant_(m.bias, 0)nn.init.constant_(m.weight, 1.0)def forward_features(self, x):x = self.patch_embed(x)if self.ape:x = x + self.absolute_pos_embedx = self.pos_drop(x)for layer in self.layers:x = layer(x)x = self.norm(x)  # B L Cx = self.avgpool(x.transpose(1, 2))  # B C 1x = torch.flatten(x, 1)return xdef forward(self, x):x = self.forward_features(x)x = self.head(x)return xdef flops(self):flops = 0flops += self.patch_embed.flops()for i, layer in enumerate(self.layers):flops += layer.flops()flops += self.num_features * self.patches_resolution[0] * self.patches_resolution[1] // (2 ** self.num_layers)flops += self.num_features * self.num_classesreturn flops
  • ViT 在输入会给 embedding 进行位置编码。而 Swin-T 这里则是作为一个可选项(self.ape),Swin-T 是在计算 Attention 的时候做了一个相对位置编码,我认为这是这篇论文设计最巧妙的地方。
  • ViT 会单独加上一个可学习参数,作为分类的 token。而 Swin-T 则是直接做平均(avgpool),输出分类,有点类似 CNN 最后的全局平均池化层。

Patch Embedding

在进入 Block 前,需要将图片切分成多个 patch,然后嵌入向量,具体做法是对原始图片裁成多个 window_size * window_size 的窗口大小,然后进行嵌入。即通过二维卷积层,设置 stride = kernel_size = window_size,设定输出通道来确定嵌入向量的大小。最后将 H,W 维度展开,并移动到第一维度。

class PatchEmbed(nn.Module):r""" Image to Patch EmbeddingArgs:img_size (int): Image size.  Default: 224.patch_size (int): Patch token size. Default: 4.in_chans (int): Number of input image channels. Default: 3.embed_dim (int): Number of linear projection output channels. Default: 96.norm_layer (nn.Module, optional): Normalization layer. Default: None"""def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):super().__init__()img_size = to_2tuple(img_size)patch_size = to_2tuple(patch_size)patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]self.img_size = img_sizeself.patch_size = patch_sizeself.patches_resolution = patches_resolutionself.num_patches = patches_resolution[0] * patches_resolution[1]self.in_chans = in_chansself.embed_dim = embed_dimself.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)if norm_layer is not None:self.norm = norm_layer(embed_dim)else:self.norm = Nonedef forward(self, x):B, C, H, W = x.shape# FIXME look at relaxing size constraintsassert H == self.img_size[0] and W == self.img_size[1], \f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."x = self.proj(x).flatten(2).transpose(1, 2)  # B Ph*Pw Cif self.norm is not None:x = self.norm(x)return xdef flops(self):Ho, Wo = self.patches_resolutionflops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])if self.norm is not None:flops += Ho * Wo * self.embed_dimreturn flops

Patch Merging

在每个 Stage 开始前做降采样,用于缩小分辨率,调整通道数进而形成层次化的设计,同时也能节省一定运算量。

在 CNN 中,则是在每个 Stage 开始前用stride=2的卷积/池化层来降低分辨率。

每次降采样是两倍,因此在行方向和列方向上,间隔 2 选取元素。

然后拼接在一起作为一整个张量,最后展开。此时通道维度会变成原先的 4 倍(因为 H,W 各缩小 2 倍),此时再通过一个全连接层再调整通道维度为原来的两倍。

class PatchMerging(nn.Module):r""" Patch Merging Layer.Args:input_resolution (tuple[int]): Resolution of input feature.dim (int): Number of input channels.norm_layer (nn.Module, optional): Normalization layer.  Default: nn.LayerNorm"""def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):super().__init__()self.input_resolution = input_resolutionself.dim = dimself.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)self.norm = norm_layer(4 * dim)def forward(self, x):"""x: B, H*W, C"""H, W = self.input_resolutionB, L, C = x.shapeassert L == H * W, "input feature has wrong size"assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."x = x.view(B, H, W, C)x0 = x[:, 0::2, 0::2, :]  # B H/2 W/2 Cx1 = x[:, 1::2, 0::2, :]  # B H/2 W/2 Cx2 = x[:, 0::2, 1::2, :]  # B H/2 W/2 Cx3 = x[:, 1::2, 1::2, :]  # B H/2 W/2 Cx = torch.cat([x0, x1, x2, x3], -1)  # B H/2 W/2 4*Cx = x.view(B, -1, 4 * C)  # B H/2*W/2 4*Cx = self.norm(x)x = self.reduction(x)return xdef extra_repr(self) -> str:return f"input_resolution={self.input_resolution}, dim={self.dim}"def flops(self):H, W = self.input_resolutionflops = H * W * self.dimflops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dimreturn flops

Window Partition/Reverse

window partition函数是用于对张量划分窗口,指定窗口大小。将原本的张量从 N H W C, 划分成 num_windowsB, window_size, window_size, C,其中 num_windows = HW / window_size*window_size,即窗口的个数。而window reverse函数则是对应的逆过程。这两个函数会在后面的Window Attention用到。

def window_partition(x, window_size):"""Args:x: (B, H, W, C)window_size (int): window sizeReturns:windows: (num_windows*B, window_size, window_size, C)"""B, H, W, C = x.shapex = x.view(B, H // window_size, window_size, W // window_size, window_size, C)windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)return windowsdef window_reverse(windows, window_size, H, W):"""Args:windows: (num_windows*B, window_size, window_size, C)window_size (int): Window sizeH (int): Height of imageW (int): Width of imageReturns:x: (B, H, W, C)"""B = int(windows.shape[0] / (H * W / window_size / window_size))x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)return x

Window Attention

传统的 Transformer 都是基于全局来计算注意力的,因此计算复杂度十分高。而 Swin Transformer 则将注意力的计算限制在每个窗口内,进而减少了计算量。

主要区别是在原始计算 Attention 的公式中的 Q,K 时加入了相对位置编码 B

class WindowAttention(nn.Module):r""" Window based multi-head self attention (W-MSA) module with relative position bias.It supports both of shifted and non-shifted window.Args:dim (int): Number of input channels.window_size (tuple[int]): The height and width of the window.num_heads (int): Number of attention heads.qkv_bias (bool, optional):  If True, add a learnable bias to query, key, value. Default: Trueqk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if setattn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0proj_drop (float, optional): Dropout ratio of output. Default: 0.0"""def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):super().__init__()self.dim = dimself.window_size = window_size  # Wh, Wwself.num_heads = num_headshead_dim = dim // num_headsself.scale = qk_scale or head_dim ** -0.5# define a parameter table of relative position biasself.relative_position_bias_table = nn.Parameter(torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))  # 2*Wh-1 * 2*Ww-1, nH# get pair-wise relative position index for each token inside the windowcoords_h = torch.arange(self.window_size[0])coords_w = torch.arange(self.window_size[1])coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Wwcoords_flatten = torch.flatten(coords, 1)  # 2, Wh*Wwrelative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Wwrelative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0relative_coords[:, :, 1] += self.window_size[1] - 1relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Wwself.register_buffer("relative_position_index", relative_position_index)self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)self.attn_drop = nn.Dropout(attn_drop)self.proj = nn.Linear(dim, dim)self.proj_drop = nn.Dropout(proj_drop)trunc_normal_(self.relative_position_bias_table, std=.02)self.softmax = nn.Softmax(dim=-1)

相关位置编码

Q,K,V.shape=[numWindwosB, num_heads, window_sizewindow_size, head_dim]

window_size*window_size 即 NLP 中token的个数
head_dim = E m b e d d i n g d i m n u m h e a d s \frac{Embedding_dim}{num_heads} numheadsEmbeddingdim即 NLP 中token的词嵌入向量的维度

Q K T QK^T QKT计算出来的Attention张量的形状为[numWindowsB, num_heads, Q_tokens, K_tokens],其中Q_tokens=K_tokens=window_sizewindow_size

以 window_size = 2 为例


第 i 行表示第 i 个 token 的query对所有token的key的attention。
对于 Attention 张量来说,以不同元素为原点,其他元素的坐标也是不同的



由于最终我们希望使用一维的位置坐标 x+y 代替二维的位置坐标(x,y),为了避免 (1,2) (2,1) 两个坐标转为一维时均为3,我们之后对相对位置索引进行了一些线性变换,使得能通过一维的位置坐标唯一映射到一个二维的位置坐标,详细可以通过代码部分进行理解。

利用torch.arange和torch.meshgrid函数生成对应的坐标

coords_h = torch.arange(self.window_size[0])
coords_w = torch.arange(self.window_size[1])
coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2*(wh, ww)
"""(tensor([[0, 0],[1, 1]]), tensor([[0, 1],[0, 1]]))
"""
# 堆叠起来,展开为一个二维向量
coords = torch.stack(coords)  # 2, Wh, Ww
coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
"""
tensor([[0, 0, 1, 1],[0, 1, 0, 1]])
"""# 利用广播机制,分别在第一维,第二维,插入一个维度,进行广播相减,得到 2, wh*ww, wh*ww的张量
relative_coords_first = coords_flatten[:, :, None]  # 2, wh*ww, 1
relative_coords_second = coords_flatten[:, None, :] # 2, 1, wh*ww
relative_coords = relative_coords_first - relative_coords_second # 最终得到 2, wh*ww, wh*ww 形状的张量


因为采取的是相减,所以得到的索引是从负数开始的,加上偏移量,让其从 0 开始。

relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
relative_coords[:, :, 0] += self.window_size[0] - 1
relative_coords[:, :, 1] += self.window_size[1] - 1


需要将其展开成一维偏移量。而对于 (1,2)和(2,1)这两个坐标。在二维上是不同的,但是通过将 x,y 坐标相加转换为一维偏移的时候,他的偏移量是相等的。

对其中做了个乘法操作,以进行区分

relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1

然后再最后一维上进行求和,展开成一个一维坐标,并注册为一个不参与网络学习的变量

relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
self.register_buffer("relative_position_index", relative_position_index)

之前计算的是相对位置索引,并不是相对位置偏置参数。真正使用到的可训练参数 β ^ \hat{\beta} β^是保存在 relative position bias table 表里的,这个表的长度是等于 (2M−1) × (2M−1) (在二维位置坐标中线性变化乘以2M-1导致)的。那么上述公式中的相对位置偏执参数 B是根据上面的相对位置索引表根据查relative position bias table表得到的。

接着Window Attention代码

def forward(self, x, mask=None):"""Args:x: input features with shape of (num_windows*B, N, C)mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None"""B_, N, C = x.shapeqkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)q = q * self.scaleattn = (q @ k.transpose(-2, -1))relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)  # Wh*Ww,Wh*Ww,nHrelative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Wwattn = attn + relative_position_bias.unsqueeze(0)if mask is not None:nW = mask.shape[0]attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)attn = attn.view(-1, self.num_heads, N, N)attn = self.softmax(attn)else:attn = self.softmax(attn)attn = self.attn_drop(attn)x = (attn @ v).transpose(1, 2).reshape(B_, N, C)x = self.proj(x)x = self.proj_drop(x)return xdef extra_repr(self) -> str:return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}'def flops(self, N):# calculate flops for 1 window with token length of Nflops = 0# qkv = self.qkv(x)flops += N * self.dim * 3 * self.dim# attn = (q @ k.transpose(-2, -1))flops += self.num_heads * N * (self.dim // self.num_heads) * N#  x = (attn @ v)flops += self.num_heads * N * N * (self.dim // self.num_heads)# x = self.proj(x)flops += N * self.dim * self.dimreturn flops
  • 首先输入张量形状为 [numWindows*B, window_size * window_size, C]
  • 然后经过self.qkv这个全连接层后,进行 reshape,调整轴的顺序,得到形状为[3, numWindowsB, num_heads, window_sizewindow_size, c//num_heads],并分配给q,k,v
  • 根据公式,我们对q乘以一个scale缩放系数,然后与k(为了满足矩阵乘要求,需要将最后两个维度调换)进行相乘。得到形状为[numWindowsB,num_heads, window_sizewindow_size, window_size*window_size]的attn张量
  • 之前我们针对位置编码设置了个形状为(2window_size-12window_size-1,numHeads)的可学习变量。我们用计算得到的相对编码位置索引self.relative_position_index.vew(-1)选取,得到形状为(window_sizewindow_size,window_size*window_size, numHeads)的编码,再permute(2,0,1)后加到attn张量上
  • 暂不考虑 mask 的情况,剩下就是跟 transformer 一样的 softmax,dropout,与V矩阵乘,再经过一层全连接层和dropout

Shifted Window Attention

Window Attention 是在每个窗口下计算注意力的,为了更好的和其他 window 进行信息交互,Swin Transformer 还引入了 shifted window 操作。

左边是没有重叠的 Window Attention,而右边则是将窗口进行移位的 Shift Window Attention。可以看到移位后的窗口包含了原本相邻窗口的元素。但这也引入了一个新问题,即 window 的个数翻倍了,由原本四个窗口变成了 9 个窗口。

在实际代码里,我们是通过对特征图移位,并给 Attention 设置 mask 来间接实现的。能在保持原有的 window 个数下,最后的计算结果等价。

Shifted Window step

代码里对特征图移位是通过 torch.roll 来实现的

torch.roll(input, shifts, dims=None) → Tensor

shifts的值为正数相当于向下挤牙膏,挤出的牙膏又从顶部塞回牙膏里面;shifts的值为负数相当于向上挤牙膏,挤出的牙膏又从底部塞回牙膏里面

以 4x4 矩阵 a 为例

a 矩阵中的 ( 1 )代表 A 区域,( 2,3,4 ) 代表 C 区域,( 5,9,13 )代表 B区域,首先将第一行挤到最后一行,如下图矩阵 b

然后再将第一列挤到最后一列,如下图矩阵 b

如果需要reverse cyclic shift的话只需把参数shifts设置为对应的正数值

Attention Mask


通过 roll 操作,我们确实把9块归为了4块,但是 cyclic shift 中,A 从左上角 移动到了 右下角,显然,直接对 cyclic shift 4块进行计算会破坏原有的语义信息,为此,这里使用了 mask 操作。

上图展示 cyclic shift 后 特征图,拿到 window后,执行 Q K T QK^T QKT ,就是将Q K T K^T KT 分别展平然后对应元素相乘,根据这一过程,可以得到如上图所示不同 Windows 的 Mask,-100的紫色区域表示遮掩,紫色部分是不同块的运算结果,应该丢弃

具体代码在 SwinTransformerBlock中

 if self.shift_size > 0:# calculate attention mask for SW-MSAH, W = self.input_resolutionimg_mask = torch.zeros((1, H, W, 1))  # 1 H W 1h_slices = (slice(0, -self.window_size),slice(-self.window_size, -self.shift_size),slice(-self.shift_size, None))w_slices = (slice(0, -self.window_size),slice(-self.window_size, -self.shift_size),slice(-self.shift_size, None))cnt = 0for h in h_slices:for w in w_slices:img_mask[:, h, w, :] = cntcnt += 1mask_windows = window_partition(img_mask, self.window_size)  # nW, window_size, window_size, 1mask_windows = mask_windows.view(-1, self.window_size * self.window_size)attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))else:attn_mask = None

从上图代码中,可以得到如下的Mask

tensor([[[[[   0.,    0.,    0.,    0.],[   0.,    0.,    0.,    0.],[   0.,    0.,    0.,    0.],[   0.,    0.,    0.,    0.]]],[[[   0., -100.,    0., -100.],[-100.,    0., -100.,    0.],[   0., -100.,    0., -100.],[-100.,    0., -100.,    0.]]],[[[   0.,    0., -100., -100.],[   0.,    0., -100., -100.],[-100., -100.,    0.,    0.],[-100., -100.,    0.,    0.]]],[[[   0., -100., -100., -100.],[-100.,    0., -100., -100.],[-100., -100.,    0., -100.],[-100., -100., -100.,    0.]]]]])

在上面的 window attention 模块的前向代码中,使用mask掩膜

if mask is not None:nW = mask.shape[0] # 一张图被分为多少个windows eg:[4,49,49]attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) # torch.Size([128, 4, 12, 49, 49]) torch.Size([1, 4, 1, 49, 49])attn = attn.view(-1, self.num_heads, N, N)attn = self.softmax(attn)
else:attn = self.softmax(attn)

将 mask 加到 attention 的计算结果,并进行 softmax。mask 的值设置为 - 100,softmax 后就会忽略掉对应的值。

W-MSA和MSA的复杂度对比

在原论文中,作者提出的基于滑动窗口操作的 W-MSA 能大幅度减少计算量。那么两者的计算量和算法复杂度大概是如何的呢,论文中给出了一下两个公式进行对比。

  • h:feature map的高度
  • w:feature map的宽度
  • C:feature map的通道数(也可以称为embedding size的大小)
  • M:window_size的大小

MSA模块的计算量

首先对于feature map中每一个token(一共有 h w 个token,通道数为C),记作 X h w × C X^{hw\times C} Xhw×C,需要通过三次线性变换 V q , W k , W v V_q,W_k,W_v Vq,Wk,Wv ,产生对应的q,k,v向量,记作 Q h w × C , K h w × C , V h w × C Q^{hw \times C},K^{hw \times C},V^{hw \times C} Qhw×C,Khw×C,Vhw×C(通道数为C)。


根据矩阵运算的计算量公式可以得到运算量为 3 h w C 2 3hwC^2 3hwC2

忽略除以 d \sqrt{d} d

以及softmax的计算量,根据根据矩阵运算的计算量公式可得 h w C × h w + h w 2 × C hwC \times hw + hw^2 \times C hwC×hw+hw2×C,即 2 ( h w 2 ) C 2(hw^2)C 2(hw2)C

最终再通过一个Linear层输出,计算量为 h w C 2 hwC^2 hwC2。因此整体的计算量为 4 h w C 2 + 2 ( h w ) 2 C 4hwC^2+2(hw)^2C 4hwC2+2(hw)2C

W-MSA模块的计算量

对于W-MSA模块,首先会将feature map根据window_size分成 h w M 2 \frac{hw}{M^2} M2hw的窗口,每个窗口的宽高均为 M,然后在每个窗口进行MSA的运算。因此,可以利用上面MSA的计算量公式,将 h = M, w = M 带入,可以得到一个窗口的计算量为 4 M 2 C 2 + 2 M 4 C 4M^2C^2 + 2M^4C 4M2C2+2M4C,又因为有 h w M 2 \frac{hw}{M^2} M2hw个窗口

整体流程



文章来源:https://zhuanlan.zhihu.com/p/430047908

13、Swin Transformer: Hierarchical Vision Transformer using Shifted Windows相关推荐

  1. 【读点论文】Swin Transformer: Hierarchical Vision Transformer using Shifted Windows通过窗口化进行局部MSA,sw-MSA融合信息

    Swin Transformer: Hierarchical Vision Transformer using Shifted Windows abstract 本文提出了一种新的视觉transfor ...

  2. 【文献阅读】Swin Transformer: Hierarchical Vision Transformer using Shifted Windows

    题目:Swin Transformer: Hierarchical Vision Transformer using Shifted Windows 时间:2021 会议/期刊:ICCV 2021 研 ...

  3. 【Transformer 论文精读】……Swin Transformer……(Hierarchical Vision Transformer using Shifted Windows)

    文章目录 一.Abstract(摘要) 二.Introduction(引言) 三.Related Work(相关工作) 四.Method(方法) 1.Patch Merging模块 2.W-MSA模块 ...

  4. Swin Transformer: Hierarchical Vision Transformer using Shifted Windows

    目录 Introduction Method Hierarchical feature maps and Linear computational complexity Patch merging S ...

  5. 论文阅读笔记:Swin Transformer: Hierarchical Vision Transformer using Shifted Windows

    论文阅读笔记:Swin Transformer 摘要 1 简介 2 相关工作 3 方法论 3.1 总览 Swin Transformer block 3.2 shifted window-based ...

  6. Swin Transformer: Hierarchical Vision Transformer using Shifted Windows论文阅读

    Swin Transformer: Hierarchical Vision Transformer using Shifted Windows论文阅读 摘要 介绍 相关工作 方法 整个架构 基于sel ...

  7. 【Swin Transformer】Swin Transformer: Hierarchical Vision Transformer using Shifted Windows

    文章:https://arxiv.org/abs/2103.14030 代码:GitHub - microsoft/Swin-Transformer: This is an official impl ...

  8. Swin Transformer: Hierarchical Vision Transformer using Shifted Windows 论文翻译 + 网络详解

    目录 1 3 4 5 是论文的翻译,如果看过论文也可以直接看关于网络的结构详解. Abstract 1. Introduction 3. Method 3.1 Overall Architicture ...

  9. 论文精读:Swin Transformer: Hierarchical Vision Transformer using Shifted Windows

    Abstract 本文提出了一种新的vision Transformer,称为Swin Transformer,它能够作为计算机视觉的通用骨干网络.从语言到视觉的挑战来自于这两个领域之间的差异,比如视 ...

最新文章

  1. Java工具类-转换字符编码
  2. 计算机应用论文2500字,计算机应用论文2500字:计算机基础.doc
  3. Yann LeCun 怒喷 Sophia:这就是彻头彻尾的骗局
  4. BSCI—8-(2):OSPF的特殊区域类型与配置
  5. Centos7 网络配置
  6. tracker服务器列表2020_个人服务器采购整理分享
  7. Redis必须要知道的几点
  8. centos6.8下安装dc2012
  9. Tensorflow教程2:使用卷积神经网络的图像分类器
  10. 用Rust保存Windows聚焦图片
  11. 软件工程 之 软件维护
  12. MenuetOS-令人不可思议的64位操作系统!
  13. Java程序员情人节_关于程序员:一封来自Java程序员的情书
  14. 云计算大数据中心(清明作业)
  15. 交通强国,标准先行【附PPT】
  16. R 数据四舍五入函数教程
  17. 有哪些冷门却好用的东西可以网购?
  18. CISP证书价值如何
  19. JavaSwing——利息计算器
  20. Project 的简单使用

热门文章

  1. wireshark分析tcp,rtp
  2. Linux初学者成为高手的学习步骤和建议【新人必看】
  3. MATLAB meshgrid函数
  4. Ubuntu 加入开机自启动命令(rc.local)
  5. 准备机试时候不懂的问题
  6. 足球胜平负数据这样分析竞猜准确率超高,你敢相信吗?
  7. solidwork学习
  8. 国外真实情景,英语日常口语提示
  9. exports 和 module.exports
  10. 叶子的离去,是风的追求还是树的不挽留!(受益终生)