【SOD论文阅读笔记】Visual Saliency Transformer

  • 一、摘要
    • Motivation:
    • Method:
    • Experimental results
  • 二、Introduction
    • 当前最先进的方法以CNN结构为主
    • CNN结构的弊端
    • 引出Transformer
    • 本文中
    • contributions
  • 三、Visual Saliency Transformer
    • Transformer Encoder(T2t_vit_t_14)
    • Transformer Convertor
    • Multi-task Transformer Decoder

一、摘要

Motivation:

现有的SOTA显著性检测方法在很大程度上依赖于基于CNN的网络。可替代地,我们从卷积free的sequence-to-sequence的角度重新考虑此任务,并通过建模长期依赖关系来预测显著性,而这不能通过卷积来实现。

这篇论文的出发点就是利用transformer来创新,并且这篇文章是纯transformer(convolution-free),所以摘要中从transformer和CNN的最大的不同出发来写motivation——即transformer对比CNN来说,是sequence-to-sequence结构的,且更有利于对长期依赖关系建模。

Method:

提出基于纯变压器的模型,即视觉显著性变压器 (VST),用于RGB和RGBD的显著性检测。

  • 以图像补丁为输入,并利用transformer在图像补丁之间传播全局上下文
  • 与视觉变压器 (ViT) 中使用的常规结构不同,我们利用多级token融合,并在变压器框架下提出了一种新的token上采样方法,以获得高分辨率的检测结果。
  • 我们还开发了基于token的多任务解码器,通过引入与任务相关的token和新颖的补丁-任务-注意力机制,同时执行显着性和边界检测。

先解释一下图像补丁。由于transormer是从NLP任务传到CV领域的,在NLP的机器翻译任务中,输入的是一个个单词,所以,把transformer移植到图像任务时,为了与其输入结构保持一致,会把图像切割成不重叠的补丁序列(可以想像一下把一张图片切割成九宫格/N宫格,每一个宫格就是一个补丁)。

再解释一下token。刚刚的图像补丁就可以被称之为一个token,它属于patch token。patch token输入到transformer中后,经过处理得到的feature也可以成为token。此外,transformer中还有一种class token,它本质上就是一个可训练的向量,通常在分类任务中直接通过这个Class token来判断类别。

这篇论文里有一个任务相关的token(task-related tokens),其实相当于tokens的一个头部,代表这个tokens是用于做什么任务的。这是因为,这篇论文提出的是多任务模型,输出的是 显著映射 和 边缘映射,本意是借助边缘的监督提升其显著映射的准确性。

Experimental results

实验结果表明,我们的模型在RGB和RGBD SOD基准数据集上都优于现有方法。

二、Introduction

当前最先进的方法以CNN结构为主

它们通常采用编码器-解码器架构,其中编码器将输入图像编码为多级特征,解码器将提取的特征集成以预测最终的显着性图。

  • RGB-SOD,旨在检测吸引人们眼睛的物体,并可以帮助许多视觉任务。

    • 各种注意力模型,多尺度特征集成方法和多任务学习框架
  • RBGD-SOD,则多了来自深度数据的额外空间结构信息。
    • 各种模态融合方法,如特征融合,知识蒸馏,动态卷积,注意力模型 ,图神经网络 。

CNN结构的弊端

所有方法在学习全局远程依赖方面受到限制

长期以来,全局上下文 和全局对比度 对于显著性检测至关重要。然而,由于cnn在局部滑动窗口中提取特征的内在限制,以前的方法很难利用关键的全局线索。

尽管一些方法利用全连接层,全局池化和非本地模块来合并全局上下文,但它们仅在某些层中这样做,并且基于CNN的体系结构保持不变。

引出Transformer

最近,提出了Transformer用于机器翻译的单词序列之间的全局远程依赖关系。

Transformer的核心思想是自注意机制,它利用query-key的相关性来关联序列中的不同位置。Transformer在编码器和解码器中多次堆叠自注意层,因此可以对每一层中的长距离依赖进行建模。因此,将变压器引入SOD是很自然的,一路利用模型中的全局线索。

本文中

我们从新的序列到序列的角度重新考虑SOD,并基于纯变压器开发了一种新颖的RGB和rgb-d SOD统一模型,称为视觉显着性变压器。

最近提出的ViT模型 [12,74],将每个图像划分为补丁,并在补丁序列上采用变压器模型。然后,变压器在图像补丁之间传播长距离依赖,而无需使用卷积。

然而,将ViT应用于SOD并不简单,存在两大问题:

  • 1.关于密集预测: 如何基于纯变压器执行密集预测任务仍然是一个悬而未决的问题。
    - 我们通过引入与任务相关的token来设计基于token的变压器解码器从而学习决策嵌入。然后,我们提出了一种新颖的补丁-任务-注意力机制来生成密集预测结果,这为在密集预测任务中使用transformer提供了新的范例。
    - 在以前的SOD模型的激励下,利用边界检测来提高SOD性能,我们构建了一个多任务解码器,通过引入显著性token和边界token,同时进行显著性和边界检测。该策略通过简单地学习与任务相关的token来简化多任务预测工作流程,从而大大降低了计算成本,同时获得了更好的结果。
  • 2.关于高分辨率:ViT通常将图像标记为非常粗糙的大小。如何使ViT适应SOD的高分辨率预测需求还不清楚。
    - 受tokens-to-tokens (T2T) 转换 [74] 的启发,该转换减少了tokens的长度,我们提出了一种新的反向T2T转换,通过将每个tokens扩展为多个子tokens来向上采样tokens。然后,我们逐步对补丁tokens进行采样,并将其与低级token融合,以获得最终的全分辨率显着性图。此外,我们还使用交叉模态transformer来深入探索rgb-d SOD的多模态信息之间的相互作用。

