本文 GitHub https://github.com/Jack-Cherish/PythonPark 已收录,有技术干货文章,整理的学习资料,一线大厂面试经验分享等,欢迎 Star 和 完善。

一、前言

本文属于Pytorch深度学习语义分割系列教程。

该系列文章的内容有:

  • Pytorch的基本使用
  • 语义分割算法讲解

如果不了解语义分割原理以及开发环境的搭建,请看该系列教程的上一篇文章《Pytorch深度学习实战教程(一):语义分割基础与环境搭建》。

本文的开发环境采用上一篇文章搭建好的Windows环境,环境情况如下:

开发环境:Windows

开发语言:Python3.7.4

框架版本:Pytorch1.3.0

CUDA:10.2

cuDNN:7.6.0

本文主要讲解UNet网络结构,以及相应代码的代码编写

PS:文中出现的所有代码,均可在我的github上下载,欢迎Follow、Star:点击查看

二、UNet网络结构

在语义分割领域,基于深度学习的语义分割算法开山之作是FCN(Fully Convolutional Networks for Semantic Segmentation),而UNet是遵循FCN的原理,并进行了相应的改进,使其适应小样本的简单分割问题。

UNet论文地址:点击查看

研究一个深度学习算法,可以先看网络结构,看懂网络结构后,再Loss计算方法、训练方法等。本文主要针对UNet的网络结构进行讲解,其它内容会在后续章节进行说明。

1、网络结构原理

UNet最早发表在2015的MICCAI会议上,4年多的时间,论文引用量已经达到了9700多次。

UNet成为了大多做医疗影像语义分割任务的baseline,同时也启发了大量研究者对于U型网络结构的研究,发表了一批基于UNet网络结构的改进方法的论文。

UNet网络结构,最主要的两个特点是:U型网络结构和Skip Connection跳层连接。

UNet是一个对称的网络结构,左侧为下采样,右侧为上采样。

按照功能可以将左侧的一系列下采样操作称为encoder,将右侧的一系列上采样操作称为decoder。

Skip Connection中间四条灰色的平行线,Skip Connection就是在上采样的过程中,融合下采样过过程中的feature map。

Skip Connection用到的融合的操作也很简单,就是将feature map的通道进行叠加,俗称Concat。

Concat操作也很好理解,举个例子:一本大小为10cm*10cm,厚度为3cm的书A,和一本大小为10cm*10cm,厚度为4cm的书B。

将书A和书B,边缘对齐地摞在一起。这样就得到了,大小为10cm*10cm厚度为7cm的一摞书,类似这种:

这种“摞在一起”的操作,就是Concat。

同样道理,对于feature map,一个大小为256*256*64的feature map,即feature map的w(宽)为256,h(高)为256,c(通道数)为64。和一个大小为256*256*32的feature map进行Concat融合,就会得到一个大小为256*256*96的feature map。

在实际使用中,Concat融合的两个feature map的大小不一定相同,例如256*256*64的feature map和240*240*32的feature map进行Concat。

这种时候,就有两种办法:

第一种:将大256*256*64的feature map进行裁剪,裁剪为240*240*64的feature map,比如上下左右,各舍弃8 pixel,裁剪后再进行Concat,得到240*240*96的feature map。

第二种:将小240*240*32的feature map进行padding操作,padding为256*256*32的feature map,比如上下左右,各补8 pixel,padding后再进行Concat,得到256*256*96的feature map。

UNet采用的Concat方案就是第二种,将小的feature map进行padding,padding的方式是补0,一种常规的常量填充。

2、代码

有些朋友可能对Pytorch不太了解,推荐一个快速入门的官方教程。一个小时,你就可以掌握一些基本概念和Pytorch代码编写方法。

Pytorch官方基础:点击查看

我们将整个UNet网络拆分为多个模块进行讲解。

DoubleConv模块:

先看下连续两次的卷积操作。

从UNet网络中可以看出,不管是下采样过程还是上采样过程,每一层都会连续进行两次卷积操作,这种操作在UNet网络中重复很多次,可以单独写一个DoubleConv模块:

import torch.nn as nnclass DoubleConv(nn.Module):"""(convolution => [BN] => ReLU) * 2"""def __init__(self, in_channels, out_channels):super().__init__()self.double_conv = nn.Sequential(nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=0),nn.BatchNorm2d(out_channels),nn.ReLU(inplace=True),nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=0),nn.BatchNorm2d(out_channels),nn.ReLU(inplace=True))def forward(self, x):return self.double_conv(x)

