import tensorflow as tf"""
实现用于CNN的注意力机制的模块
"""def cbam(inputs, reduction=8):"""变量重用,我们使用的是 tf.AUTO_REUSE:param inputs:  输入的tensor格式为: [N, H, W, C]:param reduction::return:"""with tf.variable_scope('cbam', reuse=tf.AUTO_REUSE):_, height, width, channels = inputs.get_shape()# 1、实现通道注意力x_mean = tf.reduce_mean(inputs, axis=[1, 2], keep_dims=True)  # [N, 1, 1, C]x_mean = tf.layers.conv2d(x_mean, channels // reduction, kernel_size=1, activation=tf.nn.relu, name='cbam1')   # [N, 1, 1, C/r]x_mean = tf.layers.conv2d(x_mean, channels, kernel_size=1, activation=None, name='cbam2')  #  [N, 1, 1, C]x_max = tf.reduce_max(inputs, axis=[1, 2], keep_dims=True)  # [N, 1, 1, C]x_max = tf.layers.conv2d(x_max, channels // reduction, kernel_size=1, activation=tf.nn.relu, name='cbam1')  # [N, 1, 1, C/r]x_max = tf.layers.conv2d(x_max, channels, kernel_size=1, activation=None, name='cbam2')  # [N, 1, 1, C]x = tf.add(x_mean, x_max)x = tf.nn.sigmoid(x)  # [N, 1, 1, C]# 获取通道注意力结果x = tf.multiply(inputs, x)  # [N, H, W, C]# 2、空间注意力y_mean = tf.reduce_mean(x, axis=[3], keepdims=True)  # [N, H, W, 1]y_max = tf.reduce_max(x, axis=[3], keep_dims=True)  # [N, H, W, 1]y = tf.concat([y_mean, y_max], axis=-1)  # [N, H, W, 2]y = tf.layers.conv2d(y, filters=1, kernel_size=7, padding='same', activation=tf.nn.sigmoid)  # [N, H, W, 1]y = tf.multiply(x, y)  # [N, H, W, C]return ydef test():with tf.Graph().as_default():data = tf.ones(shape=[64, 32, 32, 128], dtype=tf.float32)cbam_out = cbam(data, reduction=8)vars_list = tf.trainable_variables()print(cbam_out)print(len(vars_list), '\n', vars_list)if __name__ == '__main__':test()
D:\Anaconda\python.exe D:/AI20/HJZ/04-深度学习/3-CNN/20191215__AI20_CNN/09_CBAM_block.py
WARNING:tensorflow:From D:/AI20/HJZ/04-深度学习/3-CNN/20191215__AI20_CNN/09_CBAM_block.py:19: calling reduce_mean (from tensorflow.python.ops.math_ops) with keep_dims is deprecated and will be removed in a future version.
Instructions for updating:
keep_dims is deprecated, use keepdims instead
WARNING:tensorflow:From D:/AI20/HJZ/04-深度学习/3-CNN/20191215__AI20_CNN/09_CBAM_block.py:27: calling reduce_max (from tensorflow.python.ops.math_ops) with keep_dims is deprecated and will be removed in a future version.
Instructions for updating:
keep_dims is deprecated, use keepdims instead
Tensor("cbam/Mul_1:0", shape=(64, 32, 32, 128), dtype=float32)
6 [<tf.Variable 'cbam/cbam1/kernel:0' shape=(1, 1, 128, 16) dtype=float32_ref>, <tf.Variable 'cbam/cbam1/bias:0' shape=(16,) dtype=float32_ref>, <tf.Variable 'cbam/cbam2/kernel:0' shape=(1, 1, 16, 128) dtype=float32_ref>, <tf.Variable 'cbam/cbam2/bias:0' shape=(128,) dtype=float32_ref>, <tf.Variable 'cbam/conv2d/kernel:0' shape=(7, 7, 2, 1) dtype=float32_ref>, <tf.Variable 'cbam/conv2d/bias:0' shape=(1,) dtype=float32_ref>]Process finished with exit code 0

