fishnet:论文阅读与代码理解

  • 一、论文概述
  • 二、整体框架
  • 三、代码理解
  • 四、总结

fishnet论文地址:http://papers.nips.cc/paper/7356-fishnet-a-versatile-backbone-for-image-region-and-pixel-level-prediction.pdf
fishnet源码地址(pytorch版本):https://github.com/kevin-ssy/FishNet

一、论文概述

  我们知道,对应不同的计算机视觉任务(图像分类、目标检测、语义分割、实例分割等),所需要卷积神经网络提取的特征是不一样的。以图像分类任务与语义分割任务为例。图像分类对应对图片级别的对象进行预测,比如预测一张图片属于猫还是狗。那么它所需要的特征需要更加抽象化的高层次语义特征。而语义分割任务所对应的是像素级别的预测,即预测每一个像素点属于哪一类。这种任务不仅需要语义特征,而且在此基础上还需要注重低层次的细节特征。所以说针对图片级别、区域级别和像素级别的预测任务,卷积神经网络的注重点是不一样的。
 目前而言,用于图像分类的网络,如:ResNet、DenseNet等可以直接将其作为Backbone(主干网)用于区域级和像素级的预测任务(如语义分割中,常用ResNet101作为特征提取的Backbone)。但是为语义分割、目标检测等设计的网络通常是在图像分类任务中发挥不了作用的。
 于是,作者为了设计一种通用于这些任务的卷积神经网络,设计了一种名为fishnet(因为网络的形状像一条鱼,所以命名为fishnet。写论文还是得搞点花里胡哨的东西才能中啊)的网络,它既可以用于图像分类,又可以用于目标检测、语义分割等任务。换言之,也就是说,这个网络所提取的特征是语义特征与细节特征都十分丰富的。那么下面我们来看一下,fishnet是怎么做到语义特征与细节特征并重的。

二、整体框架

 如下图,是fishnet的网络结构图。可以看到,确实,还真挺像一条鱼。

 整个网络分为三个部分从左至右分别命名为:鱼尾(fish tail)、鱼身(fish body)、鱼头(fish head)。鱼尾实际上就是一个resnet的结构,它负责提取语义特征。到了鱼身之后,开始使用上采样提升特征图的分辨率,并进行了跳层连接。这两个操作都是为了让网络拥有更多的细节特征。至此,如果你是要进行语义分割、目标检测等任务的话,就可以不用管鱼头部分了。你可以将鱼身的输出直接上采样到原图大小(到这里,实际上就是一个类似于FCN结构的网络,只是内部实现的细节有所不同)。然后,如果想要进行图像分类任务的话,就用最后的鱼头,下采样得到最后的score vector。下面详细的讲一下这三个部分:

  1. 鱼尾:一个resnet结构。具体结构如下图。值得注意的是,这里的结构据采用maxpooling进行下采样而不采用步长为2的卷积。
    -
  2. 鱼身与鱼头:详细结构如下图:  鱼尾的输出特征图经过SE block的处理后得到鱼身的输入(对应图C3)。然后将其上采样一倍后与鱼尾中对应分辨率经过Transferring Block的特征图相连。这里的Transferring Block实际上就是一个Bottleneck block。串联后送入Up-sampling & Refinement block (UR-block) 中。UR blcok顾名思义就是用来讲特征图上采样与精细化特征的。上采样我们是知道的,它对应这幅图右上角的up(.)。论文中用最近淋插值法上采样。那么怎么进行特征精细化呢?它对用M(.)与r(.)操作。其中M(.)是bottleneck block 。它将特征图的通道变为输入通道图的1/k。这里的K是个超参数,人为通过实验设定。而r(.)则是把输入特征图中的相邻k个通道求和变为一个通道。这样也得到一个通道变为输入通道图的1/k的特征图。然后对二者求和得到特征细化的结果。读到这里,你可能就理解了,所谓的特征细化其实就是一个减少通道数的过程。后续在重复上采样、串联、UR block两次后便完成了鱼身的过程。得到了一个分辨率为原图1/4大小的富含语义信息与细节信息的特征图。
      随后的鱼头的才做与鱼身中类似。只不过上采样换为下采样、UR block换为DR block。而DR block与UR block的不同之处在于:
       1)使用2x2最大池化来下采样。
       2)不使用通道缩减函数,以使得当前阶段的梯度可以直接被传送到先前的阶段。

