一、前言

经过慎重考虑,决定新开一个系列,该系列文章主要的目的就是利用PyTorch、Python实现深度学习中的一些经典模型,接下来一段时间的安排如下:

  • UNet
  • ResNet
  • VggNet
  • AlexNet

本文首先实现UNet,关于UNet的详细介绍请移步深度学习模型解析系列文章–白话详解UNet

二、网络结构详解


UNet总体上分为编码器和解码器,其中编码器负责提取特征信息,解码器负责还原特征信息;编码器主要由4个块组成,每个块分别由2个卷积层、1个最大池化层组成。解码器也是由4个块组成,每个块都是由1个上采样层、2个卷积层组成,详细信息请见下图。

三、网络组成部分实现

  • 第1步:导入需要的包
import torch
import torch.nn as nn
import torch.nn.functional as F
  • 第2步:我们需要自定义一个卷积的基础块,该基础块由2个卷积层组成。
class DoubleConv(nn.Module):"""(convolution => [BN] => ReLU) * 2"""def __init__(self, in_channels, out_channels, mid_channels=None):super().__init__()if not mid_channels:mid_channels = out_channelsself.double_conv = nn.Sequential(nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1),nn.BatchNorm2d(mid_channels),nn.ReLU(inplace=True),nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1),nn.BatchNorm2d(out_channels),nn.ReLU(inplace=True))def forward(self, x):return self.double_conv(x)
  • 第3步:我们需要自定义一个编码器的基础块,该块由1个最大池化层和第2步的卷积基础块组成。
class Down(nn.Module):"""Downscaling with maxpool then double conv"""def __init__(self, in_channels, out_channels):super().__init__()self.maxpool_conv = nn.Sequential(nn.MaxPool2d(2),DoubleConv(in_channels, out_channels))def forward(self, x):return self.maxpool_conv(x)
  • 第4步:我们需要自定义一个解码器的基础块,该基础块由1个上采样层和2个卷积层组成。
class Up(nn.Module):"""Upscaling then double conv"""def __init__(self, in_channels, out_channels, bilinear=True):super().__init__()# if bilinear, use the normal convolutions to reduce the number of channelsif bilinear:self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)  # 双线性插值self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)else:self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)  # 转置卷积self.conv = DoubleConv(in_channels, out_channels)def forward(self, x1, x2):x1 = self.up(x1)# input is CHWdiffY = x2.size()[2] - x1.size()[2]diffX = x2.size()[3] - x1.size()[3]x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,diffY // 2, diffY - diffY // 2])x = torch.cat([x2, x1], dim=1)return self.conv(x)
  • 第5步:定义一个最后的输出层
class OutConv(nn.Module):def __init__(self, in_channels, out_channels):super(OutConv, self).__init__()self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)def forward(self, x):return self.conv(x)

四、网络结构实现

  • 第1步:我们需要把上述定义的类一股脑的导入到你要定义的网络文件中,因为每个人的文件夹不同,这里就不详细讲述。
  • 第2步:初始化你的网络模型参数
  • 第3步:编写前向传播方法
class UNet(nn.Module):def __init__(self, args, n_channels, n_classes, bilinear=True):super(UNet, self).__init__()  # 简单点讲:就是子类使用父类的初始化方法进行初始化,这会使得代码非常的整洁self.n_channels = n_channelsself.n_classes = n_classesself.bilinear = bilinear"""DoubleConv <-> (convolution => [BN] => ReLU) * 2"""self.inc = DoubleConv(n_channels, 64)self.down1 = Down(64, 128)self.down2 = Down(128, 256)self.down3 = Down(256, 512)factor = 2 if bilinear else 1self.down4 = Down(512, 1024 // factor)self.up1 = Up(1024, 512 // factor, bilinear)self.up2 = Up(512, 256 // factor, bilinear)self.up3 = Up(256, 128 // factor, bilinear)self.up4 = Up(128, 64, bilinear)self.outc = OutConv(64, n_classes)def forward(self, x):x1 = self.inc(x)x2 = self.down1(x1)x3 = self.down2(x2)x4 = self.down3(x3)x5 = self.down4(x4)x = self.up1(x5, x4)x = self.up2(x, x3)x = self.up3(x, x2)x = self.up4(x, x1)logits = self.outc(x)return logits

至此,UNet经典网络结构就编写好了,是不是非常的简单呢?如果您觉得写的还不错,欢迎一键三连,这对我真的帮助很大,非常感谢!我也会继续努力,提升文章的质量与数量!