在RGB和RGBD数据上,以有可比性的数量的参数和计算成本,优于现有的最先进的SOD方法

contributions

  • 以序列to序列建模的新视角,设计了一种基于纯变压器架构的RGB和rgb-d SOD的新型统一模型。
  • 设计了一种多任务变压器解码器,通过引入任务相关的token和补丁-任务-注意力来联合进行显著性和边界检测
  • 一种新的基于transformer的token上采样方法
  • state-of-the-art结果

三、Visual Saliency Transformer

我们为RGB和RGBD SOD提出的VST模型的整体架构。它首先使用编码器从输入的图像补丁序列中生成多级tokens。然后,采用转换器将补丁tokens转换为解码器空间,并对rgb-d数据进行跨模态信息融合。最后,解码器通过我们提出的与任务相关的token以及补丁-任务-注意机制同时预测显着图和边界图。还提出了一种RT2T转换,以逐步上采样补丁tokens。虚线表示rgb-d SOD的专用成分。

  • 主要组件包括3部分:基于T2T-ViT的变压器encoder (T2t_vit_t_14),用于将补丁tokens从编码器空间转换到解码器空间的变压器转换器 (Transformer),以及多任务变压器decoder (token_Transformer, Decoder)。
class ImageDepthNet(nn.Module):def __init__(self, args):super(ImageDepthNet, self).__init__()# VST Encoderself.rgb_backbone = T2t_vit_t_14(pretrained=True, args=args)# VST Convertorself.transformer = Transformer(embed_dim=384, depth=4, num_heads=6, mlp_ratio=3.)# VST Decoderself.token_trans = token_Transformer(embed_dim=384, depth=4, num_heads=6, mlp_ratio=3.)self.decoder = Decoder(embed_dim=384, token_dim=64, depth=2, img_size=args.img_size)def forward(self, image_Input):B, _, _, _ = image_Input.shape# image_Input [B, 3, 224, 224]# VST Encoderrgb_fea_1_16, rgb_fea_1_8, rgb_fea_1_4 = self.rgb_backbone(image_Input)# rgb_fea_1_16 [B, 14*14, 384]# rgb_fea_1_8 [B, 28*28, 384]# rgb_fea_1_4 [B, 56*56, 384]# VST Convertorrgb_fea_1_16 = self.transformer(rgb_fea_1_16)# rgb_fea_1_16 [B, 14*14, 384]# VST Decodersaliency_fea_1_16, fea_1_16, saliency_tokens, contour_fea_1_16, contour_tokens = self.token_trans(rgb_fea_1_16)# saliency_fea_1_16 [B, 14*14, 384]# fea_1_16 [B, 1 + 14*14 + 1, 384]# saliency_tokens [B, 1, 384]# contour_fea_1_16 [B, 14*14, 384]# contour_tokens [B, 1, 384]outputs = self.decoder(saliency_fea_1_16, fea_1_16, saliency_tokens, contour_fea_1_16, contour_tokens, rgb_fea_1_8, rgb_fea_1_4)# [mask_1_16, mask_1_8, mask_1_4, mask_1_1],[contour_1_16, contour_1_8, contour_1_4, contour_1_1]# mask_1_16/contour_1_16 [B, 1, 14, 14]# mask_1_1/contour_1_1 [B, 1, 224, 224]return outputs

Transformer Encoder(T2t_vit_t_14)

以下是Transformer Encoder的整体框架

class T2T_ViT(nn.Module):def __init__(self, img_size=224, tokens_type='performer', in_chans=3, num_classes=1000, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,drop_path_rate=0., norm_layer=nn.LayerNorm):super().__init__()self.tokens_to_token = T2T_module(img_size=img_size, tokens_type=tokens_type, in_chans=in_chans, embed_dim=embed_dim)num_patches = self.tokens_to_token.num_patchesself.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))self.pos_embed = nn.Parameter(data=get_sinusoid_encoding(n_position=num_patches + 1, d_hid=embed_dim), requires_grad=False)self.blocks = nn.ModuleList([Block(dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,drop=drop_rate, attn_drop=attn_drop_rate, norm_layer=norm_layer)for i in range(depth)])self.norm = norm_layer(embed_dim)def forward(self, x):B = x.shape[0]x, x_1_8, x_1_4 = self.tokens_to_token(x)#[B,196,384],[B, 28×28, 384],[B, 56×56, 384]cls_tokens = self.cls_token.expand(B, -1, -1)#[1,1,384]->[B,1,384]x = torch.cat((cls_tokens, x), dim=1)#cat([B,1,384],[B,196,384])->[B,197,384]x = x + self.pos_embed#[B,197,384]+[1,197,384]->[B,197,384]# T2T-ViT backbonefor blk in self.blocks:x = blk(x)#[B,197,384]x = self.norm(x)#[B,197,384]return x[:, 1:, :], x_1_8, x_1_4

可以看出,Transformer Encoder由一个T2T模块和一些后处理步骤构成。
输入:(B,3,224,224)
输出:由于我们做的是像素级分类而不是对象级分类,所以输出了多级特征:fea_1_16 [B, 14×14, 384],fea_1_8 [B, 28×28, 384],fea_1_4 [B, 56×56, 384]。

T2T模块:待会儿详细介绍。
后处理步骤:

  1. 首先,x被concat了一个1维的全零分类tokens,由于其被初始化为0,所以没什么好介绍的。

x = torch.cat((cls_tokens, x), dim=1)

  1. 其次,x被add了一个shape与其shape相同的正弦位置tokens