三、代码理解

  模型主要分为三个文件:

  • **fishnet.py:**构建fishnet模型的文件。主要分为两个类和一个函数:
1、class Fish(nn.Module):封装了fishnet的主要结构
2、class FishNet(nn.Module):调用Fish类进行更高一层的封装
3、def fish(**kwargs):fishnet.py文件的对外接口,调用该函数会返回一个Fishnet类对象
  • fish_block.py:包含一个与原始resnet中经过稍微调整的bottleneck block 类。是fishnet.py文件Fish类中构建fishnet模型的重要组件。
  • net_factory.py: 整个这三个文件的对外接口。其中包含三个函数:
1、def fishnet99(**kwargs):
2、def fishnet150(**kwargs):
3、def fishnet201(**kwargs):

调用不同的函数可以返回不同模型大小的fishnet模型。
1) fishnet.py:

from __future__ import division
import torch
import math
from .fish_block import *__all__ = ['fish']class Fish(nn.Module):def __init__(self, block, num_cls=1000, num_down_sample=5, num_up_sample=3, trans_map=(2, 1, 0, 6, 5, 4),network_planes=None, num_res_blks=None, num_trans_blks=None):super(Fish, self).__init__()self.block = blockself.trans_map = trans_mapself.upsample = nn.Upsample(scale_factor=2)self.down_sample = nn.MaxPool2d(2, stride=2)self.num_cls = num_clsself.num_down = num_down_sampleself.num_up = num_up_sampleself.network_planes = network_planes[1:]self.depth = len(self.network_planes)self.num_trans_blks = num_trans_blksself.num_res_blks = num_res_blksself.fish = self._make_fish(network_planes[0])def _make_score(self, in_ch, out_ch=1000, has_pool=False):bn = nn.BatchNorm2d(in_ch)relu = nn.ReLU(inplace=True)conv_trans = nn.Conv2d(in_ch, in_ch // 2, kernel_size=1, bias=False)bn_out = nn.BatchNorm2d(in_ch // 2)conv = nn.Sequential(bn, relu, conv_trans, bn_out, relu)if has_pool:fc = nn.Sequential(nn.AdaptiveAvgPool2d(1),nn.Conv2d(in_ch // 2, out_ch, kernel_size=1, bias=True))else:fc = nn.Conv2d(in_ch // 2, out_ch, kernel_size=1, bias=True)return [conv, fc]def _make_se_block(self, in_ch, out_ch):bn = nn.BatchNorm2d(in_ch)sq_conv = nn.Conv2d(in_ch, out_ch // 16, kernel_size=1)ex_conv = nn.Conv2d(out_ch // 16, out_ch, kernel_size=1)return nn.Sequential(bn,nn.ReLU(inplace=True),nn.AdaptiveAvgPool2d(1),sq_conv,nn.ReLU(inplace=True),ex_conv,nn.Sigmoid())def _make_residual_block(self, inplanes, outplanes, nstage, is_up=False, k=1, dilation=1):layers = []if is_up:layers.append(self.block(inplanes, outplanes, mode='UP', dilation=dilation, k=k))else:layers.append(self.block(inplanes, outplanes, stride=1))for i in range(1, nstage):layers.append(self.block(outplanes, outplanes, stride=1, dilation=dilation))return nn.Sequential(*layers)def _make_stage(self, is_down_sample, inplanes, outplanes, n_blk, has_trans=True,has_score=False, trans_planes=0, no_sampling=False, num_trans=2, **kwargs):sample_block = []if has_score:sample_block.extend(self._make_score(outplanes, outplanes * 2, has_pool=False))if no_sampling or is_down_sample:res_block = self._make_residual_block(inplanes, outplanes, n_blk, **kwargs)else:res_block = self._make_residual_block(inplanes, outplanes, n_blk, is_up=True, **kwargs)sample_block.append(res_block)if has_trans:trans_in_planes = self.in_planes if trans_planes == 0 else trans_planessample_block.append(self._make_residual_block(trans_in_planes, trans_in_planes, num_trans))if not no_sampling and is_down_sample:sample_block.append(self.down_sample)elif not no_sampling:  # Up-Samplesample_block.append(self.upsample)return nn.ModuleList(sample_block)def _make_fish(self, in_planes):def get_trans_planes(index):map_id = self.trans_map[index-self.num_down-1] - 1p = in_planes if map_id == -1 else cated_planes[map_id]return pdef get_trans_blk(index):return self.num_trans_blks[index-self.num_down-1]def get_cur_planes(index):return self.network_planes[index]def get_blk_num(index):return self.num_res_blks[index]cated_planes, fish = [in_planes] * self.depth, []for i in range(self.depth):# even num for down-sample, odd for up-sampleis_down, has_trans, no_sampling = i not in range(self.num_down, self.num_down+self.num_up+1),\i > self.num_down, i == self.num_down# is_down, has_trans, no_sampling:True False False; True False False; True False False; False False True# False True False; False True False; False True False; True True False;True True False; True True Falsecur_planes, trans_planes, cur_blocks, num_trans =\get_cur_planes(i), get_trans_planes(i), get_blk_num(i), get_trans_blk(i)# cur_planes, trans_planes, cur_blocks, num_trans:128 64 2 1;256 64 2 1; 512 64 6 1; 512 64 2 4# 512 256 1 1; 384 128 1 1; 256 64 1 1; 320 512 1 1;832 768 2 1; 1600 512 2 4stg_args = [is_down, cated_planes[i - 1], cur_planes, cur_blocks]# inplanes:64,128,256,512,1024,512,768,512,320,832,1600if is_down or no_sampling:k, dilation = 1, 1else:k, dilation = cated_planes[i - 1] // cur_planes, 2 ** (i-self.num_down-1)sample_block = self._make_stage(*stg_args, has_trans=has_trans, trans_planes=trans_planes,has_score=(i==self.num_down), num_trans=num_trans, k=k, dilation=dilation,no_sampling=no_sampling)if i == self.depth - 1:sample_block.extend(self._make_score(cur_planes + trans_planes, out_ch=self.num_cls, has_pool=True))elif i == self.num_down:sample_block.append(nn.Sequential(self._make_se_block(cur_planes*2, cur_planes)))if i == self.num_down-1:cated_planes[i] = cur_planes * 2elif has_trans:cated_planes[i] = cur_planes + trans_planeselse:cated_planes[i] = cur_planesfish.append(sample_block)return nn.ModuleList(fish)def _fish_forward(self, all_feat):def _concat(a, b):return torch.cat([a, b], dim=1)def stage_factory(*blks):def stage_forward(*inputs):if stg_id < self.num_down:  # tailtail_blk = nn.Sequential(*blks[:2])# print(stg_id)# print(tail_blk)return tail_blk(*inputs)elif stg_id == self.num_down:score_blks = nn.Sequential(*blks[:2])score_feat = score_blks(inputs[0])att_feat = blks[3](score_feat)return blks[2](score_feat) * att_feat + att_featelse:  # refinefeat_trunk = blks[2](blks[0](inputs[0]))feat_branch = blks[1](inputs[1])return _concat(feat_trunk, feat_branch)return stage_forwardstg_id = 0# tail:while stg_id < self.depth:stg_blk = stage_factory(*self.fish[stg_id])if stg_id <= self.num_down:in_feat = [all_feat[stg_id]]else:trans_id = self.trans_map[stg_id-self.num_down-1]in_feat = [all_feat[stg_id], all_feat[trans_id]]all_feat[stg_id + 1] = stg_blk(*in_feat)stg_id += 1# loop exitif stg_id == self.depth:score_feat = self.fish[self.depth-1][-2](all_feat[-1])score = self.fish[self.depth-1][-1](score_feat)for fea in all_feat:print(fea.shape)return scoredef forward(self, x):all_feat = [None] * (self.depth + 1)all_feat[0] = xreturn self._fish_forward(all_feat)class FishNet(nn.Module):def __init__(self, block, **kwargs):super(FishNet, self).__init__()inplanes = kwargs['network_planes'][0]# resolution: 224x224self.conv1 = self._conv_bn_relu(3, inplanes // 2, stride=2)self.conv2 = self._conv_bn_relu(inplanes // 2, inplanes // 2)self.conv3 = self._conv_bn_relu(inplanes // 2, inplanes)self.pool1 = nn.MaxPool2d(3, padding=1, stride=2)# construct fish, resolution 56x56self.fish = Fish(block, **kwargs)self._init_weights()def _conv_bn_relu(self, in_ch, out_ch, stride=1):return nn.Sequential(nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1, stride=stride, bias=False),nn.BatchNorm2d(out_ch),nn.ReLU(inplace=True))def _init_weights(self):for m in self.modules():if isinstance(m, nn.Conv2d):n = m.kernel_size[0] * m.kernel_size[1] * m.out_channelsm.weight.data.normal_(0, math.sqrt(2. / n))elif isinstance(m, nn.BatchNorm2d):m.weight.data.fill_(1)m.bias.data.zero_()def forward(self, x):x = self.conv1(x)x = self.conv2(x)x = self.conv3(x)x = self.pool1(x)# x.Size([1, 64, 56, 56])score = self.fish(x)# 1*1 outputout = score.view(x.size(0), -1)return outdef fish(**kwargs):return FishNet(Bottleneck, **kwargs)

2) fish_block.py:

import torch.nn as nnclass Bottleneck(nn.Module):def __init__(self, inplanes, planes, stride=1, mode='NORM', k=1, dilation=1):"""Pre-act residual block, the middle transformations are bottle-necked:param inplanes::param planes::param stride::param downsample::param mode: NORM | UP:param k: times of additive"""super(Bottleneck, self).__init__()self.mode = modeself.relu = nn.ReLU(inplace=True)self.k = kbtnk_ch = planes // 4self.bn1 = nn.BatchNorm2d(inplanes)self.conv1 = nn.Conv2d(inplanes, btnk_ch, kernel_size=1, bias=False)self.bn2 = nn.BatchNorm2d(btnk_ch)self.conv2 = nn.Conv2d(btnk_ch, btnk_ch, kernel_size=3, stride=stride, padding=dilation,dilation=dilation, bias=False)self.bn3 = nn.BatchNorm2d(btnk_ch)self.conv3 = nn.Conv2d(btnk_ch, planes, kernel_size=1, bias=False)if mode == 'UP':self.shortcut = Noneelif inplanes != planes or stride > 1:self.shortcut = nn.Sequential(nn.BatchNorm2d(inplanes),self.relu,nn.Conv2d(inplanes, planes, kernel_size=1, stride=stride, bias=False))else:self.shortcut = Nonedef _pre_act_forward(self, x):residual = xout = self.bn1(x)out = self.relu(out)out = self.conv1(out)out = self.bn2(out)out = self.relu(out)out = self.conv2(out)out = self.bn3(out)out = self.relu(out)out = self.conv3(out)if self.mode == 'UP':residual = self.squeeze_idt(x)elif self.shortcut is not None:residual = self.shortcut(residual)out += residualreturn outdef squeeze_idt(self, idt):n, c, h, w = idt.size()return idt.view(n, c // self.k, self.k, h, w).sum(2)def forward(self, x):out = self._pre_act_forward(x)return out

