1

前言

本文属于Pytorch深度学习语义分割系列教程。

该系列文章的内容有:

  • Pytorch的基本使用

  • 语义分割算法讲解

本文的开发环境如下:

  • 开发环境:Windows

  • 开发语言:Python3.7.4

  • 框架版本:Pytorch1.3.0

  • CUDA:10.2

  • cuDNN:7.6.0

本文主要讲解UNet网络结构,以及相应代码的代码编写

2

UNet网络结构

在语义分割领域,基于深度学习的语义分割算法开山之作是FCN(Fully Convolutional Networks for Semantic Segmentation),而UNet是遵循FCN的原理,并进行了相应的改进,使其适应小样本的简单分割问题。

UNet论文地址:https://arxiv.org/pdf/1505.04597.pdf

研究一个深度学习算法,可以先看网络结构,看懂网络结构后,再Loss计算方法、训练方法等。本文主要针对UNet的网络结构进行讲解,其它内容会在后续章节进行说明。

1、网络结构原理

UNet最早发表在2015的MICCAI会议上,4年多的时间,论文引用量已经达到了9700多次。

UNet成为了大多做医疗影像语义分割任务的baseline,同时也启发了大量研究者对于U型网络结构的研究,发表了一批基于UNet网络结构的改进方法的论文。

UNet网络结构,最主要的两个特点是:U型网络结构和Skip Connection跳层连接。

UNet是一个对称的网络结构,左侧为下采样,右侧为上采样。

按照功能可以将左侧的一系列下采样操作称为encoder,将右侧的一系列上采样操作称为decoder。

Skip Connection中间四条灰色的平行线,Skip Connection就是在上采样的过程中,融合下采样过过程中的feature map。

Skip Connection用到的融合的操作也很简单,就是将feature map的通道进行叠加,俗称Concat。

Concat操作也很好理解,举个例子:一本大小为10cm*10cm,厚度为3cm的书A,和一本大小为10cm*10cm,厚度为4cm的书B。

将书A和书B,边缘对齐地摞在一起。这样就得到了,大小为10cm*10cm厚度为7cm的一摞书,类似这种:

这种“摞在一起”的操作,就是Concat。

同样道理,对于feature map,一个大小为256*256*64的feature map,即feature map的w(宽)为256,h(高)为256,c(通道数)为64。和一个大小为256*256*32的feature map进行Concat融合,就会得到一个大小为256*256*96的feature map。

在实际使用中,Concat融合的两个feature map的大小不一定相同,例如256*256*64的feature map和240*240*32的feature map进行Concat。

这种时候,就有两种办法:

第一种:将大256*256*64的feature map进行裁剪,裁剪为240*240*64的feature map,比如上下左右,各舍弃8 pixel,裁剪后再进行Concat,得到240*240*96的feature map。

第二种:将小240*240*32的feature map进行padding操作,padding为256*256*32的feature map,比如上下左右,各补8 pixel,padding后再进行Concat,得到256*256*96的feature map。

UNet采用的Concat方案就是第二种,将小的feature map进行padding,padding的方式是补0,一种常规的常量填充。

2、代码

有些朋友可能对Pytorch不太了解,推荐一个快速入门的官方教程。一个小时,你就可以掌握一些基本概念和Pytorch代码编写方法。

Pytorch官方基础:https://github.com/yunjey/pytorch-tutorial

我们将整个UNet网络拆分为多个模块进行讲解。

DoubleConv模块:

先看下连续两次的卷积操作。

从UNet网络中可以看出,不管是下采样过程还是上采样过程,每一层都会连续进行两次卷积操作,这种操作在UNet网络中重复很多次,可以单独写一个DoubleConv模块:

import torch.nn as nnclass DoubleConv(nn.Module):"""(convolution => [BN] => ReLU) * 2"""def __init__(self, in_channels, out_channels):super().__init__()self.double_conv = nn.Sequential(nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=0),nn.BatchNorm2d(out_channels),nn.ReLU(inplace=True),nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=0),nn.BatchNorm2d(out_channels),nn.ReLU(inplace=True))def forward(self, x):return self.double_conv(x)

解释下,上述的Pytorch代码:torch.nn.Sequential是一个时序容器,Modules 会以它们传入的顺序被添加到容器中。比如上述代码的操作顺序:卷积->BN->ReLU->卷积->BN->ReLU。

DoubleConv模块的in_channels和out_channels可以灵活设定,以便扩展使用。

如上图所示的网络,in_channels设为1,out_channels为64。

输入图片大小为572*572,经过步长为1,padding为0的3*3卷积,得到570*570的feature map,再经过一次卷积得到568*568的feature map。

