ExFuse: Enhancing Feature Fusion for Semantic Segmentation论文解读

代码链接:https://github.com/lxtGH/fuse_seg_pytorch
参考链接:https://zhuanlan.zhihu.com/p/74551902

摘要:

在本文中,我们首先指出,由于在语义层次和空间分辨率上的差距,低级和高级特征的简单融合可能效果较差。我们发现,将语义信息引入低级特征,将高分辨率细节引入高级特征,对以后的融合更有效。基于此观察结果,我们提出了一个新的框架,名为ExFuse,以弥补低层次和高级特征之间的差距,从而显著提高了4.0%的分割效果。

背景:

低级特征和高级特征在本质上是互补的,其中低级特征的空间细节丰富,但缺乏语义信息,反之亦然。考虑一种极端情况,即“纯”低级特性(分辨率大)只编码低级概念,如点、线或边。直观地说,高级特征与这种“纯”低级特征的融合帮助很少,因为低级特征的噪声太大,无法提供足够的高分辨率语义指导。相反,如果低级特征包含更多的语义信息,例如,编码相对更清晰的语义边界,那么融合就会变得容易——通过将高级特征映射与边界对齐起来,就可以获得精细的分割结果。类似地,空间信息少的“纯”高级特征(分辨率小)不能充分利用低级特征;然而,由于嵌入了额外的高分辨率特征,高级特征可能有机会通过对齐到最近的低级边界来完善自己。
从经验上看,低级特征和高级特征之间的语义重叠和分辨率重叠对特征融合的有效性起着重要的作用。换句话说,可以通过在低级特征(大分辨率)中引入更多的语义概念或在高级特征(小分辨率)中嵌入更多的空间信息来增强特征融合。

提出了一个名为ExFuse的框架,它解决了这个差距从以下两个方面来看:

  1. 为了在低级特征中引入更多的语义信息,我们提出了层重新排列(Layer Rearrangement)、语义监督(Semantic Supervision)和语义嵌入分支(Semantic Embedding Branch)
  2. 将更多的空间信息嵌入更多的高级特征,我们提出了两种新的方法:显式信道分辨率嵌入(Explicit Channel Resolution Embedding)和密集相邻预测(Densely Adjacent Prediction)

网络整体结构:其中,使用GCN(Global Convolution Network)作为backbone。

训练顺序:

  1. 将图片通过一层卷积层
  2. 使用ResNet或者ResNeXt作为下采样网络,将特征映射下采样为四个不同分辨率的特征映射,即四个等级水平的映射。在预训练阶段,会加入语义监督SS在四个特征映射中,即加入辅助损失,一起进行训练,低级特征被迫编码更多的语义概念,预训练后,除去辅助损失,进行微调。
  3. 四个特征映射进入语义嵌入分支SEB,将低级特征与高级特征进行融合。res-5不进入。
  4. 接下来进入GCN。
  5. 最底下的最高级特征单独进入ECRE,使通道中含有分辨率的信息。上三层进入反卷积层。
  6. 最后进入密集相邻预测DAP,使模型可以预测邻近位置的结果。

GCN:


代码:

class _GlobalConvModule(nn.Module):def __init__(self, in_dim, out_dim, kernel_size):super(_GlobalConvModule, self).__init__()pad0 = int((kernel_size[0] - 1) / 2)pad1 = int((kernel_size[1] - 1) / 2)# kernel size had better be odd number so as to avoid alignment errorsuper(_GlobalConvModule, self).__init__()self.conv_l1 = nn.Conv2d(in_channels=in_dim, out_channels=out_dim, kernel_size=(kernel_size[0], 1),padding=(pad0, 0),bias = False)self.conv_l2 = nn.Conv2d(out_dim, out_dim, kernel_size=(1, kernel_size[1]),padding=(0, pad1),bias = False)self.conv_r1 = nn.Conv2d(in_dim, out_dim, kernel_size=(1, kernel_size[1]),padding=(0, pad1))self.conv_r2 = nn.Conv2d(out_dim, out_dim, kernel_size=(kernel_size[0], 1),padding=(pad0, 0))def forward(self, x):x_l = self.conv_l1(x)x_l = self.conv_l2(x_l)x_r = self.conv_r1(x)x_r = self.conv_r2(x_r)x = x_l + x_rreturn x

Layer Rearrangement(层重新排列)

