论文地址:CCNet: Criss-Cross Attention for Semantic Segmentation

  代码地址:CCNet github

文章目录

  • 一、简介
  • 二、结构
    • 1、CCNet结构
    • 2、Criss-Cross Attention
    • 3、Recurrent Criss-Cross Attention
    • 4、代码
  • 三、结果

一、简介

  CCNet是2018年11月发布的一篇语义分割方面的文章中提到的网络,该网络有三个优势:

  • GPU内存友好;
  • 计算高效;
  • 性能好。

  CCNet之前的论文比如FCNs只能管制局部特征和少部分的上下文信息,空洞卷积只能够集中于当前像素而无法生成密集的上下文信息,虽然PSANet能够生成密集的像素级的上下文信息但是计算效率过低,其计算复杂度高达O((H*W)*(H*\W))。因此可以明显的看出,CCNet的目的是高效的生成密集的像素级的上下文信息。
  Cirss-Cross Attention Block的参数对比如下图所示:

  CCNet论文的主要贡献:

  • 提出了Cirss-Cross Attention Module;
  • 提出了高效利用Cirss-Cross Attention Module的CCNet。

二、结构

1、CCNet结构

  CCNet的网络结构如下图所示:

  CCNet的基本结构描述如下:

  • 1、图像通过特征提取网络得到feature map的大小为H∗WH*WH∗W,为了更高效的获取密集的特征图,将原来的特征提取网络中的后面两个下采样去除,替换为空洞卷积,使得feature map的大小为输入图像的1/8;
  • 2、feature map X分为两个分支,分别进入3和4;
  • 3、一个分支先将X进行通道缩减压缩特征,然后通过两个CCA(Cirss-Cross Attention)模块,两个模块共享相同的参数,得到特征H′′H^{''}H′′;
  • 4、另一个分支保持不变为X;
  • 5、将3和4两个分支的特征融合到一起最终经过upsample得到分割图像。

2、Criss-Cross Attention

 Criss-Cross Attention模块的结构如下所示,输入feature为H∈RC∗W∗HH\in \mathbb{R}^{C*W*H}H∈RC∗W∗H,HHH分为Q,K,VQ,K,VQ,K,V三个分支,都通过1*1的卷积网络的进行降维得到Q,K∈RC′∗W∗H{Q,K}\in \mathbb{R}^{C^{'}*W*H}Q,K∈RC′∗W∗H(C′<CC^{'}<CC′<C)。其中Attention Map A∈R(H+W−1)∗W∗HA\in \mathbb{R}^{(H+W-1)*W*H}A∈R(H+W−1)∗W∗H是QQQ和KKK通过Affinity操作计算的。Affinity操作定义为:
di,u=QuΩi,uTd_{i,u}=Q_u\Omega_{i,u}^{T} di,u​=Qu​Ωi,uT​
  其中Qu∈RC′Q_u\in\mathbb{R}^{C^{'}}Qu​∈RC′是在特征图Q的空间维度上的u位置的值。Ωu∈R(H+W−1)C′\Omega_u\in\mathbb{R}^{(H+W-1)C^{'}}Ωu​∈R(H+W−1)C′是KKK上uuu位置处的同列和同行的元素的集合。因此,Ωu,i∈RC′\Omega_{u,i}\in\mathbb{R}^{C^{'}}Ωu,i​∈RC′是Ωu\Omega_uΩu​中的第iii个元素,其中i=[1,2,...,∣Ωu∣]i=[1,2,...,|\Omega_u|]i=[1,2,...,∣Ωu​∣]。而di,u∈Dd_{i,u}\in Ddi,u​∈D表示QuQ_uQu​和Ωi,u\Omega_{i,u}Ωi,u​之间的联系的权重,D∈R(H+W−1)∗W∗HD\in \mathbb{R}^{(H+W-1)*W*H}D∈R(H+W−1)∗W∗H。最后对DDD进行在通道维度上继续进行softmax操作计算Attention Map AAA。
  另一个分支VVV经过一个1*1卷积层得到V∈RC∗W∗HV \in \mathbb{R}^{C*W*H}V∈RC∗W∗H的适应性特征。同样定义Vu∈RCV_u \in \mathbb{R}^CVu​∈RC和Φu∈R(H+W−1)∗C\Phi_u\in \mathbb{R}^{(H+W-1)*C}Φu​∈R(H+W−1)∗C,Φu\Phi_uΦu​是VVV上u点的同行同列的集合,则定义Aggregation操作为:
Hu′∑i∈∣Φu∣Ai,uΦi,u+HuH_u^{'}\sum_{i \in |\Phi_u|}{A_{i,u}\Phi_{i,u}+H_u} Hu′​i∈∣Φu​∣∑​Ai,u​Φi,u​+Hu​
  该操作在保留原有feature的同时使用经过attention处理过的feature来保全feature的语义性质。

3、Recurrent Criss-Cross Attention

  单个Criss-Cross Attention模块能够提取更好的上下文信息,但是下图所示,根据criss-cross attention模块的计算方式左边右上角蓝色的点只能够计算到和其同列同行的关联关系,也就是说相应的语义信息的传播无法到达左下角的点,因此再添加一个Criss-Cross Attention模块可以将该语义信息传递到之前无法传递到的点。

  采用Recurrent Criss-Cross Attention之后,先定义loop=2,第一个loop的attention map为AAA,第二个loop的attention map为A′A^{'}A′,从原feature上位置x′,y′x^{'},y^{'}x′,y′到权重Ai,x,yA_{i,x,y}Ai,x,y​的映射函数为Ai,x,y=f(A,x,y,x′,y′)A_{i,x,y}=f(A,x,y,x^{'},y^{'})Ai,x,y​=f(A,x,y,x′,y′),feature HHH中的位置用θ\thetaθ表示,feature中H′′H^{''}H′′用uuu表示,如果uuu和θ\thetaθ相同则:
Hu′′←[f(A,u,θ)+1]⋅f(A′,u,θ)⋅HθH_u^{''}\leftarrow[f(A,u,\theta)+1]\cdot f(A^{'},u,\theta)\cdot H_{\theta} Hu′′​←[f(A,u,θ)+1]⋅f(A′,u,θ)⋅Hθ​
  其中←\leftarrow←表示加到操作,如果uuu和θ\thetaθ不同则:
Hu′′←[f(A,ux,θy,θx,θy)⋅f(A′,ux,uy,ux,θy)+f(A,θx,uy,θx,θy)⋅f(A′,ux,uy,θx,θy)]⋅HθH_u^{''}\leftarrow[f(A,u_x,\theta_{y}, \theta_{x}, \theta_{y})\cdot f(A^{'},u_x,u_{y}, u_{x}, \theta_{y})+f(A,\theta_x,u_{y}, \theta_{x}, \theta_{y})\cdot f(A^{'},u_x,u_{y}, \theta_{x}, \theta_{y})]\cdot H_{\theta} Hu′′​←[f(A,ux​,θy​,θx​,θy​)⋅f(A′,ux​,uy​,ux​,θy​)+f(A,θx​,uy​,θx​,θy​)⋅f(A′,ux​,uy​,θx​,θy​)]⋅Hθ​
  Cirss-Cross Attention模块可以应用于多种任务不仅仅是语义分割,作者同样在多种任务中使用了该模块,可以参考论文。

4、代码

  下面是Cirss-Cross Attention模块的代码可以看到ca_weight便是Affinity操作,ca_map便是Aggregation操作。

class CrissCrossAttention(nn.Module):""" Criss-Cross Attention Module"""def __init__(self,in_dim):super(CrissCrossAttention,self).__init__()self.chanel_in = in_dimself.query_conv = nn.Conv2d(in_channels = in_dim , out_channels = in_dim//8 , kernel_size= 1)self.key_conv = nn.Conv2d(in_channels = in_dim , out_channels = in_dim//8 , kernel_size= 1)self.value_conv = nn.Conv2d(in_channels = in_dim , out_channels = in_dim , kernel_size= 1)self.gamma = nn.Parameter(torch.zeros(1))def forward(self,x):proj_query = self.query_conv(x)proj_key = self.key_conv(x)proj_value = self.value_conv(x)energy = ca_weight(proj_query, proj_key)attention = F.softmax(energy, 1)out = ca_map(attention, proj_value)out = self.gamma*out + xreturn out

  Affinity操作定义如下:

class CA_Weight(autograd.Function):@staticmethoddef forward(ctx, t, f):# Save contextn, c, h, w = t.size()size = (n, h+w-1, h, w)weight = torch.zeros(size, dtype=t.dtype, layout=t.layout, device=t.device)_ext.ca_forward_cuda(t, f, weight)# Outputctx.save_for_backward(t, f)return weight@staticmethod@once_differentiabledef backward(ctx, dw):t, f = ctx.saved_tensorsdt = torch.zeros_like(t)df = torch.zeros_like(f)_ext.ca_backward_cuda(dw.contiguous(), t, f, dt, df)_check_contiguous(dt, df)return dt, df

  Aggregation操作定义如下:

class CA_Map(autograd.Function):@staticmethoddef forward(ctx, weight, g):# Save contextout = torch.zeros_like(g)_ext.ca_map_forward_cuda(weight, g, out)# Outputctx.save_for_backward(weight, g)return out@staticmethod@once_differentiabledef backward(ctx, dout):weight, g = ctx.saved_tensorsdw = torch.zeros_like(weight)dg = torch.zeros_like(g)_ext.ca_map_backward_cuda(dout.contiguous(), weight, g, dw, dg)_check_contiguous(dw, dg)return dw, dg

  其中使用ext是c库文件:

  RCC模块的实现如下所示:

class RCCAModule(nn.Module):def __init__(self, in_channels, out_channels, num_classes):super(RCCAModule, self).__init__()inter_channels = in_channels // 4self.conva = nn.Sequential(nn.Conv2d(in_channels, inter_channels, 3, padding=1, bias=False),InPlaceABNSync(inter_channels))self.cca = CrissCrossAttention(inter_channels)self.convb = nn.Sequential(nn.Conv2d(inter_channels, inter_channels, 3, padding=1, bias=False),InPlaceABNSync(inter_channels))self.bottleneck = nn.Sequential(nn.Conv2d(in_channels+inter_channels, out_channels, kernel_size=3, padding=1, dilation=1, bias=False),InPlaceABNSync(out_channels),nn.Dropout2d(0.1),nn.Conv2d(512, num_classes, kernel_size=1, stride=1, padding=0, bias=True))def forward(self, x, recurrence=1):output = self.conva(x)for i in range(recurrence):output = self.cca(output)output = self.convb(output)output = self.bottleneck(torch.cat([x, output], 1))return output

  CCNet的整体结构:

class ResNet(nn.Module):def __init__(self, block, layers, num_classes):self.inplanes = 128super(ResNet, self).__init__()self.conv1 = conv3x3(3, 64, stride=2)self.bn1 = BatchNorm2d(64)self.relu1 = nn.ReLU(inplace=False)self.conv2 = conv3x3(64, 64)self.bn2 = BatchNorm2d(64)self.relu2 = nn.ReLU(inplace=False)self.conv3 = conv3x3(64, 128)self.bn3 = BatchNorm2d(128)self.relu3 = nn.ReLU(inplace=False)self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)self.relu = nn.ReLU(inplace=False)self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1, ceil_mode=True) # changeself.layer1 = self._make_layer(block, 64, layers[0])self.layer2 = self._make_layer(block, 128, layers[1], stride=2)self.layer3 = self._make_layer(block, 256, layers[2], stride=1, dilation=2)self.layer4 = self._make_layer(block, 512, layers[3], stride=1, dilation=4, multi_grid=(1,1,1))#self.layer5 = PSPModule(2048, 512)self.head = RCCAModule(2048, 512, num_classes)self.dsn = nn.Sequential(nn.Conv2d(1024, 512, kernel_size=3, stride=1, padding=1),InPlaceABNSync(512),nn.Dropout2d(0.1),nn.Conv2d(512, num_classes, kernel_size=1, stride=1, padding=0, bias=True))def forward(self, x, recurrence=1):x = self.relu1(self.bn1(self.conv1(x)))x = self.relu2(self.bn2(self.conv2(x)))x = self.relu3(self.bn3(self.conv3(x)))x = self.maxpool(x)x = self.layer1(x)x = self.layer2(x)x = self.layer3(x)x_dsn = self.dsn(x)x = self.layer4(x)x = self.head(x, recurrence)return [x, x_dsn]

三、结果

  与主流的方法的比较:

  下面是不同loop时的效果可以看到loop=2时的效果要比loop=2好。下面是不同loop的attention map。




语义分割之《CCNet: Criss-Cross Attention for Semantic Segmentation》论文阅读笔记相关推荐

  1. CCNet: Criss-Cross Attention for Semantic Segmentation论文读书笔记

    CCNet: Criss-Cross Attention for Semantic Segmentation读书笔记 Criss-Cross Network(CCNet): 作用: 用来获得上下文信息 ...

  2. 轻量级实时语义分割:Guided Upsampling Network for Real-Time Semantic Segmentation

    轻量级实时语义分割:Guided Upsampling Network for Real-Time Semantic Segmentation 介绍 网络设计 Guided unsampling mo ...

  3. 论文阅读:PMF基于视觉感知的多传感器融合点云语义分割Perception-Aware Multi-Sensor Fusion for 3D LiDAR Semantic Segmentation

    题目:Perception-Aware Multi-Sensor Fusion for 3D LiDAR Semantic Segmentation 中文:用于 3D LiDAR 语义分割的多传感器感 ...

  4. 语义分割--RefineNet: Multi-Path Refinement Networks for High-Resolution Semantic Segmentation

    RefineNet: Multi-Path Refinement Networks for High-Resolution Semantic Segmentation CVPR2017 https:/ ...

  5. 论文笔记(FCN网络,语义分割):Fully Convolutional Networks for Semantic Segmentation

    FCN论文笔记:Fully Convolutional Networks for Semantic Segmentation 语义分割模型结构时序: FCN SegNet Dilated Convol ...

  6. 语义分割--Efficient and Robust Deep Networks for Semantic Segmentation

    Efficient and Robust Deep Networks for Semantic Segmentation Code: https://lmb.informatik.uni-freibu ...

  7. 【语义分割】OCRNet:Object-Context Representations for Semantic Segmentation

    文章目录 一.文章出发点 二.方法 三.效果 一.文章出发点 每个像素点的类别(label)应该是它所属目标(object)的类别. 所以这篇文章对像素的上下文信息建模 建模方法:求每个像素点和每个类 ...

  8. CVF2020邻域自适应/语义分割:FDA: Fourier Domain Adaptation for Semantic SegmentationFDA:用于语义分割的傅立叶域自适应算法

    邻域自适应/语义分割:FDA: Fourier Domain Adaptation for Semantic Segmentation FDA:用于语义分割的傅立叶域自适应算法 0.摘要 1.概述 1 ...

  9. 论文阅读笔记:MGAT: Multi-view Graph Attention Networks

    论文阅读笔记:MGAT: Multi-view Graph Attention Networks 文章目录 论文阅读笔记:MGAT: Multi-view Graph Attention Networ ...

  10. RFA-Net: Residual feature attention network for fine-grained image inpainting 论文阅读笔记

    RFA-Net: Residual feature attention network for fine-grained image inpainting 论文阅读笔记 摘要 尽管大多数使用生成对抗性 ...

最新文章

  1. linux apache cpu,linux – Apache使用100%的CPU. “ps”命令可以告诉我它在做什么吗?...
  2. 过滤当前主机的IPV4地址
  3. 基于单片机的超市储物柜设计_基于51单片机对电子储物柜系统的设计
  4. 三朵云 华为_云时代和5G将重构网络结构
  5. mysql allowmultiqueries=true_Mysql批量更新的一个坑-allowMultiQueries=true允许批量更新(转)...
  6. C++static类静态成员函数及变量解析
  7. 毕马威首次发布《初探元宇宙》报告:从科幻畅想到产业风口(附报告下载链接)...
  8. centos7 查看oracle运行日志_Linux(CentOS7)部署系列---常规应用部署方案
  9. IPTV码流分析指标
  10. SQL2008如何建立数据库
  11. 2pin接口耳机_让耳机“轻松一下”—— QDC BTX(耳机蓝牙线)
  12. 免费PDF阅读器都是坑?这些开源神器我可是恨不得所有人都知道
  13. android aso优化工具,如何使用ASO优化工具优化安卓应用商店
  14. 怎样在word中打印框选对√
  15. stm32f4晶振管理
  16. 服务器域共享文件夹,访问域共享文件夹
  17. 上海有哪些牛逼的互联网公司?
  18. 潮汕牛肉丸是熟的还是生的 潮汕牛肉生丸和熟丸区别
  19. 内存颗粒版本判断方法和编号解析(三星、美光、海力士)
  20. 【工作经验分享】,大厂面试经验分享

热门文章

  1. 求两个数的最大公约数和最小公倍数
  2. MATLAB图形用户界面设计(GUI)
  3. 正规手游代理该怎么选?
  4. 中文如何翻译成英文?手机中英文一键翻译超简单
  5. 口算训练 HDU - 6287
  6. 华为5G专利收费标准曝光!原来卖专利真的很挣钱
  7. Gauss 求积公式及代码
  8. Android实现本地图片、视频左右镜像翻转
  9. Excel两行交换及两列交换,快速互换相邻表格数据的方法
  10. 一本通 1194:移动路线