专栏目录:pytorch(图像分割UNet)快速入门与实战——零、前言
pytorch快速入门与实战——一、知识准备(要素简介)
pytorch快速入门与实战——二、深度学习经典网络发展
pytorch快速入门与实战——三、Unet实现
pytorch快速入门与实战——四、网络训练与测试

续上文pytorch快速入门与实战——二、深度学习经典网络发展的8.4章节

Unet实现

  • 1 前期准备
    • 1.1 torch安装
    • 1.2 数据集准备
    • 1.3 网络结构骨架
    • 1.4 数据分析、完善网络
  • 2.网络实现
    • 2.1 相关知识
    • 2.2 代码实现:
      • 2.2.1 初始化方法__init__():
      • 2.2.2 参数回传方法forward():
      • 2.2.3 语义分割实现流程
      • 2.2.4 整合!(网络完整代码)

1 前期准备

1.1 torch安装

pytorch安装自行解决

1.2 数据集准备

我的是自己模拟的数据,所有数据一共是1600对(inputs,labels),训练集与测试集是9:1抽取的。
我的输入大小为120*240,label大小为256*256.

1.3 网络结构骨架

backbone是Unet,根据自己需求再变。不是改自己网络,而是自己加个卷积适应自己的输入输出。
先down一个基础的Unet图。基于此来修改。

1.4 数据分析、完善网络

【具体的size和channel没所谓的,都是可以直接设置的,怎么设置在实现里面说,这里只说流程】
图中input的size是572x572x1,而我的size是120x240x1,我选择在Unet之前加一个卷积层以让我的输入成为方形120x120x1,为了后续计算方便,通过padding(直接padding或者通过卷积都可以)变成128x128x1。接下来就是常规Unet操作,所以我的网络结构图为:

可以看到整个变化过程:具体如何变在实现中说明(一张破图一下午,骂骂咧咧ing)

120x240x1--卷积-->120x120x1--卷积-->128x128x1--卷积-->128x128x32
--池化-->64x64x32--卷积-->64x64x64
--池化-->32x32x64--卷积-->32x32x128
--池化-->16x16x128--卷积-->16x16x256
--池化-->8x8x256--卷积-->8x8x512--上采样-->16x16x256
--通道拼接-->16x16x512--反卷积-->16x16x256--上采样-->32x32x128
--通道拼接-->32x32x256--反卷积-->32x32x128--上采样-->64x64x64
--通道拼接-->64x64x128--反卷积-->64x64x64--上采样-->128x128x32
--通道拼接-->128x128x64--反卷积-->128x128x32--上采样-->256x256x16
(注意我左边是128开始的,所以没法拼接了,网络结构并不是严格对称的)
--1x1卷积核代替全连接-->256x256x1

2.网络实现

2.1 相关知识

  1. 首先我们要知道卷积的计算公式:

O = (I − K + 2P )/S+1
O(output)是输出图像、I(input)为原始图像、K(kernel)为卷积核尺寸、P为padding、S(stride)是步长

  1. 以及反卷积的计算公式:

output = (input-1)stride+output_padding -2*padding+kernel_size
O = (I-1)*S + OP - 2P + K
O(output)是输出图像、I(input)为原始图像、K(kernel)为卷积核尺寸、P为padding、S(stride)是步长,OP为output_padding

  1. 通道channel

说一下我的理解:

现实意义上是特征(我用分类来举例子:比如西瓜的根蒂,颜色,花纹等)
在图片中,色彩是一种特征,但特征不只是色彩。
比如我的灰度图,那channel就是1,如果是其他彩色图(RGB,BGR,CMY)的channel都是3
那可能就要问了,那图中channel为64,难道是64种色彩?
参照上面那句话“特征不只是色彩”,其他特征,我也不懂,猜测是分布什么的吧。

2.2 代码实现:

emmmm还是由浅入深地讲解吧:网络的整体代码放在文章最后。
首先导入torch包:

import torch
import torch.nn as nn

然后设计我的网络AdUNet,编写成类,该类继承nn.module。
主要重写两个方法:初始化__init__和参数回传forward

