文章目录

  • 前言
  • 1. ECA
  • 2. Coordinate attention
  • 3. Dual attention
  • 4. FrequencyChannelAttention
  • 5. BAM
  • 6.GlobalContext
    • 部分参考文献

前言

研究生阶段的一些工作、因为涉及到了注意力方面的研究,所以复现了一些比较出名的注意力模块,这些都是我和朋友根据自己理解复现的,用的是keras,不保证复现的正确性,欢迎交流。

1. ECA


https://blog.csdn.net/qq_35054151/article/details/115434812import math
from keras.layers import *
from keras.layers import Activation
from keras.layers import GlobalAveragePooling2D
import keras.backend as K
import tensorflow as tf
def eca_layer(inputs_tensor=None,num=None,gamma=2,b=1):"""注意力模块-NET:param inputs_tensor: input_tensor.shape=[batchsize,h,w,channels]:param num::param gamma::param b::return:"""channels = K.int_shape(inputs_tensor)[-1]t = int(abs((math.log(channels,2)+b)/gamma))k = t if t%2 else t+1x_global_avg_pool = GlobalAveragePooling2D()(inputs_tensor)x = Reshape((channels,1))(x_global_avg_pool)x = Conv1D(1, kernel_size=k,padding="same",name="eca_conv1_" + str(num))(x)x = Activation('sigmoid', name='eca_conv1_relu_' + str(num))(x)  #shape=[batch,chnnels,1]x = Reshape((1, 1, channels))(x)output = multiply([inputs_tensor,x])return output

2. Coordinate attention