为了使低级特征(上图res-2或res-3)“更接近”监督,一种直接的方法是在早期阶段安排更多的层,而不是后期。例如,ResNeXt101模型分别有{3、2、4、3、23、3}的构建块;我们重新安排分配到{8、8、9、8},并调整通道的数量,以确保相同的整体计算复杂度。实验表明,尽管新设计的模型的ImageNet分类评分几乎没有变化,但其分割性能提高了0.8%。

Semantic Supervision(语义监督SS)

提出了另一种改进低级特征的方法,即语义监督(SS),即将辅助监督直接分配到编码器网络的早期阶段(见上图)。为了在辅助分支中生成语义输出,低级特征被迫编码更多的语义概念,这将有助于以后的特征融合。我们的语义监督方法主要关注于提高低级特征的质量,而不是提升主干模型本身
需要注意的是,这个模块并不是分割网络的一部分,而是在预训练的时候加在分类网络上的,总体分类损失等于所有辅助分支的加权求和。然后在预训练后,我们删除这些分支,并使用剩余的部分进行微调。
结构如下:

Semantic Embedding Branch(语义嵌入分支SEB)

这个模块就是在把特征融合到decoder之前先将高低级别的特征进行融合。具体来说就是高级别的先经过卷积,再上采样,然后和低级别的进行逐像素相乘

Explicit Channel Resolution Embedding(显式通道分辨率嵌入ECRE)将分辨率信息嵌入到channels中

高级特征含有极少的分辨率,为了获得更多的细节,往往需要使用扩充策略,然而,扩充会带来极高的计算量,因此,我们将把分辨率信息嵌入到通道中。即通道中含有分辨率的信息

按照文中的说法,这个模块是加在第一个上采样环节的,对上采样的feature map加一个辅助loss,因为反卷积含有权重,参数是可学习的,辅助loss没办法嵌入模型中,因此使用Sub-pixel Upsample代替反卷积。但是这里是和谁做loss?这里只能理解为,该部分的上采样是直接变为原图大小,和label做loss,也就意味着这里的上采样并不是分割网络的主路的一部分,只是通过该模块对feature map施加影响。代码中使用nn.PixelShuffle完成。nn.PixelShuffle将(B,C, H, W)转换为(B,C/r^2, Hxr, Wxr),将通道信息转换为分辨率信息。

Densely Adjacent Prediction(密度相邻预测DAP):

  • 传统的方法在最后预测分割结果是,每个像素点预测一个概率值,各个像素独立预测。而作者提出了一种方法,在预测某一位置像素的概率值时,参考周围3*3邻域的值,求平均得到。
  • 空间定位 (i,j) 上的特征点主要负责相同位置的语义信息。为尽可能多地把空间信息编码进通道,本文提出一种全新的机制——密集邻域预测,可以预测邻近位置的结果,比如 (i-1,j+1) 。
  • 原始GCN最后阶段为21通道(对应21个分类),因为是33邻域,一个像素需要参考9个像素的值,所以把21通道扩展为189(219)。
  • 实现过程:先将通道信息转换为分辨率信息,即使用nn.PixelShuffle,再采用平均池化层获取。

代码:

import torch
from torch import nnfrom model.deeplab_resnet import ModelBuilderclass _GlobalConvModule(nn.Module):def __init__(self, in_dim, out_dim, kernel_size):super(_GlobalConvModule, self).__init__()pad0 = int((kernel_size[0] - 1) / 2)pad1 = int((kernel_size[1] - 1) / 2)# kernel size had better be odd number so as to avoid alignment errorsuper(_GlobalConvModule, self).__init__()self.conv_l1 = nn.Conv2d(in_channels=in_dim, out_channels=out_dim, kernel_size=(kernel_size[0], 1),padding=(pad0, 0),bias = False)self.conv_l2 = nn.Conv2d(out_dim, out_dim, kernel_size=(1, kernel_size[1]),padding=(0, pad1),bias = False)self.conv_r1 = nn.Conv2d(in_dim, out_dim, kernel_size=(1, kernel_size[1]),padding=(0, pad1))self.conv_r2 = nn.Conv2d(out_dim, out_dim, kernel_size=(kernel_size[0], 1),padding=(pad0, 0))def forward(self, x):x_l = self.conv_l1(x)x_l = self.conv_l2(x_l)x_r = self.conv_r1(x)x_r = self.conv_r2(x_r)x = x_l + x_rreturn xclass SEB(nn.Module):def __init__(self, in_channels, out_channels):super(SEB, self).__init__()self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1,padding=1)self.upsample = nn.Upsample(scale_factor=2, mode="bilinear")def forward(self, x):x1, x2 = xreturn x1 * self.upsample(self.conv(x2))class GCNFuse(nn.Module):def __init__(self, configer=None,kernel_size=7, dap_k=3):super(GCNFuse, self).__init__()self.num_classes =20num_classes = self.num_classesself.resnet_features = ModelBuilder().build_encoder("resnet101")self.layer0 = nn.Sequential(self.resnet_features.conv1, self.resnet_features.bn1,self.resnet_features.relu1, self.resnet_features.conv3,self.resnet_features.bn3, self.resnet_features.relu3)self.layer1 = nn.Sequential(self.resnet_features.maxpool, self.resnet_features.layer1)self.layer2 = self.resnet_features.layer2self.layer3 = self.resnet_features.layer3self.layer4 = self.resnet_features.layer4self.gcm1 = _GlobalConvModule(2048, num_classes * 4, (kernel_size, kernel_size))self.gcm2 = _GlobalConvModule(1024, num_classes, (kernel_size, kernel_size))self.gcm3 = _GlobalConvModule(512, num_classes * dap_k**2, (kernel_size, kernel_size))self.gcm4 = _GlobalConvModule(256, num_classes * dap_k**2, (kernel_size, kernel_size))self.deconv1 = nn.ConvTranspose2d(num_classes, num_classes * dap_k**2, kernel_size=4, stride=2, padding=1, bias=False)self.deconv2 = nn.ConvTranspose2d(num_classes, num_classes * dap_k**2, kernel_size=4, stride=2, padding=1, bias=False)self.deconv3 = nn.ConvTranspose2d(num_classes * dap_k**2, num_classes * dap_k**2, kernel_size=4, stride=2, padding=1, bias=False)self.deconv4 = nn.ConvTranspose2d(num_classes * dap_k**2, num_classes * dap_k**2, kernel_size=4, stride=2, padding=1, bias=False)self.deconv5 = nn.ConvTranspose2d(num_classes * dap_k**2, num_classes * dap_k**2, kernel_size=4, stride=2, padding=1, bias=False)self.ecre = nn.PixelShuffle(2)self.seb1 = SEB(2048, 1024)self.seb2 = SEB(3072, 512)self.seb3 = SEB(3584, 256)self.upsample2 = nn.Upsample(scale_factor=2, mode="bilinear")self.upsample4 = nn.Upsample(scale_factor=4, mode="bilinear")self.DAP = nn.Sequential(nn.PixelShuffle(dap_k),nn.AvgPool2d((dap_k,dap_k)))def forward(self, x):# suppose input = x , if x 512f0 = self.layer0(x)  # 256f1 = self.layer1(f0)  # 128print (f1.size())f2 = self.layer2(f1)  # 64print (f2.size())f3 = self.layer3(f2)  # 32print (f3.size())f4 = self.layer4(f3)  # 16print (f4.size())x = self.gcm1(f4)out1 = self.ecre(x)seb1 = self.seb1([f3, f4])gcn1 = self.gcm2(seb1)seb2 = self.seb2([f2, torch.cat([f3, self.upsample2(f4)], dim=1)])gcn2 = self.gcm3(seb2)seb3 = self.seb3([f1, torch.cat([f2, self.upsample2(f3), self.upsample4(f4)], dim=1)])gcn3 = self.gcm4(seb3)y = self.deconv2(gcn1 + out1)y = self.deconv3(gcn2 + y)y = self.deconv4(gcn3 + y)y = self.deconv5(y)y = self.DAP(y)return ydef freeze_bn(self):for m in self.modules():if isinstance(m, nn.BatchNorm2d):m.eval()if __name__ == '__main__':model = GCNFuse(20)model.freeze_bn()model.eval()image = torch.autograd.Variable(torch.randn(1, 3, 512, 512), volatile=True)res1= model(image)print (res1.size())

