EGNet: Edge Guidance Network for Salient Object Detection
论文及代码解读

注:本文原创作者为Jia-Xing Zhao, Jiang-Jiang Liu, Deng-Ping Fan, Yang Cao, Ju-Feng Yang, Ming-Ming Cheng* TKLNDST, CS, Nankai University
本人仅对论文和代码进行注释讲解,如有侵权请联系删除文章即可。

论文分析:
1、论文的研究初衷,为什么要研究这个课题或者研究这个课题解决了什么问题?
因为现在流行的显著性检测大都是基于深度学习而衍生出来的,所以显著性目标的检测大多存在边缘轮廓不清晰、显著性不能精确提取的问题,因此本文作者提出了一种新颖的方法,即将在深度学习过程中,利用VGG网络的特性,即第二个池化层输出的特征图具有良好的边缘信息特征,而最后一层具有丰富的显著性特征,作者将边缘信息特征与显著性特征进行像素级的融合,得到具有清晰轮廓的显著性目标。
2、论文所基于的网络框架,特征提取方式,以及边缘信息和显著性息如何进行融合?
论文是以VGG16为base net,但并不是全部进行套用,而是截取其中池化以及卷积层进行特征的提取,作者利用了U-net的思想,对每个池化层进行了side path处理,即将每个池化层的输出都通过三层卷积输出,但并不对输出的大小进行改变。因为边缘信息以及显著目标信息的输出尺寸不一致,所以作者通过上采样的方式,将显著目标的特征信息进行上采样,采样的输出与边缘特征输出大小一致,然后进行像素级的相加,得到融合后的边缘信息。
3、模块分析:
1、Complementary information modeling
以VGG为基础,删除最后三个全连接层,在最后一个pooling 层加入一个side path,然后从vgg中得到六个side features,因为第一层输入太小,不合适所以舍去第一层。将其余五个side features放到一个集合中,其中第二层的特征c2因为有更好的边缘信息所以用作提取边缘特征。
以下是该模块代码实现的部分:

vgg16


    for m in self.modules():if isinstance(m, nn.Conv2d):n = m.kernel_size[0] * m.kernel_size[1] * m.out_channelsm.weight.data.normal_(0, 0.01)elif isinstance(m, nn.BatchNorm2d):m.weight.data.fill_(1)m.bias.data.zero_()

2、Progressive salient object features extraction
为了得到更丰富的显著性目标特征,采用U-Net结构,并且在每一个side path后加入三个卷积层,在每层卷积层后加入ReLU层以确保得到的结果呈非线性。在最后加入了深度监督,在每个side path后加入一个卷积层,用来将得到的显著性图转换为单通道的预测标记。
以下是该模块代码实现的部分:

class MergeLayer1(nn.Module): # list_k: [[64, 512, 64], [128, 512, 128], [256, 0, 256] ... ]def __init__(self, list_k):super(MergeLayer1, self).__init__()self.list_k = list_ktrans, up, score = [], [], []for ik in list_k:if ik[1] > 0:         # 如果ik中的第二项大于0trans.append(nn.Sequential(nn.Conv2d(ik[1], ik[0], 1, 1, bias=False), nn.ReLU(inplace=True)))   #  测路径提取
        up.append(nn.Sequential(nn.Conv2d(ik[0], ik[2], ik[3], 1, ik[4]), nn.ReLU(inplace=True),nn.Conv2d(ik[2], ik[2], ik[3], 1, ik[4]), nn.ReLU(inplace=True),nn.Conv2d(ik[2], ik[2], ik[3], 1, ik[4]), nn.ReLU(inplace=True)))           # 提取后的信息经过三层卷积处理score.append(nn.Conv2d(ik[2], 1, 3, 1, 1))                                                          # 将数据转化为单通道进行输出trans.append(nn.Sequential(nn.Conv2d(512, 128, 1, 1, bias=False), nn.ReLU(inplace=True)))self.trans, self.up, self.score = nn.ModuleList(trans), nn.ModuleList(up), nn.ModuleList(score)         # 测路径提取,同上一行一致self.relu =nn.ReLU()                                                                                    # 增加relu层,使得结果呈非线性def forward(self, list_x, x_size):up_edge, up_sal, edge_feature, sal_feature = [], [], [], []num_f = len(list_x)tmp = self.up[num_f - 1](list_x[num_f-1])sal_feature.append(tmp)U_tmp = tmpup_sal.append(F.interpolate(self.score[num_f - 1](tmp), x_size, mode='bilinear', align_corners=True))for j in range(2, num_f ):i = num_f - jif list_x[i].size()[1] < U_tmp.size()[1]:U_tmp = list_x[i] + F.interpolate((self.trans[i](U_tmp)), list_x[i].size()[2:], mode='bilinear', align_corners=True)else:U_tmp = list_x[i] + F.interpolate((U_tmp), list_x[i].size()[2:], mode='bilinear', align_corners=True)tmp = self.up[i](U_tmp)U_tmp = tmpsal_feature.append(tmp)up_sal.append(F.interpolate(self.score[i](tmp), x_size, mode='bilinear', align_corners=True))U_tmp = list_x[0] + F.interpolate((self.trans[-1](sal_feature[0])), list_x[0].size()[2:], mode='bilinear', align_corners=True)tmp = self.up[0](U_tmp)edge_feature.append(tmp)up_edge.append(F.interpolate(self.score[0](tmp), x_size, mode='bilinear', align_corners=True)) return up_edge, edge_feature, up_sal, sal_feature

3、Non-local salient edge features extraction
为了得到更精确的位置信息,作者直接将最高级的特征信息融入到边缘信息中,即将本文中Conv6-3的位置特征信息融入到Conv2-2中的边缘特征中,采取的融合方式是:
其中 表示卷积层的参数 , 表示ReLU激活函数, 是双线性插值运算,旨在将*向上采样到与C(2) 相同的大小。
表示侧路径(side path) 的特征,

表示 的参数, 表示卷积和非线性参数 。
最后边缘信息特征表示为 。
以下是该模块代码实现的部分:

class MergeLayer2(nn.Module): def __init__(self, list_k):super(MergeLayer2, self).__init__()self.list_k = list_ktrans, up, score = [], [], []for i in list_k[0]:tmp = []tmp_up = []tmp_score = []feature_k = [[3,1],[5,2], [5,2], [7,3]]for idx, j in enumerate(list_k[1]):tmp.append(nn.Sequential(nn.Conv2d(j, i, 1, 1, bias=False), nn.ReLU(inplace=True)))    #  上采样到与边缘特征同样输出
            tmp_up.append(nn.Sequential(nn.Conv2d(i , i, feature_k[idx][0], 1, feature_k[idx][1]), nn.ReLU(inplace=True),nn.Conv2d(i, i,  feature_k[idx][0],1 , feature_k[idx][1]), nn.ReLU(inplace=True),nn.Conv2d(i, i, feature_k[idx][0], 1, feature_k[idx][1]), nn.ReLU(inplace=True)))   # 三层卷积tmp_score.append(nn.Conv2d(i, 1, 3, 1, 1))           # 单通道输出trans.append(nn.ModuleList(tmp))up.append(nn.ModuleList(tmp_up))score.append(nn.ModuleList(tmp_score))self.trans, self.up, self.score = nn.ModuleList(trans), nn.ModuleList(up), nn.ModuleList(score)       self.final_score = nn.Sequential(nn.Conv2d(list_k[0][0], list_k[0][0], 5, 1, 2), nn.ReLU(inplace=True), nn.Conv2d(list_k[0][0], 1, 3, 1, 1))  # 最终合成一维输出self.relu =nn.ReLU()

4、One-to-one guidance module
在这个module中,作者又提出了一个sub-side path,得到的显著性特征表示为:

注:该式表达的意思为:将side path中的特征信息上采样到与FE相同输出维度,然后进行像素级的相加。
以下是该模块代码实现的部分:

def forward(self, list_x, list_y, x_size):up_score, tmp_feature = [], []list_y = list_y[::-1]
for i, i_x in enumerate(list_x):for j, j_x in enumerate(list_y):                              tmp = F.interpolate(self.trans[i][j](j_x), i_x.size()[2:], mode='bilinear', align_corners=True) + i_x                tmp_f = self.up[i][j](tmp)             up_score.append(F.interpolate(self.score[i][j](tmp_f), x_size, mode='bilinear', align_corners=True))                  tmp_feature.append(tmp_f)tmp_fea = tmp_feature[0]
for i_fea in range(len(tmp_feature) - 1):tmp_fea = self.relu(torch.add(tmp_fea, F.interpolate((tmp_feature[i_fea+1]), tmp_feature[0].size()[2:], mode='bilinear', align_corners=True)))
up_score.append(F.interpolate(self.final_score(tmp_fea), x_size, mode='bilinear', align_corners=True))

代码分析:
一、训练时代码所需修改的地方:
1:

vgg_path = 'D:/date set/vgg16_20M.pth'
resnet_path = 'D:/date set/resnet50_caffe.pth'

修改路径,原文作者在上传代码时也有上传这两个文件,点击链接下载后,将代码中的路径修改为该文件所在路径即可

2:
parser.add_argument('--epoch', type=int, default=1)

原文代码此处为30个epoch,为方便测试可以在此处将默认值改为1,后续调通代码进行测试可以改为自己所需的值

3:

   parser.add_argument('--mode', type=str, default='test', choices=['train', 'test'])

该行代码可以用来选择模式,train或者test,train表示训练,test表示测试
4

base_model_cfg = 'vgg'

在solver.py文件中找到上述代码,可以进行选择base net为VGG16或者rest net。
5

        self.sal_root = 'D:/date set/DUTS-TR'                             #训练数据集的位置self.sal_source = 'D:/date set/DUTS-TR/train_pair_edge.lst'       #训练数据集的清单

在dateset.py文件中找到上述代码,根据自己数据集所在位置修改路径即可。
6:

  if i % 200 == 0:vutils.save_image(torch.sigmoid(up_sal_f[-1].data), tmp_path+'/iter%d-sal-0.jpg' % i, normalize=True, padding = 0)#vutils.save_image(up_sal_f[-1].data, tmp_path + '/iter%d-sal-0.jpg' % i,padding = 0)vutils.save_image(sal_image.data, tmp_path+'/iter%d-sal-data.jpg' % i, padding = 0)#print(os.path.abspath(sal_image))vutils.save_image(sal_label.data, tmp_path+'/iter%d-sal-target.jpg' % i, padding = 0)

在dateset.py文件中找到上述代码,可以根据自己所需修改每n次进行打印一次图片。
二、测试时需要修改的地方
1:

    parser.add_argument('--model', type=str, default='D:/date set/epoch_resnet.pth')parser.add_argument('--test_fold', type=str, default='D:/date set/results/test')parser.add_argument('--test_mode', type=int, default=1)parser.add_argument('--sal_mode', type=str, default='p')# Miscparser.add_argument('--mode', type=str, default='test', choices=['train', 'test'])

根据自己的文件位置找到对应路径,将第一二行代码的路径进行修改即可。
第三行和第四行是选择测试模式可以根据自己所需进行修改默认值,也可以将此处的默认值去掉,在dateset.py文件中修改以下默认值,(注意,如果在前面未修改默认值而直接在下述代码修改默认值是无效的,运行代码时默认为前述默认值)

class ImageDataTest(data.Dataset):def __init__(self, test_mode=1, sal_mode='p'):

2:

class ImageDataTest(data.Dataset):def __init__(self, test_mode=1, sal_mode='p'):if test_mode == 0:# self.image_root = '/home/liuj/dataset/saliency_test/ECSSD/Imgs/'# self.image_source = '/home/liuj/dataset/saliency_test/ECSSD/test.lst'self.image_root = 'D:/date set/LFSD/test_images/'self.image_source = 'D:/date set/LFSD/test.lst'elif test_mode == 1:if sal_mode == 'e':self.image_root = 'D:/date set/saliency_test/ECSSD/Imgs/'self.image_source = 'D:/date set/saliency_test/ECSSD/test.lst'self.test_fold = '/media/ubuntu/disk/Result/saliency/ECSSD/'elif sal_mode == 'p':self.image_root = 'D:/date set/LFSD/test_images/'self.image_source = 'D:/date set/LFSD/test.lst'self.test_fold = 'D:/date set/LFSD/'elif sal_mode == 'd':self.image_root = 'D:/date set/saliency_test/DUTOMRON/Imgs/'self.image_source = 'D:/date set/saliency_test/DUTOMRON/test.lst'self.test_fold = '/media/ubuntu/disk/Result/saliency/DUTOMRON/'elif sal_mode == 'h':self.image_root = 'D:/date set/saliency_test/HKU-IS/Imgs/'self.image_source = 'D:/date set/saliency_test/HKU-IS/test.lst'self.test_fold = '/media/ubuntu/disk/Result/saliency/HKU-IS/'elif sal_mode == 's':self.image_root = 'D:/date set/saliency_test/SOD/Imgs/'self.image_source = 'D:/date set/saliency_test/SOD/test.lst'self.test_fold = '/media/ubuntu/disk/Result/saliency/SOD/'elif sal_mode == 'm':self.image_root = 'D:/date set/saliency_test/MSRA/Imgs/'self.image_source = 'D:/date set/saliency_test/MSRA/test.lst'elif sal_mode == 'o':self.image_root = 'D:/date set/saliency_test/SOC/TestSet/Imgs/'self.image_source = 'D:/date set/saliency_test/SOC/TestSet/test.lst'self.test_fold = '/media/ubuntu/disk/Result/saliency/SOC/'elif sal_mode == 't':self.image_root = 'D:/date set/DUTS/DUTS-TE/DUTS-TE-Image/'self.image_source = 'D:/date set/DUTS/DUTS-TE/test.lst'self.test_fold = '/media/ubuntu/disk/Result/saliency/DUTS/'elif test_mode == 2:self.image_root = '/home/liuj/dataset/SK-LARGE/images/test/'self.image_source = '/home/liuj/dataset/SK-LARGE/test.lst'

修改此处代码的路径,将路径改为自己电脑中数据集的路径即可。
此处需要一个self.image_source 文件,百度了许多没有有效的生成代码,因此自己进行编写了一个代码,可以生成列表,但是在运行此代码之前需要将数据集中图片名称修改成统一的命名格式,自己测试过可以使用。

import osRoot = 'D:\\date set\ECSSD\\images\\'                 #数据集路径
Dest = open('D:\\date set\ECSSD\\ECSSD.lst','w+')     #创建的lst文件路径,w+表示可以写的文件形式for (root, dirs, files) in os.walk(Root):             #遍历数据集目录for i in files:                                   #便利数据集图片的名称print(i)Dest.write(i+'\n')                            #将图片名称写入lst文件

后续还会跟进,如有相同一起学习本文的,可以留言进行交流