3) fish_block.py:

from models.fishnet import fish
import torchdef fishnet99(**kwargs):""":return:"""net_cfg = {#  input size:   [224, 56, 28,  14  |  7,   7,  14,  28 | 56,   28,  14]# output size:   [56,  28, 14,   7  |  7,  14,  28,  56 | 28,   14,   7]#                  |    |    |   |     |    |    |    |    |     |    |'network_planes': [64, 128, 256, 512, 512, 512, 384, 256, 320, 832, 1600],'num_res_blks': [2, 2, 6, 2, 1, 1, 1, 1, 2, 2],'num_trans_blks': [1, 1, 1, 1, 1, 4],'num_cls': 1000,'num_down_sample': 3,'num_up_sample': 3,}cfg = {**net_cfg, **kwargs}return fish(**cfg)def fishnet150(**kwargs):""":return:"""net_cfg = {#  input size:   [224, 56, 28,  14  |  7,   7,  14,  28 | 56,   28,  14]# output size:   [56,  28, 14,   7  |  7,  14,  28,  56 | 28,   14,   7]#                  |    |    |   |     |    |    |    |    |     |    |'network_planes': [64, 128, 256, 512, 512, 512, 384, 256, 320, 832, 1600],'num_res_blks': [2, 4, 8, 4, 2, 2, 2, 2, 2, 4],'num_trans_blks': [2, 2, 2, 2, 2, 4],'num_cls': 1000,'num_down_sample': 3,'num_up_sample': 3,}cfg = {**net_cfg, **kwargs}return fish(**cfg)def fishnet201(**kwargs):""":return:"""net_cfg = {#  input size:   [224, 56, 28,  14  |  7,   7,  14,  28 | 56,   28,  14]# output size:   [56,  28, 14,   7  |  7,  14,  28,  56 | 28,   14,   7]#                  |    |    |   |     |    |    |    |    |     |    |'network_planes': [64, 128, 256, 512, 512, 512, 384, 256, 320, 832, 1600],'num_res_blks': [3, 4, 12, 4, 2, 2, 2, 2, 3, 10],'num_trans_blks': [2, 2, 2, 2, 2, 9],'num_cls': 1000,'num_down_sample': 3,'num_up_sample': 3,}cfg = {**net_cfg, **kwargs}return fish(**cfg)

四、总结

  • 创造性的在类FCN的网络后再次添加了卷积神经网络。这样的处理使得用于目标检测、语义分割等任务的卷积神经网络可以用于图像分类。并且充分利用到了卷积神经网络所提取到的细节信息。
  • 在网络中不再使用孤立卷积,使得深层的梯度可以直接传递到浅层。

fishnet:论文阅读与代码理解相关推荐

  1. Self-Tuning Spectral Clustering论文阅读和代码理解

    一.代码问题 运行test_segimage.m时,存在如下错误: Building affinity matrix took 0.092672 second Error using dist2aff ...

  2. 《Scale Invariant Feature Transform on the Sphere: Theory and Applications》论文阅读和源码理解(一)

    <Scale Invariant Feature Transform on the Sphere: Theory and Applications>论文阅读和源码理解(一) 摘要 主要贡献 ...

  3. OpenCV图像处理算法——7(《Contrast image correction method》 论文阅读及代码实现)

    <Contrast image correction method> 论文阅读及代码实现 以下内容大部分引自:https://cloud.tencent.com/developer/art ...

  4. 【异构图笔记,篇章3】GATNE论文阅读笔记与理解:General Attributed Multiplex HeTerogeneous Network Embedding

    [异构图笔记,篇章3]GATNE论文阅读笔记与理解:General Attributed Multiplex HeTerogeneous Network Embedding 上期回顾 论文信息概览 论 ...

  5. [软件工程程序修复论文阅读]基于代码感知机器翻译的程序修复

    本文约2871字,预计阅读时长6分钟. 原文标题为CURE: Code-Aware Neural Machine Translation for Automatic Program Repair 论文 ...

  6. paperswithcode 论文阅读与代码复现

    Machine Learning论文阅读与复现 神奇宝贝 1.丰富的论文合集 2.丰富的数据集 3.方法合集 4.论文解析 要是有一个cs科研er不知道这个宝藏网站,我都会伤心的,OK?https:/ ...

  7. 九月学习笔记 (FM、一些论文阅读、代码)

    目录 2020.09.16 FM 因子分解机 2021.09.18 论文阅读 Interactive Recommender System via Knowledge Graph-enhanced R ...

  8. GLMP:任务型对话中全局到局部的记忆指针网络 论文阅读及代码解析

    UPDATE 11.6.2020 复习代码,修正部分内容,清晰化部分表述.如发现问题,欢迎留言讨论! 文章目录 UPDATE GLMP ABSTRACT 1.INTRODUCTION 2.GLMP M ...

  9. 风格化渲染之油画渲染:Customizing Painterly Rendering Styles Using Stroke Processes——论文阅读和个人理解

    摘要 文中提出了一种基于sketch(单个笔划)Non-photorealistic rendering的方法.效果如图 1 引言 本章主要引入了几个用于控制式样渲染的参数 Density--密度:这 ...

最新文章

  1. 快速排序 python菜鸟教程-快速排序
  2. 关于面向过程编程的一些思考
  3. [云炬ThinkPython阅读笔记]2.9 术语表
  4. Android开发之非常好用的日志工具类(公司项目挖出来的)
  5. 日志 中文乱码、nacos 中文乱码、saltstack 中文乱码、docker中文乱码
  6. BZOJ2563: 阿狸和桃子的游戏 贪心
  7. 邮件合并保存为一个个单独的文档_你还在为考计算机二级烦恼吗? 基本操作步骤分享...
  8. ORACLE 常用操作命令
  9. PLSQL导入SQL文件
  10. JanusGraph学习手册
  11. 数学期望方差 expectationvariance
  12. 高级英语(张汉熙版)第一册学习笔记(原文及全文翻译)——2 - Hiroshima-The “Liveliest“ City in Japan (excerpts)(广岛——日本“最有活力”的城市)
  13. 机器学习笔记之概率图模型(八)信念传播(Belief Propagation,BP)(基于树结构)
  14. 批量替换 Word 文档某几页
  15. oracle11监视器,Oracle 11g 表空间监控(一) datafile autoextend
  16. 报错:启动apache服务时出现报错
  17. Workbench导入xls文件
  18. SPSS Modeler 项目实战之超市商品购买关联分析
  19. 奢侈品电商,压死趣店的最后一根稻草?
  20. 国外酷站设计:15个带给你灵感的作品集网站

热门文章

  1. EDT技术 ug - 第四章节Creation of the EDT Logic (持续更新)
  2. 110kV级电力变压器系列技术参数:
  3. 销售火爆,APS自动排产提升咖啡机家电企业生产管理效益
  4. 毕业设计-剪叉式物流液压升降台的设计【论文+CAD图纸(整机图A0+液压系统图A1+液压缸A1)+开题报告+外文翻译+文献综述】
  5. [AcWing算法刷题]之DFS+BFS迷宫模板(简单)
  6. 数据仓库模型报表设计
  7. UTNet: A Hybrid Transformer Architecture for Medical Image Segmentation
  8. 老人地摊淘到旧书 发现刊有老伴年轻照片(图)
  9. .net core平台socket调用失败 This protocol version is not supported.
  10. 米兔积木机器人与履带机甲零件差别_这只兔子有点酷—米兔积木机器人履带机甲测评...