附代码 ExFuse相关推荐

  1. Get了!用Python制作数据预测集成工具 | 附代码

    作者 | 李秋键 责编 | 晋兆雨 大数据预测是大数据最核心的应用,是它将传统意义的预测拓展到"现测".大数据预测的优势体现在,它把一个非常困难的预测问题,转化为一个相对简单的描述 ...

  2. java中自造类是什么意思_Java建造者模式是什么?如何实现?(附代码)

    本篇文章给大家带来的内容是关于Java建造者模式是什么?如何实现?(附代码),有一定的参考价值,有需要的朋友可以参考一下,希望对你有所帮助. 建造者模式 一.什么是建筑者模式? 建造者模式(Build ...

  3. 一文弄懂元学习 (Meta Learing)(附代码实战)《繁凡的深度学习笔记》第 15 章 元学习详解 (上)万字中文综述

    <繁凡的深度学习笔记>第 15 章 元学习详解 (上)万字中文综述(DL笔记整理系列) 3043331995@qq.com https://fanfansann.blog.csdn.net ...

  4. html5自定义属性作用,html5自定义属性:如何获取自定义属性值(附代码)

    这篇文章给大家介绍的内容是关于html5自定义属性:如何获取自定义属性值(附代码),有一定的参考价值,有需要的朋友可以参考一下,希望对你有所帮助. 自定义属性: 在HTML5中我们可以自定义属性,其格 ...

  5. 手把手教你用Keras进行多标签分类(附代码)_数据派THU-CSDN博客 (翻译:程思衍校对:付宇帅)

    手把手教你用Keras进行多标签分类(附代码)_数据派THU-CSDN博客 手把手教你用Keras进行多标签分类(附代码)_数据派THU-CSDN博客

  6. 独家 | 手把手教TensorFlow(附代码)

    上一期我们发布了"一文读懂TensorFlow(附代码.学习资料)",带领大家对TensorFlow进行了全面了解,并分享了入门所需的网站.图书.视频等资料,本期文章就来带你一步步 ...

  7. MobileViT: 一种更小,更快,高精度的轻量级Transformer端侧网络架构(附代码实现)...

    点击上方,选择星标或置顶,不定期资源大放送! 阅读大概需要5分钟 Follow小博主,每天更新前沿干货 [导读]之前详细介绍了轻量级网络架构的开源项目,详情请看深度学习中的轻量级网络架构总结与代码实现 ...

  8. 【卷积神经网络结构专题】一文详解AlexNet(附代码实现)

    关注上方"深度学习技术前沿",选择"星标公众号", 资源干货,第一时间送达! [导读]本文是卷积神经网络结构系列专题第二篇文章,前面我们已经介绍了第一个真正意义 ...

  9. 数据表格搜索php代码_手把手教学:提取PDF各种表格文本数据(附代码)

    标星★公众号     爱你们♥ 量化投资与机器学习编辑部报道 近期原创文章: ♥ 5种机器学习算法在预测股价的应用(代码+数据) ♥ Two Sigma用新闻来预测股价走势,带你吊打Kaggle ♥  ...

最新文章

  1. seci-log 1.11 发布 增加了ftpserver,远程ftp,sftp采集简化配置等功能
  2. 如何重装Citrix XenServer不丢失SR数据
  3. 32位地址的寻址方式
  4. 数据中心机房设计及各专业技术平衡
  5. CIPAddressCtrl的用法
  6. 「前端早读君007」css进阶之彻底理解视觉格式化模型
  7. vstar为什么登录不了_一手的闲鱼号,为什么现在闲鱼号一号难求
  8. nodejs在Liunx上的部署生产方式-PM2
  9. mysql 合并相加_mysql 多条记要判断相加减合并一条
  10. UG12.0基础绘图3D建模造型 工程图视频教程
  11. matlab多久可以入门,5分钟入门matlab
  12. VMware ESXi 安装部署过程
  13. 爬虫_app 2.7 packet capture抓包工具介绍
  14. ofd电子文档内容分析工具(分析文档、签章和证书)
  15. 理解v8的Isolate调度
  16. 【机器学习】可决系数R^2和MSE,MAE,SMSE
  17. 续:Windows Vista操作系统最新安全特性分析:改进和局限 (下)
  18. WSTMart 分销说明,三级分销与返利
  19. 10005 内联函数
  20. bzoj3573米特运输

热门文章

  1. 微信企业号开发之图文消息
  2. mybatis多表查询(两表)例子
  3. PostgreSQL 常用命令速查表
  4. 文献阅读_Joint Prostate Cancer Detection and Gleason Score Prediction in mp-MRI via FocalNet
  5. 【PC工具】白领办公必备,电脑定时提醒休息护眼软件:眼睛护士
  6. lora 与 485 双线备份式通讯
  7. 可靠性设计:容错设计
  8. 非Xposed版 修改微信摇塞子
  9. 视频质量和大小、分辨率200*200、码率kb/s、帧率FPS、带宽、码流、人数
  10. 从一代名将到一代名流 法切蒂的蓝黑人生