前言

近期想用各种方法提高一下mAP,看了一下ASPP的方法,都说效果不错。感觉可以当一个即插即用的模块。搜了一下代码,都是pytorch,于是就跟着用TensorFlow改写了一下。

主要难点可能就在TensorFlow和Pytorch的部分的函数转换和数据结构的不同。TensorFlow是BHWC,即批数量、高、宽、通道数。Pytorch是BCHW,即批数量、通道数、高、宽。

参考

论文:Learning Spatial Fusion for Single-Shot Object Detection
代码:ASFF

原作代码 level_2 部分有点问题,建议参考以下代码
YOLOX改进之添加ASFF

一点就分享系列(实践篇3—上篇)—修改YOLOV5 魔刀小试+ Trick心得分享(非常推荐研读一下)

ASFF结构

ASFF本质上就还是一种金字塔特征融合,不严谨的说有点像全连接,权重还能自适应学习。

ASFF代码

import tensorflow as tf
from tensorflow.keras.layers import Conv2D,BatchNormalization,Activation,LeakyReLU,ReLU
# https://github.com/GOATmessi7/ASFF/blob/4df6f7288b7882a45b8c2dcc3e6e7b499d6cc883/models/network_blocks.py
# https://blog.csdn.net/weixin_45679938/article/details/122354725
# class add_conv(tf.keras.layers.Layer):"""Add a conv2d / batchnorm / leaky ReLU block.Args:out_ch (int): number of output channels of the convolution layer.ksize (int): kernel size of the convolution layer.stride (int): stride of the convolution layer.Returns:out: Sequential layers composing a convolution block."""   def __init__(self, out_ch,ksize,stride,leaky=True):super(add_conv,self).__init__()self.conv = Conv2D(filters=out_ch,kernel_size=ksize,strides=stride,padding='same',use_bias=False)self.bn = BatchNormalization()self.act = LeakyReLU(0.1) if leaky==True else ReLU(6.0)def call(self,x):return self.act(self.bn(self.conv(x)))class ASFF(tf.keras.layers.Layer):def __init__(self, level,rfb=False,# vis=False):super(ASFF,self).__init__()self.level = levelself.dim = [512,256,128]self.inter_dim = self.dim[self.level]rfbif level == 0:self.stride_level_1 = add_conv(self.inter_dim,3,2) # in_channel = 512self.stride_level_2 = add_conv(self.inter_dim,3,2) # in_channel = 512self.expand = add_conv(512,3,1) # 输出是要给head的特征elif level==1:self.compress_level_0 = add_conv(self.inter_dim, 1, 1) # in_channel = 256self.stride_level_2 = add_conv(self.inter_dim, 3, 2) # in_channel = 256self.expand = add_conv(256, 3, 1) # 输出是要给head的特征elif level==2:self.compress_level_0 = add_conv(self.inter_dim, 1, 1) # in_channel = 128self.compress_level_1 = add_conv(self.inter_dim, 1, 1) # in_channel = 128self.expand = add_conv(128, 3, 1) # 输出是要给head的特征compress_c = 8 if rfb else 16  #when adding rfb, we use half number of channels to save memoryself.weight_level_0 = add_conv(compress_c, 1, 1)self.weight_level_1 = add_conv(compress_c, 1, 1)self.weight_level_2 = add_conv(compress_c, 1, 1)self.weight_levels = Conv2D(3,kernel_size=1,strides=(1,1),padding='valid')# self.vis = visdef call(self, x_level_0, x_level_1, x_level_2):if self.level == 0:level_0_resized = x_level_0level_1_resized = self.stride_level_1(x_level_1)level_2_downsampled_inter =tf.nn.max_pool2d(x_level_2, 3, strides=2, padding="SAME")   # 源代码padding = 1怀疑就是same,毕竟ksize=3,3//2=1       level_2_resized = self.stride_level_2(level_2_downsampled_inter)elif self.level==1:level_0_compressed = self.compress_level_0(x_level_0)# level_0_resized =F.interpolate(level_0_compressed, scale_factor=2, mode='nearest')# https://blog.csdn.net/weixin_40128276/article/details/104958708_,H,W,_ = level_0_compressed.shapelevel_0_resized = tf.image.resize(level_0_compressed,[H*2,W*2],method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)level_1_resized = x_level_1level_2_resized = self.stride_level_2(x_level_2)elif self.level==2:level_0_compressed = self.compress_level_0(x_level_0)_,H,W,_ = level_0_compressed.shapelevel_0_resized = tf.image.resize(level_0_compressed,[H*4,W*4],method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)x_level_1_compressed = self.compress_level_1(x_level_1)_,H,W,_ = x_level_1.shapelevel_1_resized = tf.image.resize(x_level_1_compressed,[H*2,W*2],method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)level_2_resized = x_level_2level_0_weight_v = self.weight_level_0(level_0_resized)level_1_weight_v = self.weight_level_1(level_1_resized)level_2_weight_v = self.weight_level_2(level_2_resized)# levels_weight_v = torch.cat((level_0_weight_v, level_1_weight_v, level_2_weight_v),1) # torch BCHW concat in channellevels_weight_v = tf.concat((level_0_weight_v, level_1_weight_v, level_2_weight_v),3) # tensorflow BHWC# levels_weight = F.softmax(levels_weight, dim=1)levels_weight = self.weight_levels(levels_weight_v)levels_weight = tf.nn.softmax(levels_weight,axis=-1)# pytorch# fused_out_reduced = level_0_resized * levels_weight[:,0:1,:,:]+\#                     level_1_resized * levels_weight[:,1:2,:,:]+\#                     level_2_resized * levels_weight[:,2:,:,:]fused_out_reduced = level_0_resized * levels_weight[:,:,:,0:1]+\level_1_resized * levels_weight[:,:,:,1:2]+\level_2_resized * levels_weight[:,:,:,2:]out = self.expand(fused_out_reduced)# if self.vis:#     return out,levels_weight,fused_out_reduced.sum()# else:#     return outreturn outif __name__ == "__main__":P5 = tf.keras.Input(shape=(13,13,512),batch_size=1)P4 = tf.keras.Input(shape=(26,26,256),batch_size=1)P3 = tf.keras.Input(shape=(52,52,128),batch_size=1)P5_ASFF =ASFF(level=0)(P5,P4,P3)P4_ASFF =ASFF(level=1)(P5,P4,P3)P3_ASFF =ASFF(level=2)(P5,P4,P3)print(P5_ASFF,P4_ASFF,P3_ASFF)