self.pos_embed = nn.Parameter(data=get_sinusoid_encoding(n_position=num_patches + 1, d_hid=embed_dim), requires_grad=False)
x = x + self.pos_embed

这里对self.pos_embed的初始化是有讲究的,用到的是《Attention is all you need》中提出的正弦位置,参数就是要生成的shape的参数。

3.最后,重复经过depth个Blocks。这里depth设置为14。
每个Block:

class Block(nn.Module):def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):super().__init__()self.norm1 = norm_layer(dim)self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)self.norm2 = norm_layer(dim)mlp_hidden_dim = int(dim * mlp_ratio)self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)def forward(self, x):x = x + self.attn(self.norm1(x))x = x + self.mlp(self.norm2(x))return

该过程就是不断Attention、MLP的迭代过程,且输出与输入的shape保持一致[B, 197, 384]。
Attention就是普通多头attention(Linear[通道数扩大三倍]、分为qkv、softmax(q*k)*v,最后再Linear[不改变通道数])

class Attention(nn.Module):def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):super().__init__()self.num_heads = num_headshead_dim = dim // num_headsself.scale = qk_scale or head_dim ** -0.5self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)self.attn_drop = nn.Dropout(attn_drop)self.proj = nn.Linear(dim, dim)def forward(self, x):B, N, C = x.shape#[B,197,384]qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)# self.qkv(x):[B,197*3,384]#.reshape(B, N, 3, self.num_heads, C // self.num_heads): [B,197,3,6,64]#.permute(2, 0, 3, 1, 4): [3,B,6,197,64]q, k, v = qkv[0], qkv[1], qkv[2]#[B,6,197,64]attn = (q @ k.transpose(-2, -1)) * self.scale# k.transpose(-2, -1): [B,6,64,197]# q @ k.transpose(-2, -1):[B,6,197,197]attn = attn.softmax(dim=-1)# [B,6,197,197]x = (attn @ v).transpose(1, 2).reshape(B, N, C)# attn @ v : [B,6,197,197] * [B,6,197,64] -> [B,6,197,64]# .transpose(1, 2) : [B,197,6,64]# .reshape(B, N, C) : [B,197,384]x = self.proj(x)#[B,197,384]return x

MLP就是(Linear[通道数扩大3倍]、Gelu激活、Linear[通道数恢复])

class Mlp(nn.Module):def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):super().__init__()out_features = out_features or in_featureshidden_features = hidden_features or in_featuresself.fc1 = nn.Linear(in_features, hidden_features)self.act = act_layer()self.fc2 = nn.Linear(hidden_features, out_features)def forward(self, x):x = self.fc1(x)#[B,197,384*3]x = self.act(x)x = self.fc2(x)#[B,197,384]return x
  • Tokens to Token模块

给定一系列长度为l的补丁tokens T’,T2T-ViT会连续堆叠T2T模块。
T2T模块是由重构步骤(a re-structurization step: 多头自注意力+多层感知机)和软拆分步骤(a soft split step:unfold)组成的,对T’中的局部结构信息进行建模,并获得新的token序列。
T2T变换可以多次迭代进行。在每次的迭代中,重构步骤首先将以前的token嵌入转换为新的嵌入,并且还在所有token内集成了远程依赖关系。然后,软拆分操作将每个k × k邻居中的token聚合成一个新token,该token准备用于下一层。
此外,当设置s <k-1时,token的长度可以逐渐减小。

个人觉得这里的tokens-to-tokens模块更应该叫做features-to-features模块,因为这个模块的输入是二维的features,进入模块后会先软分割(unfold)变形为1维的向量,即tokens串,然后self-attention,最后再reshape成二维的特征图。

  • 重构步骤 a re-structurization step
    tokens T’会首先使用一个transformer层,获得一个新的tokens T∈Rl×cT∈R^{l×c}TRl×c
    transformer层: MSA 多头自注意力+MLP多层感知机
    之后,T会被reshape为2维图像I∈Rh×w×c,从而恢复空间结构
  • 软拆分步骤 a soft split step
    与ViT不同,T2T-ViT中采用的重叠补丁拆分在相邻补丁中引入了局部对应关系,从而带来了空间先验。
    I∈Rh×w×cI∈R^{h×w×c}IRh×w×c首先会给边界补上p个0,之后被拆分为重叠区域为s的k×k个补丁块。
    然后图像补丁块会被展开成一系列tokens To∈Rlo×ck2T_{o}∈ R^{l_{o}×ck^{2}}ToRlo×ck2
  • 具体设置:我们按照 [74] 首先将输入图像软分割成补丁,然后两次迭代T2T模块。在三个软拆分步骤中,补丁大小设置为k = [7,3,3],重叠映射设置为s = [3,1,1],填充大小设置为p = [2,1,1]。因此,我们可以获得多级tokensT1 ∈ Rl1 × c,T2 ∈ Rl2 × c和T3 ∈ Rl3 × c。给定输入图像的宽度和高度分别为H和W,则l1 = H /4 × W/ 4,l2 = H/8 × W/8,l3 = H/16 × W/16。我们遵循 [74] 设置c = 64,并使用t3上的线性投影层将其嵌入尺寸从c转换为d = 384。