import tensorflow as tf
from keras.layers import Lambda,Concatenate,Reshape,Conv2D,BatchNormalization,Activation,Multiply,Adddef coordinate(inputs,ratio=2, name="name"):W,H,C = [int(x) for x in inputs.shape[1:]]temp_dim = max(int(C//ratio),ratio)H_pool = Lambda(lambda x: tf.reduce_mean(x, axis=1))(inputs)W_pool = Lambda(lambda x: tf.reduce_mean(x, axis=2))(inputs)x = Concatenate(axis=1)([H_pool,W_pool])x = Reshape((1,W+H,C))(x)x = Conv2D(temp_dim,1, name=name+'1')(x)x = BatchNormalization()(x)x = Activation('relu')(x)x_h,x_w = Lambda(lambda x:tf.split(x,[H,W],axis=2))(x)x_w = Reshape((W,1,temp_dim))(x_w)x_h = Conv2D(C,1,activation='sigmoid',name=name+"2")(x_h)x_w = Conv2D(C, 1, activation='sigmoid',name=name+"3")(x_w)x = Multiply()([inputs,x_h,x_w])x = Add()([inputs,x])return x

3. Dual attention


import keras
from keras.layers import Activation, Conv2D
import keras.backend as K
import tensorflow as tf
from keras.layers import Layer#  位置注意
class PAM(Layer):def __init__(self,# beta_initializer=tf.zeros_initializer()beta_initializer=keras.initializers.Zeros(),beta_regularizer=None,beta_constraint=None,kernal_initializer='he_normal',kernal_regularizer=None,kernal_constraint=None,**kwargs):super(PAM, self).__init__(**kwargs)self.beta_initializer = beta_initializerself.beta_regularizer = beta_regularizerself.beta_constraint = beta_constraintself.kernal_initializer = kernal_initializerself.kernal_regularizer = kernal_regularizerself.kernal_constraint = kernal_constraintdef build(self, input_shape):_, h, w, filters = input_shapeself.beta = self.add_weight(shape=(1,),initializer=self.beta_initializer,name='beta',regularizer=self.beta_regularizer,constraint=self.beta_constraint,trainable=True)# print(self.beta)self.kernel_b = self.add_weight(shape=(filters, filters // 8),initializer=self.kernal_initializer,name='kernel_b',regularizer=self.kernal_regularizer,constraint=self.kernal_constraint,trainable=True)self.kernel_c = self.add_weight(shape=(filters, filters // 8),initializer=self.kernal_initializer,name='kernel_c',regularizer=self.kernal_regularizer,constraint=self.kernal_constraint,trainable=True)self.kernel_d = self.add_weight(shape=(filters, filters),initializer=self.kernal_initializer,name='kernel_d',regularizer=self.kernal_regularizer,constraint=self.kernal_constraint,trainable=True)self.built = Truedef compute_output_shape(self, input_shape):return input_shapedef call(self, inputs):input_shape = inputs.get_shape().as_list()_, h, w, filters = input_shapeb = K.dot(inputs, self.kernel_b)c = K.dot(inputs, self.kernel_c)d = K.dot(inputs, self.kernel_d)vec_b = K.reshape(b, (-1, h * w, filters // 8))vec_cT = K.permute_dimensions(K.reshape(c, (-1, h * w, filters // 8)), (0, 2, 1))bcT = K.batch_dot(vec_b, vec_cT)softmax_bcT = Activation('softmax')(bcT)vec_d = K.reshape(d, (-1, h * w, filters))bcTd = K.batch_dot(softmax_bcT, vec_d)bcTd = K.reshape(bcTd, (-1, h, w, filters))out = self.beta * bcTd + inputs# print(self.beta)return out#  通道注意
class CAM(Layer):def __init__(self,# gamma_initializer=tf.zeros_initializer()gamma_initializer=keras.initializers.Zeros(),gamma_regularizer=None,gamma_constraint=None,**kwargs):super(CAM, self).__init__(**kwargs)self.gamma_initializer = gamma_initializerself.gamma_regularizer = gamma_regularizerself.gamma_constraint = gamma_constraintdef build(self, input_shape):self.gamma = self.add_weight(shape=(1,),initializer=self.gamma_initializer,name='gamma',regularizer=self.gamma_regularizer,constraint=self.gamma_constraint)# print(self.gamma)self.built = Truedef compute_output_shape(self, input_shape):return input_shapedef call(self, inputs):input_shape = inputs.get_shape().as_list()_, h, w, filters = input_shapevec_a = K.reshape(inputs, (-1, h * w, filters))vec_aT = K.permute_dimensions(K.reshape(vec_a, (-1, h * w, filters)), (0, 2, 1))aTa = K.batch_dot(vec_aT, vec_a)softmax_aTa = Activation('softmax')(aTa)aaTa = K.batch_dot(vec_a, softmax_aTa)aaTa = K.reshape(aaTa, (-1, h, w, filters))out = self.gamma * aaTa + inputs# print(self.gamma)return out#  使用方法
# pam = PAM()(reduce_conv5_3)
# cam = CAM()(reduce_conv5_3)
# feature_sum = add([pam, cam])

4. FrequencyChannelAttention

import math
import tensorflow as tf
import mathdef get_freq_indices(method):assert method in ['top1', 'top2', 'top4', 'top8', 'top16', 'top32','bot1', 'bot2', 'bot4', 'bot8', 'bot16', 'bot32','low1', 'low2', 'low4', 'low8', 'low16', 'low32']num_freq = int(method[3:])if 'top' in method:all_top_indices_x = [0, 0, 6, 0, 0, 1, 1, 4, 5, 1, 3, 0, 0, 0, 3, 2, 4, 6, 3, 5, 5, 2, 6, 5, 5, 3, 3, 4, 2, 2,6, 1]all_top_indices_y = [0, 1, 0, 5, 2, 0, 2, 0, 0, 6, 0, 4, 6, 3, 5, 2, 6, 3, 3, 3, 5, 1, 1, 2, 4, 2, 1, 1, 3, 0,5, 3]mapper_x = all_top_indices_x[:num_freq]mapper_y = all_top_indices_y[:num_freq]elif 'low' in method:all_low_indices_x = [0, 0, 1, 1, 0, 2, 2, 1, 2, 0, 3, 4, 0, 1, 3, 0, 1, 2, 3, 4, 5, 0, 1, 2, 3, 4, 5, 6, 1, 2,3, 4]all_low_indices_y = [0, 1, 0, 1, 2, 0, 1, 2, 2, 3, 0, 0, 4, 3, 1, 5, 4, 3, 2, 1, 0, 6, 5, 4, 3, 2, 1, 0, 6, 5,4, 3]mapper_x = all_low_indices_x[:num_freq]mapper_y = all_low_indices_y[:num_freq]elif 'bot' in method:all_bot_indices_x = [6, 1, 3, 3, 2, 4, 1, 2, 4, 4, 5, 1, 4, 6, 2, 5, 6, 1, 6, 2, 2, 4, 3, 3, 5, 5, 6, 2, 5, 5,3, 6]all_bot_indices_y = [6, 4, 4, 6, 6, 3, 1, 4, 4, 5, 6, 5, 2, 2, 5, 1, 4, 3, 5, 0, 3, 1, 1, 2, 4, 2, 1, 1, 5, 3,3, 3]mapper_x = all_bot_indices_x[:num_freq]mapper_y = all_bot_indices_y[:num_freq]else:raise NotImplementedErrorreturn mapper_x, mapper_y#  注意力层
def MultiSpectralAttentionLayer(x, channel, dct_h, dct_w, reduction=16, freq_sel_method='top2'):print("------MultiSpectralAttentionLayer----start")n, h, w, c = x.shapex_pooled = xmapper_x, mapper_y = get_freq_indices(freq_sel_method)num_split = len(mapper_x)mapper_x = [temp_x * (dct_h // 7) for temp_x in mapper_x]mapper_y = [temp_y * (dct_w // 7) for temp_y in mapper_y]y = MultiSpectralDCTLayer(x_pooled, dct_h, dct_w, mapper_x, mapper_y, channel)y = tf.layers.dense(y, channel // reduction, activation=tf.nn.relu)y = tf.layers.dense(y, channel)y = tf.math.sigmoid(y)y = tf.reshape(y, [n, 1, 1, c])y = tf.transpose(y, (0, 3, 1, 2))y = tf.tile(y, (1, 1, h, w))print("------MultiSpectralAttentionLayer----end")y = tf.transpose(y, (0, 2, 3, 1))return x * ydef MultiSpectralDCTLayer(x, height, width, mapper_x, mapper_y, channel):print("------MutilSpectralDCTLaer----start")# assert len(mapper_x)==(mapper_y)assert channel % len(mapper_x) == 0num_freq = len(mapper_x)weight = get_dct_filter(height, width, mapper_x, mapper_y, channel)print(height)print(width)x = x * weightresult = tf.reduce_sum(x, [1, 2])print("------MutilSpectralDCTLaer----end")return resultdef build_filter(pos, freq, POS):# print("------build_filter----statr")pi = tf.constant(math.pi)POS = tf.cast(pos, tf.float32)freq = tf.cast(freq, tf.float32)POS = tf.cast(POS, tf.float32)result = tf.math.cos(pi * freq * (pos + 0.5) / POS) / tf.math.sqrt(POS)# print("------build_filter----end")if freq == 0:return resultelse:return result * tf.math.sqrt(tf.cast(2, tf.float32))def get_dct_filter(tile_size_x, tile_size_y, mapper_x, mapper_y, channel):print("------get_dct_filter----statr")dct_filter = tf.Variable(tf.zeros([channel, tile_size_x, tile_size_y]))c_part = channel // len(mapper_x)for i, (u_x, v_y) in enumerate(zip(mapper_x, mapper_y)):for t_x in range(tile_size_x):for t_y in range(tile_size_y):dct_filter[i * c_part: (i + 1) * c_part, t_x, t_y].assign(build_filter(t_x, u_x, tile_size_x) * build_filter(t_y, v_y, tile_size_y))dct_filter = tf.transpose(dct_filter, [1, 2, 0])print("------get_dct_filter----end")return dct_filter

5. BAM

# -*- coding: utf-8 -*-import tensorflow as tf
import tensorflow.contrib.slim as slimbatch_norm_params = {# Decay for moving averages'decay': 0.995,# epsilon to prevent 0 in variance'epsilon': 0.001,# force in-place updates of mean and variances estimates'updates_collections': None,# moving averages ends up in the trainable variables collection'variables_collections': [tf.GraphKeys.TRAINABLE_VARIABLES]}def BAM(inputs, batch_norm_params, reduction_ratio=16, dilation_value=4, reuse=None, scope='BAM'):with tf.variable_scope(scope, reuse=reuse):with slim.arg_scope([slim.conv2d, slim.fully_connected],weights_initializer=slim.xavier_initializer(),weights_regularizer=slim.l2_regularizer(0.0005)):with slim.arg_scope([slim.conv2d], activation_fn=None):input_channel = inputs.get_shape().as_list()[-1]num_squeeze = input_channel // reduction_ratio# Channel attentiongap = tf.reduce_mean(inputs, axis=[1, 2], keepdims=True)channel = slim.fully_connected(gap, num_squeeze, activation_fn=None)channel = slim.fully_connected(channel, input_channel, activation_fn=None,normalizer_fn=slim.batch_norm, normalizer_params=batch_norm_params)# Spatial attentionspatial = slim.conv2d(inputs, num_squeeze, 1, padding='SAME')spatial = slim.repeat(spatial, 2, slim.conv2d, num_squeeze, 3, padding='SAME', rate=dilation_value)spatial = slim.conv2d(spatial, 1, 1, padding='SAME',normalizer_fn=slim.batch_norm, normalizer_params=batch_norm_params)# combined two attention branchcombined = tf.nn.sigmoid(channel + spatial)output = inputs + inputs * combinedreturn output

6.GlobalContext

"""
import tensorflow as tf
import tensorflow.contrib.slim as slim
import numpy as npdef conv(x, out_channel, kernel_size, stride=1, dilation=1):x = slim.conv2d(x, out_channel, kernel_size, stride, rate=dilation,activation_fn=None)return xdef global_avg_pool2D(x):with tf.variable_scope(None, 'global_pool2D'):n,h,w,c=x.get_shape().as_listx = slim.avg_pool2d(x, (h,w), stride=1)return xdef global_context_module(x,squeeze_depth,fuse_method='add',attention_method='att',scope=None):assert fuse_method in ['add','mul']assert attention_method in ['att','avg']with tf.variable_scope(scope,"GCModule"):if attention_method == 'avg':context = global_avg_pool2D(x)#[N,1,1,C]else:n,h,w,c=x.get_shape().as_list()context_mask = conv(x,1,1)# [N, H, W,1]context_mask = tf.reshape(context_mask,shape=tf.convert_to_tensor([tf.shape(x)[0], -1, 1]))# [N, H*W, 1]context_mask=tf.transpose(context_mask,perm=[0,2,1])# [N, 1, H*W]context_mask = tf.nn.softmax(context_mask,axis=2)# [N, 1, H*W]input_x = tf.reshape(x, shape=tf.convert_to_tensor([tf.shape(x)[0], -1,c]))# [N,H*W,C]context=tf.matmul(context_mask,input_x)# [N, 1, H*W] x [N,H*W,C] =[N,1,C]context=tf.expand_dims(context,axis=1)#[N,1,1,C]context=conv(context,squeeze_depth,1)context=slim.layer_norm(context)context=tf.nn.relu(context)context=conv(context,c,1)#[N,1,1,C]if fuse_method=='mul':context=tf.nn.sigmoid(context)out=context*xelse:out=context+xreturn out

部分参考文献

[91]Wang Q ,  Wu B ,  Zhu P , et al. ECA-Net: Efficient Channel Attention for Deep Convolutional Neural Networks[C]// 2020 IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR). IEEE, 2020.
[95]Woo S, Park J, Lee J Y, et al. Cbam: Convolutional block attention module[C]//Proceedings of the European conference on computer vision (ECCV). 2018: 3-19.
[105] Hou Q, Zhou D, Feng J. Coordinate attention for efficient mobile network design[C]//Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition. 2021: 13713-13722.
[106] Cao Y, Xu J, Lin S, et al. Gcnet: Non-local networks meet squeeze-excitation networks and beyond[C]//Proceedings of the IEEE/CVF International Conference on Computer Vision Workshops. 2019: 0-0.
[107] Li X, Wang W, Hu X, et al. Selective kernel networks[C]//Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition. 2019: 510-519.

【附代码实现】Attention注意力模块的keras\tf实现(ECA、BAM、Coordinate、DualAttention、GlobalContext等)相关推荐

  1. 循环神经网络RNN 2—— attention注意力机制(附代码)

    attention方法是一种注意力机制,很明显,是为了模仿人的观察和思维方式,将注意力集中到关键信息上,虽然还没有像人一样,完全忽略到不重要的信息,但是其效果毋庸置疑,本篇我们来总结注意力机制的不同方 ...

  2. 图像处理注意力机制Attention汇总(附代码)

    原文链接: 图像处理注意力机制Attention汇总(附代码,SE.SK.ECA.CBAM.DA.CA等) 1. 介绍 注意力机制(Attention Mechanism)是机器学习中的一种数据处理方 ...

  3. CBAM——即插即用的注意力模块(附代码)

    论文:CBAM: Convolutional Block Attention Module 代码: code 目录 前言 1.什么是CBAM? (1)Channel attention module( ...

  4. 注意力机制BAM和CBAM详细解析(附代码)

    论文题目①:BAM: Bottleneck Attention Module 论文题目②:CBAM:CBAM: Convolutional Block Attention Module Bottlen ...

  5. 《最新开源 随插即用》SAM 自增强注意力深度解读与实践(附代码及分析)

    写在前面 大家好,我是cv君,前段时间忙碌工作,许久没更新,越发觉得对不起csdn的读者们,决定继续加油保持更新,保持一周2-3篇的高频率和高质量文章更新:论文分析.代码讲解.代码实操和训练.优化部署 ...

  6. ECA 注意力模块 原理分析与代码实现

    前言 本文介绍ECA注意力模块,它是在ECA-Net中提出的,ECA-Net是2020 CVPR中的论文:ECA模块可以被用于CV模型中,能提取模型精度,所以给大家介绍一下它的原理,设计思路,代码实现 ...

  7. 手把手教你用Keras进行多标签分类(附代码)_数据派THU-CSDN博客 (翻译:程思衍校对:付宇帅)

    手把手教你用Keras进行多标签分类(附代码)_数据派THU-CSDN博客 手把手教你用Keras进行多标签分类(附代码)_数据派THU-CSDN博客

  8. 教你用Keras和CNN建立模型识别神奇宝贝!(附代码)

    作者:ADRIAN ROSEBROCK 翻译:张恬钰 校对:万文菁 本文8500字,建议阅读30+分钟. 本文将讲解如何用Keras和卷积神经网络(CNN)来建立模型识别神奇宝贝! 用Keras创造一 ...

  9. SE 注意力模块 原理分析与代码实现

    前言 本文介绍SE注意力模块,它是在SENet中提出的,SENet是ImageNet 2017的冠军模型:SE模块常常被用于CV模型中,能较有效提取模型精度,所以给大家介绍一下它的原理,设计思路,代码 ...

最新文章

  1. 不允许所请求的注册表访问权
  2. 如果有一个类是 myClass , 关于下面代码正确描述的是?
  3. 动画演示 Delphi 2007 IDE 功能[3] - 修改属性
  4. C# 版本 疫情传播仿真程序
  5. 【渝粤教育】广东开放大学 建筑工程施工 形成性考核 (58)
  6. 事件(二):事件处理程序
  7. js面试题:创建一个json对象people,并追加属性:姓名、性别、年龄,追加run方法...
  8. 190606每日一句
  9. 翻出过去的一个多彩泡泡屏保特效(JS+CSS版)
  10. Excel学习笔记:P18-COUNTIFS函数与SUMIFS函数
  11. 微信公众号二维码生成
  12. Dubbo 常见的负载均衡(Load Balance)算法,一起学习一下吧~
  13. 如何在Vue项目中引入ArcGIS JavaScript API​ 创建三维可视化地图(含vue项目创建教程)
  14. 木马是如何编写的(一)
  15. 杨建允:抖快直播电商的运营逻辑是否可以复制
  16. matlab触发igbt电路设计,IGBT单相桥式无源逆变电路设计(纯电阻负载).doc
  17. python复利计算_python复利代码
  18. Maven手动导入依赖
  19. 计算机右键括号内的字母,排序快捷键 excel|关于EXCEL的。如图,括号里面的字母是不是代表快捷键的意思?怎么快捷?...
  20. 自学本科计算机课程要多久,22岁完全0基础自考计算机本科是否现实?

热门文章

  1. LWIP-TCP Server连接两次之后无法连接问题
  2. 关于SV的一些知识1
  3. matlab实现模糊控制器并仿真,用Matlab实现空调温度模糊控制器的设计与仿真.pdf...
  4. JVM分化回收机制(年轻代、年老代、永久代)
  5. MarkDown首行缩进和换行
  6. 【SpringDataJPA从入门到精通】02-JPA API
  7. DB 查询分析器 6.04 发布 ,本人为之撰写的相关技术文章达78篇
  8. 批量删除github工程仓库的办法
  9. 人工智能---图像识别
  10. Python中int32转int64