ASFF的TensorFlow2实现相关推荐

  1. tensorflow2 目标检测_一文了解YOLO-v4目标检测

    一.YOLO-v4主要做了什么? 通俗的讲,就是说这个YOLO-v4算法是在原有YOLO目标检测架构的基础上,采用了近些年CNN领域中最优秀的优化策略,从数据处理.主干网络.网络训练.激活函数.损失函 ...

  2. Anaconda3+python3.7.10+TensorFlow2.3.0+PyQt5环境搭建

    Anaconda3+python3.7.10+TensorFlow2.3.0+PyQt5环境搭建 一.Anaconda 创建 python3.7环境 1.进入 C:\Users\用户名 目录下,找到 ...

  3. 【深度学习】(6) tensorflow2.0使用keras高层API

    各位同学好,今天和大家分享一下TensorFlow2.0深度学习中借助keras的接口减少神经网络代码量.主要内容有: 1. metrics指标:2. compile 模型配置:3. fit 模型训练 ...

  4. 【TensorFlow2.0】(7) 张量排序、填充、复制、限幅、坐标选择

    各位同学好,今天和大家分享一下TensorFlow2.0中的一些操作.内容有: (1)排序 tf.sort().tf.argsort().top_k():(2)填充 tf.pad():(3)复制 tf ...

  5. 【TensorFlow2.0】(6) 数据统计,范数、最值、求和、均值、最值位置、唯一值、张量比较

    各位同学好,今天和大家分享一下TensorFlow2.0中的数据分析操作.内容有: (1)范数 tf.norm():(2)最值 tf.reduce_min(), tf.reduce_max()(3)求 ...

  6. 【TensorFlow2.0】(5) 数学计算、合并、分割

    各位同学好,今天和大家分享一下TensorFlow2.0中的数学运算方法.合并与分割.内容有: (1)基本运算:(2)矩阵相乘:(3)合并 tf.concat().tf.stack():(4)分割 t ...

  7. 【TensorFlow2.0】(4) 维度变换、广播

    各位同学好,今天我和大家分享一下TensorFlow2.0中有关数学计算的相关操作,主要内容有: (1) 改变维度:reshape():(2) 维度转置:transpose():(3) 增加维度:ex ...

  8. 【TensorFlow2.0】(3) 索引与切片操作

    各位同学好,今天我和大家分享一下TensorFlow2.0中索引与切片.内容有: (1) 给定每一维度的索引来获取数据:(2) 切片索引:(3) 省略号应用:(4) tf.gather() 方法:(5 ...

  9. 【TensorFlow2.0】(2) 创建tensor的方法

    各位同学好,今天和大家分享一下TensorFlow2.0中的tensor变量的创建方法.内容有: (1) 通过numpy和list创建tensor:(2) 创建全部为某个值的tensor:(3) 随机 ...

最新文章

  1. ES单字段支持的最大字符数
  2. 什么是m叉树_重型货架是什么?重型仓储货架介绍
  3. SAP MM Error message - Customizing incorrectly maintained – in transaction code ML81N
  4. Java纸牌拖拉机简单模拟
  5. 将所有单个json标注文件合并成一个总的json标注文件(COCO数据集格式)
  6. linux deepin手动升级内核命令
  7. 非刚性人脸跟踪 —— 人脸跟踪
  8. kubernetes1.8.4 安装指南 -- 8. 安装Kube DNS
  9. php pdo使用事务,PHP内PDO事务使用步骤详解
  10. python--从入门到实践--chapter 15 16 17 生成数据/下载数据/web API
  11. STM32利用光敏二极管实现光度测量
  12. 在自己的电脑上搭建服务器(可供对外访问)
  13. Ubuntu系统中各种文件颜色的含义
  14. R中安装LightGBM(Windows 64位)
  15. pandas 作图 统计_Pandas数据可视化工具——Seaborn用法整理(下)
  16. 计算机设备硬件设备,计算机硬件设备有哪些
  17. CSDN下载频道于2014年7月17日改版,23日-24日系统维护
  18. Android 开发工具一键下载
  19. docker部署kafka踩坑
  20. 忘记teamviewer密码怎么办?

热门文章

  1. 聚类分析在用户行为中的实例_聚类分析案例之市场细分
  2. 智能车浅谈——硬件篇
  3. Web server failed to start. Port 9080 was already in use报错解决
  4. 中国联通沃云----弹性云主机使用说明
  5. 一个关于随机矩阵谱范数的不等式
  6. OkGo上传文件、图片的用法
  7. H5推流解决方案测试环境搭建指南
  8. 使用 youtube api封装播放器的坑
  9. 搭建可以通过外网访问本地服务器CentOS7,这一篇就够了
  10. Oozie-4.1.0-cdh5.5.2 安装部署使用