计算公式:O=(H−F+2×P)/S+1

H为输入feature map的大小,O为输出feature map的大小,F为卷积核的大小,P为padding的大小,S为步长。

Down模块:

UNet网络一共有4次下采样过程,模块化代码如下:

class Down(nn.Module):"""Downscaling with maxpool then double conv"""def __init__(self, in_channels, out_channels):super().__init__()self.maxpool_conv = nn.Sequential(nn.MaxPool2d(2),DoubleConv(in_channels, out_channels))def forward(self, x):return self.maxpool_conv(x)

这里的代码很简单,就是一个maxpool池化层,进行下采样,然后接一个DoubleConv模块。

至此,UNet网络的左半部分的下采样过程的代码都写好了,接下来是右半部分的上采样过程

Up模块:

上采样过程用到的最多的当然就是上采样了,除了常规的上采样操作,还有进行特征的融合。

这块的代码实现起来也稍复杂一些:

class Up(nn.Module):"""Upscaling then double conv"""def __init__(self, in_channels, out_channels, bilinear=True):super().__init__()# if bilinear, use the normal convolutions to reduce the number of channelsif bilinear:self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)else:self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)self.conv = DoubleConv(in_channels, out_channels)def forward(self, x1, x2):x1 = self.up(x1)# input is CHWdiffY = x2.size()[2] - x1.size()[2]diffX = x2.size()[3] - x1.size()[3]x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,diffY // 2, diffY - diffY // 2])# if you have padding issues, see# https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a# https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bdx = torch.cat([x2, x1], dim=1)return self.conv(x)

代码复杂一些,我们可以分开来看,首先是__init__初始化函数里定义的上采样方法以及卷积采用DoubleConv。上采样,定义了两种方法:Upsample和ConvTranspose2d,也就是双线性插值反卷积

双线性插值很好理解,示意图:

熟悉双线性插值的朋友对于这幅图应该不陌生,简单地讲:已知Q11、Q12、Q21、Q22四个点坐标,通过Q11和Q21求R1,再通过Q12和Q22求R2,最后通过R1和R2求P,这个过程就是双线性插值。

对于一个feature map而言,其实就是在像素点中间补点,补的点的值是多少,是由相邻像素点的值决定的。

反卷积,顾名思义,就是反着卷积。卷积是让featuer map越来越小,反卷积就是让feature map越来越大,示意图:

下面蓝色为原始图片,周围白色的虚线方块为padding结果,通常为0,上面绿色为卷积后的图片。

这个示意图,就是一个从2*2的feature map->4*4的feature map过程。

在forward前向传播函数中,x1接收的是上采样的数据,x2接收的是特征融合的数据。特征融合方法就是,上文提到的,先对小的feature map进行padding,再进行concat。

OutConv模块:

用上述的DoubleConv模块、Down模块、Up模块就可以拼出UNet的主体网络结构了。UNet网络的输出需要根据分割数量,整合输出通道,结果如下图所示:

操作很简单,就是channel的变换,上图展示的是分类为2的情况(通道为2)。

虽然这个操作很简单,也就调用一次,为了美观整洁,也封装一下吧。

class OutConv(nn.Module):def __init__(self, in_channels, out_channels):super(OutConv, self).__init__()self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)def forward(self, x):return self.conv(x)

至此,UNet网络用到的模块都已经写好,我们可以将上述的模块代码都放到一个unet_parts.py文件里,然后再创建unet_model.py,根据UNet网络结构,设置每个模块的输入输出通道个数以及调用顺序,编写如下代码:

""" Full assembly of the parts to form the complete network """
"""Refer https://github.com/milesial/Pytorch-UNet/blob/master/unet/unet_model.py"""import torch.nn.functional as Ffrom unet_parts import *class UNet(nn.Module):def __init__(self, n_channels, n_classes, bilinear=False):super(UNet, self).__init__()self.n_channels = n_channelsself.n_classes = n_classesself.bilinear = bilinearself.inc = DoubleConv(n_channels, 64)self.down1 = Down(64, 128)self.down2 = Down(128, 256)self.down3 = Down(256, 512)self.down4 = Down(512, 1024)self.up1 = Up(1024, 512, bilinear)self.up2 = Up(512, 256, bilinear)self.up3 = Up(256, 128, bilinear)self.up4 = Up(128, 64, bilinear)self.outc = OutConv(64, n_classes)def forward(self, x):x1 = self.inc(x)x2 = self.down1(x1)x3 = self.down2(x2)x4 = self.down3(x3)x5 = self.down4(x4)x = self.up1(x5, x4)x = self.up2(x, x3)x = self.up3(x, x2)x = self.up4(x, x1)logits = self.outc(x)return logitsif __name__ == '__main__':net = UNet(n_channels=3, n_classes=1)print(net)

