2. Window & Shifted Window based Self-Attention

Swin Transformer另一个重要的改进就是window-based的self-attention layer,之前提到过,ViT的一个缺点是计算复杂度是和patch数量成平方关系的,为了减少计算量,Swin的做法是将输入图片划分成不重合的windows,然后在不同的window内进行self-attention计算。假设一个图片有  的patches,每个window包含MxM个patches,作者将其成为window based self-attention(W-MSA)layer,W-MSA和multi-head self-attention(MSA)的计算复杂度分别为:

由于window内部的patch数量远小于图片patch数量,并且window数量是保持不变的,W-MSA的计算复杂度和图像尺寸呈线性关系,从而大大降低了模型的计算复杂度。

虽然W-MSA能够降低计算复杂度,但是不重合的window之间缺乏信息交流,这样其实就失去了transformer利用self-attention从全局构建关系的能力,于是文章进一步引入shifted window partition来跨window进行信息交流,作者将其成为shifted window based self-attention(SW-MSA)。

如上图所示,Layer 1中8x8尺寸feature map划分成2x2个patch,每个patch尺寸为4x4, 通过将patch位置整体平移1/2个patch大小,在下一层得到新的window,包括3x3个不重合的patch。移动window的划分方式使上一层相邻的不重合window之间引入连接,大大的增加了感受野。

但这样做带来的另一个问题就是window内部patch的数量从原本的4个增加到了9个,为了让patch数量保持不变,如下图所示,作者的解决思路是把平移之后左上角A,B,C部分的patch与右下角不满足4x4尺度的patch拼接,这样patch的数量还是4个,但是又满足了window外的信息交互,作者将其成为cyclic shift。

window的划分与合并

