↑ 点击蓝字 关注极市平台作者丨ChaucerG来源丨AI人工智能初学者编辑丨极市平台

极市导读

本文介绍了一种新的注意力机制——Triplet Attention,它通过使用Triplet Branch结构捕获跨维度交互来计算注意力权重,是一个即插即用、简单高效的注意力模块。>>加入极市CV技术交流群,走在计算机视觉的最前沿


论文下载地址和代码开源地址:https://github.com/LandskapeAI/triplet-attention
https://arxiv.org/abs/2010.03045本文研究了轻量且有效的注意力机制,并提出了Triplet Attention,该注意力机制是一种通过使用Triplet Branch结构捕获跨维度交互来计算注意力权重的新方法。对于输入张量,Triplet Attention通过旋转操作和残差变换建立维度间的依存关系,并以可忽略的计算开销对通道和空间信息进行编码。该方法既简单又有效,并且可以轻松地插入经典Backbone中。

1、简介和相关方法

最近许多工作提出使用Channel Attention或Spatial Attention,或两者结合起来提高神经网络的性能。这些Attention机制通过建立Channel之间的依赖关系或加权空间注意Mask有能力改善由标准CNN生成的特征表示。学习注意力权重背后是让网络有能力学习关注哪里,并进一步关注目标对象。这里列举一些具有代表的工作:

1、SENet(Squeeze and Excite module)
2、CBAM(Convolutional Block Attention Module)
3、BAM(Bottleneck Attention Module)
4、Grad-CAM
5、Grad-CAM++
6、-Nets(Double Attention Networks)
7、NL(Non-Local blocks)
8、GSoP-Net(Global Second order Pooling Networks)
9、GC-Net(Global Context Networks)
10、CC-Net(Criss-Cross Networks)
11、SPNet等等方法(这些方法都值得大家去学习和调研,说不定会给你的项目带来意想不到的效果)。
以上大多数方法都有明显的缺点(Cross-dimension),Triplet Attention解决了这些缺点。Triplet Attention模块旨在捕捉Cross-dimension交互,从而能够在一个合理的计算开销内(与上述方法相比可以忽略不计)提供显著的性能收益。

2、本文方法

2.1、分析

本文的目标是研究如何在不涉及任何维数降低的情况下建立廉价但有效的通道注意力模型。Triplet Attention不像CBAM和SENet需要一定数量的可学习参数来建立通道间的依赖关系,本文提出了一个几乎无参数的注意机制来建模通道注意和空间注意,即Triplet Attention。

2.2、Triplet Attention

所提出的Triplet Attention见下图所示。顾名思义,Triplet Attention由3个平行的Branch组成,其中两个负责捕获通道C和空间H或W之间的跨维交互。最后一个Branch类似于CBAM,用于构建Spatial Attention。最终3个Branch的输出使用平均进行聚合。


1、Cross-Dimension Interaction

传统的计算通道注意力的方法涉及计算一个权值,然后使用权值统一缩放这些特征图。但是在考虑这种方法时,有一个重要的缺失。通常,为了计算这些通道的权值,输入张量在空间上通过全局平均池化分解为一个像素。这导致了空间信息的大量丢失,因此在单像素通道上计算注意力时,通道维数和空间维数之间的相互依赖性也不存在。虽然后期提出基于Spatial和Channel的CBAM模型缓解了空间相互依赖的问题,但是依然存在一个问题,即,通道注意和空间注意是分离的,计算是相互独立的。基于建立空间注意力的方法,本文提出了跨维度交互作用(cross dimension interaction)的概念,通过捕捉空间维度和输入张量通道维度之间的交互作用,解决了这一问题。


这里是通过三个分支分别捕捉输入张量的(C, H),(C, W)和(H, W)维间的依赖关系来引入Triplet Attention中的跨维交互作用。

2、Z-pool

Z-pool层负责将C维度的Tensor缩减到2维,将该维上的平均汇集特征和最大汇集特征连接起来。这使得该层能够保留实际张量的丰富表示,同时缩小其深度以使进一步的计算量更轻。可以用下式表示:

class ChannelPool(nn.Module):def forward(self, x):return torch.cat((torch.max(x,1)[0].unsqueeze(1), torch.mean(x,1).unsqueeze(1)), dim=11)

3、Triplet Attention

给定一个输入张量,首先将其传递到Triplet Attention模块中的三个分支中。在第1个分支中,在H维度和C维度之间建立了交互:


为了实现这一点,输入张量沿H轴逆时针旋转90°。这个旋转张量表示为的形状为(W×H×C),再然后经过Z-Pool后的张量的shape为(2×H×C),然后,通过内核大小为k×k的标准卷积层,再通过批处理归一化层,提供维数(1×H×C)的中间输出。然后,通过将张量通过sigmoid来生成的注意力权值。在最后输出是沿着H轴进行顺时针旋转90°保持和输入的shape一致。在第2个分支中,在C维度和W维度之间建立了交互:


为了实现这一点,输入张量沿W轴逆时针旋转90°。这个旋转张量表示为的形状为(H×C×W),再然后经过Z-Pool后的张量的shape为(2×C×W ),然后,通过内核大小为k×k的标准卷积层,再通过批处理归一化层,提供维数(1×C×W)的中间输出。然后,通过将张量通过sigmoid来生成的注意力权值。在最后输出是沿着W轴进行顺时针旋转90°保持和输入的shape一致。在第3个分支中,在H维度和W维度之间建立了交互:输入张量的通道通过Z-pool将变量简化为2。将这个形状的简化张量(2×H×W)简化后通过核大小k定义的标准卷积层,然后通过批处理归一化层。输出通过sigmoid激活层生成形状为(1×H×W)的注意权值,并将其应用于输入,得到结果。然后通过简单的平均将3个分支产生的精细张量(C×H×W)聚合在一起。最终输出的Tensor:

class BasicConv(nn.Module):def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1, groups=1, relu=True, bn=True, bias=False):
        super(BasicConv, self).__init__()
        self.out_channels = out_planes
        self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias)
        self.bn = nn.BatchNorm2d(out_planes,eps=1e-5, momentum=0.01, affine=True) if bn else None
        self.relu = nn.ReLU() if relu else Nonedef forward(self, x):
        x = self.conv(x)if self.bn is not None:
            x = self.bn(x)if self.relu is not None:
            x = self.relu(x)return xclass ChannelPool(nn.Module):def forward(self, x):return torch.cat( (torch.max(x,1)[0].unsqueeze(1), torch.mean(x,1).unsqueeze(1)), dim=1 )class SpatialGate(nn.Module):def __init__(self):
        super(SpatialGate, self).__init__()
        kernel_size = 7
        self.compress = ChannelPool()
        self.spatial = BasicConv(2, 1, kernel_size, stride=1, padding=(kernel_size-1) // 2, relu=False)def forward(self, x):
        x_compress = self.compress(x)
        x_out = self.spatial(x_compress)
        scale = torch.sigmoid_(x_out) return x * scaleclass TripletAttention(nn.Module):def __init__(self, gate_channels, reduction_ratio=16, pool_types=['avg', 'max'], no_spatial=False):
        super(TripletAttention, self).__init__()
        self.ChannelGateH = SpatialGate()
        self.ChannelGateW = SpatialGate()
        self.no_spatial=no_spatialif not no_spatial:
            self.SpatialGate = SpatialGate()def forward(self, x):
        x_perm1 = x.permute(0,2,1,3).contiguous()
        x_out1 = self.ChannelGateH(x_perm1)
        x_out11 = x_out1.permute(0,2,1,3).contiguous()
        x_perm2 = x.permute(0,3,2,1).contiguous()
        x_out2 = self.ChannelGateW(x_perm2)
        x_out21 = x_out2.permute(0,3,2,1).contiguous()if not self.no_spatial:
            x_out = self.SpatialGate(x)
            x_out = (1/3)*(x_out + x_out11 + x_out21)else:
            x_out = (1/2)*(x_out11 + x_out21)return x_out

4、Complexity Analysis

通过与其他标准注意力机制的比较,验证了Triplet Attention的效率,C为该层的输入通道数,r为MLP在计算通道注意力时瓶颈处使用的缩减比,用于2D卷积的核大小用k表示,k<<

3、实验结果

3.1、图像分类实验


3.2、目标检测实验



3.3、消融实验


3.4、HeatMap输出对比



4、总结

在这项工作中提出了一个新的注意力机制Triplet Attention,它抓住了张量中各个维度特征的重要性。Triplet Attention使用了一种有效的注意计算方法,不存在任何信息瓶颈。实验证明,Triplet Attention提高了ResNet和MobileNet等标准神经网络架构在ImageNet上的图像分类和MS COCO上的目标检测等任务上的Baseline性能,而只引入了最小的计算开销。是一个非常不错的即插即用的注意力模块。

更为详细内容可以参见论文中的描述。

References

[1] Rotate to Attend: Convolutional Triplet Attention Module

推荐阅读

  • 即插即用的涨点神器,南航开源AFF:注意力特征融合

  • 盘点十大即插即用的涨点神器

  • 极市直播回放丨第62期-魏恺轩:免调试即插即用的近端优化算法

ACCV 2020国际细粒度网络图像识别竞赛正式开赛!

添加极市小助手微信(ID : cvmart2),备注:姓名-学校/公司-研究方向-城市(如:小极-北大-目标检测-深圳),即可申请加入极市目标检测/图像分割/工业检测/人脸/医学影像/3D/SLAM/自动驾驶/超分辨率/姿态估计/ReID/GAN/图像增强/OCR/视频理解等技术交流群:每月大咖直播分享、真实项目需求对接、求职内推、算法竞赛、干货资讯汇总、与 10000+来自港科大、北大、清华、中科院、CMU、腾讯、百度等名校名企视觉开发者互动交流~△长按添加极市小助手△长按关注极市平台,获取最新CV干货觉得有用麻烦给个在看啦~  

通道注意力机制_即插即用,Triplet Attention机制让Channel和Spatial交互更加丰富(附开源代码)...相关推荐

  1. attention机制_简析Attention机制—优缺点,实现,应用

    什么是Attention机制? Attention机制的本质来自于人类视觉注意力机制.人们在看东西的时候一般不会从到头看到尾全部都看,往往只会根据需求观察注意特定的一部分. 简单来说,就是一种权重参数 ...

  2. 建立完善的员工晋升机制_【员工晋升机制】多渠道员工晋升机制如何建立

    北京华恒智信人力资源顾问有限公司 [员工晋升机制]多渠道员工晋升机制如何建立 引言: 员工晋升机制是员工由较低层级职位上升到较高层级职位的过程, 合理的员工晋 升机制可以实现良好的资源配置, 使合适的 ...

  3. 工作中用到的java反射机制_(转)JAVA-反射机制的使用

    Java反射机制的实现原理 反射机制:所谓的反射机制就是java语言在运行时拥有一项自观的能力.通过这种能力可以彻底的了解自身的情况为下一步的动作做准备.下面具体介绍一下java的反射机制.这里你将颠 ...

  4. 哈佛NLP组论文解读:基于隐变量的注意力模型 | 附开源代码

    作者丨邓云天 学校丨哈佛大学NLP组博士生 研究方向丨自然语言处理 摘要 Attention 注意力模型在神经网络中被广泛应用.在已有的工作中,Attention 机制一般是决定性的而非随机变量.我们 ...

  5. java提供两种处理异常的机制_浅析Java异常处理机制

    关于异常处理的文章已有相当的篇幅,本文简单总结了Java的异常处理机制,并结合代码分析了一些异常处理的最佳实践,对异常的性能开销进行了简单分析. 博客另一篇文章<[译]Java异常处理的最佳实践 ...

  6. java线程锁机制_多线程之锁机制

    前言 在Java并发编程实战,会经常遇到多个线程访问同一个资源的情况,这个时候就需要维护数据的一致性,否则会出现各种数据错误,其中一种同步方式就是利用Synchronized关键字执行锁机制,锁机制是 ...

  7. 简述java的异常处理机制_简述java异常处理机制

    引言: Hello,我的好朋友们,又到我们相聚的时间了,今天我要和大家分享一些有关java异常处理的相关 知识,也是通过老师的讲解和相关材料的借鉴之后的一个比较系统的总结,真心希望写完这篇文章的我和看 ...

  8. 双亲委托类加载机制_图解JVM类加载机制和双亲委派模型

    我们都知道以 .java 结尾的 Java 源文件,经过编译之后会变成 .class 结尾的字节码文件.JVM 通过类加载器来加载字节码文件,然后再执行程序. 什么时候加载一个类 那么,什么时候类加载 ...

  9. mysql数据变化通通知机制_深入理解Notification机制

    先贴上这些源码里面相关的文件: framework/base/core/java/android/app/NotificationManager.java framework/base/service ...

最新文章

  1. logback property 默认值_看完这篇文章还不会给spring boot配置logback,请你吃瓜
  2. HDU-2444 The Accomodation of Students
  3. Python对象类型
  4. 一条空间不足报警的分析
  5. 滑动到底部或顶部响应的ScrollView实现
  6. Hyperledger Fabric chaincode 开发(疑难解答)
  7. oracle中的sql文本类型,Oracle数据库的空间数据类型
  8. 【数据结构笔记04】线性结构:线性表及其实现
  9. 【C语言】02-第一个C程序
  10. 地平线开源网站源码Deepsoon v1.2.3
  11. en开头的单词_英语四级en-词汇前后缀解析
  12. 聊聊程序员的简历应该怎么写(帮修改简历)
  13. 最详细的工业网络通讯技术与协议总结解读(现场总线、工业以太网、工业无线)
  14. 所在位置 行:1 字符: 1+ cnpm i+ ~~~~ + CategoryInfo : SecurityError: (:) [],PSSecurityExcepti
  15. DDoS 攻击防御方法
  16. 【ROSE】1. Rational Rose简介
  17. Cesium火灾动画(模型动画,粒子特效)
  18. 深度学习笔记(一)了解深度学习
  19. #UVM# 关于多次TB中 include “uvm_macros.svh“的疑问篇
  20. Python3,自动识别图片文字,这个库,我爱了。

热门文章

  1. 家庭用计算机怎样选择设置网络位置,win7系统怎么选择网络位置
  2. html让空间高度跟随父级,CSS子元素跟父元素的高度一致的实现方法
  3. android手机常用功能,Windows Phone 7/Android手机常用功能对比
  4. oracle11环境变量path设置_LUENT软件UDF环境变量配置
  5. qt android刘海屏状态栏,华为Mate30 Pro设计曝光:仍配刘海屏+3D结构光
  6. springboot 文件服务器_spring boot还不了解?一份spring boot实战文档送给你
  7. git合并分支的时候将某个文件添加到忽略列表_常用的 Git 命令
  8. mysql list转表_mysql系统表【转】
  9. android横向滑动缩放,移动端实现内容左右滑动,并点击放大效果的问题
  10. mongodb java and or,【MongoDB】-Java实现对mongodb的And、Or、In操作