使用命令python unet_model.py,如果没有错误,你会得到如下结果:

UNet((inc): DoubleConv((double_conv): Sequential((0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1))(1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(2): ReLU(inplace=True)(3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1))(4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(5): ReLU(inplace=True)))(down1): Down((maxpool_conv): Sequential((0): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)(1): DoubleConv((double_conv): Sequential((0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1))(1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(2): ReLU(inplace=True)(3): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1))(4): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(5): ReLU(inplace=True)))))(down2): Down((maxpool_conv): Sequential((0): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)(1): DoubleConv((double_conv): Sequential((0): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1))(1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(2): ReLU(inplace=True)(3): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1))(4): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(5): ReLU(inplace=True)))))(down3): Down((maxpool_conv): Sequential((0): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)(1): DoubleConv((double_conv): Sequential((0): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1))(1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(2): ReLU(inplace=True)(3): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1))(4): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(5): ReLU(inplace=True)))))(down4): Down((maxpool_conv): Sequential((0): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)(1): DoubleConv((double_conv): Sequential((0): Conv2d(512, 1024, kernel_size=(3, 3), stride=(1, 1))(1): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(2): ReLU(inplace=True)(3): Conv2d(1024, 1024, kernel_size=(3, 3), stride=(1, 1))(4): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(5): ReLU(inplace=True)))))(up1): Up((up): ConvTranspose2d(1024, 512, kernel_size=(2, 2), stride=(2, 2))(conv): DoubleConv((double_conv): Sequential((0): Conv2d(1024, 512, kernel_size=(3, 3), stride=(1, 1))(1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(2): ReLU(inplace=True)(3): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1))(4): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(5): ReLU(inplace=True))))(up2): Up((up): ConvTranspose2d(512, 256, kernel_size=(2, 2), stride=(2, 2))(conv): DoubleConv((double_conv): Sequential((0): Conv2d(512, 256, kernel_size=(3, 3), stride=(1, 1))(1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(2): ReLU(inplace=True)(3): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1))(4): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(5): ReLU(inplace=True))))(up3): Up((up): ConvTranspose2d(256, 128, kernel_size=(2, 2), stride=(2, 2))(conv): DoubleConv((double_conv): Sequential((0): Conv2d(256, 128, kernel_size=(3, 3), stride=(1, 1))(1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(2): ReLU(inplace=True)(3): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1))(4): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(5): ReLU(inplace=True))))(up4): Up((up): ConvTranspose2d(128, 64, kernel_size=(2, 2), stride=(2, 2))(conv): DoubleConv((double_conv): Sequential((0): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1))(1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(2): ReLU(inplace=True)(3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1))(4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(5): ReLU(inplace=True))))(outc): OutConv((conv): Conv2d(64, 1, kernel_size=(1, 1), stride=(1, 1)))
)

网络搭建完成,下一步就是使用网络进行训练了,具体实现会在该系列教程的下一篇文章进行讲解。

3

小结

  • 本文主要讲解了UNet网络结构,并对UNet网络进行了模块化梳理。

  • 下篇文章讲解如何使用UNet网络,编写训练代码。

【推荐原创干货阅读】
2019年的个人总结和2020年的一些展望
【资源分享】对于时间序列,你所能做的一切.聊聊近状, 唠十块钱的
【Deep Learning】详细解读LSTM与GRU单元的各个公式和区别
【手把手AI项目】一、安装win10+linux-Ubuntu16.04的双系统(全网最详细)
【Deep Learning】为什么卷积神经网络中的“卷积”不是卷积运算?【TOOLS】Pandas如何进行内存优化和数据加速读取(附代码详解)【TOOLS】python3利用SMTP进行邮件Email自主发送【手把手AI项目】七、MobileNetSSD通过Ncnn前向推理框架在PC端的使用【时空序列预测第一篇】什么是时空序列问题?这类问题主要应用了哪些模型?主要应用在哪些领域?
公众号:AI蜗牛车保持谦逊、保持自律、保持进步个人微信
备注:昵称+学校/公司+方向
如果没有备注不拉群!
拉你进AI蜗牛车交流群点个在看,么么哒!