解释下,上述的Pytorch代码:torch.nn.Sequential是一个时序容器,Modules 会以它们传入的顺序被添加到容器中。比如上述代码的操作顺序:卷积->BN->ReLU->卷积->BN->ReLU。

DoubleConv模块的in_channels和out_channels可以灵活设定,以便扩展使用。

如上图所示的网络,in_channels设为1,out_channels为64。

输入图片大小为572*572,经过步长为1,padding为0的3*3卷积,得到570*570的feature map,再经过一次卷积得到568*568的feature map。

计算公式:O=(H−F+2×P)/S+1

H为输入feature map的大小,O为输出feature map的大小,F为卷积核的大小,P为padding的大小,S为步长。

Down模块:

UNet网络一共有4次下采样过程,模块化代码如下:

class Down(nn.Module):"""Downscaling with maxpool then double conv"""def __init__(self, in_channels, out_channels):super().__init__()self.maxpool_conv = nn.Sequential(nn.MaxPool2d(2),DoubleConv(in_channels, out_channels))def forward(self, x):return self.maxpool_conv(x)

这里的代码很简单,就是一个maxpool池化层,进行下采样,然后接一个DoubleConv模块。

至此,UNet网络的左半部分的下采样过程的代码都写好了,接下来是右半部分的上采样过程

Up模块:

上采样过程用到的最多的当然就是上采样了,除了常规的上采样操作,还有进行特征的融合。

这块的代码实现起来也稍复杂一些:

class Up(nn.Module):"""Upscaling then double conv"""def __init__(self, in_channels, out_channels, bilinear=True):super().__init__()# if bilinear, use the normal convolutions to reduce the number of channelsif bilinear:self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)else:self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)self.conv = DoubleConv(in_channels, out_channels)def forward(self, x1, x2):x1 = self.up(x1)# input is CHWdiffY = torch.tensor([x2.size()[2] - x1.size()[2]])diffX = torch.tensor([x2.size()[3] - x1.size()[3]])x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,diffY // 2, diffY - diffY // 2])# if you have padding issues, see# https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a# https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bdx = torch.cat([x2, x1], dim=1)return self.conv(x)

代码复杂一些,我们可以分开来看,首先是__init__初始化函数里定义的上采样方法以及卷积采用DoubleConv。上采样,定义了两种方法:Upsample和ConvTranspose2d,也就是双线性插值反卷积

双线性插值很好理解,示意图:

熟悉双线性插值的朋友对于这幅图应该不陌生,简单地讲:已知Q11、Q12、Q21、Q22四个点坐标,通过Q11和Q21求R1,再通过Q12和Q22求R2,最后通过R1和R2求P,这个过程就是双线性插值。

对于一个feature map而言,其实就是在像素点中间补点,补的点的值是多少,是由相邻像素点的值决定的。

反卷积,顾名思义,就是反着卷积。卷积是让featuer map越来越小,反卷积就是让feature map越来越大,示意图:

下面蓝色为原始图片,周围白色的虚线方块为padding结果,通常为0,上面绿色为卷积后的图片。

这个示意图,就是一个从2*2的feature map->4*4的feature map过程。

在forward前向传播函数中,x1接收的是上采样的数据,x2接收的是特征融合的数据。特征融合方法就是,上文提到的,先对小的feature map进行padding,再进行concat。

OutConv模块:

用上述的DoubleConv模块、Down模块、Up模块就可以拼出UNet的主体网络结构了。UNet网络的输出需要根据分割数量,整合输出通道,结果如下图所示:

操作很简单,就是channel的变换,上图展示的是分类为2的情况(通道为2)。

虽然这个操作很简单,也就调用一次,为了美观整洁,也封装一下吧。

class OutConv(nn.Module):def __init__(self, in_channels, out_channels):super(OutConv, self).__init__()self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)def forward(self, x):return self.conv(x)

至此,UNet网络用到的模块都已经写好,我们可以将上述的模块代码都放到一个unet_parts.py文件里,然后再创建unet_model.py,根据UNet网络结构,设置每个模块的输入输出通道个数以及调用顺序,编写如下代码:

""" Full assembly of the parts to form the complete network """
"""Refer https://github.com/milesial/Pytorch-UNet/blob/master/unet/unet_model.py"""import torch.nn.functional as Ffrom unet_parts import *class UNet(nn.Module):def __init__(self, n_channels, n_classes, bilinear=False):super(UNet, self).__init__()self.n_channels = n_channelsself.n_classes = n_classesself.bilinear = bilinearself.inc = DoubleConv(n_channels, 64)self.down1 = Down(64, 128)self.down2 = Down(128, 256)self.down3 = Down(256, 512)self.down4 = Down(512, 1024)self.up1 = Up(1024, 512, bilinear)self.up2 = Up(512, 256, bilinear)self.up3 = Up(256, 128, bilinear)self.up4 = Up(128, 64, bilinear)self.outc = OutConv(64, n_classes)def forward(self, x):x1 = self.inc(x)x2 = self.down1(x1)x3 = self.down2(x2)x4 = self.down3(x3)x5 = self.down4(x4)x = self.up1(x5, x4)x = self.up2(x, x3)x = self.up3(x, x2)x = self.up4(x, x1)logits = self.outc(x)return logitsif __name__ == '__main__':net = UNet(n_channels=3, n_classes=1)print(net)

使用命令python unet_model.py,如果没有错误,你会得到如下结果:

UNet((inc): DoubleConv((double_conv): Sequential((0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1))(1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(2): ReLU(inplace=True)(3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1))(4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(5): ReLU(inplace=True)))(down1): Down((maxpool_conv): Sequential((0): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)(1): DoubleConv((double_conv): Sequential((0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1))(1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(2): ReLU(inplace=True)(3): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1))(4): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(5): ReLU(inplace=True)))))(down2): Down((maxpool_conv): Sequential((0): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)(1): DoubleConv((double_conv): Sequential((0): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1))(1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(2): ReLU(inplace=True)(3): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1))(4): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(5): ReLU(inplace=True)))))(down3): Down((maxpool_conv): Sequential((0): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)(1): DoubleConv((double_conv): Sequential((0): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1))(1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(2): ReLU(inplace=True)(3): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1))(4): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(5): ReLU(inplace=True)))))(down4): Down((maxpool_conv): Sequential((0): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)(1): DoubleConv((double_conv): Sequential((0): Conv2d(512, 1024, kernel_size=(3, 3), stride=(1, 1))(1): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(2): ReLU(inplace=True)(3): Conv2d(1024, 1024, kernel_size=(3, 3), stride=(1, 1))(4): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(5): ReLU(inplace=True)))))(up1): Up((up): ConvTranspose2d(1024, 512, kernel_size=(2, 2), stride=(2, 2))(conv): DoubleConv((double_conv): Sequential((0): Conv2d(1024, 512, kernel_size=(3, 3), stride=(1, 1))(1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(2): ReLU(inplace=True)(3): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1))(4): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(5): ReLU(inplace=True))))(up2): Up((up): ConvTranspose2d(512, 256, kernel_size=(2, 2), stride=(2, 2))(conv): DoubleConv((double_conv): Sequential((0): Conv2d(512, 256, kernel_size=(3, 3), stride=(1, 1))(1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(2): ReLU(inplace=True)(3): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1))(4): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(5): ReLU(inplace=True))))(up3): Up((up): ConvTranspose2d(256, 128, kernel_size=(2, 2), stride=(2, 2))(conv): DoubleConv((double_conv): Sequential((0): Conv2d(256, 128, kernel_size=(3, 3), stride=(1, 1))(1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(2): ReLU(inplace=True)(3): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1))(4): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(5): ReLU(inplace=True))))(up4): Up((up): ConvTranspose2d(128, 64, kernel_size=(2, 2), stride=(2, 2))(conv): DoubleConv((double_conv): Sequential((0): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1))(1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(2): ReLU(inplace=True)(3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1))(4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(5): ReLU(inplace=True))))(outc): OutConv((conv): Conv2d(64, 1, kernel_size=(1, 1), stride=(1, 1)))
)

网络搭建完成,下一步就是使用网络进行训练了,具体实现会在该系列教程的下一篇文章进行讲解。

三、小结

  • 本文主要讲解了UNet网络结构,并对UNet网络进行了模块化梳理。
  • 下篇文章讲解如何使用UNet网络,编写训练代码。

点赞再看,养成习惯,微信公众号搜索【JackCui-AI】关注一个在互联网摸爬滚打的潜行者