class T2T_module(nn.Module):"""Tokens-to-Token encoding module"""def __init__(self, img_size=224, tokens_type='performer', in_chans=3, embed_dim=768, token_dim=64):super().__init__()if tokens_type == 'transformer':print('adopt transformer encoder for tokens-to-token')self.soft_split0 = nn.Unfold(kernel_size=(7, 7), stride=(4, 4), padding=(2, 2))self.soft_split1 = nn.Unfold(kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))self.soft_split2 = nn.Unfold(kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))self.attention1 = Token_transformer(dim=in_chans * 7 * 7, in_dim=token_dim, num_heads=1, mlp_ratio=1.0)self.attention2 = Token_transformer(dim=token_dim * 3 * 3, in_dim=token_dim, num_heads=1, mlp_ratio=1.0)self.project = nn.Linear(token_dim * 3 * 3, embed_dim)elif tokens_type == 'performer':……elif tokens_type == 'convolution':  # just for comparison with conolution, not our model……self.num_patches = (img_size // (4 * 2 * 2)) * (img_size // (4 * 2 * 2))  # there are 3 sfot split, stride are 4,2,2 seperatelydef forward(self, x):#Input[B,3,224,224]# step0: soft splitx = self.soft_split0(x).transpose(1, 2)# (224 + 2*2 - 7) / 4 + 1 =  56# self.soft_split0(x):[B,147=7*7*3,56*56]# .transpose(1, 2):[B, 56*56, 147=7*7*3]# iteration1: restricturization/reconstructionx_1_4 = self.attention1(x)# [B, 56*56, 64]B, new_HW, C = x_1_4.shapex = x_1_4.transpose(1,2).reshape(B, C, int(np.sqrt(new_HW)), int(np.sqrt(new_HW)))#[B,64,56,56]# iteration1: soft splitx = self.soft_split1(x).transpose(1, 2)# self.soft_split1(x) : [B,576=3*3*64,28*28]#.transpose(1, 2) : [B,28*28,576]# iteration2: restricturization/reconstructionx_1_8 = self.attention2(x)#[B,28*28,64]B, new_HW, C = x_1_8.shapex = x_1_8.transpose(1, 2).reshape(B, C, int(np.sqrt(new_HW)), int(np.sqrt(new_HW)))#[B,64,28,28]# iteration2: soft splitx = self.soft_split2(x).transpose(1, 2)#[B,14*14,576=3*3*64]# final tokensx = self.project(x)#[B,196,384]return x, x_1_8, x_1_4

其中,Token_transformer的结构与上述Block特别相似,都是由Attention和MLP组成。
区别:
Attention中:to_qkv时不再设置为原有通道数的3倍,而是64的3倍,从而实现了通道数的改变;
不再设置multi-head;最终残差相加的不是原来的输入(因为通道数变了,没办法直接加),而是v。

MLP中:两次Linear的通道数没有改变。

Encoder with T2T-ViT Backbone

  • 最后的token序列T3与编码2D位置信息的正弦位置嵌入 [61] add起来。然后,使用LεL^{\varepsilon}Lε transformer层对T3之间的长期依赖进行建模,以提取强大的补丁token嵌入Tε∈Rl3×dT^{\varepsilon} ∈ R^{l_{3} × d}TεRl3×d
  • SOD:应用1个transformer encoder将RGB图像编码为补丁tokens Trε∈Rl3×dT_{r}^{\varepsilon} ∈ R^{l_{3} × d}TrεRl3×d
  • RSOD:应用双流transformer encoder,将深度图像以同样的方式编码为补丁tokens Tdε∈Rl3×dT_{d}^{\varepsilon} ∈ R^{l_{3} × d}TdεRl3×d

Transformer Convertor

我们在变压器编码器和解码器之间插入一个转换器模块,以将编码器补丁tokensTE ∗ 从编码器空间转换到解码器空间,从而获得转换后的补丁tokensTc ∈ Rl3 × d。从输出的shape可以看出,这里特征的形状并没有改变。

  • RGB-D Convertor
  • RGB Convertor
  • transforner层:多个Block+layernorm
    Block:
    x = x+self-attention(layernorm(x))
    x = x+mlp(layernorm(x))

与刚刚Transformer Encoder中最后进行的多个Block的完全一样,这次设置了4个Block,加上刚刚的14个,相当于让fea_1_16经历了18次Attention+MLP。

class TransformerEncoder(nn.Module):def __init__(self, depth, num_heads, embed_dim, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,drop_path_rate=0., norm_layer=nn.LayerNorm):super(TransformerEncoder, self).__init__()self.blocks = nn.ModuleList([Block(dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,norm_layer=norm_layer)for i in range(depth)])self.rgb_norm = norm_layer(embed_dim)def forward(self, rgb_fea):for block in self.blocks:rgb_fea = block(rgb_fea)rgb_fea = self.rgb_norm(rgb_fea)return

这里不改变输入的shape,输入该模块的是fea_1_16[B,14×14,384],输出的仍然是fea_1_16[B,14×14,384]。

Multi-task Transformer Decoder

这个模块在论文中的思路已经在思维导图中写了,以下按照代码思路串一遍。
刚刚在总框架代码中写了,decoder实际上包含了两部分:token_Transformer, Decoder。

def __init__(self, args):……# VST Decoderself.token_trans = token_Transformer(embed_dim=384, depth=4, num_heads=6, mlp_ratio=3.)self.decoder = Decoder(embed_dim=384, token_dim=64, depth=2, img_size=args.img_size)
def forward(self, image_Input):……# VST Decodersaliency_fea_1_16, fea_1_16, saliency_tokens, contour_fea_1_16, contour_tokens = self.token_trans(rgb_fea_1_16)# saliency_fea_1_16 [B, 14*14, 384]# fea_1_16 [B, 1 + 14*14 + 1, 384]# saliency_tokens [B, 1, 384]# contour_fea_1_16 [B, 14*14, 384]# contour_tokens [B, 1, 384]outputs = self.decoder(saliency_fea_1_16, fea_1_16, saliency_tokens, contour_fea_1_16, contour_tokens, rgb_fea_1_8, rgb_fea_1_4)# [mask_1_16, mask_1_8, mask_1_4, mask_1_1],[contour_1_16, contour_1_8, contour_1_4, contour_1_1]# mask_1_16/contour_1_16 [B, 1, 14, 14]# mask_1_1/contour_1_1 [B, 1, 224, 224]return outputs