# window_partition是划分,window_reverse是合并
def window_partition(x, window_size):"""Args:x: (B, H, W, C)window_size (int): window sizeReturns:windows: (num_windows*B, window_size, window_size, C)"""B, H, W, C = x.shapex = x.reshape([B, H // window_size, window_size, W // window_size, window_size, C])windows = x.transpose([0, 1, 3, 2, 4, 5]).reshape([-1, window_size, window_size, C])return windowsdef window_reverse(windows, window_size, H, W):"""Args:windows: (num_windows*B, window_size, window_size, C)window_size (int): Window sizeH (int): Height of imageW (int): Width of imageReturns:x: (B, H, W, C)"""B = int(windows.shape[0] / (H * W / window_size / window_size))x = windows.reshape([B, H // window_size, W // window_size, window_size, window_size, -1])x = x.transpose([0, 1, 3, 2, 4, 5]).reshape([B, H, W, -1])return x

Window Attention

这是这篇文章的关键。传统的Transformer都是基于全局来计算注意力的,因此计算复杂度十分高。而Swin Transformer则将注意力的计算限制在每个窗口内,进而减少了计算量。

我们先简单看下公式

 主要区别是:在原始计算Attention的公式中的Q,K时加入了相对位置编码。后续实验有证明相对位置编码的加入提升了模型性能。

在原始计算Attention的公式中的Q,K时加入了相对位置编码

总体代码:

window_size =(2,2)
coords_flatten = np.array([[0, 0, 1, 1],[0, 1, 0, 1]])
new = torch.tensor(coords_flatten)new_first=new[:, :, None]  # (2,4,1)
new_second=new[:, None, :] # (2,1,4)relative_coords = (new_first-new_second).permute(1, 2, 0).contiguous() # (4,4,2):4个4行2列的矩阵
relative_coords[:, :, 0] += window_size[0] - 1  # 在每个矩阵的第0列加1
relative_coords[:, :, 1] += window_size[1] - 1  # 在每个矩阵的第1列加1
relative_coords[:, :, 0] *= 2 * window_size[1] - 1 # 在每个矩阵的第0列*3
relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww 将4个矩阵的每一个矩阵按行求和,作为新的行(4,2)==》(1,4)

下面我把涉及到相关位置编码的逻辑给单独拿出来,这部分比较绕

首先QK计算出来的Attention张量形状为(numWindows*B, num_heads, window_size*window_size, window_size*window_size)

而对于Attention张量来说,以不同元素为原点,其他元素的坐标也是不同的,以window_size=2为例,其相对位置编码如下图所示

首先我们利用torch.arangetorch.meshgrid函数生成对应的坐标,这里我们以windowsize=2为例子

关于torch.meshgrid函数请看: 【pytorch】torch.meshgrid()==>常用于生成二维网格,比如图像的坐标点_小马牛的博客-CSDN博客

coords_h = torch.arange(self.window_size[0])
coords_w = torch.arange(self.window_size[1])
coords = torch.meshgrid([coords_h, coords_w]) # -> 2*(wh, ww)
"""
此时是两个张量,每个 张量是一个2*2的二维矩阵(tensor([[0, 0],[1, 1]]), tensor([[0, 1],[0, 1]]))
"""

然后堆叠起来,展开为一个二维向量

coords = torch.stack(coords)  # 2, Wh, Ww
# 将两个tensor堆叠起来,就变成了一个2*2*2的三维矩阵
coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
# 再将三维矩阵压扁为2*4的一个二维矩阵
"""
tensor([[0, 0, 1, 1],[0, 1, 0, 1]])
"""

利用广播机制,分别在第二维,第一维,插入一个维度,进行广播相减,得到 2, wh*ww, wh*ww的张量

解释:

relative_coords_first = coords_flatten[:, :, None] # 2, wh*ww, 1

relative_coords_second = coords_flatten[:, None, :] # 2, 1, wh*ww

relative_coords_first的None是在第二维,那么就是在第二维插入了一个维度1,就是由 2, wh*ww ==>2, wh*ww, 1

relative_coords_second的None是在第一维,那么就是在第一维插入了一个维度1,就是由 2, wh*ww ==>2, 1,wh*ww

图二a :

relative_coords_first = coords_flatten[:, :, None]  # 2, wh*ww, 1
relative_coords_second = coords_flatten[:, None, :] # 2, 1, wh*ww
relative_coords = relative_coords_first - relative_coords_second # 最终得到 2, wh*ww, wh*ww 形状的张量
# 此处加上下面这一句就是图二的第一个
relative_coords = relative_coords.permute(1, 2, 0).contiguous()

因为采取的是相减,所以得到的索引是从负数开始的

图二a对应的索引矩阵:

图二b :

因为采取的是相减,所以得到的索引是从负数开始的,我们加上偏移量,让其从0开始

relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
relative_coords[:, :, 0] += self.window_size[0] - 1
relative_coords[:, :, 1] += self.window_size[1] - 1

此处相当于在矩阵中每个加(1,1),相当于在索引矩阵上加2

图二c :

后续我们需要将其展开成一维偏移量。而对于(1,2)和(2,1)这两个坐标。在二维上是不同的,但是通过将x,y坐标相加转换为一维偏移的时候,他的偏移量是相等的

所以最后我们对其中做了个乘法操作,以进行区分

relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1

图二:

图二d :

然后再最后一维上进行求和,展开成一个一维坐标,并注册为一个不参与网络学习的变量

relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
self.register_buffer("relative_position_index", relative_position_index)

现在得到了位置索引,那么下一步就是输入进计算公式(4)进行forward

特征图移位操作

代码里对特征图移位是通过torch.roll来实现的,下面是示意图

如果需要reverse cyclic shift的话只需把参数shifts设置为对应的正数值。

Attention Mask

我认为这是Swin Transformer的精华,通过设置合理的mask,让Shifted Window Attention在与Window Attention相同的窗口个数下,达到等价的计算结果。

首先我们对Shift Window后的每个窗口都给上index,并且做一个roll操作(window_size=2, shift_size=1)

我们希望在计算Attention的时候,让具有相同index QK进行计算,而忽略不同index QK计算结果。例如下图的5353与5353进行计算,那么只会5与5计算,3与3计算,而不会5与3计算

最后正确的结果如下图所示

而要想在原始四个窗口下得到正确的结果,我们就必须给Attention的结果加入一个mask

相关代码如下:

        if self.shift_size > 0:# calculate attention mask for SW-MSAH, W = self.input_resolutionimg_mask = torch.zeros((1, H, W, 1))  # 1 H W 1h_slices = (slice(0, -self.window_size),slice(-self.window_size, -self.shift_size),slice(-self.shift_size, None))w_slices = (slice(0, -self.window_size),slice(-self.window_size, -self.shift_size),slice(-self.shift_size, None))cnt = 0for h in h_slices:for w in w_slices:img_mask[:, h, w, :] = cntcnt += 1mask_windows = window_partition(img_mask, self.window_size)  # nW, window_size, window_size, 1mask_windows = mask_windows.view(-1, self.window_size * self.window_size)attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))

以上图的设置,我们用这段代码会得到这样的一个mask

tensor([[[[[   0.,    0.,    0.,    0.],[   0.,    0.,    0.,    0.],[   0.,    0.,    0.,    0.],[   0.,    0.,    0.,    0.]]],[[[   0., -100.,    0., -100.],[-100.,    0., -100.,    0.],[   0., -100.,    0., -100.],[-100.,    0., -100.,    0.]]],[[[   0.,    0., -100., -100.],[   0.,    0., -100., -100.],[-100., -100.,    0.,    0.],[-100., -100.,    0.,    0.]]],[[[   0., -100., -100., -100.],[-100.,    0., -100., -100.],[-100., -100.,    0., -100.],[-100., -100., -100.,    0.]]]]])

在之前的window attention模块的前向代码里,包含这么一段

        if mask is not None:nW = mask.shape[0]attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)attn = attn.view(-1, self.num_heads, N, N)attn = self.softmax(attn)

将mask加到attention的计算结果,并进行softmax。mask的值设置为-100,softmax后就会忽略掉对应的值

图解Swin Transformer - 知乎

SOTA 模型 Swin Transformer 是如何炼成的! - 极市社区

Swin Transformer(W-MSA详解)代码+图解相关推荐

  1. 深度学习之图像分类(十九)-- Bottleneck Transformer(BoTNet)网络详解

    深度学习之图像分类(十九)Bottleneck Transformer(BoTNet)网络详解 目录 深度学习之图像分类(十九)Bottleneck Transformer(BoTNet)网络详解 1 ...

  2. 五分钟搞懂后缀数组!后缀数组解析以及应用(附详解代码)

    为什么学后缀数组 后缀数组是一个比较强大的处理字符串的算法,是有关字符串的基础算法,所以必须掌握. 学会后缀自动机(SAM)就不用学后缀数组(SA)了?不,虽然SAM看起来更为强大和全面,但是有些SA ...

  3. NLP:Transformer的架构详解之详细攻略(持续更新)

    NLP:Transformer的架构详解之详细攻略(持续更新) 目录 Transformer的架构详解 1. Encoder 1.1.Positional Encoding-数据预处理的部分 1.2. ...

  4. win7 64位操作系统中Oracle 11g + plsql安装教程详解(图解)

    这篇文章主要介绍了win7 64位操作系统中Oracle 11g + plsql安装教程详解(图解),详细的介绍了Oracle 11g 安装的步骤,有兴趣的可以了解一下. 先去网上把下面列表里的文件下 ...

  5. 『ML笔记』HOG特征提取原理详解+代码

    HOG特征提取原理详解+代码! 文章目录 一. HOG特征介绍 二. HOG算法具体流程+代码 2.1. 图像灰度化和gamma矫正 2.2. 计算图像像素梯度图 2.3. 在8×8的网格中计算梯度直 ...

  6. Transformer模型详解(图解最完整版)

    前言 Transformer由论文<Attention is All You Need>提出,现在是谷歌云TPU推荐的参考模型.论文相关的Tensorflow的代码可以从GitHub获取, ...

  7. 计算机代码 w6,蓝魔w6hd的参数介绍和刷机教程详解【图解】

    手机和电脑的逐渐普及,让人们已经逐渐的习惯数码产品的存在,但是随着时间的推移,人们对于数码产品的要求也越来越高.手机虽然方便,但是毕竟它的功能还是没有电脑强大,而且屏幕还很小.电脑虽然功能强大,但是携 ...

  8. xvid 详解 代码分析 编译等

    1.   Xvid参数详解 众所周知,Mencoder以其极高的压缩速率和不错的画质赢得了很多朋友的认同! 原来用Mencoder压缩Xvid的AVI都是使用Xvid编码器的默认设置,现在我来给大家冲 ...

  9. Pytorch中 nn.Transformer的使用详解与Transformer的黑盒讲解

    文章目录 本文内容 将Transformer看成黑盒 Transformer的推理过程 Transformer的训练过程 Pytorch中的nn.Transformer nn.Transformer简 ...

最新文章

  1. C++拷贝构造函数详解
  2. DownloadProvider 源码详细分析
  3. 动态规划 —— 背包问题 P08 —— 泛化物品背包
  4. [深度学习-NLP]什么是Self-attention, Muti-attention和Transformer
  5. java mybits架构图_java架构之路-(mybatis源码)mybatis执行流程源码解析
  6. C语言编写汇编的编译器,用c编写一个asm的编译器
  7. 根号x_8.八年级数学:根号(2a1)=12a,怎么求a的取值范围?二次根式
  8. spark 的RDD各种转换和动作
  9. 【java】 jsp网页session和application,全局变量方法
  10. 拓端tecdat|R语言社区发现算法检测心理学复杂网络:spinglass、探索性图分析walktrap算法与可视化
  11. 徐州一姑娘写的(女孩看了是自省,男孩看了是激励)
  12. 华为mate40鸿蒙系统用久了会卡吗,华为mate40用多久会卡_华为mate40能流畅使用多久...
  13. PHP文件处理--打开文件
  14. (一)ROS中新建机器人模型(urdf格式)并用rviz显示
  15. 10个重要的算法C语言实现源代码:拉格朗日,牛顿插值,高斯,龙贝格,牛顿迭代,牛顿-科特斯,雅克比,秦九昭,幂法,高斯塞德尔 (转帖)
  16. GBASE 8s UDR内存管理_04_mi_zalloc
  17. php统计页面访问量_PHP 统计 网页 总访问次数 附代码
  18. 访问ftp服务器网页,访问ftp服务器是网页
  19. 获取Google Advertising ID作为唯一识别码
  20. RoboCup仿真3D底层通信模块介绍(一)

热门文章

  1. 【jeecg Docker安装】使用 Docker 搭建 Java Web 运行环境
  2. JEECG支付服务窗专题 - 平台与服务窗接口对接
  3. 【JEECG技术博文】JEECG国际化介绍
  4. Jeecg平台扩展性不好的地方收集启动。
  5. 电脑删除的文件怎么恢复?你要找的方案
  6. 13.Azure流量管理器(上)
  7. Struts2学习(四):Action执行的时候发生了什么
  8. 用 cctld工具创建带有国家代码的IP地址表
  9. Ubuntu下将dash装换成bash
  10. JAVA a --; 与 -- a;