文章目录

  • 前言
  • 一、Compat Position Attention Module紧凑型位置注意力模块
  • 二、Compat Channel Attention Module紧凑型通道注意力模块
  • 三、效果
  • 四、代码实现
    • 1.Pytorch源码(省略了引用库)
    • 2.keras实现

前言

之前看过一篇dual attention做自然图像分割的文章[1],后来看到作者还出了个优化版,叫Dual Relation-Aware Attention[2],主要解决的问题是dual attention计算和存储成本过高的问题(我跑dual attention也是一直OOM)。顺便试着实现了一下,但是不保证准确性,欢迎讨论,指出错误^_=

[1]Dual Attention Network for Scene Segmentation
论文:https://arxiv.org/abs/1809.02983
[2]Scene Segmentation With Dual Relation-Aware Attention Network
论文:https://ieeexplore.ieee.org/abstract/document/9154612/

代码地址(Pytorch):https://github.com/junfu1115/DANet (作者把两个论文的代码放在一个项目里了)


一、Compat Position Attention Module紧凑型位置注意力模块


Dual Attention的位置注意力PAM是通过特征图的内积实现的(实际上就是Self Attention,通过矩阵点乘建模像素间的全局关系),但当特征图比较大的时候,需要高昂的GPU内存开销和计算成本,因此作者提出了紧凑型位置注意力模块CPAM,其实现是通过金字塔池化(不同大小的池化)构建了每个像素和几个聚集中心之间的关系,将这些池化特征拼接起来做self attention内积,一定程度上减少了运算量和内存消耗。

二、Compat Channel Attention Module紧凑型通道注意力模块

作者注意到,使用通道注意力(CAM)模块时,如果特征映射数(通道数)较大时,需要注意计算量的问题。为了解决这个问题,作者提出了一个紧凑型通道注意力模块CCAM来建立每个通道图和通道聚集中心之间的关系。主要实现是先对输入特征通过1x1卷积进行降维, 再对其计算self attention。

三、效果


可以看出,降维之后的dual attention对算力和内存的消耗还是少了很多的(在我自己的网络上体现为终于不会OOM了,不过我做小样本医学图像分割,也没观察到涨点)。CAM和CCAM在涨点上差不多就不放了,值得注意的是PAM和CPAM,虽然两者涨点相同,但是作者发现PAM对小目标的提升更大,CPAM对大目标的提升更大(表3),感觉是因为池化导致像素间的关系缺失,导致小目标的识别比较受限。

另外,作者比较了CPAM和CCAM的连接方式,最后发现对输入特征分别并行计算CPAM和CCAM,再拼接最有效。

文章还提出了一个Details of cross-level gating decoder,不过还没仔细研究。

四、代码实现

1.Pytorch源码(省略了引用库)