EGNet: Edge Guidance Network for Salient Object Detection 论文及代码解读相关推荐

  1. EGNet: Edge Guidance Network for Salient Object Detection

    论文主要解决的问题: 全卷积神经网络(FCNs)在突出的目标检测任务中显示出了其优势.然而,大多数现有的基于fcns的方法仍然存在粗糙的对象边界.与基于区域的方法相比,像素级显著目标检测方法具有优势. ...

  2. 【论文笔记】Multi-Content Complementation Network for Salient Object Detection in Optical RSI

    论文 论文:Multi-Content Complementation Network for Salient Object Detection in Optical Remote Sensing I ...

  3. ECCV 2020预会议 直播笔记| Suppress and Balance: A Simple Gated Network for Salient Object Detection

    目标跟踪基础与智能前沿 寻找 目标跟踪方向的小伙伴,如果你苦于没有地方可以和同方向的小伙伴交流,我们创建了一个交流群,点上方链接可以进入,每周的交流活动通过该号宣传,群里随时随地可以展开讨论,无论是学 ...

  4. 分析显著性目标检测--Global Context-Aware Progressive Aggregation Network for Salient Object Detection

    分析显著性目标检测--Global Context-Aware Progressive Aggregation Network for Salient Object Detection 引入 方法 网 ...

  5. 显著性目标检测之Global Context-Aware Progressive Aggregation Network for Salient Object Detection

    Global Context-Aware Progressive Aggregation Network for Salient Object Detection 文章目录 Global Contex ...

  6. Lightweight Adversarial Network for Salient Object Detection

    Abstract 作者提出了一种用于显着目标检测(salient object detection)的轻量级对抗网络,该网络通过进行对抗性训练来实现更高阶的空间一致性,并分别通过轻量级bottlene ...

  7. Multi-scale Interactive Network for Salient Object Detection(用于显著性目标检测的多尺度交互网络)

    Abstract 基于深度学习的显著性目标检测方法取得了很大的进步,然而,物体的尺度变化和类别的未知一直是显著性目标检测任务的挑战,这些与多层次和多尺度特征的利用紧密相关.在本文中,提出了聚合交互模块 ...

  8. Motion Guided Attention for Video Salient Object Detection论文详读

    abstract 视频显著目标检测的主要目的是检测出视频中视觉上最突出.最独特的目标,现有的方法没有获取和使用视频中的运动线索,或忽略了光流图像中的空间上下文. 本文的方法使用两个子网络分别实现两个子 ...

  9. OTA: Optimal Transport Assignment for Object Detection 原理与代码解读

    paper:OTA: Optimal Transport Assignment for Object Detection code:https://github.com/Megvii-BaseDete ...

最新文章

  1. jvm的那些设置参数你都知道吗
  2. 电脑主板线路连接图解_电工速学手册:306页现场电工全能图解,实用技术精选大合集!...
  3. moveTaskToback退后台
  4. python dendrogram_【聚类分析】《数学建模算法与应用》第十章 多元分析 第一节 聚类分析 python实现...
  5. linux2.4内核下载,升级到Linux 2.4内核
  6. excel2010设置列宽为像素_职场新手都能学会的Excel技巧:快速调整行高、列宽
  7. [leetcode] 题型整理之排列组合
  8. java中的垃圾收集器_Java中的垃圾收集
  9. Densenet论文解读 深度学习领域论文分析博主
  10. java-jdk环境下载
  11. vue富文本编辑器组件
  12. 电脑安装有道后打开word文档很慢
  13. 微信小程序开发之组件official-account(配置公众号关注组件)
  14. Kotlin Sealed 是什么?为什么 Google 都用
  15. python获取arduino数据可视化_Arduino数据可视化在实验教学中的应用
  16. cadence 画电路图时出现绿色的倒三角
  17. javamail发送邮件到qq邮箱图片不能显示问题
  18. 函数指针的强制类型转换与void指针
  19. 腾讯cos做文件服务器,将腾讯云COS对象存储挂载至腾讯云服务器实现大硬盘存储...
  20. http协议的状态码(statue) / readyState状态码

热门文章

  1. 《操作系统真象还原》第七章
  2. 嵌入式实训 智能家居项目
  3. Qt如何读取.txt文件(将内容读到文本编辑框)
  4. 一个简洁的斐波那契求法和它的简单应用
  5. Joggler的MeeGo系统移植
  6. 飞桨分布式训练又推新品,4D混合并行可训千亿级AI模型
  7. wolframalpha最新版_WolframAlpha安卓版中文最新版
  8. Activiti的配置文件
  9. CTF 每日一题 Day28 异性相吸
  10. 【每日一读】Deep Variational Network Embedding in Wasserstein Space