首先看 token_Transformer,
这部分主要引入了与任务相关的token以及patch-任务-注意力。
它的输入是fea_1_16[B,14×14,384],输出了5部分:

  • 代表saliency任务的任务tokens: saliency_tokens [B, 1, 384]
  • 代表saliency任务的特征tokens:saliency_fea_1_16 [B, 14×14, 384]
  • 代表边缘任务的任务tokens: contour_tokens [B, 1, 384]
  • 代表边缘任务的特征tokens:contour_fea_1_16 [B, 14*14, 384]
  • 总的特征tokens:fea_1_16 [B, 1 + 14×14 + 1, 384]
class token_Transformer(nn.Module):def __init__(self, embed_dim=384, depth=14, num_heads=6, mlp_ratio=3.):super(token_Transformer, self).__init__()self.norm = nn.LayerNorm(embed_dim)self.mlp_s = nn.Sequential(nn.Linear(embed_dim, embed_dim),nn.GELU(),nn.Linear(embed_dim, embed_dim),)self.saliency_token = nn.Parameter(torch.zeros(1, 1, embed_dim))self.contour_token = nn.Parameter(torch.zeros(1, 1, embed_dim))self.encoderlayer = token_TransformerEncoder(embed_dim=embed_dim, depth=depth, num_heads=num_heads, mlp_ratio=mlp_ratio)self.saliency_token_pre = saliency_token_inference(dim=embed_dim, num_heads=1)self.contour_token_pre = contour_token_inference(dim=embed_dim, num_heads=1)def forward(self, rgb_fea):B, _, _ = rgb_fea.shapefea_1_16 = self.mlp_s(self.norm(rgb_fea))   # [B, 14*14, 384]saliency_tokens = self.saliency_token.expand(B, -1, -1) # [B, 1, 384]fea_1_16 = torch.cat((saliency_tokens, fea_1_16), dim=1) # [B, 1+14*14, 384]contour_tokens = self.contour_token.expand(B, -1, -1) # [B, 1, 384]fea_1_16 = torch.cat((fea_1_16, contour_tokens), dim=1) #[B, 1 + 14*14 + 1, 384]fea_1_16 = self.encoderlayer(fea_1_16)# fea_1_16 [B, 1 + 14*14 + 1, 384]saliency_tokens = fea_1_16[:, 0, :].unsqueeze(1) # [B, 1, 384]contour_tokens = fea_1_16[:, -1, :].unsqueeze(1) # [B, 1, 384]saliency_fea_1_16 = self.saliency_token_pre(fea_1_16) # [B, 14*14, 384]contour_fea_1_16 = self.contour_token_pre(fea_1_16) # [B, 14*14, 384]return saliency_fea_1_16, fea_1_16, saliency_tokens, contour_fea_1_16,

这里,token_TransformerEncoder与刚刚的Transformer Convertor设置完全一样,仍然是4个多头注意力Attention+MLP组成的blocks。

重点介绍一下saliency_token_inference和contour_token_inference。
它们俩的输入都是总的特征tokens fea_1_16 [B, 1 + 14×14 + 1, 384],输出的是分别代表saliency和边缘的特征tokens: [B, 14×14, 384] 。

saliency_token_inference:

class saliency_token_inference(nn.Module):def __init__(self, dim, num_heads=1, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):super().__init__()self.norm = nn.LayerNorm(dim)self.num_heads = num_headshead_dim = dim // num_headsself.scale = qk_scale or head_dim ** -0.5self.q = nn.Linear(dim, dim, bias=qkv_bias)self.k = nn.Linear(dim, dim, bias=qkv_bias)self.v = nn.Linear(dim, dim, bias=qkv_bias)self.attn_drop = nn.Dropout(attn_drop)self.proj = nn.Linear(dim, dim)self.proj_drop = nn.Dropout(proj_drop)self.sigmoid = nn.Sigmoid()def forward(self, fea):B, N, C = fea.shapex = self.norm(fea)T_s, F_s = x[:, 0, :].unsqueeze(1), x[:, 1:-1, :]# T_s [B, 1, 384]  F_s [B, 14*14, 384]q = self.q(F_s).reshape(B, N-2, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)#[B,196,1,384]->[B,1,196,384]k = self.k(T_s).reshape(B, 1, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)#[B,1,1,384]->[B,1,1,384]v = self.v(T_s).reshape(B, 1, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)#[B,1,1,384]->[B,1,1,384]attn = (q @ k.transpose(-2, -1)) * self.scale#[B,1,196,384]*[B,1,384,1]->[B,1,196,1]attn = self.sigmoid(attn)attn = self.attn_drop(attn)infer_fea = (attn @ v).transpose(1, 2).reshape(B, N-2, C)#[B,1,196,1]*[B,1,1,384]->[B,1,196,384]->[B,196,1,384]->[B,196,384]infer_fea = self.proj(infer_fea)#[B,196,384]infer_fea = self.proj_drop(infer_fea)infer_fea = infer_fea + fea[:, 1:-1, :]#[B,196,384]return infer_fea

contour_token_inference与saliency_token_inference一样,只不过在取任务token时,取的是-1位。

