SegNet 的应用

SegNet常用于图像的语义分割。什么是语义分割了?,我们知道图像分割大致可以划分为三类,一类是语义分割、一类是实例分割,一类是全景分割,另外还有一些可以归为超像素分割。打个比方,如果是有一群人的沙滩排球这样的一个场景,图像中有一群人,有蓝天,大海,沙滩,还有一些椰子树。语义分割就是将人从这张图中分出来,其他的全部认为是背景;实例分割,就是不仅仅要把人分出来,还要区别不同的人;全景分割就是不仅仅要将人区分出来,而且不同椰子树区分出来,蓝天区分出来,总之不同类别的都区分出来。下图就是SegNet进行分割效果的图

SegNet原理

SegNet网络是最开始明确定义ecoder端和decoder端,它的ecoder端是使用Vgg16,总计使用了Vgg16的13个卷积层,相对于采用比较典型的ecoder用来提取图像特征,SegNet的改进侧重点在于设计优良的decoder端,decoder端将pooling indices技术应用在max pooling过程中来连接encoder的输出做一个非线性的上采样,根据SegNet这篇文章的说明,使用SegNet在计算机消耗资源以及预测分类的准确性上取得较好的平衡。
SegNet相对于全卷积分割网络的重要改进是在pooling indices上,具体体现在encoder和decoder过程中,示例如下。

SegNet 网络结构


注意看一下,这里相对于Unet是有一个差别的,encoder和decoder之间skip 连接的不是tensor而是文章中所定义的pooling indices

这个图中其他部分比较好理解,根据文章的介绍,pooling过程需要较好的审视,以及upsampling过程。

在encoder端,文章记载使用的是简单的max pooling 2x2的窗宽,步长为2,毫无疑问的是采用max pooling过后的图像会变小,而且按照这样的参数配比,会变为原图的一半,同时,这会增加不变性的东西,就比如大尺度的背景,与之伴随的就是空间分辨率的下降。毫无疑问这种空间分辨率的下降对于边界的勾画是不利的,在文章的论述中,作者认为需要对这种边界信息进行存储保留,比如作者将每一个feature map中每一个pooling 窗口中最大值的localtion记录了下来;简而言之在encoder端口除了常规的conv、bn、maxpooling以外,作者将pooling窗口中的最大值location记录了下来

在decoder端,采用的也是正常的卷积、bn、relu的操作,不同的地方在于最后加上了一个softmax用来整理输出的k个channel的概率图,k是指定的分割类别数目。但是与众不同的地方在于SegNet在上采样过程中使用了在encoder端所获得的pooling indices,用这个来指导上采样过程,而不是同一般的卷积网络那样直接采用了一个卷积过程进行上采样,这样的上采样后的每一个pixel 位置都是上采样前输入图像的一个加权平均。具体情况如下所示

从这里就可以看出,使用indics的上采样,是一个不变input tensor pixel 相对位置的填充,这个也就是所谓的pixel-wise过程,按照论文的记载,这样的做法能够有一定程度的边界信息保留。

我个人的看法

SegNet首先相对于Unet或者其他需要传递feature map的网络来说,一般而言还是能够减少很多decoder模块的权重的,这也就是符合文章作者提到的在计算机消耗和accuracy中取得平衡,虽然我本人并没有在大数据集上进行过SegNet的准确性测试,但是从一些文章来看,大家普遍认为SegNet的预测精准性比较高[1]。但是我并不推荐在医学图像这样精度要求非常严格的领域内使用,相对于pooling indices,直接传送feature map的信息量还是高很多。

代码