手把手带你撸深度学习经典模型(一)----- UNet相关推荐

  1. 手把手带你入门深度学习(一):保姆级Anaconda和PyTorch环境配置指南

    手把手带你入门深度学习(一):保姆级Anaconda和PyTorch环境配置指南 一. 前言和准备工作 1.1 python.anaconda和pytorch的关系 二. Anconda安装 2.1 ...

  2. 【杂谈】手把手带你配置深度学习环境

    要想AI学的好,那就得linux系统用的好.放弃windows系统,只用linux系统是你学习AI,或者说做一个合格程序猿的基础.今天就手把手教大家如何在linux系统上配置OpenCV和Caffe, ...

  3. 【手把手带你入门深度学习之150行代码的汉字识别系统】学习笔记 ·003 用训练模型进行预测

    立即学习:https://edu.csdn.net/course/play/24719/279510?utm_source=blogtoedu 目录 一.用训练模型进行预测代码 二.思路总结 1.模型 ...

  4. 【手把手带你入门深度学习之150行代码的汉字识别系统】学习笔记 ·002 训练神经网络

    立即学习:https://edu.csdn.net/course/play/24719/279509?utm_source=blogtoedu 目录 一.神经网络训练代码 二.思路总结 1.数据集图片 ...

  5. 【手把手带你入门深度学习之150行代码的汉字识别系统】学习笔记 ·001 用OpenCV制作数据集

    立即学习:https://edu.csdn.net/course/play/24719/279505?utm_source=blogtoedu 目录 一.制作数据集代码 二.思路总结 1.数据集目录的 ...

  6. 手把手教你用深度学习做物体检测(四):模型使用

    上一篇<手把手教你用深度学习做物体检测(三):模型训练>中介绍了如何使用yolov3训练我们自己的物体检测模型,本篇文章将重点介绍如何使用我们训练好的模型来检测图片或视频中的物体.   如 ...

  7. 手把手教你用深度学习做物体检测(三):模型训练

    本篇文章旨在快速试验使用yolov3算法训练出自己的物体检测模型,所以会重过程而轻原理,当然,原理是非常重要的,只是原理会安排在后续文章中专门进行介绍.所以如果本文中有些地方你有原理方面的疑惑,也没关 ...

  8. 谷歌、阿里们的杀手锏:三大领域,十大深度学习CTR模型演化图谱

    作者 | 王喆 来源 | 转载自知乎专栏王喆的机器学习笔记 今天我们一起回顾一下近3年来的所有主流深度学习CTR模型,也是我工作之余的知识总结,希望能帮大家梳理推荐系统.计算广告领域在深度学习方面的前 ...

  9. 谷歌、阿里们的杀手锏:3大领域,10大深度学习CTR模型演化图谱(附论文)

    来源:知乎 作者:王喆 本文约4000字,建议阅读8分钟. 本文为你介绍近3年来的所有主流深度学习CTR模型. 今天我们一起回顾一下近3年来的所有主流深度学习CTR模型,也是我工作之余的知识总结,希望 ...

最新文章

  1. 试用最新版本的live writer发一篇日志看看
  2. python的优缺点-Python语言的优点和缺点 - 深度剖析
  3. 2017年CISCN初赛
  4. 网络安装LINUX系统原理,PXE网络引导系统自动化安装CentOS7
  5. python中的基本数据结构
  6. php跨域传sessionid,php中http与https跨域共享session的解决方法
  7. C++_类和对象_C++运算符重载_关系运算符重载_对== !=重载实现对象的对比_---C++语言工作笔记059
  8. 接到有用数据的5个做法,让你不再头疼
  9. 论文列表——text classification
  10. DELL R340 14G服务器的RAID划分
  11. 线程池创建线程数量讨论
  12. ARKit入门到精通 1.0 - 实战案例 AR打地鼠-史小川-专题视频课程
  13. JAVA开发交互式CAD系统_用VB.NET和VC#.NET开发交互式CAD系统(源代码)
  14. 帧间差分法函数python_【目标追踪】python帧差法原理及其实现
  15. C语言分数加减乘除化简操作集(含测试源码)
  16. 开源流媒体SRS结合硬件视频实时转码服务器的部署
  17. 云服务器是什么?云服务器有什么作用?
  18. vue按钮字体大小设置_用Vue模仿antd的样式造UI组件之button
  19. 【腾讯TMQ】走进标准化测试
  20. CAKEPHP 常见错误

热门文章

  1. C# 获取utc时间,以及utc datetime 互相转化
  2. SpringMVC+ZTree实现树形菜单权限配置
  3. noip2016 组合数问题
  4. java全面的知识体系结构总结
  5. IIS 发布网站到外网
  6. 基于Angularjs+jasmine+karma的测试驱动开发(TDD)实例
  7. Delphi编程之系统OEM DIY
  8. Google Analytics 跟踪代码安装后状态总是显示'未安装跟踪代码'
  9. MATLAB中的wavedec、wrcoef函数简析
  10. Golang库学习笔记 Gin(三)