Pytorch 深度学习实战教程(二):UNet语义分割网络相关推荐

  1. Pytorch深度学习实战教程:UNet语义分割网络

    1 前言 本文属于Pytorch深度学习语义分割系列教程. 该系列文章的内容有: Pytorch的基本使用 语义分割算法讲解 本文的开发环境如下: 开发环境:Windows 开发语言:Python3. ...

  2. Pytorch深度学习实战教程(二):UNet语义分割网络

    1 前言 本文属于Pytorch深度学习语义分割系列教程. 该系列文章的内容有: Pytorch的基本使用 语义分割算法讲解 如果不了解语义分割原理以及开发环境的搭建,请看该系列教程的上一篇文章< ...

  3. 深度学习实战23(进阶版)-语义分割实战,实现人物图像抠图的效果(计算机视觉)

    大家好,我是微学AI,今天给大家带来深度学习实战23(进阶版)-语义分割实战,实现人物图像抠图的效果.语义分割是计算机视觉中的一项重要任务,其目标是将图像中的每个像素都分配一个语义类别标签.与传统的目 ...

  4. Pytorch深度学习实战教程(一):语义分割基础与环境搭建

    Pytorch的基本使用&&语义分割算法讲解 先从最简单的语义分割基础与开发环境搭建开始讲解. 二.语义分割 语义分割是什么? 语义分割(semantic segmentation) ...

  5. Pytorch深度学习实战教程:语义分割基础与环境搭建

    一.前言 许久没有更新技术博文了,给自己挖一个新坑:语义分割系列文章. 该系列文章的内容有: Pytorch的基本使用 语义分割算法讲解 先从最简单的语义分割基础与开发环境搭建开始讲解. 二.语义分割 ...

  6. Pytorch 深度学习实战教程(六):仝卓自爆,快本打码。

    本文 GitHub https://github.com/Jack-Cherish/PythonPark 已收录,有技术干货文章,整理的学习资料,一线大厂面试经验分享等,欢迎 Star 和 完善. 一 ...

  7. Pytorch 深度学习实战教程:今天,你垃圾分类了吗?

    1 垃圾分类 还记得去年,上海如火如荼进行的垃圾分类政策吗? 2020年5月1日起,北京也开始实行「垃圾分类」了! 北京的垃圾分类标准与上海略有差别,垃圾分为厨余垃圾.可回收物.有害垃圾和其他垃圾四大 ...

  8. 【Pytorch】Pytorch深度学习实战教程:超分辨率重建AI与环境搭建

    一.基础开发环境搭建 1)cuda安装 需要根据自己的显卡的型号选择支持的CUDA版本 显卡驱动查看: 鼠标右键 CUDA安装版本查看:https://docs.nvidia.com/cuda/cud ...

  9. PyTorch深度学习实战(5)——计算机视觉基础

    PyTorch深度学习实战(5)--计算机视觉基础 0. 前言 1. 图像表示 2. 将图像转换为结构化数组 2.1 灰度图像表示 2.2 彩色图像表示 3 利用神经网络进行图像分析的优势 小结 系列 ...

最新文章

  1. ​别再用方括号在Python中获取字典的值,试试这个方法
  2. UVA-10954 Add All
  3. VTK修炼之道75:交互部件_测量类Widget的应用
  4. 【高斯消元】球形空间产生器(luogu 4035/金牌导航 高斯消元-1)
  5. Mybatis源码学习笔记
  6. httpclient案例一(调用识别接口)
  7. 踩坑日记(二):记一次线上业务—Redis 的缓存雪崩
  8. excel服务器数据同步修改,勤哲Excel服务器同步解决海量数据快速上传问题
  9. Dubbo源码解析之SPI(一):扩展类的加载过程
  10. 回弹强度记录表填写_混凝土抗压强度回弹法测试原始记录表.doc-_装配图网
  11. 携程的产品与收入模式分析
  12. python 复制word内容_Python读取word文本操作详解
  13. sql 累计占比_sql 面试题(难题汇总)
  14. 【MobileViT】
  15. MySQL auto_increment介绍及自增键断层的原因分析
  16. Tushare介绍和入门级实践(1)——使用tushare接口获取沪深300成分股交易日涨跌数据
  17. fatal error C1004: 发现意外的文件尾
  18. 21 天入门机器学习-第05期
  19. es6 class 跟普通function的区别
  20. 查看表所有列名SQL

热门文章

  1. css:linear-gradient实现水平条纹背景,垂直条纹背景,斜向条纹背景
  2. Android 音乐盒完整版
  3. 远程连接工具——WindTerm
  4. 新浪ID.邮箱.新浪注册器::最新免费可用,自己测试
  5. 在ffmpeg中添加编解码器
  6. 操作:EFS加密、密钥备份、恢复代理
  7. VR AR MR到底是什么
  8. python猴子选大王_sicily 猴子选大王
  9. 看了下面的关于大亚湾核电站的一些报道 我终于明白为什么老弟在惠阳买的房子会这么便宜了
  10. Oracle学习笔记(七)——分组统计查询