接下来介绍Decoder。
这部分主要是反T2T的上采样,以及多级特征融合。
输入的是7部分,包括刚刚第一部分的decoder的输出,以及 encoder输出的fea_1_8和 fea_1_4。

  • saliency_fea_1_16 [B, 14*14, 384]
  • fea_1_16 [B, 1 + 14*14 + 1, 384]
  • saliency_tokens [B, 1, 384]
  • contour_fea_1_16 [B, 14*14, 384]
  • contour_tokens [B, 1, 384]
  • fea_1_8 [B, 28*28, 64]
  • fea_1_4 [B, 56*56, 64]
class Decoder(nn.Module):def __init__(self, embed_dim=384, token_dim=64, depth=2, img_size=224):super(Decoder, self).__init__()self.norm = nn.LayerNorm(embed_dim)self.mlp = nn.Sequential(nn.Linear(embed_dim, embed_dim),nn.GELU(),nn.Linear(embed_dim, token_dim),)self.norm_c = nn.LayerNorm(embed_dim)self.mlp_c = nn.Sequential(nn.Linear(embed_dim, embed_dim),nn.GELU(),nn.Linear(embed_dim, token_dim),)self.img_size = img_size# token upsampling and multi-level token fusionself.decoder1 = decoder_module(dim=embed_dim, token_dim=token_dim, img_size=img_size, ratio=8, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), fuse=True)self.decoder2 = decoder_module(dim=embed_dim, token_dim=token_dim, img_size=img_size, ratio=4, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), fuse=True)self.decoder3 = decoder_module(dim=embed_dim, token_dim=token_dim, img_size=img_size, ratio=1, kernel_size=(7, 7), stride=(4, 4), padding=(2, 2), fuse=False)self.decoder3_c = decoder_module(dim=embed_dim, token_dim=token_dim, img_size=img_size, ratio=1, kernel_size=(7, 7), stride=(4, 4), padding=(2, 2), fuse=False)# token based multi-task predictionsself.token_pre_1_8 = token_trans(in_dim=token_dim, embed_dim=embed_dim, depth=depth, num_heads=1)self.token_pre_1_4 = token_trans(in_dim=token_dim, embed_dim=embed_dim, depth=depth, num_heads=1)# predict saliency mapsself.pre_1_16 = nn.Linear(token_dim, 1)self.pre_1_8 = nn.Linear(token_dim, 1)self.pre_1_4 = nn.Linear(token_dim, 1)self.pre_1_1 = nn.Linear(token_dim, 1)# predict contour mapsself.pre_1_16_c = nn.Linear(token_dim, 1)self.pre_1_8_c = nn.Linear(token_dim, 1)self.pre_1_4_c = nn.Linear(token_dim, 1)self.pre_1_1_c = nn.Linear(token_dim, 1)for m in self.modules():classname = m.__class__.__name__if classname.find('Conv') != -1:nn.init.xavier_uniform_(m.weight),if m.bias is not None:nn.init.constant_(m.bias, 0)elif classname.find('Linear') != -1:nn.init.xavier_uniform_(m.weight),if m.bias is not None:nn.init.constant_(m.bias, 0)elif classname.find('BatchNorm') != -1:nn.init.constant_(m.weight, 1)nn.init.constant_(m.bias, 0)def forward(self, saliency_fea_1_16, token_fea_1_16, saliency_tokens, contour_fea_1_16, contour_tokens, rgb_fea_1_8, rgb_fea_1_4):# saliency_fea_1_16 [B, 14*14, 384]# contour_fea_1_16 [B, 14*14, 384]# token_fea_1_16  [B, 1 + 14*14 + 1, 384] (contain saliency token and contour token)# saliency_tokens [B, 1, 384]# contour_tokens [B, 1, 384]# rgb_fea_1_8 [B, 28*28, 64]# rgb_fea_1_4 [B, 56*56, 64]B, _, _, = token_fea_1_16.size()saliency_fea_1_16 = self.mlp(self.norm(saliency_fea_1_16))# saliency_fea_1_16 [B, 14*14, 64]mask_1_16 = self.pre_1_16(saliency_fea_1_16)# mask_1_16 [B,14*14,1]mask_1_16 = mask_1_16.transpose(1, 2).reshape(B, 1, self.img_size // 16, self.img_size // 16)# mask_1_16 [B,1,14,14]contour_fea_1_16 = self.mlp_c(self.norm_c(contour_fea_1_16))# contour_fea_1_16 [B, 14*14, 64]contour_1_16 = self.pre_1_16_c(contour_fea_1_16)contour_1_16 = contour_1_16.transpose(1, 2).reshape(B, 1, self.img_size // 16, self.img_size // 16)# 1/16 -> 1/8# reverse T2T and fuse low-level featurefea_1_8 = self.decoder1(token_fea_1_16[:, 1:-1, :], rgb_fea_1_8)# token predictionsaliency_fea_1_8, contour_fea_1_8, token_fea_1_8, saliency_tokens, contour_tokens = self.token_pre_1_8(fea_1_8, saliency_tokens, contour_tokens)# predict saliency maps and contour mapsmask_1_8 = self.pre_1_8(saliency_fea_1_8)mask_1_8 = mask_1_8.transpose(1, 2).reshape(B, 1, self.img_size // 8, self.img_size // 8)contour_1_8 = self.pre_1_8_c(contour_fea_1_8)contour_1_8 = contour_1_8.transpose(1, 2).reshape(B, 1, self.img_size // 8, self.img_size // 8)# 1/8 -> 1/4fea_1_4 = self.decoder2(token_fea_1_8[:, 1:-1, :], rgb_fea_1_4)# token predictionsaliency_fea_1_4, contour_fea_1_4, token_fea_1_4, saliency_tokens, contour_tokens = self.token_pre_1_4(fea_1_4, saliency_tokens, contour_tokens)# predict saliency maps and contour mapsmask_1_4 = self.pre_1_4(saliency_fea_1_4)mask_1_4 = mask_1_4.transpose(1, 2).reshape(B, 1, self.img_size // 4, self.img_size // 4)contour_1_4 = self.pre_1_4_c(contour_fea_1_4)contour_1_4 = contour_1_4.transpose(1, 2).reshape(B, 1, self.img_size // 4, self.img_size // 4)# 1/4 -> 1saliency_fea_1_1 = self.decoder3(saliency_fea_1_4)contour_fea_1_1 = self.decoder3_c(contour_fea_1_4)mask_1_1 = self.pre_1_1(saliency_fea_1_1)mask_1_1 = mask_1_1.transpose(1, 2).reshape(B, 1, self.img_size // 1, self.img_size // 1)contour_1_1 = self.pre_1_1_c(contour_fea_1_1)contour_1_1 = contour_1_1.transpose(1, 2).reshape(B, 1, self.img_size // 1, self.img_size // 1)return [mask_1_16, mask_1_8, mask_1_4, mask_1_1], [contour_1_16, contour_1_8, contour_1_4, contour_1_1]