在此之前,为了提高代码复用性,将重复出现的双层卷积设计成一个函数,方便代码复用:

def double_conv(in_channels, out_channels):  # 双层卷积模型,神经网络最基本的框架return nn.Sequential(nn.Conv2d(in_channels, out_channels, 3, padding=1),nn.BatchNorm2d(out_channels),  # 加入Bn层提高网络泛化能力(防止过拟合),加收敛速度nn.ReLU(inplace=True),nn.Conv2d(out_channels, out_channels, 3, padding=1),  # 3指kernel_size,即卷积核3*3nn.BatchNorm2d(out_channels),nn.ReLU(inplace=True))

OK,开始。

2.2.1 初始化方法__init__():

  1. 输入适配层
    首先自主设计让输入适应网络的卷积层adnet放入网络AdUNet的类里,将输入1x120x240卷积为方形1x120x120,利用pytorch自带的卷积核方法Conv2d来实现:

设置输入的通道in_channels和输出的通道out_channels,选择2x1的卷积核,padding设为0,步长设置为(2,1)即行方向上步长为2,列方向上步长为1。这样设置步长才能让行方向的size缩小一倍。size调整为120x120
然后再绑定一个BN层和ReLu层,作用与原因参照上一篇文章。
然后再用一个padding=5的3x3卷积核将size从120x120调整为128x128

