Swin Transformer

Model

Overall architecture

  1. First, the image goes through the Patch_Embeding operation, which splits it into patches; this is the same pre-processing as in ViT, except that each patch here is 4*4.
  2. The resulting patch tokens are fed into the stages; each stage is made of a different number of blocks, [2,2,6,2] in the figure above.
  3. The resulting vector is fed into the classification head, and the pipeline is complete (a minimal end-to-end sketch follows below).
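
To make the three steps concrete, here is a minimal end-to-end sketch of how they compose. It assumes the Patch_Embeding, Swin_stage and Patch_Merging classes defined later in this post; the SwinTransformer wrapper, num_classes and the global-average-pool head are illustrative assumptions, not code from the original.

import torch
import torch.nn as nn

class SwinTransformer(nn.Module):
    def __init__(self, num_classes=1000, dim=96, depths=(2, 2, 6, 2),
                 num_heads=(3, 6, 12, 24), window_size=7, input_res=(56, 56)):
        super().__init__()
        self.patch_embed = Patch_Embeding(dim=dim, patch_size=4)  # defined below
        stages = []
        res = input_res
        for i, depth in enumerate(depths):
            last = (i == len(depths) - 1)
            stages.append(Swin_stage(depth=depth, dim=dim, num_heads=num_heads[i],
                                     input_res=res, window_size=window_size,
                                     qkv_bias=True,
                                     patch_merging=None if last else True))
            if not last:
                dim *= 2                          # C doubles after each merge
                res = (res[0] // 2, res[1] // 2)  # H and W halve
        self.stages = nn.ModuleList(stages)
        self.norm = nn.LayerNorm(dim)
        self.head = nn.Linear(dim, num_classes)

    def forward(self, x):
        x = self.patch_embed(x)        # [B, 56*56, 96]
        for stage in self.stages:
            x = stage(x)
        x = self.norm(x).mean(dim=1)   # global average over the remaining patches
        return self.head(x)            # [B, num_classes]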


Each part is explained in detail below.

Patch_Embeding


'''Crop the image into patch_size-sized patches. After the Patch_Embeding operation, we get a tensor of shape [batch, num_patches, embed_dim].
'''
import torch
import torch.nn as nn

class Patch_Embeding(nn.Module):
    def __init__(self, dim=96, patch_size=4):
        super().__init__()
        # 96 = 4*4*3*2
        # Project the 3-channel image to dim channels by scanning each (4x4) patch, just like ViT
        self.patch = nn.Conv2d(3, dim, kernel_size=patch_size, stride=patch_size)
        self.norm = nn.LayerNorm(dim)

    def forward(self, x):
        x = self.patch(x)                 # [B, C, H, W], C = dim
        x = x.flatten(2).transpose(1, 2)  # [B, num_patches, C]
        x = self.norm(x)                  # [batch_size, num_patches, C]
        return x
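
A quick shape sanity check (assuming a 224x224 RGB input; hypothetical usage, not from the original post):

embed = Patch_Embeding(dim=96, patch_size=4)
img = torch.randn(1, 3, 224, 224)   # [B, 3, H, W]
print(embed(img).shape)             # torch.Size([1, 3136, 96]) -> [B, (224/4)**2, dim]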

Swin_stage

Next, the tensor is fed into the first stage. Each stage stacks depth blocks,
so each stage runs its block layers once and then applies one patch_merge operation.


'''Each stage corresponds to one level of the feature pyramid.
'''
class Swin_stage(nn.Module):
    def __init__(self,
                 depth,              # number of blocks in this stage
                 dim,                # input dimension
                 num_heads,          # number of attention heads
                 input_res,          # H, W of the input feature map
                 window_size,        # window size
                 qkv_bias=None,      # bias for the attention qkv projection
                 patch_merging=None  # whether to merge patches at the end
                 ):
        super().__init__()
        # Stack depth blocks for this stage
        self.blocks = nn.ModuleList([
            Swin_Block(
                dim=dim,
                num_heads=num_heads,
                input_res=input_res,
                window_size=window_size,
                qkv_bias=qkv_bias,
                # alternate between regular windows (shift 0) and shifted windows
                shift_size=0 if (i % 2 == 0) else window_size // 2
            )
            for i in range(depth)
        ])
        if patch_merging is None:
            self.patch_merge = nn.Identity()
        else:
            self.patch_merge = Patch_Merging(input_res, dim)

    def forward(self, x):
        # With patch_size = 4, there are 56*56 patches in total,
        # so the feature map entering the first stage is [B, 56*56, 96]
        for block in self.blocks:
            x = block(x)         # see 3.1
        x = self.patch_merge(x)  # see 4.1
        return x
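
With depths [2,2,6,2], the four stages can be instantiated like this; after each merge the channel dimension doubles and the resolution halves. A sketch assuming the standard Swin-T configuration (window_size=7, qkv_bias=True are assumptions):

dims      = (96, 192, 384, 768)
depths    = (2, 2, 6, 2)
num_heads = (3, 6, 12, 24)
res       = [(56, 56), (28, 28), (14, 14), (7, 7)]

stages = [
    Swin_stage(depth=depths[i], dim=dims[i], num_heads=num_heads[i],
               input_res=res[i], window_size=7, qkv_bias=True,
               patch_merging=None if i == 3 else True)
    for i in range(4)
]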

Each block passes through a LayerNorm layer, the attention operation, another LayerNorm layer, and an MLP layer, then produces its output.

Swin_Block

# swin_encode & Patch_Merging
class Swin_Block(nn.Module):
    def __init__(self, dim, num_heads, input_res, window_size, qkv_bias=False, shift_size=0):
        super().__init__()
        self.dim = dim               # input dimension C
        self.resolution = input_res  # H, W of the current feature map
        self.num_heads = num_heads
        self.window_size = window_size
        self.shift_size = shift_size
        self.atten_norm = nn.LayerNorm(dim)
        self.atten = window_attention(dim, num_heads, qkv_bias)
        self.mlp_norm = nn.LayerNorm(dim)
        self.mlp = MLP(dim, mlp_ratio=4)

    def forward(self, x):
        # x: [B, num_patches, embed_dim]
        # resolution is the size of the feature map at each stage:
        # [56,56] --> [28,28] --> [14,14] --> [7,7]
        H, W = self.resolution
        B, N, C = x.shape
        assert N == H * W
        h = x
        x = self.atten_norm(x)
        # reshape back to 2D so the windows can be shifted
        x = x.reshape(B, H, W, C)
        # the first block of each pair has no shift; the shift_size > 0 branch
        # is explained below, so you can skip it on a first read
        if self.shift_size > 0:
            shift_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
            atten_mask = generate_mask(input_res=self.resolution,
                                       window_size=self.window_size,
                                       shift_size=self.shift_size)
        else:
            shift_x = x
            atten_mask = None
        # split the feature map into windows and do self-attention inside each window
        # [B*num_windows, window_size, window_size, C]
        x_window = window_partition(shift_x, self.window_size)
        x_window = x_window.reshape(-1, self.window_size * self.window_size, C)
        # self-attention
        atten_window = self.atten(x_window, mask=atten_mask)  # [B*num_windows, window_size*window_size, C]
        # reshape back to windows
        atten_window = atten_window.reshape(-1, self.window_size, self.window_size, C)
        # merge the windows back into the full feature map
        x = window_reverse(atten_window, self.window_size, H, W)  # [B, H, W, C]
        if self.shift_size > 0:
            # roll back to undo the cyclic shift, so the residual below lines up
            x = torch.roll(x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
        x = x.reshape(B, -1, C)
        # residual connection
        x = h + x
        h = x
        x = self.mlp_norm(x)
        # MLP
        x = self.mlp(x)
        x = h + x
        return x
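
The block above calls three helpers that this post never defines: window_partition (shown again inside the mask example further down), window_reverse, and MLP. Minimal sketches following the shapes used above, not the original post's code:

def window_partition(x, window_size):
    # x: [B, H, W, C] -> [B*num_windows, window_size, window_size, C]
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    return x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)

def window_reverse(windows, window_size, H, W):
    # windows: [B*num_windows, window_size, window_size, C] -> [B, H, W, C]
    B = windows.shape[0] // ((H // window_size) * (W // window_size))
    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
    return x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)

class MLP(nn.Module):
    # the usual transformer feed-forward: Linear -> GELU -> Linear
    def __init__(self, dim, mlp_ratio=4):
        super().__init__()
        self.fc1 = nn.Linear(dim, dim * mlp_ratio)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(dim * mlp_ratio, dim)

    def forward(self, x):
        return self.fc2(self.act(self.fc1(x)))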

Attention operation: window_attention

To let windows interact with each other during self-attention, regions A, B, and C are shifted left and up, wrapping around to fill the bottom-right corner.

import numpy as np
import matplotlib.pylab as plt
import torch
data = np.array([[1, 2, 2, 3],
                 [4, 5, 5, 6],
                 [4, 5, 5, 6],
                 [7, 8, 8, 9]])
shift_x = torch.roll(torch.from_numpy(data), shifts=(-1, -1), dims=(0, 1))
plt.matshow(data)
plt.matshow(shift_x.numpy())
plt.show()

Figure: the matrix before the shift

Figure: the matrix after the roll
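
Since the figures themselves are not reproduced here, the two matrices can be written out. torch.roll with shifts=(-1, -1) moves every row up by one and every column left by one, wrapping the first row and column around to the end:

before:           after roll(-1, -1):
1 2 2 3           5 5 6 4
4 5 5 6           5 5 6 4
4 5 5 6           8 8 9 7
7 8 8 9           2 2 3 1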

For the mask, see this window_mask reference.
We need to do self-attention within every window, but 3 and 6 should not attend to each other, nor should 1 and 2, nor should (4, 5, 7, 8) attend across one another, so we need a mask operation.


# Positions that should not attend to each other get a large negative number (-100),
# so their attention weight becomes 0 after softmax
def generate_mask(input_res, window_size, shift_size):
    H, W = input_res
    # pad H and W up so they are divisible by window_size (ceil = round up)
    Hp = int(np.ceil(H / window_size)) * window_size
    Wp = int(np.ceil(W / window_size)) * window_size
    image_mask = torch.zeros((1, Hp, Wp, 1))
    h_slice = (slice(0, -window_size),
               slice(-window_size, -shift_size),
               slice(-shift_size, None))
    w_slice = (slice(0, -window_size),
               slice(-window_size, -shift_size),
               slice(-shift_size, None))
    cnt = 0
    for h in h_slice:
        for w in w_slice:
            image_mask[:, h, w, :] = cnt
            cnt += 1
    # split the mask into windows
    # [num_windows, window_size, window_size, 1]
    mask_window = window_partition(image_mask, window_size)
    # flatten each window
    # [num_windows, window_size*window_size]
    mask_window = mask_window.reshape(-1, window_size * window_size)
    # [num_windows, 1, ws*ws] - [num_windows, ws*ws, 1] broadcasts to [num_windows, ws*ws, ws*ws]
    # (see the worked example below)
    attn_mask = mask_window.unsqueeze(1) - mask_window.unsqueeze(2)
    # non-zero entries (different regions) become -100, zero entries stay 0
    attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
    return attn_mask

Within a window, self-attention is only computed between positions with the same region number, as shown in the figure below. Positions with different numbers should get the mask added so they do not attend to each other.

mask_window.unsqueeze(1) - mask_window.unsqueeze(2)

Example:

import torch
import numpy as np
import matplotlib.pylab as plt
a=np.array([1,1,3,4,4])
b = a
aa = torch.from_numpy(a).unsqueeze(0)  # [1, 5]
bb = torch.from_numpy(b).unsqueeze(1)  # [5, 1]
print('-------------')
print(aa)
print('-------------')
print(bb)
print(aa.shape, bb.shape)
dd = aa - bb  # broadcasts to [5, 5]
print(dd)
print(dd.shape)
plt.matshow(dd.numpy())
plt.show()
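
For a = [1, 1, 3, 4, 4], the broadcasted difference dd works out to the matrix below (entry (i, j) is a[j] - a[i], so zeros appear exactly where the two labels match):

tensor([[ 0,  0,  2,  3,  3],
        [ 0,  0,  2,  3,  3],
        [-2, -2,  0,  1,  1],
        [-3, -3, -1,  0,  0],
        [-3, -3, -1,  0,  0]])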

As you can see, the positions that are allowed to attend to each other become 0, and then masked_fill can be applied.

Below is the code example from the GitHub link above, which helps in understanding the mask.

import torch
import matplotlib.pyplot as plt

def window_partition(x, window_size):
    """
    Args:
        x: (B, H, W, C)
        window_size (int): window size
    Returns:
        windows: (num_windows*B, window_size, window_size, C)
    """
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
    return windows

window_size = 7
shift_size = 3
H, W = 14, 14
img_mask = torch.zeros((1, H, W, 1))  # 1 H W 1
h_slices = (slice(0, -window_size),
            slice(-window_size, -shift_size),
            slice(-shift_size, None))
w_slices = (slice(0, -window_size),
            slice(-window_size, -shift_size),
            slice(-shift_size, None))
cnt = 0
for h in h_slices:
    for w in w_slices:
        img_mask[:, h, w, :] = cnt
        cnt += 1

mask_windows = window_partition(img_mask, window_size)  # nW, window_size, window_size, 1
mask_windows = mask_windows.view(-1, window_size * window_size)
attn_mask = mask_windows.unsqueeze(2) - mask_windows.unsqueeze(1)
print(attn_mask.shape)
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))

# give each region a different color
plt.matshow(img_mask[0, :, :, 0].numpy())
plt.matshow(attn_mask[0].numpy())
plt.matshow(attn_mask[1].numpy())
plt.matshow(attn_mask[2].numpy())
plt.matshow(attn_mask[3].numpy())
plt.show()

In this example the window is 7x7, so why is the mask we get 49x49?
Because the mask is added after Q@Kᵀ.
Each window contains 49 tokens, so Q and K each hold 49 vectors; self-attention multiplies them into a 49x49 score matrix, and a mask added at that point must also be 49x49.
Multiplying by V afterwards brings the shape back to [49, C].
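
The heading above names window_attention, but the post never shows the class itself. Here is a minimal sketch of per-window multi-head self-attention in which the mask is added right after Q@Kᵀ, exactly as described. Note that the official implementation also adds a learnable relative position bias at the same point, which this sketch omits (see the relative-position link at the end); the class is an assumption, not the original code.

class window_attention(nn.Module):
    def __init__(self, dim, num_heads, qkv_bias=False):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5
        self.qkv = nn.Linear(dim, dim * 3, bias=bool(qkv_bias))
        self.proj = nn.Linear(dim, dim)

    def forward(self, x, mask=None):
        # x: [B*num_windows, N, C] with N = window_size*window_size
        B_, N, C = x.shape
        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]               # each [B_, heads, N, head_dim]
        attn = (q @ k.transpose(-2, -1)) * self.scale  # [B_, heads, N, N]
        if mask is not None:
            nW = mask.shape[0]                         # mask: [num_windows, N, N]
            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N)
        attn = attn.softmax(dim=-1)                    # masked entries (-100) become ~0
        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
        return self.proj(x)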

patch_merge

The figure shows it directly: neighboring 2x2 patches are grouped and concatenated along the channel axis, halving the resolution and doubling the dimension.


class Patch_Merging(nn.Module):
    def __init__(self, input_res, dim):
        super().__init__()
        self.resolution = input_res
        self.dim = dim
        self.reduction = nn.Linear(4 * dim, 2 * dim)
        self.norm = nn.LayerNorm(2 * dim)

    def forward(self, x):
        # x: [B, num_patches, C]
        H, W = self.resolution
        B, _, C = x.shape
        x = x.reshape(B, H, W, C)
        # Focus-style operation: take every other pixel in each direction
        x0 = x[:, 0::2, 0::2, :]
        x1 = x[:, 0::2, 1::2, :]
        x2 = x[:, 1::2, 0::2, :]
        x3 = x[:, 1::2, 1::2, :]
        x = torch.cat((x0, x1, x2, x3), -1)  # [B, H/2, W/2, 4C]
        x = x.reshape(B, -1, 4 * C)
        x = self.reduction(x)                # 4C -> 2C
        x = self.norm(x)
        return x
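
Shape-wise, merging takes [B, H*W, C] to [B, (H/2)*(W/2), 2C]. A quick check for the first merge (hypothetical usage, not from the original post):

merge = Patch_Merging(input_res=(56, 56), dim=96)
x = torch.randn(2, 56 * 56, 96)
print(merge(x).shape)   # torch.Size([2, 784, 192]) -> [B, 28*28, 2*96]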

For relative position encoding, see this blog post.
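
As a rough pointer to what that covers: the bias is looked up with a relative position index computed once per window size. A short sketch of that standard computation (this mirrors the official implementation; it is not code from this post):

window_size = 7
coords = torch.stack(torch.meshgrid(torch.arange(window_size),
                                    torch.arange(window_size), indexing='ij'))  # [2, ws, ws]
coords = coords.flatten(1)                          # [2, ws*ws]
relative = coords[:, :, None] - coords[:, None, :]  # [2, ws*ws, ws*ws]
relative = relative.permute(1, 2, 0).contiguous()   # [ws*ws, ws*ws, 2]
relative[:, :, 0] += window_size - 1                # shift offsets to start at 0
relative[:, :, 1] += window_size - 1
relative[:, :, 0] *= 2 * window_size - 1            # flatten the 2D offset into one index
index = relative.sum(-1)                            # [49, 49], values in [0, (2*ws-1)**2)
print(index.shape, index.max().item())              # torch.Size([49, 49]) 168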