核心在于decoder_module模块。
我们用出现的第一个decoder_module模块为例,它的参数设置为:

self.decoder1 = decoder_module(dim=384, token_dim=64, img_size=224, ratio=8, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), fuse=True)

输入的是token_fea_1_16的中间段(即去掉两头的任务token,留下feature token)[B,196,384]
以及rgb_fea_1_8 [B, 28*28, 64]

fea_1_8 = self.decoder1(token_fea_1_16[:, 1:-1, :], rgb_fea_1_8)

下面是decoder_module

class decoder_module(nn.Module):def __init__(self, dim=384, token_dim=64, img_size=224, ratio=8, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), fuse=True):super(decoder_module, self).__init__()self.project = nn.Linear(token_dim, token_dim * kernel_size[0] * kernel_size[1])self.upsample = nn.Fold(output_size=(img_size // ratio,  img_size // ratio), kernel_size=kernel_size, stride=stride, padding=padding)self.fuse = fuseif self.fuse:self.concatFuse = nn.Sequential(nn.Linear(token_dim*2, token_dim),nn.GELU(),nn.Linear(token_dim, token_dim),)self.att = Token_performer(dim=token_dim, in_dim=token_dim, kernel_ratio=0.5)# project input feature to 64 dimself.norm = nn.LayerNorm(dim)self.mlp = nn.Sequential(nn.Linear(dim, token_dim),nn.GELU(),nn.Linear(token_dim, token_dim),)def forward(self, dec_fea, enc_fea=None):if self.fuse:# from 384 to 64#[B,14*14,384]->[B,14*14,64]dec_fea = self.mlp(self.norm(dec_fea))# [1] token upsampling by the proposed reverse T2T module#由于要扩大feature的面积,所以要改变通道#[B,14*14,64]->[B,14*14,64*3*3]dec_fea = self.project(dec_fea)#[B,14*14,64*3*3]->[B,64*3*3,14*14]->[B,64,28,28]dec_fea = self.upsample(dec_fea.transpose(1, 2))B, C, _, _ = dec_fea.shape#[B,64,28*28]->[B,28*28,64]dec_fea = dec_fea.view(B, C, -1).transpose(1, 2)# [B, HW, C]if self.fuse:# [2] fuse encoder fea and decoder fea#concat([B,28*28,64],[B, 28*28, 64])->[B, 28*28, 128]->[B, 28*28, 64]dec_fea = self.concatFuse(torch.cat([dec_fea, enc_fea], dim=2))#[B, 28*28, 64]dec_fea = self.att(dec_fea)return

这里的att不同于以上的Token_transformer。
以上的Token_transformer是由多头Attention+MLP(通道数先扩大再缩小)组成。
而此处的att由token_performer和MLP(通道数保持不变)组成。

class Token_performer(nn.Module):def __init__(self, dim, in_dim, head_cnt=1, kernel_ratio=0.5, dp1=0.1, dp2 = 0.1):super().__init__()self.emb = in_dim * head_cnt # we use 1, so it is no need hereself.kqv = nn.Linear(dim, 3 * self.emb)self.dp = nn.Dropout(dp1)self.proj = nn.Linear(self.emb, self.emb)self.head_cnt = head_cntself.norm1 = nn.LayerNorm(dim)self.norm2 = nn.LayerNorm(self.emb)self.epsilon = 1e-8  # for stable in divisionself.mlp = nn.Sequential(nn.Linear(self.emb, 1 * self.emb),nn.GELU(),nn.Linear(1 * self.emb, self.emb),nn.Dropout(dp2),)self.m = int(self.emb * kernel_ratio)self.w = torch.randn(self.m, self.emb)self.w = nn.Parameter(nn.init.orthogonal_(self.w) * math.sqrt(self.m), requires_grad=False)def prm_exp(self, x):# part of the function is borrow from https://github.com/lucidrains/performer-pytorch# and Simo Ryu (https://github.com/cloneofsimo)# ==== positive random features for gaussian kernels ====# x = (B, T, hs)# w = (m, hs)# return : x : B, T, m# SM(x, y) = E_w[exp(w^T x - |x|/2) exp(w^T y - |y|/2)]# therefore return exp(w^Tx - |x|/2)/sqrt(m)xd = ((x * x).sum(dim=-1, keepdim=True)).repeat(1, 1, self.m) / 2wtx = torch.einsum('bti,mi->btm', x.float(), self.w)return torch.exp(wtx - xd) / math.sqrt(self.m)def single_attn(self, x):k, q, v = torch.split(self.kqv(x), self.emb, dim=-1)kp, qp = self.prm_exp(k), self.prm_exp(q)  # (B, T, m), (B, T, m)D = torch.einsum('bti,bi->bt', qp, kp.sum(dim=1)).unsqueeze(dim=2)  # (B, T, m) * (B, m) -> (B, T, 1)kptv = torch.einsum('bin,bim->bnm', v.float(), kp)  # (B, emb, m)y = torch.einsum('bti,bni->btn', qp, kptv) / (D.repeat(1, 1, self.emb) + self.epsilon)  # (B, T, emb)/Diag# skip connection# y = v + self.dp(self.proj(y))  # same as token_transformer in T2T layer, use v as skip connectiony = self.dp(self.proj(y))return ydef forward(self, x):x = x + self.single_attn(self.norm1(x))x = x + self.mlp(self.norm2(x))return x

【SOD论文阅读笔记】Visual Saliency Transformer相关推荐

  1. 论文阅读笔记:Swin Transformer: Hierarchical Vision Transformer using Shifted Windows

    论文阅读笔记:Swin Transformer 摘要 1 简介 2 相关工作 3 方法论 3.1 总览 Swin Transformer block 3.2 shifted window-based ...

  2. Dynamic MDETR: A Dynamic Multimodal Transformer Decoder for Visual Grounding 论文阅读笔记

    Dynamic MDETR: A Dynamic Multimodal Transformer Decoder for Visual Grounding 论文阅读笔记 一.Abstract 二.引言 ...

  3. [论文翻译] Visual Saliency Transformer

    (2021.8.4)注意:本文仍正在施工中,实验以及RGB-D SOD部分尚未翻译 论文地址:https://arxiv.org/abs/2104.12099 代码:https://github.co ...

  4. Visual Attribute Transfer through Deep Image Analogy论文阅读笔记

    Visual Attribute Transfer through Deep Image Analogy论文阅读笔记 介绍 论文提出了一种新的两张图片直接进行视觉属性迁移的方法.该方法针对的是两张具有 ...

  5. Visual Saliency Transformer 读后感

    最近一直在忙着搞各种杂活 好久没有认真的读文章了 之前做的是 SOD,因为transformer的兴起,就想试试它在SOD里的效果.因为论文审稿意见下来,拖了很久,一直没有实现,现在先读一下这篇pap ...

  6. PolyFormer: Referring Image Segmentation as Sequential Polygon Generation 论文阅读笔记

    PolyFormer: Referring Image Segmentation as Sequential Polygon Generation 论文阅读笔记 一.Abstract 二.引言 三.相 ...

  7. 异常检测阅读笔记《Inpainting Transformer for Anomaly Detection》CVPR 2021

    异常检测阅读笔记<Inpainting Transformer for Anomaly Detection> CVPR 2021 来源:2021年的CVPR,原文论链接 论文的方向是图像方 ...

  8. 解决参考图像分割中的随机性问题:MMNet: Multi-Mask Network for Referring Image Segmentation 论文阅读笔记

    解决参考图像分割中的随机性问题:MMNet: Multi-Mask Network for Referring Image Segmentation 论文阅读笔记 一.Abstract 二.引言 三. ...

  9. DCP(Deep Closest Point)论文阅读笔记以及详析

    DCP论文阅读笔记 前言 本文中图片仓库位于github,所以如果阅读的时候发现图片加载困难.建议挂个梯子. 作者博客:https://codefmeister.github.io/ 转载前请联系作者 ...

最新文章

  1. Thrift的接口定义语言IDL
  2. 《C++primer》第二章--变量和基本内置类型
  3. c# 如何找到项目中图片的相对路径
  4. 【SIS-OAS 1.52.0】【C03-测试报告】常规版本回归测试报告-------回归测试报告模板...
  5. java 子类 父类 转换_Java子类与父类之间的类型转换
  6. Flutter 权限申请
  7. 查看游戏帧数:FPS的软件
  8. ANSYS 2020R2 workbench汉化的方法
  9. 微信小程序入门4-扫普通二维码进入小程序、打开短链接进入小程序
  10. 基于51单片机的铂电阻PT100温度计proteus仿真
  11. 论一个好翻译的重要性
  12. foxmail 不知道这样的主机
  13. 图解:麦肯锡工作术!
  14. linux内核 4g拨号,openwrt 基于qmi的 3G|4G拨号
  15. bway ESL电竞联赛十六季C组对战前瞻 三组战队情报分析
  16. 说说Oracle分区
  17. 有效前沿,CAMP, CAL, SML
  18. 我00后,会Python,月薪5000,兼职1.5w
  19. Java游戏项目之王者荣耀
  20. 【仓储管理系统需求分析(四)】

热门文章

  1. 承德医学院计算机信息,研究生院 信息发布
  2. 《Mysql是怎样运行的》读书笔记之B+树索引
  3. Extraction of individual trees based on Canopy Height Model to monitor the state of the forest
  4. 【最简易c语言】有一篇文章,共有3行文字,每行有80个字符。要求分别统计其中英文大写字母、小写字母、数字、空格以及其他字符的个数。
  5. 天空之城 单音版
  6. 与阿里云整个生态体系共同成长,更快更好的为房地产行业客户提供高价值的服务。
  7. NRZ码位同步原理及FPGA实现--CDR
  8. @SuppressWarnings是什么意思?
  9. js单行代码------数组
  10. IIS中应用程序池和站点通过命令启停方法