self.adnet = nn.Sequential(nn.Conv2d(in_channels=1, out_channels=1, kernel_size=(2, 1), padding=0, stride=(2, 1)),nn.BatchNorm2d(1),  # 加入Bn层提高网络泛化能力(防止过拟合),加收敛速度nn.ReLU(inplace=True),nn.Conv2d(in_channels=1, out_channels=1, kernel_size=3, padding=5, stride=1),nn.BatchNorm2d(1),  # 加入Bn层提高网络泛化能力(防止过拟合),加收敛速度nn.ReLU(inplace=True))
  1. 4个下采样时的卷积层+一个底层的卷积层
        self.dconv_down0 = double_conv(1, 32)self.dconv_down1 = double_conv(32, 64)self.dconv_down2 = double_conv(64, 128)self.dconv_down3 = double_conv(128, 256)self.dconv_down4 = double_conv(256, 512)
  1. 最大池化层
self.maxpool = nn.MaxPool2d(2)
  1. 4个上采样时的卷积层



        self.dconv_up3 = double_conv(256 + 256, 256)self.dconv_up2 = double_conv(128 + 128, 128)self.dconv_up1 = double_conv(64 + 64, 64)self.dconv_up0 = double_conv(64, 32)
  1. 5个上采样
        self.upsample4 = nn.ConvTranspose2d(512, 256, 3, stride=2, padding=1, output_padding=1)self.upsample3 = nn.ConvTranspose2d(256, 128, 3, stride=2, padding=1, output_padding=1)self.upsample2 = nn.ConvTranspose2d(128, 64, 3, stride=2, padding=1, output_padding=1)self.upsample1 = nn.ConvTranspose2d(64, 32, 3, stride=2, padding=1, output_padding=1)self.upsample0 = nn.ConvTranspose2d(32, 16, 3, stride=2, padding=1, output_padding=1)
  1. 代替全连接层的1x1卷积层
        self.conv_last = nn.Conv2d(16, 1, 1)

2.2.2 参数回传方法forward():

按照上图的网络结构将他们拼接起来!就OK了!
哦对,别忘了concat。
为什么不把下采样和上采样的那个重复模块写在一起呢?就是因为我不想传参,因为前面下采样的时候要在pool池之前保留值留给上采样的时候concat,所以就单独写了。concat操作也简单,看看代码就懂了,没什么难点。

    def forward(self, x):# reshapex = self.adnet(x)  # 1x128x128# encodeconv0 = self.dconv_down0(x)  # 32x128x128x = self.maxpool(conv0)  # 32x64x64conv1 = self.dconv_down1(x)  # 64x64x64x = self.maxpool(conv1)  # 64x32x32conv2 = self.dconv_down2(x)  # 128x32x32x = self.maxpool(conv2)  # 128x16x16conv3 = self.dconv_down3(x)  # 256x16x16x = self.maxpool(conv3)  # 256x8x8x = self.dconv_down4(x)  # 512x8x8# decodex = self.upsample4(x)  # 256x16x16# 因为使用了3*3卷积核和 padding=1 的组合,所以卷积过程图像尺寸不发生改变,所以省去了crop操作!x = torch.cat([x, conv3], dim=1)  # 512x16x16x = self.dconv_up3(x)  # 256x16x16x = self.upsample3(x)  # 128x32x32x = torch.cat([x, conv2], dim=1)  # 256x32x32x = self.dconv_up2(x)  # 128x32x32x = self.upsample2(x)  # 64x64x64x = torch.cat([x, conv1], dim=1)  # 128x64x64x = self.dconv_up1(x)  # 64x64x64x = self.upsample1(x)  # 32x128x128x = torch.cat([x, conv0], dim=1)  # 64x128x128x = self.dconv_up0(x)  # 32x128x128x = self.upsample0(x)   # 16x256x256out = self.conv_last(x)  # 1x256x256return out

2.2.3 语义分割实现流程

很遗憾地说,网络的结构虽然实现了,但是距离我们的目标还有一些路,但是还好,这个网络是确确实实可以用的,只要加载数据训练就可以得出结果,甚至可以随机生成一些矩阵当做图像来进行训练。
这里简单说一下流程,预感细节不少,详细实现下篇再说:pytorch快速入门与实战——四、网络训练与测试
训练:

根据batch size大小,将数据集中的训练样本和标签读入卷积神经网络。根据实际需要,应先对训练图片及标签进行预处理,如裁剪、数据增强等。这有利于深层网络的的训练,加速收敛过程,同时也避免过拟合问题并增强了模型的泛化能力。

验证:

训练一个epoch结束后,将数据集中的验证样本和标签读入卷积神经网络,并载入训练权重。根据编写好的语义分割指标进行验证,得到当前训练过程中的指标分数,保存对应权重。常用一次训练一次验证的方法更好的监督模型表现。

测试:

所有训练结束后,将数据集中的测试样本和标签读入卷积神经网络,并将保存的最好权重值载入模型,进行测试。 测试结果分为两种,一种是根据常用指标分数衡量网络性能,另一种是将网络的预测结果以 图片的形式保存下来,直观感受分割的精确程度。

2.2.4 整合!(网络完整代码)

import torch
import torch.nn as nndef double_conv(in_channels, out_channels):  # 双层卷积模型,神经网络最基本的框架return nn.Sequential(nn.Conv2d(in_channels, out_channels, 3, padding=1),nn.BatchNorm2d(out_channels),  # 加入Bn层提高网络泛化能力(防止过拟合),加收敛速度nn.ReLU(inplace=True),nn.Conv2d(out_channels, out_channels, 3, padding=1),  # 3指kernel_size,即卷积核3*3nn.BatchNorm2d(out_channels),nn.ReLU(inplace=True))# class UpSample(nn.Module):
#     def __init__(self, in_channels, out_channels, kernel_size, stride, padding, output_padding):
#         super(UpSample, self).__init__()
#         self.up = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=kernel_size, stride=2, padding=1)
#         self.conv_relu = nn.Sequential(
#             nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding),
#             nn.BatchNorm2d(num_features=out_channels),
#             nn.ReLU(),
#             nn.Conv2d(out_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding),
#             nn.BatchNorm2d(num_features=out_channels),
#             nn.ReLU(),
#         )
#
#     def forward(self, x, y):
#         x = self.up(x)
#         x1 = torch.cat((x, y), dim=0)
#         x1 = self.conv_relu(x1)
#         return x1 + xclass AdUNet(nn.Module):def __init__(self):super().__init__()self.adnet = nn.Sequential(nn.Conv2d(in_channels=1, out_channels=1, kernel_size=(2, 1), padding=0, stride=(2, 1)),nn.BatchNorm2d(1),  # 加入Bn层提高网络泛化能力(防止过拟合),加收敛速度nn.ReLU(inplace=True),nn.Conv2d(1, 1, kernel_size=3, padding=5, stride=1),nn.BatchNorm2d(1),  # 加入Bn层提高网络泛化能力(防止过拟合),加收敛速度nn.ReLU(inplace=True))self.dconv_down0 = double_conv(1, 32)self.dconv_down1 = double_conv(32, 64)self.dconv_down2 = double_conv(64, 128)self.dconv_down3 = double_conv(128, 256)self.dconv_down4 = double_conv(256, 512)self.maxpool = nn.MaxPool2d(2)# self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)self.upsample4 = nn.ConvTranspose2d(512, 256, 3, stride=2, padding=1, output_padding=1)self.upsample3 = nn.ConvTranspose2d(256, 128, 3, stride=2, padding=1, output_padding=1)self.upsample2 = nn.ConvTranspose2d(128, 64, 3, stride=2, padding=1, output_padding=1)self.upsample1 = nn.ConvTranspose2d(64, 32, 3, stride=2, padding=1, output_padding=1)self.upsample0 = nn.ConvTranspose2d(32, 16, 3, stride=2, padding=1, output_padding=1)self.dconv_up3 = double_conv(256 + 256, 256)self.dconv_up2 = double_conv(128 + 128, 128)self.dconv_up1 = double_conv(64 + 64, 64)self.dconv_up0 = double_conv(64, 32)self.conv_last = nn.Conv2d(16, 1, 1)def forward(self, x):# reshapex = self.adnet(x)  # 1x128x128# encodeconv0 = self.dconv_down0(x)  # 32x128x128x = self.maxpool(conv0)  # 32x64x64conv1 = self.dconv_down1(x)  # 64x64x64x = self.maxpool(conv1)  # 64x32x32conv2 = self.dconv_down2(x)  # 128x32x32x = self.maxpool(conv2)  # 128x16x16conv3 = self.dconv_down3(x)  # 256x16x16x = self.maxpool(conv3)  # 256x8x8x = self.dconv_down4(x)  # 512x8x8# decodex = self.upsample4(x)  # 256x16x16# 因为使用了3*3卷积核和 padding=1 的组合,所以卷积过程图像尺寸不发生改变,所以省去了crop操作!x = torch.cat([x, conv3], dim=1)  # 512x16x16x = self.dconv_up3(x)  # 256x16x16x = self.upsample3(x)  # 128x32x32x = torch.cat([x, conv2], dim=1)  # 256x32x32x = self.dconv_up2(x)  # 128x32x32x = self.upsample2(x)  # 64x64x64x = torch.cat([x, conv1], dim=1)  # 128x64x64x = self.dconv_up1(x)  # 64x64x64x = self.upsample1(x)  # 32x128x128x = torch.cat([x, conv0], dim=1)  # 64x128x128x = self.dconv_up0(x)  # 32x128x128x = self.upsample0(x)   # 16x256x256out = self.conv_last(x)  # 1x256x256return out

pytorch快速入门与实战——三、Unet实现相关推荐

  1. appinventor2 MySQL,写给大家看的安卓应用开发书 App Inventor 2快速入门与实战pdf

    没错,你有能力创建自己的安卓应用,而且一点都不难.AppInventor2,让你分分钟成为应用开发者! 本书由浅入深地介绍了强大的可视化编程工具AppInventor2,任何人都可以用它来开发自己的应 ...

  2. 深度学习框架PyTorch快速开发与实战

    深度学习框架PyTorch快速开发与实战 邢梦来,王硕,孙洋洋 著 ISBN:9787121345647 包装:平装 开本:16开 用纸:胶版纸 正文语种:中文 出版社:电子工业出版社 出版时间:20 ...

  3. (上)小程序从0快速入门到实战项目打造个性简历,让你轻松脱颖而出吸引面试官眼球(附源码)

    前言 分享之前我们先来认识一下小程序,官方定义的微信小程序是一种新的开放能力,开发者可以快速地开发一个小程序.更是一种全新的连接用户与服务的方式,它可以在微信内被便捷地获取和传播,同时具有出色的使用体 ...

  4. 带你少走弯路:强烈推荐的Pytorch快速入门资料和翻译(可下载)

    上次写了TensorFlow的快速入门资料,受到很多好评,读者强烈建议我出一个pytorch的快速入门路线,经过翻译和搜索网上资源,我推荐3份入门资料,希望对大家有所帮助. 备注:TensorFlow ...

  5. 新手必备pr 2021快速入门教程「三」素材的导入与管理

    PR2021快速入门教程,学完之后,制作抖音视频,vlog,电影混剪,日常记录等不在话下!零基础,欢迎入坑! 本节内容 上节内容我们学习了新建项目以及软件首选项的一些基本设置,接下来我们就可以导入素材 ...

  6. Linux 快速入门到实战【二】

    一.Linux用户与权限 1. 用户和权限的基本概念 1.1.基本概念 用户 是Linux系统工作中重要的一环, 用户管理包括 用户 与 组 管理 在Linux系统中, 不论是由本级或是远程登录系统, ...

  7. 【视频课】永久免费!5小时快速掌握Pytorch框架入门及实战

    前言 PyTorch是深度学习的主流框架之一,新手入门相对容易.为了帮助初学者解决PyTorch入门及实践的问题,有三AI推出<深度学习之PyTorch-入门及实战>课程,课程将算法.模型 ...

  8. Python快速入门到实战(三)逻辑控制语句,函数与类

    目录 一.逻辑控制语句 条件控制语句 if-else for 循环语句 while 循环 break 语句 continue 语句 Pass 语句 二.函数 函数的定义与调用 参数传递 函数的参数类型 ...

  9. Pytorch快速入门笔记

    Pytorch 入门笔记 1. Pytorch下载与安装 2. Pytorch的使用教程 2.1 Pytorch设计理念及其基本操作 2.2 使用torch.nn搭建神经网络 2.3 创建属于自己的D ...

最新文章

  1. Google 是如何定制 Material 主题的?
  2. QTP自动化测试-笔记 注释、大小写
  3. Go变量地址值和指针的关系
  4. Mysql快照读和当前读
  5. 数据结构与算法——递归、回溯与分治
  6. 使用jupyterthemes插件定制jupyter notebook界面
  7. 动态调用Webservice 支持Soapheader身份验证(转)
  8. 126.单词接龙II
  9. viper4android 机顶盒,利用VIPer53封装上系统实现经济型机顶盒供电
  10. USB转串口驱动安装失败解决方法
  11. 王牌英雄怎么服务器维护了,王牌英雄steam版无法运行问题解决方法
  12. 计算机管理禁用usb,电脑如何禁用U盘、怎样禁用USB存储工具,防止USB端口泄密?...
  13. 新浪微博相册图片外链限制,图床不显示解决方法总结!
  14. c语言中百分号后面跟的数字_C语言中的各种百分号都代表什么意思? c语言中百分号后的数字是...
  15. python找不到指定模块sklearn怎么办_python中sklearn找不到指定模块怎么办
  16. 笔记本电脑安装固态硬盘并重装win10系统
  17. 年度大促将至,企业如何进行性能压测
  18. python如何撤销上一步_python代码运行到某一步能返回到前面某一步吗?
  19. bzoj 1646 bfs
  20. 读取手机或SD卡的音频

热门文章

  1. 最新UE下载地址和可使用注册码(公布)
  2. 个人博客处理——页面处理
  3. DES加密解密kotlin版
  4. 八、STM32串口通信
  5. SVN客户端 创建分支/合并分支/切换分支
  6. 记录大坑:用Xamarin引入UHF读写器dll,报错: 所生成项目的处理器架构“MSIiL”与引用的Reader.dll处理器架构“x86”不匹配
  7. Linux 系统查询处理器架构
  8. 【图文详解】一文全面彻底搞懂HBase、LevelDB、RocksDB等NoSQL背后的存储原理:LSM-tree 日志结构合并树...
  9. Apple M1 Sourcetree 卡 卡顿 卡死
  10. 数字签名与数字证书技术简介(二)