07-CBAM_block注意力机制相关推荐

  1. 各种注意力机制PyTorch实现

    给出了整个系列的PyTorch的代码实现,以及使用方法. 各种注意力机制 Pytorch implementation of "Beyond Self-attention: External ...

  2. 【Attention九层塔】注意力机制的九重理解

    本文作者:电光幻影炼金术 研究生话题Top1,上海交大计算机第一名,高中物理竞赛一等奖,段子手,上海交大计算机国奖,港中文博士在读 https://zhuanlan.zhihu.com/p/36236 ...

  3. 收藏 | PyTorch实现各种注意力机制

    点上方计算机视觉联盟获取更多干货 仅作学术分享,不代表本公众号立场,侵权联系删除 转载于:机器之心 AI博士笔记系列推荐 周志华<机器学习>手推笔记正式开源!可打印版本附pdf下载链接 P ...

  4. 基于注意力机制的seq2seq网络

    六月 北京 | 高性能计算之GPU CUDA培训 6月22-24日三天密集式学习  快速带你入门阅读全文> 正文共1680个字,26张图,预计阅读时间10分钟. seq2seq的用途有很多,比如 ...

  5. 空间注意力机制sam_自己挖坑自己填,谷歌大改Transformer注意力,速度、内存利用率都提上去了...

    考虑到 Transformer 对于机器学习最近一段时间的影响,这样一个研究就显得异常引人注目了. 机器之心报道,机器之心编辑部. Transformer 有着巨大的内存和算力需求,因为它构造了一个注 ...

  6. keras实现注意力机制

    分别来用keras实现通道注意力模块和空间注意力模块. #通道注意力机制 def channel_attention(input_feature, ratio=8):channel_axis = 1 ...

  7. 新年美食鉴赏——基于注意力机制CBAM的美食101分类

    新年美食鉴赏--基于注意力机制CBAM的美食101分类 一.数据预处理 1.数据集介绍 2.读取标签 3.统一命名 4.整理图片路径 5.划分训练集与验证集 6.定义美食数据集 二.注意力机制 1.简 ...

  8. 【CBAM 解读】混合注意力机制:Convolutional Block Attention Module

    摘要 本文提出了卷积块注意模块(CBAM),这是一种简单而有效的前馈卷积神经网络注意模块.在给定中间特征图的情况下,我们的模块沿着通道和空间两个不同的维度顺序地推断关注图,然后将关注图与输入特征图相乘 ...

  9. 小目标检测3_注意力机制_Self-Attention

    主要参考: (强推)李宏毅2021/2022春机器学习课程 P38.39 李沐老师:64 注意力机制[动手学深度学习v2] 手把手带你Yolov5 (v6.1)添加注意力机制(一)(并附上30多种顶会 ...

  10. NLP系列(9)_深入理解BERT Transformer ,不仅仅是注意力机制

    大数据文摘与百度NLP联合出品 作者:Damien Sileo 审校:百度NLP.龙心尘 编译:张驰.毅航 https://blog.csdn.net/longxinchen_ml/article/d ...

最新文章

  1. 基于深度学习的目标检测研究进展
  2. sql 关联使用id还是code_使用sh格式化nginx访问日志并存入mysql
  3. Activity栈管理(三):Intent的Flag与taskAffinity
  4. 倒排列表求交集算法汇总
  5. VC中char,TCHAR,WCHAR总结
  6. 如何在Windows即服务上安装Memcached Server
  7. java 认证考试题_2017年Java认证考试真题及答案
  8. Oracle中SQL*plus常用命令
  9. 微信小程序 宠物论坛1
  10. Jib快速打包Docker镜像
  11. 〔首届CSDN.南京区程序员聚会〕正式报名情况[每日更新7月19日 17:30]
  12. php博客系统答辩ppt,基于PHP实现的WEB图片共享系统-php(开题报告+源程序+论文+答辩PPT+文献综述)...
  13. oracle是dbms还是dbs,Oracle学习笔记三——DBS
  14. LeetCode每日一题 1238.循环码排列
  15. java魔法师_RxJava魔法师app
  16. stlink下载调试器使用说明(STM32采用stlink下载程序)
  17. mysql的联合索引_mysql联合索引详解
  18. 计算机行业就业的发展前景怎么样?
  19. uc浏览器安卓版 打不开php吗,javascript只允许安卓uc浏览器访问
  20. 关于Group By 单个和多个字段

热门文章

  1. 高级驾驶辅助系统(ADAS)的安全性和静态分析
  2. android 百度人脸识别,百度人脸识别模块使用分享
  3. php循环输出数组 json,php循环通过json数组(php loop through json array)
  4. [4G+5G专题-135]: 部署 - 5G带宽只有100Mhz,下载速度能达到1Gbps吗?
  5. android jni skia,Android NDK 调用Skia进行底层绘图
  6. 人脉拓展的重要性:如何通过异业合作扩大自己的人脉资源?
  7. css背景 背景颜色 颜色渐变
  8. Deepin安装Wireshark
  9. EBS发票AP常用表
  10. 我使用python的进程池技术下载企业工商数据,速度1000万条/天,超快!