Swin Transformer won the Best Paper award at ICCV 2021.

Paper: Swin Transformer: Hierarchical Vision Transformer using Shifted Windows
Link: https://arxiv.org/abs/2103.14030

Motivation:

Applying Transformers from NLP to vision tasks raises two main challenges:
1. Scale. A single image may contain many cars and pedestrians of different sizes, so visual elements carrying the same semantics can appear at very different scales, unlike word tokens in text.
2. Resolution. Images have far more pixels than sentences have words, so a pixel-level token sequence is prohibitively long. Prior work reduces the sequence length in three main ways:
  1. Use feature maps from the later stages of a backbone network as the Transformer input.
  2. Split the image into patches to reduce the effective resolution.
  3. Divide the image into small windows and compute self-attention within each window.

The authors propose a hierarchical Transformer whose features are learned through shifted windows. Window multi-head self-attention (Window Multi-head Self-Attention, W-MSA) restricts attention to local windows and thus greatly reduces computational complexity compared with global multi-head self-attention (Multi-head Self-Attention, MSA). Because the windows do not overlap, however, adjacent windows cannot exchange information. Shifted windows (Shifted Window Multi-head Self-Attention, SW-MSA) fix this: moving the window grid between consecutive layers creates cross-window connections, so information flows between neighboring windows and the model indirectly gains global modeling capacity. The hierarchical design provides features at multiple scales, giving the Transformer a CNN-like pyramid structure that transfers well to downstream tasks.
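For reference, the paper quantifies this saving. For a feature map of $h \times w$ patches with embedding dimension $C$ and window size $M$ (7 by default), the complexities of the two attention variants are

$$\Omega(\mathrm{MSA}) = 4hwC^2 + 2(hw)^2C, \qquad \Omega(\text{W-MSA}) = 4hwC^2 + 2M^2hwC,$$

so the term quadratic in the number of patches becomes linear once attention is confined to fixed-size windows.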

Model Architecture


As the architecture figure in the paper shows, the model is composed of Patch Partition, Linear Embedding, Swin Transformer Blocks (built from W-MSA and SW-MSA), and Patch Merging. Each component is introduced below together with its implementation.

Patch Partition + Linear Embedding

Patch Partition downsamples the input image: an input of shape [H, W, 3] has its height and width reduced to 1/4 and its channels expanded to 48 (4 × 4 × 3). Linear Embedding then projects the channel dimension of the Patch Partition output to C and applies Layer Normalization; the value of C differs between model variants. In the implementation, both steps are fused into a single nn.Conv2d whose kernel size and stride equal the patch size, mapping [H, W, 3] -> [H/4, W/4, C]. (The imports used by all snippets are collected in the complete-code section at the end.)

class patchEmbed(nn.Module):
    # Split the image into non-overlapping patches; the downsampling factor equals patch_size.
    def __init__(self, patch_size=4, in_channels=3, embed_dim=96, norm_layer=None):
        super(patchEmbed, self).__init__()
        patch_size = (patch_size, patch_size)
        self.patch_size = patch_size
        self.in_channels = in_channels
        self.embed_dim = embed_dim
        self.proj = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

    def forward(self, x):
        _, _, H, W = x.shape
        # Pad if the input H or W is not an integer multiple of patch_size.
        pad_input = (H % self.patch_size[0] != 0) or (W % self.patch_size[1] != 0)
        if pad_input:
            # Pad the right side along W and the bottom along H.
            # F.pad order: (w_left, w_right, h_top, h_bottom, c_front, c_back)
            x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1],
                          0, self.patch_size[0] - H % self.patch_size[0],
                          0, 0))
        # Downsample by a factor of patch_size.
        x = self.proj(x)
        _, _, H, W = x.shape
        # flatten:   [B, C, H, W] -> [B, C, HW]
        # transpose: [B, C, HW]  -> [B, HW, C]
        x = x.flatten(2).transpose(1, 2)
        x = self.norm(x)
        return x, H, W
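A quick shape check (a minimal sketch, assuming the class above and its imports are in scope):

import torch
import torch.nn as nn

embed = patchEmbed(patch_size=4, in_channels=3, embed_dim=96, norm_layer=nn.LayerNorm)
x = torch.randn(1, 3, 224, 224)
tokens, H, W = embed(x)
print(tokens.shape, H, W)  # torch.Size([1, 3136, 96]) 56 56  (224/4 = 56, 56*56 = 3136)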

Patch Merging

Patch Merging performs downsampling: it halves the feature map's H and W and doubles the channel dimension. A 2×2 window splits the feature map into four interleaved sub-maps, which are concatenated along the channel dimension (4× the channels); after Layer Norm, a linear layer reduces the channels from 4C to 2C.

class patchmerging(nn.Module):
    # Downsampling: halve H and W, double the channel dimension.
    def __init__(self, dim, norm_layer=nn.LayerNorm):
        super(patchmerging, self).__init__()
        self.dim = dim
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
        self.norm = norm_layer(4 * dim)

    def forward(self, x, H, W):
        # x: [B, HW, C]
        B, L, C = x.shape
        x = x.view(B, H, W, C)
        # Pad if H or W is odd.
        pad_input = (H % 2 == 1) or (W % 2 == 1)
        if pad_input:
            # F.pad order for [B, H, W, C]: (c_front, c_back, w_left, w_right, h_top, h_bottom)
            x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))
        # [B, H, W, C]
        x0 = x[:, 0::2, 0::2, :]  # [B, H/2, W/2, C] top-left
        x1 = x[:, 1::2, 0::2, :]  # [B, H/2, W/2, C] bottom-left
        x2 = x[:, 0::2, 1::2, :]  # [B, H/2, W/2, C] top-right
        x3 = x[:, 1::2, 1::2, :]  # [B, H/2, W/2, C] bottom-right
        x = torch.cat([x0, x1, x2, x3], -1)  # [B, H/2, W/2, 4*C]
        x = x.view(B, -1, 4 * C)             # [B, H/2*W/2, 4*C]
        x = self.norm(x)
        x = self.reduction(x)                # [B, H/2*W/2, 4*C] -> [B, H/2*W/2, 2*C]
        return x
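A quick shape check (a sketch assuming the class above):

import torch

merge = patchmerging(dim=96)
tokens = torch.randn(1, 56 * 56, 96)  # e.g. the token sequence of a 56x56 stage-1 feature map
out = merge(tokens, 56, 56)
print(out.shape)  # torch.Size([1, 784, 192]): 28*28 tokens, channels doubled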

Swin Transformer Block

1. (Shifted) Window Multi-head Self-Attention (W-MSA / SW-MSA)

Multi-head self-attention (MSA): every pixel of the feature map produces a query, key, and value; each pixel's query is matched against the keys of all pixels, followed by the softmax and value-weighting steps, so every pixel exchanges information with the entire feature map. Window multi-head self-attention (W-MSA): the feature map is first divided into windows, and multi-head self-attention is computed inside each window independently. Shifted window multi-head self-attention (SW-MSA): the window grid is shifted right and down by floor(M/2), where M is the window size. For example, with feature_map = 9 and window_size = 3, the top rows of the feature map wrap around to the bottom and the leftmost columns wrap to the right; the shifted feature map is then split into windows of size window_size, and W-MSA/SW-MSA is computed inside them. Tokens that were not neighbors before the shift must not attend to each other, so computing SW-MSA first requires building a mask template (create_mask below). The cyclic shift itself is done as follows:
if self.shift_size > 0:
    # SW-MSA: cyclic shift; the top shift_size rows wrap to the bottom,
    # the left shift_size columns wrap to the right.
    shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
else:
    # W-MSA
    shifted_x = x
    attn_mask = None
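The cyclic shift itself is plain torch.roll; a standalone toy example (values purely illustrative) makes its semantics concrete:

import torch

t = torch.arange(16).view(1, 4, 4, 1)            # toy [B, H, W, C] feature map
s = torch.roll(t, shifts=(-1, -1), dims=(1, 2))  # rows shift up, columns shift left
print(t[0, :, :, 0])
# tensor([[ 0,  1,  2,  3],
#         [ 4,  5,  6,  7],
#         [ 8,  9, 10, 11],
#         [12, 13, 14, 15]])
print(s[0, :, :, 0])  # row 0 wrapped to the bottom, column 0 wrapped to the right
# tensor([[ 5,  6,  7,  4],
#         [ 9, 10, 11,  8],
#         [13, 14, 15, 12],
#         [ 1,  2,  3,  0]])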
def window_partition(x, window_size: int):
    # Split the feature map (or image mask) into non-overlapping windows of size window_size.
    B, H, W, C = x.shape
    # [B, H//M, M, W//M, M, C], where M = window_size
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    # permute: [B, H//M, M, W//M, M, C] -> [B, H//M, W//M, M, M, C]
    # contiguous(): make the data contiguous in memory
    # view: [B, H//M, W//M, M, M, C] -> [B * num_windows, Mh, Mw, C]
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
    return windows


def window_reverse(windows, window_size: int, H: int, W: int):
    # Merge windows back into a feature map; H and W are the sizes before partitioning.
    # windows: [B * num_windows, Mh, Mw, C]
    B = int(windows.shape[0] / (H * W / window_size / window_size))
    # view: [B * num_windows, Mh, Mw, C] -> [B, H//M, W//M, M, M, C]
    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
    # permute: [B, H//M, W//M, M, M, C] -> [B, H//M, M, W//M, M, C]
    # view:    [B, H//M, M, W//M, M, C] -> [B, H, W, C]
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
    return x
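A round-trip sanity check (a sketch assuming the two functions above): window_partition followed by window_reverse must reproduce the input exactly.

import torch

x = torch.randn(2, 14, 14, 96)                # [B, H, W, C] with H, W divisible by window_size
windows = window_partition(x, window_size=7)  # 2*2 = 4 windows per image, 8 in total
restored = window_reverse(windows, 7, 14, 14)
print(windows.shape, torch.equal(x, restored))  # torch.Size([8, 7, 7, 96]) True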

    def create_mask(self, x, H, W):
        # Round H and W up to multiples of window_size (ceil).
        H_padding = int(np.ceil(H / self.window_size)) * self.window_size
        W_padding = int(np.ceil(W / self.window_size)) * self.window_size
        # Same channel layout as the feature map: [1, Hp, Wp, 1]
        img_mask = torch.zeros((1, H_padding, W_padding, 1), device=x.device)
        # Slice out the regions whose elements belonged to the same area before the shift,
        # as in the feature_map = 9x9, window_size = 3x3 example above.
        h_slices = (slice(0, -self.window_size),
                    slice(-self.window_size, -self.shift_size),
                    slice(-self.shift_size, None))
        w_slices = (slice(0, -self.window_size),
                    slice(-self.window_size, -self.shift_size),
                    slice(-self.shift_size, None))
        cnt = 0
        for h in h_slices:
            for w in w_slices:
                img_mask[:, h, w, :] = cnt
                cnt += 1
        # Partition the mask into windows: [num_windows, Mh, Mw, 1]
        mask_windows = window_partition(img_mask, self.window_size)
        # Flatten each window: [num_windows, Mh*Mw]
        mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
        # Broadcast [num_windows, 1, Mh*Mw] - [num_windows, Mh*Mw, 1] -> [num_windows, Mh*Mw, Mh*Mw]:
        # unsqueeze(1) repeats each window's row vector, unsqueeze(2) repeats each element of it.
        attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
        # Pairs from the same region give 0; pairs from different regions are nonzero and are
        # filled with -100 so that softmax suppresses them. The result is the attention mask
        # each pixel of the window uses.
        attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
        return attn_mask
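To see what the slices produce, here is a standalone sketch of the region labelling for the toy 9x9 feature map with window_size=3 and shift_size=1 used in the example above (the real model uses 7 and 3):

import torch

window_size, shift_size = 3, 1
img_mask = torch.zeros(1, 9, 9, 1)
h_slices = (slice(0, -window_size), slice(-window_size, -shift_size), slice(-shift_size, None))
w_slices = (slice(0, -window_size), slice(-window_size, -shift_size), slice(-shift_size, None))
cnt = 0
for h in h_slices:
    for w in w_slices:
        img_mask[:, h, w, :] = cnt
        cnt += 1
print(img_mask[0, :, :, 0])
# Pixels sharing a label belonged to the same region before the shift and may attend
# to each other; pairs with different labels get -100 added to their attention logits.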
class WindowAttention(nn.Module):
    # Implements W-MSA and SW-MSA with relative position bias.
    def __init__(self, dim, window_size, num_heads, qkv_bias=True, attn_drop=0., proj_drop=0.):
        super(WindowAttention, self).__init__()
        self.dim = dim
        self.window_size = window_size
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5  # 1 / sqrt(d)
        # Bias table of shape [(2*Mh-1) * (2*Mw-1), num_heads]
        self.relative_positive_bias_table = nn.Parameter(
            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))
        # Build relative_position_index.
        coords_h = torch.arange(self.window_size[0])
        coords_w = torch.arange(self.window_size[1])
        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # [2, Mh, Mw]
        # Row 0 holds the row coordinate of every pixel, row 1 the column coordinate.
        coords_flatten = torch.flatten(coords, 1)  # [2, Mh*Mw], absolute position indices
        # [2, Mh*Mw, 1] - [2, 1, Mh*Mw] -> [2, Mh*Mw, Mh*Mw]: taking each pixel as the reference
        # point, subtract the absolute indices of all pixels in the window to get the matrix
        # of relative position indices (broadcasting repeats the rows/elements as needed).
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # [Mh*Mw, Mh*Mw, 2]
        # Convert the 2D relative index into a single index.
        relative_coords[:, :, 0] += self.window_size[0] - 1
        relative_coords[:, :, 1] += self.window_size[1] - 1
        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
        relative_position_index = relative_coords.sum(-1)  # [Mh*Mw, Mh*Mw]
        # Store in the module's buffers (not a trainable parameter).
        self.register_buffer('relative_position_index', relative_position_index)
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        nn.init.trunc_normal_(self.relative_positive_bias_table, std=.02)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, mask: Optional[torch.Tensor] = None):
        # x: [batch_size * num_windows, Mh*Mw, total_embed_dim]
        B_, N, C = x.shape
        # qkv:     -> [B_, Mh*Mw, 3 * total_embed_dim]
        # reshape: -> [B_, Mh*Mw, 3, num_heads, embed_dim_per_head]
        # permute: -> [3, B_, num_heads, Mh*Mw, embed_dim_per_head]
        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)  # each: [B_, num_heads, Mh*Mw, embed_dim_per_head]
        q = q * self.scale
        # transpose: -> [B_, num_heads, embed_dim_per_head, Mh*Mw]
        # @:         -> [B_, num_heads, Mh*Mw, Mh*Mw]
        attn = (q @ k.transpose(-2, -1))
        # Look up the bias table: [Mh*Mw*Mh*Mw, num_heads] -> [Mh*Mw, Mh*Mw, num_heads]
        relative_position_bias = self.relative_positive_bias_table[self.relative_position_index.view(-1)].view(
            self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # [num_heads, Mh*Mw, Mh*Mw]
        attn = attn + relative_position_bias.unsqueeze(0)  # [B_, num_heads, Mh*Mw, Mh*Mw]
        if mask is not None:
            # mask: [num_windows, Mh*Mw, Mh*Mw]
            num_window = mask.shape[0]
            # view: [batch_size, num_windows, num_heads, Mh*Mw, Mh*Mw]
            # mask is broadcast as [1, num_windows, 1, Mh*Mw, Mh*Mw]
            attn = attn.view(B_ // num_window, num_window, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
            # back to [batch_size * num_windows, num_heads, Mh*Mw, Mh*Mw]
            attn = attn.view(-1, self.num_heads, N, N)
            attn = self.softmax(attn)
        else:
            attn = self.softmax(attn)
        attn = self.attn_drop(attn)
        # @:         [B_, num_heads, Mh*Mw, embed_dim_per_head]
        # transpose: [B_, Mh*Mw, num_heads, embed_dim_per_head]
        # reshape:   [B_, Mh*Mw, num_heads * embed_dim_per_head]
        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
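The relative position index construction is easiest to verify on a tiny window. A standalone sketch for a 2x2 window (4 tokens; the bias table would then have (2*2-1)**2 = 9 entries per head):

import torch

Mh, Mw = 2, 2
coords = torch.stack(torch.meshgrid([torch.arange(Mh), torch.arange(Mw)]))  # [2, Mh, Mw]
coords_flatten = torch.flatten(coords, 1)                                   # [2, Mh*Mw]
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]   # [2, Mh*Mw, Mh*Mw]
relative_coords = relative_coords.permute(1, 2, 0).contiguous()             # [Mh*Mw, Mh*Mw, 2]
relative_coords[:, :, 0] += Mh - 1      # shift row offsets to start at 0
relative_coords[:, :, 1] += Mw - 1      # shift column offsets to start at 0
relative_coords[:, :, 0] *= 2 * Mw - 1  # fold (row, col) offsets into one index
relative_position_index = relative_coords.sum(-1)  # [Mh*Mw, Mh*Mw]
print(relative_position_index)
# tensor([[4, 3, 1, 0],
#         [5, 4, 2, 1],
#         [7, 6, 4, 3],
#         [8, 7, 5, 4]])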

2. MLP

The same MLP structure used in the standard Transformer block: two linear layers with a GELU activation and dropout.

class MLP(nn.Module):
    def __init__(self, in_features, hidden_features=None, act=nn.GELU, drop=0.):
        super(MLP, self).__init__()
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act()
        self.fc2 = nn.Linear(hidden_features, in_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x

3. The overall Swin Transformer block

class SwimTransformerBlock(nn.Module):
    def __init__(self, dim, num_heads, window_size=7, shift_size=0, mlp_ratio=4., qkv_bias=True,
                 drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super(SwimTransformerBlock, self).__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.window_size = window_size
        self.shift_size = shift_size
        self.mlp_ratio = mlp_ratio
        self.norm1 = norm_layer(dim)
        self.attn = WindowAttention(dim=dim, window_size=(self.window_size, self.window_size),
                                    num_heads=num_heads, qkv_bias=qkv_bias,
                                    attn_drop=attn_drop, proj_drop=drop)
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = MLP(in_features=dim, hidden_features=mlp_hidden_dim, act=act_layer, drop=drop)

    def forward(self, x, attn_mask):
        H, W = self.H, self.W  # feature map height and width, set by BasicLayer
        B, L, C = x.shape      # L = H * W
        shortcut = x
        x = self.norm1(x)
        x = x.view(B, H, W, C)
        # Pad the right and bottom so H and W become multiples of window_size.
        pad_r = (self.window_size - W % self.window_size) % self.window_size
        pad_b = (self.window_size - H % self.window_size) % self.window_size
        x = F.pad(x, (0, 0, 0, pad_r, 0, pad_b))
        _, Hp, Wp, _ = x.shape  # Hp, Wp: padded height and width
        if self.shift_size > 0:
            # SW-MSA: cyclic shift; the top shift_size rows wrap to the bottom,
            # the left shift_size columns wrap to the right.
            shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
        else:
            # W-MSA
            shifted_x = x
            attn_mask = None
        # Partition the feature map into windows.
        x_windows = window_partition(shifted_x, self.window_size)  # [B * num_windows, Mh, Mw, C]
        x_windows = x_windows.view(-1, self.window_size * self.window_size, C)  # [B * num_windows, Mh*Mw, C]
        # W-MSA / SW-MSA
        attn_windows = self.attn(x_windows, mask=attn_mask)
        # Merge the windows back into a feature map.
        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)  # [B * num_windows, Mh, Mw, C]
        shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp)  # [B, Hp, Wp, C]
        # After SW-MSA, reverse the cyclic shift (bottom-to-top, right-to-left).
        if self.shift_size > 0:
            x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
        else:
            x = shifted_x
        # Remove the padding.
        if pad_r > 0 or pad_b > 0:
            x = x[:, :H, :W, :].contiguous()
        x = x.view(B, H * W, C)
        x = shortcut + self.drop_path(x)
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x

Building a stage of the model (Swin Transformer blocks + Patch Merging)

class BasicLayer(nn.Module):
    # One stage: `depth` Swin Transformer blocks, optionally followed by Patch Merging.
    def __init__(self, dim, depth, num_heads, window_size, mlp_ratio=4., qkv_bias=True,
                 drop=0., attn_drop=0., drop_path=0., norm_layer=nn.LayerNorm, downsample=None):
        super(BasicLayer, self).__init__()
        self.dim = dim
        self.depth = depth
        self.window_size = window_size
        # floor(window_size / 2): the step by which windows are shifted right and down.
        self.shift_size = window_size // 2
        self.blocks = nn.ModuleList([
            SwimTransformerBlock(dim=dim, num_heads=num_heads, window_size=window_size,
                                 shift_size=0 if (i % 2 == 0) else self.shift_size,
                                 mlp_ratio=mlp_ratio, qkv_bias=qkv_bias,
                                 drop=drop, attn_drop=attn_drop,
                                 drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
                                 norm_layer=norm_layer)
            for i in range(depth)])
        if downsample is not None:
            self.downsample = downsample(dim=dim, norm_layer=norm_layer)
        else:
            self.downsample = None

    # create_mask(self, x, H, W) is the method defined in the (S)W-MSA section above.

    def forward(self, x, H, W):
        attn_mask = self.create_mask(x, H, W)  # [num_windows, Mh*Mw, Mh*Mw]
        for blk in self.blocks:
            blk.H, blk.W = H, W
            x = blk(x, attn_mask)
        if self.downsample is not None:
            x = self.downsample(x, H, W)
            H, W = (H + 1) // 2, (W + 1) // 2
        return x, H, W

Final assembly: Swin Transformer for classification

class SwimTransformer(nn.Module):
    def __init__(self, downsample_size=4, in_channels=3, num_classes=1000, embed_dim=96,
                 depths=(2, 2, 6, 2), num_heads=(3, 6, 12, 24), window_size=7, mlp_ratio=4.,
                 qkv_bias=True, drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
                 norm_layer=nn.LayerNorm, patch_norm=True, **kwargs):
        super(SwimTransformer, self).__init__()
        self.num_classes = num_classes
        self.num_layers = len(depths)
        self.embed_dim = embed_dim
        self.patch_norm = patch_norm
        # Channel dimension of the stage-4 output feature map.
        self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))
        self.mlp_ratio = mlp_ratio
        self.patch_embed = patchEmbed(patch_size=downsample_size, in_channels=in_channels,
                                      embed_dim=embed_dim,
                                      norm_layer=norm_layer if self.patch_norm else None)
        self.pos_drop = nn.Dropout(p=drop_rate)
        # Stochastic depth: drop-path rate increases linearly across all blocks.
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
        self.layers = nn.ModuleList()
        for i_layer in range(self.num_layers):
            layers = BasicLayer(dim=int(embed_dim * 2 ** i_layer),
                                depth=depths[i_layer],
                                num_heads=num_heads[i_layer],
                                window_size=window_size, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias,
                                drop=drop_rate, attn_drop=attn_drop_rate,
                                drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
                                norm_layer=norm_layer,
                                downsample=patchmerging if (i_layer < self.num_layers - 1) else None)
            self.layers.append(layers)
        # Classification head.
        self.norm = norm_layer(self.num_features)
        self.avgpool = nn.AdaptiveAvgPool1d(1)
        self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
        # Initialize Linear and LayerNorm weights.
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            nn.init.trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward(self, x):
        # patch embedding: -> [B, L, C]
        x, H, W = self.patch_embed(x)
        x = self.pos_drop(x)
        for layer in self.layers:
            x, H, W = layer(x, H, W)
        x = self.norm(x)                     # [B, L, C]
        x = self.avgpool(x.transpose(1, 2))  # [B, C, 1]
        x = torch.flatten(x, 1)
        x = self.head(x)
        return x
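The defaults above correspond to Swin-T. The paper's other variants only change embed_dim, depths, and num_heads; a hedged sketch of how they would map onto this constructor (the helper names are illustrative, not from the original code):

def swin_tiny(num_classes=1000):
    # Swin-T: C=96, depths (2, 2, 6, 2)
    return SwimTransformer(embed_dim=96, depths=(2, 2, 6, 2),
                           num_heads=(3, 6, 12, 24), num_classes=num_classes)

def swin_small(num_classes=1000):
    # Swin-S: C=96, depths (2, 2, 18, 2)
    return SwimTransformer(embed_dim=96, depths=(2, 2, 18, 2),
                           num_heads=(3, 6, 12, 24), num_classes=num_classes)

def swin_base(num_classes=1000):
    # Swin-B: C=128, depths (2, 2, 18, 2)
    return SwimTransformer(embed_dim=128, depths=(2, 2, 18, 2),
                           num_heads=(4, 8, 16, 32), num_classes=num_classes)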

Complete network code

The complete script consists of the imports below followed by the definitions from the sections above (patchEmbed, patchmerging, MLP, WindowAttention, SwimTransformerBlock, window_partition, window_reverse, BasicLayer, SwimTransformer), plus a quick smoke test:

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from timm.models.layers import DropPath
from typing import Optional

# ... class and function definitions from the sections above ...

if __name__ == '__main__':
    x = torch.randn(1, 3, 224, 224)
    model = SwimTransformer()
    out = model(x)
    print(out.shape)  # torch.Size([1, 1000])
