Continuing from the previous post on the CVPR 2019 DANet (Dual Attention Network for Scene Segmentation), this post covers DRANet, which can be seen as an upgraded, or computationally lighter, version of DANet.
The paper frames DANet's problem as follows: although it adds no model parameters, computing the affinities between every pair of positions and every pair of channels greatly increases computation and GPU memory usage:

attention modeling brings a heavy burden on computation and memory if the number of pixels/channels is huge

Hence, the original relationship between any two pixels/channels (PAM/CAM) is replaced by the relationship between each pixel/channel and a small set of gathering centers (CPAM/CCAM).

Compact Position Attention Module

Compared with PAM, CPAM mainly adds the boxed part of the figure, whose purpose is to reduce computation. The method: pool the input at several scales and use the pooled pixel subsets as gathering centers. The paper describes it as follows (a nice lesson in how to make simple downsampling sound fancier):

construct the relationships between each pixel and a few numbers of gathering centers. The gathering centers are formally defined as a compact feature vector by gathering feature vectors from a pixel subset in the input tensor. They are implemented by a spatial pyramid pooling scheme that provides context information from different spatial scales.


In plain terms, following along with the code:

  • The feature $\mathbf{A} \in \mathbb{R}^{C \times H \times W}$ is downsampled by a set of adaptive poolings with different output sizes; a pooling of output size $L \times L$ yields a feature bin of shape $\mathbb{R}^{C \times L^{2}}$ (the smallest cubes in the figure). All bins are then concatenated along the flattened spatial dimension into $\mathbf{F} \in \mathbb{R}^{C \times M}$, where $M = \sum_i L_i^{2}$; with pooling sizes 1, 2, 3, and 6, $M = 1 + 4 + 9 + 36 = 50$.

The corresponding code is encoding.nn.dran_att.CPAMEnc:

import torch
from torch.nn import Module, Sequential, Conv2d, ReLU, AdaptiveAvgPool2d

class CPAMEnc(Module):
    """CPAM encoding module"""
    def __init__(self, in_channels, norm_layer):
        super(CPAMEnc, self).__init__()
        # Spatial pyramid pooling with output sizes 1x1, 2x2, 3x3 and 6x6
        self.pool1 = AdaptiveAvgPool2d(1)
        self.pool2 = AdaptiveAvgPool2d(2)
        self.pool3 = AdaptiveAvgPool2d(3)
        self.pool4 = AdaptiveAvgPool2d(6)
        # One 1x1 conv + norm + ReLU per pyramid level
        self.conv1 = Sequential(Conv2d(in_channels, in_channels, 1, bias=False),
                                norm_layer(in_channels), ReLU(True))
        self.conv2 = Sequential(Conv2d(in_channels, in_channels, 1, bias=False),
                                norm_layer(in_channels), ReLU(True))
        self.conv3 = Sequential(Conv2d(in_channels, in_channels, 1, bias=False),
                                norm_layer(in_channels), ReLU(True))
        self.conv4 = Sequential(Conv2d(in_channels, in_channels, 1, bias=False),
                                norm_layer(in_channels), ReLU(True))

    def forward(self, x):
        b, c, h, w = x.size()
        # Each level: pool -> conv -> flatten spatial dims to (b, c, L^2)
        feat1 = self.conv1(self.pool1(x)).view(b, c, -1)  # (b, c, 1)
        feat2 = self.conv2(self.pool2(x)).view(b, c, -1)  # (b, c, 4)
        feat3 = self.conv3(self.pool3(x)).view(b, c, -1)  # (b, c, 9)
        feat4 = self.conv4(self.pool4(x)).view(b, c, -1)  # (b, c, 36)
        # Concatenate along the flattened spatial dim: (b, c, M) with M = 50
        return torch.cat((feat1, feat2, feat3, feat4), 2)
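A quick shape check (a sketch of my own, not from the repo) confirms that the pooling sizes 1/2/3/6 give $M = 1 + 4 + 9 + 36 = 50$ gathering centers:

import torch
import torch.nn as nn

enc = CPAMEnc(in_channels=512, norm_layer=nn.BatchNorm2d)
x = torch.randn(2, 512, 96, 96)   # a typical 1/8-resolution backbone feature
print(enc(x).shape)               # torch.Size([2, 512, 50])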

The rest of the operation is essentially the same as in DANet. The only difference is that the dimensions of the key and value branches (C and D in DANet's figure) change from $(H \times W) \times d$ to $d \times M$, shrinking the attention map from $N \times N$ to $N \times M$ and thereby cutting the computation. Here $d$ is the channel dimension after further compression: a fully connected layer for the gathering centers and a 1×1 convolution for the query. In other words, both the spatial and the channel dimensions are compressed to reduce computation.
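As a rough sanity check on the savings (my own back-of-the-envelope numbers, not from the paper), compare the attention-map sizes on a typical 1/8-resolution feature map:

# Illustrative attention-map sizes, assuming a 96x96 feature map and M = 50
H, W, M = 96, 96, 50
N = H * W                           # 9216 spatial positions
pam_entries = N * N                 # DANet PAM: pixel-to-pixel map, ~85M entries
cpam_entries = N * M                # DRANet CPAM: pixel-to-center map, ~0.46M entries
print(pam_entries // cpam_entries)  # 184: the attention map is ~184x smaller

The decoding module, CPAMDec: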

import torch
from torch.nn import Module, Conv2d, Linear, Parameter, Softmax

class CPAMDec(Module):
    """CPAM decoding module"""
    def __init__(self, in_channels):
        super(CPAMDec, self).__init__()
        self.softmax = Softmax(dim=-1)
        self.scale = Parameter(torch.zeros(1))  # learnable residual weight
        # Query: 1x1 conv compresses channels C -> C/4 (= d)
        self.conv_query = Conv2d(in_channels=in_channels, out_channels=in_channels//4,
                                 kernel_size=1)  # query_conv2
        # Key: fully connected layer compresses the centers C -> C/4 (= d)
        self.conv_key = Linear(in_channels, in_channels//4)  # key_conv2
        self.conv_value = Linear(in_channels, in_channels)  # value2

    def forward(self, x, y):
        """
        inputs :
            x : input feature (N, C, H, W)
            y : gathering centers (N, K, M)
        returns :
            out : compact position attention feature
            attention map : (H*W) x M
        """
        m_batchsize, C, width, height = x.size()
        m_batchsize, K, M = y.size()
        proj_query = self.conv_query(x).view(m_batchsize, -1, width*height).permute(0, 2, 1)  # B x N x d
        proj_key = self.conv_key(y).view(m_batchsize, K, -1).permute(0, 2, 1)  # B x d x K
        energy = torch.bmm(proj_query, proj_key)  # B x N x K
        attention = self.softmax(energy)  # B x N x K
        proj_value = self.conv_value(y).permute(0, 2, 1)  # B x C x K
        out = torch.bmm(proj_value, attention.permute(0, 2, 1))  # B x C x N
        out = out.view(m_batchsize, C, width, height)
        out = self.scale*out + x  # residual connection with learnable scale
        return out
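Wiring the encoder and decoder together, a minimal usage sketch mirroring DranHead (the channel count 512 is my own example value):

import torch
import torch.nn as nn

enc = CPAMEnc(in_channels=512, norm_layer=nn.BatchNorm2d)
dec = CPAMDec(in_channels=512)
x = torch.randn(2, 512, 96, 96)
y = enc(x).permute(0, 2, 1)   # (B, M, C): the M = 50 gathering centers
out = dec(x, y)               # (2, 512, 96, 96), same shape as x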

Where the whole CPAM module sits can be seen in class DranHead in encoding.models.sseg.dran.py: a 3×3 convolution is applied both before and after it.

## Convs or modules for CPAM
self.conv_cpam_b = nn.Sequential(
    nn.Conv2d(in_channels, inter_channels, 3, padding=1, bias=False),
    norm_layer(inter_channels),
    nn.ReLU())  # conv5_s
self.cpam_enc = CPAMEnc(inter_channels, norm_layer)  # en_s
self.cpam_dec = CPAMDec(inter_channels)  # de_s
self.conv_cpam_e = nn.Sequential(
    nn.Conv2d(inter_channels, inter_channels, 3, padding=1, bias=False),
    norm_layer(inter_channels),
    nn.ReLU())  # conv52

# ... in forward():
## Compact Spatial Attention Module (CPAM)
cpam_b = self.conv_cpam_b(multix[-1])
cpam_f = self.cpam_enc(cpam_b).permute(0, 2, 1)  # B x K x D gathering centers
cpam_feat = self.cpam_dec(cpam_b, cpam_f)

Compact Channel Attention Module

Once CPAM is understood, CCAM is easy to grasp.

The core idea: build on DANet's CAM, but compress the channel dimension of the gathering centers with a 1×1 convolution. The decoding module:

import torch
from torch.nn import Module, Parameter, Softmax

class CCAMDec(Module):
    """CCAM decoding module"""
    def __init__(self):
        super(CCAMDec, self).__init__()
        self.softmax = Softmax(dim=-1)
        self.scale = Parameter(torch.zeros(1))  # learnable residual weight

    def forward(self, x, y):
        """
        inputs :
            x : input feature (N, C, H, W)
            y : gathering centers (N, K, H, W)
        returns :
            out : compact channel attention feature
            attention map : K x C
        """
        m_batchsize, C, width, height = x.size()
        x_reshape = x.view(m_batchsize, C, -1)
        B, K, W, H = y.size()
        y_reshape = y.view(B, K, -1)
        proj_query = x_reshape  # B x C x N
        proj_key = y_reshape.permute(0, 2, 1)  # B x N x K
        energy = torch.bmm(proj_query, proj_key)  # B x C x K
        # Same max-subtraction trick as DANet's CAM before the softmax
        energy_new = torch.max(energy, -1, keepdim=True)[0].expand_as(energy) - energy
        attention = self.softmax(energy_new)
        proj_value = y.view(B, K, -1)  # B x K x N
        out = torch.bmm(attention, proj_value)  # B x C x N
        out = out.view(m_batchsize, C, width, height)
        out = x + self.scale*out  # residual connection with learnable scale
        return out
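A minimal shape sketch (channel counts are my own example values): the gathering centers are just a channel-compressed copy of the input produced by a 1×1 convolution, as in the ccam_enc branch of DranHead:

import torch
import torch.nn as nn

x = torch.randn(2, 512, 96, 96)
ccam_enc = nn.Sequential(nn.Conv2d(512, 512 // 16, 1, bias=False),
                         nn.BatchNorm2d(512 // 16), nn.ReLU())
ccam_dec = CCAMDec()
f = ccam_enc(x)        # (2, 32, 96, 96): K = C/16 channels of gathering centers
out = ccam_dec(x, f)   # (2, 512, 96, 96), same shape as x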

Likewise, the placement of the whole CCAM module can be seen in class DranHead in encoding.models.sseg.dran.py: the front applies a 3×3 convolution (plus a 1×1 convolution to produce the gathering centers), and the back applies another 3×3 convolution.

## Convs or modules for CCAM
self.conv_ccam_b = nn.Sequential(
    nn.Conv2d(in_channels, inter_channels, 3, padding=1, bias=False),
    norm_layer(inter_channels),
    nn.ReLU())  # conv5_c
self.ccam_enc = nn.Sequential(
    nn.Conv2d(inter_channels, inter_channels//16, 1, bias=False),
    norm_layer(inter_channels//16),
    nn.ReLU())  # conv51_c
self.ccam_dec = CCAMDec()  # de_c
self.conv_ccam_e = nn.Sequential(
    nn.Conv2d(inter_channels, inter_channels, 3, padding=1, bias=False),
    norm_layer(inter_channels),
    nn.ReLU())  # conv51

def forward(self, multix):
    ## Compact Channel Attention Module (CCAM)
    ccam_b = self.conv_ccam_b(multix[-1])
    ccam_f = self.ccam_enc(ccam_b)
    ccam_feat = self.ccam_dec(ccam_b, ccam_f)

Cross-Level Gating Decoder

Finally, there is one more interesting module: a decoder that cross-fuses high- and low-level features with a gate (i.e., by re-weighting). The data flow is as follows, with H and L denoting the high- and low-level image features. The key point is that, unlike U-Net, it does not simply concatenate high- and low-level features along the channel dimension. Instead, the high-level feature acts as a gate that re-weights the low-level feature, effectively selecting the useful low-level features before concatenating them with the high-level ones.
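In formula form (my own notation, summarizing the CLGD code below): with $L$ the low-level feature after a 3×3 convolution and $H$ the upsampled high-level feature,

$$\tilde{L} = \gamma \cdot \sigma\big(f_{1 \times 1}([L; H])\big) \odot L, \qquad \mathrm{out} = f_{3 \times 3}\big(f_{3 \times 3}([\tilde{L}; H])\big)$$

where $[\cdot\,;\cdot]$ is channel concatenation, $\sigma$ the sigmoid, and $\gamma$ a learnable scalar.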

First look at class DranHead in encoding.models.sseg.dran.py to pin down the high- and low-level inputs. As in the figure above, the high-level feature comes from fusing the two attention branches followed by a 3×3 convolution; the low-level feature comes from the output of ResNet layer1.

## Fusion conv
self.conv_cat = nn.Sequential(
    nn.Conv2d(inter_channels*2, inter_channels//2, 3, padding=1, bias=False),
    norm_layer(inter_channels//2),
    nn.ReLU())  # conv_f
## Cross-level Gating Decoder (CLGD)
self.clgd = CLGD(inter_channels//2, inter_channels//2, norm_layer)

def forward(self, multix):
    ## Cross-level Gating Decoder (CLGD)
    final_feat = self.clgd(multix[0], feat_sum)

Now for the key CLGD module: the low-level feature passes through a 3×3 convolution, the high-level feature is upsampled 2×, and the two are concatenated. A 1×1 convolution plus a sigmoid then produces a gate map which, scaled by a learnable coefficient, multiplies the low-level feature (the gating step). Finally, the gated low-level feature is concatenated with the high-level feature and passed through two 3×3 convolutions to produce the output.

import torch
import torch.nn as nn
from torch.nn import Module

class CLGD(Module):
    """Cross-level Gating Decoder"""
    def __init__(self, in_channels, out_channels, norm_layer):
        super(CLGD, self).__init__()
        inter_channels = 32
        self.conv_low = nn.Sequential(
            nn.Conv2d(in_channels, inter_channels, 3, padding=1, bias=False),
            norm_layer(inter_channels), nn.ReLU())  # skipconv
        self.conv_cat = nn.Sequential(
            nn.Conv2d(in_channels+inter_channels, in_channels, 3, padding=1, bias=False),
            norm_layer(in_channels), nn.ReLU())  # fusion1
        # 1x1 conv + sigmoid produces the single-channel gate map
        self.conv_att = nn.Sequential(
            nn.Conv2d(in_channels+inter_channels, 1, 1), nn.Sigmoid())  # att
        self.conv_out = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, padding=1, bias=False),
            norm_layer(out_channels), nn.ReLU())  # fusion2
        self._up_kwargs = up_kwargs  # repo-level upsampling settings
        self.gamma = nn.Parameter(torch.ones(1))

    def forward(self, x, y):
        """
        inputs :
            x : low-level feature (N, C, H, W)
            y : high-level feature (N, C, H, W)
        returns :
            out : cross-level gating decoder feature
        """
        low_lvl_feat = self.conv_low(x)
        # Upsample the high-level feature to the low-level spatial size
        high_lvl_feat = upsample(y, low_lvl_feat.size()[2:], **self._up_kwargs)
        feat_cat = torch.cat([low_lvl_feat, high_lvl_feat], 1)
        # Gate: sigmoid map, scaled by gamma, re-weights the low-level feature
        low_lvl_feat_refine = self.gamma * self.conv_att(feat_cat) * low_lvl_feat
        low_high_feat = torch.cat([low_lvl_feat_refine, high_lvl_feat], 1)
        low_high_feat = self.conv_cat(low_high_feat)
        low_high_feat = self.conv_out(low_high_feat)
        return low_high_feat
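CLGD references the repo-level upsample helper and up_kwargs. To try the module standalone, a minimal sketch with a reasonable substitution (my assumption, matching the repo's bilinear default):

import torch
import torch.nn as nn
import torch.nn.functional as F

up_kwargs = {'mode': 'bilinear', 'align_corners': True}
upsample = F.interpolate   # stand-in for the repo's upsample helper

clgd = CLGD(in_channels=256, out_channels=256, norm_layer=nn.BatchNorm2d)
low = torch.randn(2, 256, 192, 192)   # low-level feature (e.g. from ResNet layer1)
high = torch.randn(2, 256, 96, 96)    # fused high-level attention feature
print(clgd(low, high).shape)          # torch.Size([2, 256, 192, 192])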

Reference

https://github.com/junfu1115/DANet