Pytorch深度学习实战教程:UNet语义分割网络相关推荐

  1. Pytorch深度学习实战教程:语义分割基础与环境搭建

    一.前言 许久没有更新技术博文了,给自己挖一个新坑:语义分割系列文章. 该系列文章的内容有: Pytorch的基本使用 语义分割算法讲解 先从最简单的语义分割基础与开发环境搭建开始讲解. 二.语义分割 ...

  2. Pytorch 深度学习实战教程(二):UNet语义分割网络

    本文 GitHub https://github.com/Jack-Cherish/PythonPark 已收录,有技术干货文章,整理的学习资料,一线大厂面试经验分享等,欢迎 Star 和 完善. 一 ...

  3. Pytorch深度学习实战教程(二):UNet语义分割网络

    1 前言 本文属于Pytorch深度学习语义分割系列教程. 该系列文章的内容有: Pytorch的基本使用 语义分割算法讲解 如果不了解语义分割原理以及开发环境的搭建,请看该系列教程的上一篇文章< ...

  4. Pytorch深度学习实战教程(一):语义分割基础与环境搭建

    Pytorch的基本使用&&语义分割算法讲解 先从最简单的语义分割基础与开发环境搭建开始讲解. 二.语义分割 语义分割是什么? 语义分割(semantic segmentation) ...

  5. Pytorch 深度学习实战教程:今天,你垃圾分类了吗?

    1 垃圾分类 还记得去年,上海如火如荼进行的垃圾分类政策吗? 2020年5月1日起,北京也开始实行「垃圾分类」了! 北京的垃圾分类标准与上海略有差别,垃圾分为厨余垃圾.可回收物.有害垃圾和其他垃圾四大 ...

  6. Pytorch 深度学习实战教程(六):仝卓自爆,快本打码。

    本文 GitHub https://github.com/Jack-Cherish/PythonPark 已收录,有技术干货文章,整理的学习资料,一线大厂面试经验分享等,欢迎 Star 和 完善. 一 ...

  7. 【Pytorch】Pytorch深度学习实战教程:超分辨率重建AI与环境搭建

    一.基础开发环境搭建 1)cuda安装 需要根据自己的显卡的型号选择支持的CUDA版本 显卡驱动查看: 鼠标右键 CUDA安装版本查看:https://docs.nvidia.com/cuda/cud ...

  8. 深度学习应用篇-计算机视觉-语义分割综述[5]:FCN、SegNet、Deeplab等分割算法、常用二维三维半立体数据集汇总、前景展望等

    [深度学习入门到进阶]必看系列,含激活函数.优化策略.损失函数.模型调优.归一化算法.卷积模型.序列模型.预训练模型.对抗神经网络等 专栏详细介绍:[深度学习入门到进阶]必看系列,含激活函数.优化策略 ...

  9. 深度学习高遥感影像语义分割

    深度学习遥感影像语义分割 深度学习大家都知道,在计算机视觉领域取得了很大的成功,在遥感影像自动解译方面,同样带来了快速的发展,我在遥感影像自动解译领域,也做了一些微薄的工作,发表几篇论文,我一直关注遥 ...

最新文章

  1. Asp.net中GridView使用详解(引)【转】
  2. LeetCode Minimum Path Sum(动态规划)
  3. 使用 icon 字体图标出现小方块问题
  4. 11.2运行异常和编译异常
  5. 【原创】RabbitMQ 之 TTL 详解(翻译)
  6. 百道Python面试题实现,搞定Python编程就靠它
  7. 网上商城—管理员修改商品
  8. Kafka解惑之Old Producer(1)—— Beginning
  9. mqtt java_MQTT和Java入门
  10. 技术真的就不是那么重要了
  11. n平方的求和公式_极限求解--数列前n项和公式推导(补充知识)
  12. 20. 自定义配置文件
  13. SQL Server 2005安装图解
  14. [BScroll warn]: Can not resolve the wrapper DOM.
  15. 如何使用MSGEQ7音频频谱分析仪芯片
  16. 苹果计算机格式化磁盘,MAC格式化移动硬盘
  17. 第六章-循环控制结构
  18. 各行业容灾备份架构#容灾#,
  19. Python GUI教程 | Lynda教程 中文字幕
  20. 2019上海网络赛icpc

热门文章

  1. Windows下Python无法正常卸载:There is a problem with this Windows Installer package.
  2. Win10右键文件无响应崩溃
  3. 微信小程序,一个有局限的类似 React Native 轮子
  4. 完美玩机 三星I9000解锁工具实测教程
  5. 35岁以上的那些测试员何去何从?
  6. java导入xmind的坑及解决方案
  7. Coding and Paper Letter(七十六)
  8. 贝拉博客,一个屌丝网站
  9. 贪官产生的本质是什么——谈谈人性与制度的博弈未来
  10. 数据库选课系统mysql_数据库设计(学生选课系统).doc