class SegNet(nn.Module):def __init__(self,input_nbr,label_nbr):super(SegNet, self).__init__()batchNorm_momentum = 0.1self.conv11 = nn.Conv2d(input_nbr, 64, kernel_size=3, padding=1)self.bn11 = nn.BatchNorm2d(64, momentum= batchNorm_momentum)self.conv12 = nn.Conv2d(64, 64, kernel_size=3, padding=1)self.bn12 = nn.BatchNorm2d(64, momentum= batchNorm_momentum)self.conv21 = nn.Conv2d(64, 128, kernel_size=3, padding=1)self.bn21 = nn.BatchNorm2d(128, momentum= batchNorm_momentum)self.conv22 = nn.Conv2d(128, 128, kernel_size=3, padding=1)self.bn22 = nn.BatchNorm2d(128, momentum= batchNorm_momentum)self.conv31 = nn.Conv2d(128, 256, kernel_size=3, padding=1)self.bn31 = nn.BatchNorm2d(256, momentum= batchNorm_momentum)self.conv32 = nn.Conv2d(256, 256, kernel_size=3, padding=1)self.bn32 = nn.BatchNorm2d(256, momentum= batchNorm_momentum)self.conv33 = nn.Conv2d(256, 256, kernel_size=3, padding=1)self.bn33 = nn.BatchNorm2d(256, momentum= batchNorm_momentum)self.conv41 = nn.Conv2d(256, 512, kernel_size=3, padding=1)self.bn41 = nn.BatchNorm2d(512, momentum= batchNorm_momentum)self.conv42 = nn.Conv2d(512, 512, kernel_size=3, padding=1)self.bn42 = nn.BatchNorm2d(512, momentum= batchNorm_momentum)self.conv43 = nn.Conv2d(512, 512, kernel_size=3, padding=1)self.bn43 = nn.BatchNorm2d(512, momentum= batchNorm_momentum)self.conv51 = nn.Conv2d(512, 512, kernel_size=3, padding=1)self.bn51 = nn.BatchNorm2d(512, momentum= batchNorm_momentum)self.conv52 = nn.Conv2d(512, 512, kernel_size=3, padding=1)self.bn52 = nn.BatchNorm2d(512, momentum= batchNorm_momentum)self.conv53 = nn.Conv2d(512, 512, kernel_size=3, padding=1)self.bn53 = nn.BatchNorm2d(512, momentum= batchNorm_momentum)self.conv53d = nn.Conv2d(512, 512, kernel_size=3, padding=1)self.bn53d = nn.BatchNorm2d(512, momentum= batchNorm_momentum)self.conv52d = nn.Conv2d(512, 512, kernel_size=3, padding=1)self.bn52d = nn.BatchNorm2d(512, momentum= batchNorm_momentum)self.conv51d = nn.Conv2d(512, 512, kernel_size=3, padding=1)self.bn51d = nn.BatchNorm2d(512, momentum= batchNorm_momentum)self.conv43d = nn.Conv2d(512, 512, kernel_size=3, padding=1)self.bn43d = nn.BatchNorm2d(512, momentum= batchNorm_momentum)self.conv42d = nn.Conv2d(512, 512, kernel_size=3, padding=1)self.bn42d = nn.BatchNorm2d(512, momentum= batchNorm_momentum)self.conv41d = nn.Conv2d(512, 256, kernel_size=3, padding=1)self.bn41d = nn.BatchNorm2d(256, momentum= batchNorm_momentum)self.conv33d = nn.Conv2d(256, 256, kernel_size=3, padding=1)self.bn33d = nn.BatchNorm2d(256, momentum= batchNorm_momentum)self.conv32d = nn.Conv2d(256, 256, kernel_size=3, padding=1)self.bn32d = nn.BatchNorm2d(256, momentum= batchNorm_momentum)self.conv31d = nn.Conv2d(256,  128, kernel_size=3, padding=1)self.bn31d = nn.BatchNorm2d(128, momentum= batchNorm_momentum)self.conv22d = nn.Conv2d(128, 128, kernel_size=3, padding=1)self.bn22d = nn.BatchNorm2d(128, momentum= batchNorm_momentum)self.conv21d = nn.Conv2d(128, 64, kernel_size=3, padding=1)self.bn21d = nn.BatchNorm2d(64, momentum= batchNorm_momentum)self.conv12d = nn.Conv2d(64, 64, kernel_size=3, padding=1)self.bn12d = nn.BatchNorm2d(64, momentum= batchNorm_momentum)self.conv11d = nn.Conv2d(64, label_nbr, kernel_size=3, padding=1)def forward(self, x):# Stage 1x11 = F.relu(self.bn11(self.conv11(x)))x12 = F.relu(self.bn12(self.conv12(x11)))x1p, id1 = F.max_pool2d(x12,kernel_size=2, stride=2,return_indices=True)# Stage 2x21 = F.relu(self.bn21(self.conv21(x1p)))x22 = F.relu(self.bn22(self.conv22(x21)))x2p, id2 = F.max_pool2d(x22,kernel_size=2, stride=2,return_indices=True)# Stage 3x31 = F.relu(self.bn31(self.conv31(x2p)))x32 = F.relu(self.bn32(self.conv32(x31)))x33 = F.relu(self.bn33(self.conv33(x32)))x3p, id3 = F.max_pool2d(x33,kernel_size=2, stride=2,return_indices=True)# Stage 4x41 = F.relu(self.bn41(self.conv41(x3p)))x42 = F.relu(self.bn42(self.conv42(x41)))x43 = F.relu(self.bn43(self.conv43(x42)))x4p, id4 = F.max_pool2d(x43,kernel_size=2, stride=2,return_indices=True)# Stage 5x51 = F.relu(self.bn51(self.conv51(x4p)))x52 = F.relu(self.bn52(self.conv52(x51)))x53 = F.relu(self.bn53(self.conv53(x52)))x5p, id5 = F.max_pool2d(x53,kernel_size=2, stride=2,return_indices=True)# Stage 5dx5d = F.max_unpool2d(x5p, id5, kernel_size=2, stride=2)x53d = F.relu(self.bn53d(self.conv53d(x5d)))x52d = F.relu(self.bn52d(self.conv52d(x53d)))x51d = F.relu(self.bn51d(self.conv51d(x52d)))# Stage 4dx4d = F.max_unpool2d(x51d, id4, kernel_size=2, stride=2)x43d = F.relu(self.bn43d(self.conv43d(x4d)))x42d = F.relu(self.bn42d(self.conv42d(x43d)))x41d = F.relu(self.bn41d(self.conv41d(x42d)))# Stage 3dx3d = F.max_unpool2d(x41d, id3, kernel_size=2, stride=2)x33d = F.relu(self.bn33d(self.conv33d(x3d)))x32d = F.relu(self.bn32d(self.conv32d(x33d)))x31d = F.relu(self.bn31d(self.conv31d(x32d)))# Stage 2dx2d = F.max_unpool2d(x31d, id2, kernel_size=2, stride=2)x22d = F.relu(self.bn22d(self.conv22d(x2d)))x21d = F.relu(self.bn21d(self.conv21d(x22d)))# Stage 1dx1d = F.max_unpool2d(x21d, id1, kernel_size=2, stride=2)x12d = F.relu(self.bn12d(self.conv12d(x1d)))x11d = self.conv11d(x12d)return x11ddef load_from_segnet(self, model_path):s_dict = self.state_dict()# create a copy of the state dictth = torch.load(model_path).state_dict() # load the weigths# for name in th:# s_dict[corresp_name[name]] = th[name]self.load_state_dict(th)