class CPAMEnc(Module):"""CPAM encoding module"""def __init__(self, in_channels, norm_layer):super(CPAMEnc, self).__init__()self.pool1 = AdaptiveAvgPool2d(1)self.pool2 = AdaptiveAvgPool2d(2)self.pool3 = AdaptiveAvgPool2d(3)self.pool4 = AdaptiveAvgPool2d(6)self.conv1 = Sequential(Conv2d(in_channels, in_channels, 1, bias=False),norm_layer(in_channels),ReLU(True))self.conv2 = Sequential(Conv2d(in_channels, in_channels, 1, bias=False),norm_layer(in_channels),ReLU(True))self.conv3 = Sequential(Conv2d(in_channels, in_channels, 1, bias=False),norm_layer(in_channels),ReLU(True))self.conv4 = Sequential(Conv2d(in_channels, in_channels, 1, bias=False),norm_layer(in_channels),ReLU(True))def forward(self, x):b, c, h, w = x.size()feat1 = self.conv1(self.pool1(x)).view(b,c,-1)feat2 = self.conv2(self.pool2(x)).view(b,c,-1)feat3 = self.conv3(self.pool3(x)).view(b,c,-1)feat4 = self.conv4(self.pool4(x)).view(b,c,-1)return torch.cat((feat1, feat2, feat3, feat4), 2)class CPAMDec(Module):"""CPAM decoding module"""def __init__(self,in_channels):super(CPAMDec,self).__init__()self.softmax  = Softmax(dim=-1)self.scale = Parameter(torch.zeros(1))self.conv_query = Conv2d(in_channels = in_channels , out_channels = in_channels//4, kernel_size= 1) # query_conv2self.conv_key = Linear(in_channels, in_channels//4) # key_conv2self.conv_value = Linear(in_channels, in_channels) # value2def forward(self, x,y):"""inputs :x : input feature(N,C,H,W) y:gathering centers(N,K,M)returns :out : compact position attention featureattention map: (H*W)*M"""m_batchsize,C,width ,height = x.size()m_batchsize,K,M = y.size()proj_query  = self.conv_query(x).view(m_batchsize,-1,width*height).permute(0,2,1)#BxNxdproj_key =  self.conv_key(y).view(m_batchsize,K,-1).permute(0,2,1)#BxdxKenergy =  torch.bmm(proj_query,proj_key)#BxNxKattention = self.softmax(energy) #BxNxkproj_value = self.conv_value(y).permute(0,2,1) #BxCxKout = torch.bmm(proj_value,attention.permute(0,2,1))#BxCxNout = out.view(m_batchsize,C,width,height)out = self.scale*out + xreturn outclass CCAMDec(Module):"""CCAM decoding module"""def __init__(self):super(CCAMDec,self).__init__()self.softmax  = Softmax(dim=-1)self.scale = Parameter(torch.zeros(1))def forward(self, x,y):"""inputs :x : input feature(N,C,H,W) y:gathering centers(N,K,H,W)returns :out : compact channel attention featureattention map: K*C"""m_batchsize,C,width ,height = x.size()x_reshape =x.view(m_batchsize,C,-1)B,K,W,H = y.size()y_reshape =y.view(B,K,-1)proj_query  = x_reshape #BXC1XNproj_key  = y_reshape.permute(0,2,1) #BX(N)XCenergy =  torch.bmm(proj_query,proj_key) #BXC1XCenergy_new = torch.max(energy,-1,keepdim=True)[0].expand_as(energy)-energyattention = self.softmax(energy_new)proj_value = y.view(B,K,-1) #BCNout = torch.bmm(attention,proj_value) #BC1Nout = out.view(m_batchsize,C,width ,height)out = x + self.scale*outreturn out

2.keras实现

本人使用的tensorflow版本为2.8.0,建议使用2.5.0及以上版本
为了使用自适应平均池化(AdaptiveAveragePooling),需要安装tensorflow_addons(当然也可以自己计算池化尺度,然后用普通池化)

# DRANet
import tensorflow as tf
import numpy as np
from keras.layers import *
import tensorflow_addons as tfadef conv_norm_act(input_tensor, filters, kernel_size , dilation=1, norm_type='batch', act_type='relu'):'''Conv2d + Normalization(norm_type:str) + Activation(act_type:str)'''output_tensor = Conv2D(filters, kernel_size, padding='same', dilation_rate=(dilation, dilation), use_bias=False if norm_type is not None else True, kernel_initializer='he_normal')(input_tensor)output_tensor = normalization(output_tensor, normalization=norm_type)output_tensor = Activation(act_type)(output_tensor)return output_tensor# 仅支持channel last
def cpam_enc(x):'''x: input tensor with shape [B, H, W, C]'''b, h, w, c = x.shape# x = tf.transpose(x, [0, 3, 1, 2])   # must be channel lastfeat1 = tfa.layers.AdaptiveAveragePooling2D(output_size=(1, 1))(x)feat2 = tfa.layers.AdaptiveAveragePooling2D(output_size=(2, 2))(x)feat3 = tfa.layers.AdaptiveAveragePooling2D(output_size=(3, 3))(x)feat4 = tfa.layers.AdaptiveAveragePooling2D(output_size=(6, 6))(x)feat1 = tf.reshape(tf.transpose(conv_norm_act(feat1, c, 1, 'batch', 'relu'), [0, 3, 1, 2]), (-1, c, 1))feat2 = tf.reshape(tf.transpose(conv_norm_act(feat2, c, 1, 'batch', 'relu'), [0, 3, 1, 2]), (-1, c, 4))feat3 = tf.reshape(tf.transpose(conv_norm_act(feat3, c, 1, 'batch', 'relu'), [0, 3, 1, 2]), (-1, c, 9))feat4 = tf.reshape(tf.transpose(conv_norm_act(feat4, c, 1, 'batch', 'relu'), [0, 3, 1, 2]), (-1, c, 36))return concatenate([feat1, feat2, feat3, feat4], 2)def cpam_dec(x, y):'''inputs :x : input feature(N,H,W,C) y:gathering centers(N,K,M)returns :out : compact position attention featureattention map: (H*W)*M'''b, h, w, c = x.shapeb, k, m = y.shape# scale = tf.Variable(tf.zeros(1))scale = tf.Variable(tf.ones(1))proj_query = Conv2D(c//4, 1)(x)proj_query = tf.transpose(proj_query, [0, 3, 1, 2])proj_query = tf.transpose(tf.reshape(proj_query, (-1, c//4, h*w)), [0, 2, 1])proj_key = Dense(c//4)(y)proj_key = tf.transpose(tf.reshape(proj_key, (-1, k, c//4)), [0, 2, 1])energy = tf.matmul(proj_query, proj_key)attention = tf.nn.softmax(energy)proj_value = tf.transpose(Dense(c)(y), [0, 2, 1])out = tf.matmul(proj_value, tf.transpose(attention, [0, 2, 1]))out = tf.reshape(out, (-1, c, h, w))out = tf.transpose(out, [0, 2, 3, 1])out = out * scale + xreturn outdef ccam_enc(x):b, h, w, c = x.shapex = conv_norm_act(x, c//8, 1, 'batch', 'relu')x = tf.transpose(x, [0, 3, 1, 2])return xdef ccam_dec(x, y):'''inputs:x : input feature(N,H,W,C), y:gathering centers(N,K,H,W)returns :out : compact channel attention featureattention map: K*C'''m_batchsize, height, width, c = x.shapex = tf.transpose(x, [0, 3, 1, 2])   # must be channel lastx_reshape = tf.reshape(x, (-1, c, height*width))# scale = tf.Variable(tf.zeros(1))scale = tf.Variable(tf.ones(1))b, k, h, w = y.shapey_reshape = tf.reshape(y, (-1, k, h*w))proj_query = x_reshapeporj_key = tf.transpose(y_reshape, [0, 2, 1])energy = tf.matmul(proj_query, porj_key)energy_new = tf.reduce_max(energy, -1, keepdims=True)energy_new = tf.repeat(energy_new, energy.shape[-1], -1)energy_new = energy_new - energyattention = tf.nn.softmax(energy_new)proj_value = tf.reshape(y, (-1, k, h*w))out = tf.matmul(attention, proj_value)out = tf.reshape(out, (-1, c, height, width))out = x + scale * outout = tf.transpose(out, [0, 2, 3, 1])return outdef dra_attention(x, filters):y1 = cpam_enc(x)y2 = ccam_enc(x)att1 = cpam_dec(x, y1)att2 = ccam_dec(x, y2)att = concatenate([att1, att2], -1)   # channel lastatt = conv_norm_act(att, filters, 1, 'batch', 'relu')return att

欢迎纠错,指出问题TAT

双重关系感知注意力机制 Dual Relation-Aware Attention[keras实现 dual attention优化版]相关推荐

  1. ECCV 2018 | 美图云联合中科院提出基于交互感知注意力机制神经网络的行为分类技术...

    以往注意机制模型通过加权所有局部特征计算和提取关键特征,忽略了各局部特征间的强相关性,特征间存在较强的信息冗余.为解决此问题,来自美图云视觉技术部门和中科院自动化所的研发人员借鉴 PCA(主成分分析) ...

  2. keras cnn注意力机制_TensorFlow、PyTorch、Keras:NLP框架哪家强

    全文共3412字,预计学习时长7分钟 在对TensorFlow.PyTorch和Keras做功能对比之前,先来了解一些它们各自的非竞争性柔性特点吧. 非竞争性特点 下文介绍了TensorFlow.Py ...

  3. 【论文解读】基于关系感知的全局注意力

    一.论文信息 标题:<Relation-Aware Global Attention for Person Re-identification> 作者:Zhizheng Zhang et ...

  4. 万字长文解析CV中的注意力机制(通道/空间/时域/分支注意力)

    点击下方卡片,关注"自动驾驶之心"公众号 ADAS巨卷干货,即可获取 点击进入→自动驾驶之心技术交流群 后台回复[transformer综述]获取2022最新ViT综述论文! 注意 ...

  5. 万字长文解读计算机视觉中的注意力机制(附论文和代码链接)

    文中论文和代码已经整理,如果需要,点击下方公号关注,领取,持续传达瓜货 所向披靡的张大刀 注意力机制是机器学习中嵌入的一个网络结构,主要用来学习输入数据对输出数据贡献:注意力机制在NLP和CV中均有使 ...

  6. 一文读懂CV中的注意力机制

    点击上方"小白学视觉",选择加"星标"或"置顶" 重磅干货,第一时间送达 作者丨AdamLau@知乎 来源丨https://zhuanlan ...

  7. 一文看懂CV中的注意力机制

    点击上方"3D视觉工坊",选择"星标" 干货第一时间送达 作者丨AdamLau@知乎 来源丨https://zhuanlan.zhihu.com/p/28875 ...

  8. 深度学习中的注意力机制(二)

    作者 | 蘑菇先生 来源 | NewBeeNLP 目前深度学习中热点之一就是注意力机制(Attention Mechanisms).Attention源于人类视觉系统,当人类观察外界事物的时候,一般不 ...

  9. Deep Reading | 从0到1再读注意力机制,此文必收藏!

    译者 | forencegan 编辑 | 琥珀 出品 | AI科技大本营(ID: rgznai100) [AI科技大本营导语]注意力机制(Attention)已经成为深度学习必学内容之一,无论是计算机 ...

最新文章

  1. [剑指Offer]5.二维数组中的查找
  2. JPA中persistence.xml模板
  3. POJ 3255 Roadblocks 次短路
  4. Select For update语句浅析
  5. nginx动态库加载出现is not binary compatible问题
  6. 【Ubuntu-Opencv】Ubuntu14.04 Opencv3.3.0 使用中出现OpenCV Error: Unspecified error
  7. UVA 10588—— Queuing at the doctors
  8. 前端学习(1381):多人管理项目1项目管理搭建
  9. javascript中的undefined 和 not defined
  10. 【英语学习】【English L06】U02 Food L2 Salad dressings
  11. win7 path环境变量被覆盖了怎么恢复_系统小技巧:还原Windows10路径环境变量
  12. GStreamer 编写一个简单的MP3播放器
  13. webpack 入门总结和实践(按需异步加载,css单独打包,生成多个入口文件)
  14. 页面回填当前日期与时间
  15. JVM 类加载机制详解
  16. 常用jQuery代码
  17. mfc窗口右下角如何显示一个三角形图案_以C4D制作金属碳笼为例:安利一个友好的三维制图软件...
  18. 200多个电脑修复工具问你要不要?
  19. 只需10行代码就让你的U盘变成纯净版winPE系统安装启动盘
  20. 物体移动时按下Shift键加快速度

热门文章

  1. 当代女性修身养性的箴言书——《读史做女人》
  2. 我丈母娘家的小店竟然被Dos攻击了
  3. ShareIntentUtil【调用系统自带的分享的工具类】
  4. PhotoShop: PSD精准切图
  5. ⭐App爬虫之路⭐:海量食谱数据爬取持久化!!!
  6. 台式计算机中远程登录在哪里,电脑如何进行远程连接
  7. vim:修改vim录制的宏
  8. redis的发布订阅缺陷
  9. c语言 游程编码,简单的行程编码-C语言实现
  10. 关于华为手机P20pro装包时总提示冲突问题