Swin Transformer(W-MSA详解)代码+图解
2. Window & Shifted Window based Self-Attention
Swin Transformer另一个重要的改进就是window-based的self-attention layer,之前提到过,ViT的一个缺点是计算复杂度是和patch数量成平方关系的,为了减少计算量,Swin的做法是将输入图片划分成不重合的windows,然后在不同的window内进行self-attention计算。假设一个图片有 的patches,每个window包含MxM个patches,作者将其成为window based self-attention(W-MSA)layer,W-MSA和multi-head self-attention(MSA)的计算复杂度分别为:
由于window内部的patch数量远小于图片patch数量,并且window数量是保持不变的,W-MSA的计算复杂度和图像尺寸呈线性关系,从而大大降低了模型的计算复杂度。
虽然W-MSA能够降低计算复杂度,但是不重合的window之间缺乏信息交流,这样其实就失去了transformer利用self-attention从全局构建关系的能力,于是文章进一步引入shifted window partition来跨window进行信息交流,作者将其成为shifted window based self-attention(SW-MSA)。
如上图所示,Layer 1中8x8尺寸feature map划分成2x2个patch,每个patch尺寸为4x4, 通过将patch位置整体平移1/2个patch大小,在下一层得到新的window,包括3x3个不重合的patch。移动window的划分方式使上一层相邻的不重合window之间引入连接,大大的增加了感受野。
但这样做带来的另一个问题就是window内部patch的数量从原本的4个增加到了9个,为了让patch数量保持不变,如下图所示,作者的解决思路是把平移之后左上角A,B,C部分的patch与右下角不满足4x4尺度的patch拼接,这样patch的数量还是4个,但是又满足了window外的信息交互,作者将其成为cyclic shift。
window的划分与合并
# window_partition是划分,window_reverse是合并
def window_partition(x, window_size):"""Args:x: (B, H, W, C)window_size (int): window sizeReturns:windows: (num_windows*B, window_size, window_size, C)"""B, H, W, C = x.shapex = x.reshape([B, H // window_size, window_size, W // window_size, window_size, C])windows = x.transpose([0, 1, 3, 2, 4, 5]).reshape([-1, window_size, window_size, C])return windowsdef window_reverse(windows, window_size, H, W):"""Args:windows: (num_windows*B, window_size, window_size, C)window_size (int): Window sizeH (int): Height of imageW (int): Width of imageReturns:x: (B, H, W, C)"""B = int(windows.shape[0] / (H * W / window_size / window_size))x = windows.reshape([B, H // window_size, W // window_size, window_size, window_size, -1])x = x.transpose([0, 1, 3, 2, 4, 5]).reshape([B, H, W, -1])return x
Window Attention
这是这篇文章的关键。传统的Transformer都是基于全局来计算注意力的,因此计算复杂度十分高。而Swin Transformer则将注意力的计算限制在每个窗口内,进而减少了计算量。
我们先简单看下公式
主要区别是:在原始计算Attention的公式中的Q,K时加入了相对位置编码。后续实验有证明相对位置编码的加入提升了模型性能。
在原始计算Attention的公式中的Q,K时加入了相对位置编码
总体代码:
window_size =(2,2)
coords_flatten = np.array([[0, 0, 1, 1],[0, 1, 0, 1]])
new = torch.tensor(coords_flatten)new_first=new[:, :, None] # (2,4,1)
new_second=new[:, None, :] # (2,1,4)relative_coords = (new_first-new_second).permute(1, 2, 0).contiguous() # (4,4,2):4个4行2列的矩阵
relative_coords[:, :, 0] += window_size[0] - 1 # 在每个矩阵的第0列加1
relative_coords[:, :, 1] += window_size[1] - 1 # 在每个矩阵的第1列加1
relative_coords[:, :, 0] *= 2 * window_size[1] - 1 # 在每个矩阵的第0列*3
relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww 将4个矩阵的每一个矩阵按行求和,作为新的行(4,2)==》(1,4)
下面我把涉及到相关位置编码的逻辑给单独拿出来,这部分比较绕
首先QK计算出来的Attention张量形状为(numWindows*B, num_heads, window_size*window_size, window_size*window_size)
。
而对于Attention张量来说,以不同元素为原点,其他元素的坐标也是不同的,以window_size=2
为例,其相对位置编码如下图所示
首先我们利用torch.arange
和torch.meshgrid
函数生成对应的坐标,这里我们以windowsize=2
为例子
关于torch.meshgrid
函数请看: 【pytorch】torch.meshgrid()==>常用于生成二维网格,比如图像的坐标点_小马牛的博客-CSDN博客
coords_h = torch.arange(self.window_size[0])
coords_w = torch.arange(self.window_size[1])
coords = torch.meshgrid([coords_h, coords_w]) # -> 2*(wh, ww)
"""
此时是两个张量,每个 张量是一个2*2的二维矩阵(tensor([[0, 0],[1, 1]]), tensor([[0, 1],[0, 1]]))
"""
然后堆叠起来,展开为一个二维向量
coords = torch.stack(coords) # 2, Wh, Ww
# 将两个tensor堆叠起来,就变成了一个2*2*2的三维矩阵
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
# 再将三维矩阵压扁为2*4的一个二维矩阵
"""
tensor([[0, 0, 1, 1],[0, 1, 0, 1]])
"""
利用广播机制,分别在第二维,第一维,插入一个维度,进行广播相减,得到 2, wh*ww, wh*ww
的张量
解释:
relative_coords_first = coords_flatten[:, :, None] # 2, wh*ww, 1
relative_coords_second = coords_flatten[:, None, :] # 2, 1, wh*ww
relative_coords_first的None是在第二维,那么就是在第二维插入了一个维度1,就是由 2, wh*ww ==>2, wh*ww, 1
relative_coords_second的None是在第一维,那么就是在第一维插入了一个维度1,就是由 2, wh*ww ==>2, 1,wh*ww
图二a :
relative_coords_first = coords_flatten[:, :, None] # 2, wh*ww, 1
relative_coords_second = coords_flatten[:, None, :] # 2, 1, wh*ww
relative_coords = relative_coords_first - relative_coords_second # 最终得到 2, wh*ww, wh*ww 形状的张量
# 此处加上下面这一句就是图二的第一个
relative_coords = relative_coords.permute(1, 2, 0).contiguous()
因为采取的是相减,所以得到的索引是从负数开始的
图二a对应的索引矩阵:
图二b :
因为采取的是相减,所以得到的索引是从负数开始的,我们加上偏移量,让其从0开始。
relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
relative_coords[:, :, 0] += self.window_size[0] - 1
relative_coords[:, :, 1] += self.window_size[1] - 1
此处相当于在矩阵中每个加(1,1),相当于在索引矩阵上加2
图二c :
后续我们需要将其展开成一维偏移量。而对于(1,2)和(2,1)这两个坐标。在二维上是不同的,但是通过将x,y坐标相加转换为一维偏移的时候,他的偏移量是相等的。
所以最后我们对其中做了个乘法操作,以进行区分
relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
图二:
图二d :
然后再最后一维上进行求和,展开成一个一维坐标,并注册为一个不参与网络学习的变量
relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
self.register_buffer("relative_position_index", relative_position_index)
现在得到了位置索引,那么下一步就是输入进计算公式(4)进行forward
特征图移位操作
代码里对特征图移位是通过torch.roll
来实现的,下面是示意图
如果需要
reverse cyclic shift
的话只需把参数shifts
设置为对应的正数值。
Attention Mask
我认为这是Swin Transformer的精华,通过设置合理的mask,让Shifted Window Attention
在与Window Attention
相同的窗口个数下,达到等价的计算结果。
首先我们对Shift Window后的每个窗口都给上index,并且做一个roll
操作(window_size=2, shift_size=1)
我们希望在计算Attention的时候,让具有相同index QK进行计算,而忽略不同index QK计算结果。例如下图的5353与5353进行计算,那么只会5与5计算,3与3计算,而不会5与3计算
最后正确的结果如下图所示
而要想在原始四个窗口下得到正确的结果,我们就必须给Attention的结果加入一个mask
相关代码如下:
if self.shift_size > 0:# calculate attention mask for SW-MSAH, W = self.input_resolutionimg_mask = torch.zeros((1, H, W, 1)) # 1 H W 1h_slices = (slice(0, -self.window_size),slice(-self.window_size, -self.shift_size),slice(-self.shift_size, None))w_slices = (slice(0, -self.window_size),slice(-self.window_size, -self.shift_size),slice(-self.shift_size, None))cnt = 0for h in h_slices:for w in w_slices:img_mask[:, h, w, :] = cntcnt += 1mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1mask_windows = mask_windows.view(-1, self.window_size * self.window_size)attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
以上图的设置,我们用这段代码会得到这样的一个mask
tensor([[[[[ 0., 0., 0., 0.],[ 0., 0., 0., 0.],[ 0., 0., 0., 0.],[ 0., 0., 0., 0.]]],[[[ 0., -100., 0., -100.],[-100., 0., -100., 0.],[ 0., -100., 0., -100.],[-100., 0., -100., 0.]]],[[[ 0., 0., -100., -100.],[ 0., 0., -100., -100.],[-100., -100., 0., 0.],[-100., -100., 0., 0.]]],[[[ 0., -100., -100., -100.],[-100., 0., -100., -100.],[-100., -100., 0., -100.],[-100., -100., -100., 0.]]]]])
在之前的window attention模块的前向代码里,包含这么一段
if mask is not None:nW = mask.shape[0]attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)attn = attn.view(-1, self.num_heads, N, N)attn = self.softmax(attn)
将mask加到attention的计算结果,并进行softmax。mask的值设置为-100,softmax后就会忽略掉对应的值
图解Swin Transformer - 知乎
SOTA 模型 Swin Transformer 是如何炼成的! - 极市社区
Swin Transformer(W-MSA详解)代码+图解相关推荐
- 深度学习之图像分类(十九)-- Bottleneck Transformer(BoTNet)网络详解
深度学习之图像分类(十九)Bottleneck Transformer(BoTNet)网络详解 目录 深度学习之图像分类(十九)Bottleneck Transformer(BoTNet)网络详解 1 ...
- 五分钟搞懂后缀数组!后缀数组解析以及应用(附详解代码)
为什么学后缀数组 后缀数组是一个比较强大的处理字符串的算法,是有关字符串的基础算法,所以必须掌握. 学会后缀自动机(SAM)就不用学后缀数组(SA)了?不,虽然SAM看起来更为强大和全面,但是有些SA ...
- NLP:Transformer的架构详解之详细攻略(持续更新)
NLP:Transformer的架构详解之详细攻略(持续更新) 目录 Transformer的架构详解 1. Encoder 1.1.Positional Encoding-数据预处理的部分 1.2. ...
- win7 64位操作系统中Oracle 11g + plsql安装教程详解(图解)
这篇文章主要介绍了win7 64位操作系统中Oracle 11g + plsql安装教程详解(图解),详细的介绍了Oracle 11g 安装的步骤,有兴趣的可以了解一下. 先去网上把下面列表里的文件下 ...
- 『ML笔记』HOG特征提取原理详解+代码
HOG特征提取原理详解+代码! 文章目录 一. HOG特征介绍 二. HOG算法具体流程+代码 2.1. 图像灰度化和gamma矫正 2.2. 计算图像像素梯度图 2.3. 在8×8的网格中计算梯度直 ...
- Transformer模型详解(图解最完整版)
前言 Transformer由论文<Attention is All You Need>提出,现在是谷歌云TPU推荐的参考模型.论文相关的Tensorflow的代码可以从GitHub获取, ...
- 计算机代码 w6,蓝魔w6hd的参数介绍和刷机教程详解【图解】
手机和电脑的逐渐普及,让人们已经逐渐的习惯数码产品的存在,但是随着时间的推移,人们对于数码产品的要求也越来越高.手机虽然方便,但是毕竟它的功能还是没有电脑强大,而且屏幕还很小.电脑虽然功能强大,但是携 ...
- xvid 详解 代码分析 编译等
1. Xvid参数详解 众所周知,Mencoder以其极高的压缩速率和不错的画质赢得了很多朋友的认同! 原来用Mencoder压缩Xvid的AVI都是使用Xvid编码器的默认设置,现在我来给大家冲 ...
- Pytorch中 nn.Transformer的使用详解与Transformer的黑盒讲解
文章目录 本文内容 将Transformer看成黑盒 Transformer的推理过程 Transformer的训练过程 Pytorch中的nn.Transformer nn.Transformer简 ...
最新文章
- C++拷贝构造函数详解
- DownloadProvider 源码详细分析
- 动态规划 —— 背包问题 P08 —— 泛化物品背包
- [深度学习-NLP]什么是Self-attention, Muti-attention和Transformer
- java mybits架构图_java架构之路-(mybatis源码)mybatis执行流程源码解析
- C语言编写汇编的编译器,用c编写一个asm的编译器
- 根号x_8.八年级数学:根号(2a1)=12a,怎么求a的取值范围?二次根式
- spark 的RDD各种转换和动作
- 【java】 jsp网页session和application,全局变量方法
- 拓端tecdat|R语言社区发现算法检测心理学复杂网络:spinglass、探索性图分析walktrap算法与可视化
- 徐州一姑娘写的(女孩看了是自省,男孩看了是激励)
- 华为mate40鸿蒙系统用久了会卡吗,华为mate40用多久会卡_华为mate40能流畅使用多久...
- PHP文件处理--打开文件
- (一)ROS中新建机器人模型(urdf格式)并用rviz显示
- 10个重要的算法C语言实现源代码:拉格朗日,牛顿插值,高斯,龙贝格,牛顿迭代,牛顿-科特斯,雅克比,秦九昭,幂法,高斯塞德尔 (转帖)
- GBASE 8s UDR内存管理_04_mi_zalloc
- php统计页面访问量_PHP 统计 网页 总访问次数 附代码
- 访问ftp服务器网页,访问ftp服务器是网页
- 获取Google Advertising ID作为唯一识别码
- RoboCup仿真3D底层通信模块介绍(一)