参考文献

[1]. A Review on Deep learining Techniques Applied to Semantic Segmentation
SegNet: A Deep Convolutional Encoder-Decoder Architecture for Image Segmentation

SegNet学习笔记(附Pytorch 代码)相关推荐

  1. mapbox 修改初始位置_一行代码教你如何随心所欲初始化Bert参数(附Pytorch代码详细解读)...

    微信公众号:NLP从入门到放弃 微信文章在这里(排版更漂亮,但是内置链接不太行,看大家喜欢哪个点哪个看吧): 一行代码带你随心所欲重新初始化bert的参数(附Pytorch代码详细解读)​mp.wei ...

  2. JUC.Condition学习笔记[附详细源码解析]

    JUC.Condition学习笔记[附详细源码解析] 目录 Condition的概念 大体实现流程 I.初始化状态 II.await()操作 III.signal()操作 3个主要方法 Conditi ...

  3. 【学习笔记】低代码平台(LCAP:Low-Code Application Platform)

    学习笔记:低代码平台(LCAP:Low-Code Application Platform) [概念] 开发者写很少的代码,通过低代码平台提供的界面.逻辑.对象.流程等可视化编排工具来完成大量的开发工 ...

  4. 吴恩达《机器学习》学习笔记十一——神经网络代码

    吴恩达<机器学习>学习笔记十一--神经网络代码 数据准备 神经网络结构与代价函数· 初始化设置 反向传播算法 训练网络与验证 课程链接:https://www.bilibili.com/v ...

  5. PyTorch学习笔记:PyTorch初体验

    PyTorch学习笔记:PyTorch初体验 一.在Anaconda里安装PyTorch 1.进入虚拟环境mlcc 2.安装PyTorch 二.在PyTorch创建张量 1.启动mlcc环境下的Spy ...

  6. 【学习笔记】Pytorch深度学习—Batch Normalization

    [学习笔记]Pytorch深度学习-Batch Normalization Batch Normalization概念 `Batch Normalization ` `Batch Normalizat ...

  7. 卷起来了,写了一套Tensorflow和Pytorch的学习笔记(20G/代码/PPT/视频)

    作为一名AI工程师,掌握一门深度学习框架是必备的生存技能之一. 谷歌的 Tensorflow 与 Facebook 的 PyTorch 一直是颇受社区欢迎的两种深度学习框架. 我们通过调研发现,80% ...

  8. ResNet论文笔记及Pytorch代码解析

    注:个人学习记录 感谢B站up主"同济子豪兄"的精彩讲解,参考视频的记录 [精读AI论文]ResNet深度残差网络_哔哩哔哩_bilibili 算法的意义(大概介绍) CV史上的技 ...

  9. pytorch adagrad_【学习笔记】Pytorch深度学习—优化器(二)

    点击文末 阅读原文,体验感更好哦! 前面学习过了Pytorch中优化器optimizer的基本属性和方法,优化器optimizer的主要功能是 "管理模型中的可学习参数,并利用参数的梯度gr ...

最新文章

  1. Camera HDR Algorithms
  2. 你们这行我懂,不给点好处都不接!
  3. Android前沿技术
  4. Apache PDFBox 存在高危 XXE 漏洞,建议升级至 2.0.15
  5. 按键改变元素背景颜色 链式编程的原理 评分案例 each方法的使用
  6. [POJ2151]Check the difficulty of problems(概率DP)
  7. MySQL和Mariadb都启动不了了_linux centos7mariadb安装成功启动不了 解决思路
  8. editview只输入英文_搜狗输入法Mac版更新:适配苹果M1处理器
  9. 网络协议:TCP拥塞控制
  10. jQuery选择器经典案例
  11. mui实现分享功能_继MIUI之后,华为EMUI更新,深度实现万物互联
  12. Spark1.0.0 应用程序部署工具spark-submit
  13. mysql图书管理系统设计答辩_基于微信的图书管理系统毕业论文+任务书+开题报告+答辩PPT+前后台(Java+Mysql)源码及数据库文件...
  14. DevCon 5 2019 活动照片
  15. (简易版)c语言人机对战五子棋
  16. mysql 报ERROR 1840 (HY000) at line 24: @@GLOBAL.GTID_PURGED can only be set when @@GLOBAL.GTID_EXECUT
  17. SpringBoot 发送电子邮件
  18. 2021-BUPT计组课设硬布线控制器
  19. 机器学习笔记 - 特征向量和特征值
  20. 手机备份到底备份什么

热门文章

  1. CSP CCF: 201903-2 二十四点 (C++)
  2. JeMalloc 内存分配器 简介
  3. 一款带ai基因的向导般生成ppt的神奇网站
  4. javacpp 人脸_javacv人脸识别项目源码
  5. 与机房收费系统重相见
  6. Apache Maven 环境变量的配置
  7. python emf转gif_将EMF/WMF文件转换为PNG/JPG
  8. 红外循迹TCRT5000 舵机SG90
  9. MySQL字符串是怎么截取substring函数的?
  10. Windows 系统双网卡冲突