U-Net模型搭建python实现
最近在学习U-Net,读完文章之后尝试搭建模型的框架,阅读了前人的模型后,试着自己搭建了一下,适合初学者。
import torch
import torch.nn as nn
from torchsummary import summary# 2次 3*3卷积
# Remark,第一个3*3卷积承担,升降维的功能
class DoubleConv(nn.Module):def __init__(self, in_channels,mid_channels, out_channels):super(DoubleConv, self).__init__()self.conv = nn.Sequential(nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, stride=1, bias=False),# Same conv. not valid conv. in original papernn.BatchNorm2d(mid_channels),nn.ReLU(inplace=True),nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, stride=1, bias=False),nn.BatchNorm2d(out_channels),nn.ReLU(inplace=True),)def forward(self, x):return self.conv(x)# 定义主网络框架
class UNet(nn.Module):def __init__(self, original_channels=1, num_classes=1):super(UNet, self).__init__()self.original_channels = original_channels # 输入可能是RGB 3通道,也有可能是灰度图 1通道,所以定义original_channels;self.num_classes = num_classes # 输出可能是二分类(前后景),也可能是多分类,所以定义num_classes# Contracting path: 卷积2次DoubleCon,下采样Maxpool 1次,一共编码5次self.encoder1 = DoubleConv(self.original_channels,mid_channels=64, out_channels=64)self.down1 = nn.MaxPool2d(kernel_size=2, stride=2)self.encoder2 = DoubleConv(in_channels=64, mid_channels=128,out_channels=128)self.down2 = nn.MaxPool2d(kernel_size=2, stride=2)self.encoder3 = DoubleConv(in_channels=128, mid_channels=256,out_channels=256)self.down3 = nn.MaxPool2d(kernel_size=2, stride=2)self.encoder4 = DoubleConv(in_channels=256, mid_channels=512,out_channels=512)self.down4 = nn.MaxPool2d(kernel_size=2, stride=2)self.encoder5 = DoubleConv(in_channels=512,mid_channels=1024, out_channels=1024)# Expansive path: 上采样ConvTranspose 1次,卷积2次DoubleCon,一共解码5次,最后一次为1*1Conv.# Remark:通道拼接放在正向传播中做,注意编码和上采样的channels匹配的问题self.up1 = nn.ConvTranspose2d(in_channels=1024, out_channels=512, kernel_size=2, stride=2)self.decoder1 = DoubleConv(in_channels=1024, mid_channels=512,out_channels=512)self.up2 = nn.ConvTranspose2d(in_channels=512, out_channels=256, kernel_size=2, stride=2)self.decoder2 = DoubleConv(in_channels=512,mid_channels=256, out_channels=256)self.up3 = nn.ConvTranspose2d(in_channels=256, out_channels=128, kernel_size=2, stride=2)self.decoder3 = DoubleConv(in_channels=256,mid_channels=128, out_channels=128)self.up4 = nn.ConvTranspose2d(in_channels=128, out_channels=64, kernel_size=2, stride=2)self.decoder4 = DoubleConv(in_channels=128, mid_channels=64,out_channels=64)self.decoder5 = nn.Conv2d(64, num_classes, kernel_size=1)# 定义正向传播的过程def forward(self, x):encoder1 = self.encoder1(x) # original_channel → 64_channels 512*512 → 512*512encoder1_pool = self.down1(encoder1) # 64_channels → 64_channels 512*512 → 256*256encoder2 = self.encoder2(encoder1_pool) # 64_channels → 128_channels 256*256 → 256*256encoder2_pool = self.down2(encoder2) # 128_channels → 128_channels 256*256 → 128*128encoder3 = self.encoder3(encoder2_pool) # 128_channels → 256_channels 128*128 → 128*128encoder3_pool = self.down3(encoder3) # 256_channels → 256_channels 128*128 → 64*64encoder4 = self.encoder4(encoder3_pool) # 256_channels → 512_channels 64*64 → 64*64encoder4_pool = self.down4(encoder4) # 512_channels → 512_channels 64*64 → 32*32encoder5 = self.encoder5(encoder4_pool) # 512_channels → 1024_channels 32*32 → 32*32decoder1_up = self.up1(encoder5) # 1024_channels → 512_channels 32*32 → 64*64decoder1 = self.decoder1(torch.cat((encoder4, decoder1_up), dim=1))# 512+512_channels → 512_channels 64*64 → 64*64decoder2_up = self.up2(decoder1) # 512_channels → 256_channels 64*64 → 128*128decoder2 = self.decoder2(torch.cat((encoder3, decoder2_up), dim=1))# 256+256_channels → 256_channels 128*128 → 128*128decoder3_up = self.up3(decoder2) # 256_channels → 128_channels 128*64 → 256*256decoder3 = self.decoder3(torch.cat((encoder2, decoder3_up), dim=1))# 128+128_channels → 128_channels 256*256 → 256*256decoder4_up = self.up4(decoder3) # 128_channels → 64_channels 256*256 → 256*256decoder4 = self.decoder4(torch.cat((encoder1, decoder4_up), dim=1))# 64+64_channels → 64_channels 256*256 → 512*512out = self.decoder5(decoder4) # 64_channels → num_classes channels 512*512 → 512*512return out"""
下面三行代码是验证模型架构用,实际不需要
"""
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# unet = UNet().to(device)
# summary(unet, (1, 512, 512))
使用pytorch的summary功能,如果没有可以pip install summary
三行测试代码开启后,可以查看模型的架构如下
----------------------------------------------------------------Layer (type) Output Shape Param #
================================================================Conv2d-1 [-1, 64, 512, 512] 576BatchNorm2d-2 [-1, 64, 512, 512] 128ReLU-3 [-1, 64, 512, 512] 0Conv2d-4 [-1, 64, 512, 512] 36,864BatchNorm2d-5 [-1, 64, 512, 512] 128ReLU-6 [-1, 64, 512, 512] 0DoubleConv-7 [-1, 64, 512, 512] 0MaxPool2d-8 [-1, 64, 256, 256] 0Conv2d-9 [-1, 128, 256, 256] 73,728BatchNorm2d-10 [-1, 128, 256, 256] 256ReLU-11 [-1, 128, 256, 256] 0Conv2d-12 [-1, 128, 256, 256] 147,456BatchNorm2d-13 [-1, 128, 256, 256] 256ReLU-14 [-1, 128, 256, 256] 0DoubleConv-15 [-1, 128, 256, 256] 0MaxPool2d-16 [-1, 128, 128, 128] 0Conv2d-17 [-1, 256, 128, 128] 294,912BatchNorm2d-18 [-1, 256, 128, 128] 512ReLU-19 [-1, 256, 128, 128] 0Conv2d-20 [-1, 256, 128, 128] 589,824BatchNorm2d-21 [-1, 256, 128, 128] 512ReLU-22 [-1, 256, 128, 128] 0DoubleConv-23 [-1, 256, 128, 128] 0MaxPool2d-24 [-1, 256, 64, 64] 0Conv2d-25 [-1, 512, 64, 64] 1,179,648BatchNorm2d-26 [-1, 512, 64, 64] 1,024ReLU-27 [-1, 512, 64, 64] 0Conv2d-28 [-1, 512, 64, 64] 2,359,296BatchNorm2d-29 [-1, 512, 64, 64] 1,024ReLU-30 [-1, 512, 64, 64] 0DoubleConv-31 [-1, 512, 64, 64] 0MaxPool2d-32 [-1, 512, 32, 32] 0Conv2d-33 [-1, 1024, 32, 32] 4,718,592BatchNorm2d-34 [-1, 1024, 32, 32] 2,048ReLU-35 [-1, 1024, 32, 32] 0Conv2d-36 [-1, 1024, 32, 32] 9,437,184BatchNorm2d-37 [-1, 1024, 32, 32] 2,048ReLU-38 [-1, 1024, 32, 32] 0DoubleConv-39 [-1, 1024, 32, 32] 0ConvTranspose2d-40 [-1, 512, 64, 64] 2,097,664Conv2d-41 [-1, 512, 64, 64] 4,718,592BatchNorm2d-42 [-1, 512, 64, 64] 1,024ReLU-43 [-1, 512, 64, 64] 0Conv2d-44 [-1, 512, 64, 64] 2,359,296BatchNorm2d-45 [-1, 512, 64, 64] 1,024ReLU-46 [-1, 512, 64, 64] 0DoubleConv-47 [-1, 512, 64, 64] 0ConvTranspose2d-48 [-1, 256, 128, 128] 524,544Conv2d-49 [-1, 256, 128, 128] 1,179,648BatchNorm2d-50 [-1, 256, 128, 128] 512ReLU-51 [-1, 256, 128, 128] 0Conv2d-52 [-1, 256, 128, 128] 589,824BatchNorm2d-53 [-1, 256, 128, 128] 512ReLU-54 [-1, 256, 128, 128] 0DoubleConv-55 [-1, 256, 128, 128] 0ConvTranspose2d-56 [-1, 128, 256, 256] 131,200Conv2d-57 [-1, 128, 256, 256] 294,912BatchNorm2d-58 [-1, 128, 256, 256] 256ReLU-59 [-1, 128, 256, 256] 0Conv2d-60 [-1, 128, 256, 256] 147,456BatchNorm2d-61 [-1, 128, 256, 256] 256ReLU-62 [-1, 128, 256, 256] 0DoubleConv-63 [-1, 128, 256, 256] 0ConvTranspose2d-64 [-1, 64, 512, 512] 32,832Conv2d-65 [-1, 64, 512, 512] 73,728BatchNorm2d-66 [-1, 64, 512, 512] 128ReLU-67 [-1, 64, 512, 512] 0Conv2d-68 [-1, 64, 512, 512] 36,864BatchNorm2d-69 [-1, 64, 512, 512] 128ReLU-70 [-1, 64, 512, 512] 0DoubleConv-71 [-1, 64, 512, 512] 0Conv2d-72 [-1, 1, 512, 512] 65
================================================================
Total params: 31,036,481
Trainable params: 31,036,481
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 1.00
Forward/backward pass size (MB): 3718.00
Params size (MB): 118.39
Estimated Total Size (MB): 3837.39
----------------------------------------------------------------`
U-Net模型搭建python实现相关推荐
- python利器app怎么查文献-科研人必备:一个工具搞定文献查阅、数据分析、模型搭建...
原标题:科研人必备:一个工具搞定文献查阅.数据分析.模型搭建 写论文有多难?这首诗形容得好: 进入学校先选题,踌躇满志万人敌:发现前辈都做过,满脸懵逼加惊奇. 终于找到大空白,我真是个小天才:左试右试 ...
- kaggle实战—泰坦尼克(五、模型搭建-模型评估)
kaggle实战-泰坦尼克(一.数据分析) kaggle实战-泰坦尼克(二.数据清洗及特征处理) kaggle实战-泰坦尼克(三.数据重构) kaggle实战-泰坦尼克(四.数据可视化) kaggle ...
- 目标检测——夏侯南溪模型搭建篇(legacy)
4 定义模型整体结构 4.3.1 输入样本的预处理操作 (注意:训练样本在输入网路之前,需要进行归一化,包括:均值归一化操作 具体可以看看我写的<目标检测--输入数据的归一化操作>) MT ...
- 泰坦尼克号 第三章 模型搭建和评估
第三章 模型搭建和评估–建模 经过前面的两章的知识点的学习,我可以对数数据的本身进行处理,比如数据本身的增删查补,还可以做必要的清洗工作.那么下面我们就要开始使用我们前面处理好的数据了.这一章我们要做 ...
- Flask项目实战——10—(前台板块页面搭建、文本编辑页面搭建、发布帖子信息前验证权限、帖子模型搭建、发布帖子功能、帖子信息渲染到前后台页面)
1.前台板块页面搭建 视图文件查询数据传输到前台界面:前台蓝图文件:apps/front/views.py 注意数据的收集方法和数据传输的类型. # -*- encoding: utf-8 -*- & ...
- PyTorch模型搭建和源码详解
文章目录: 一.VGG模型简单介绍 二.PyTorch源码分析 三.预训练模型的使用 本文是以VGG模型为例,深入介绍了完整的模型搭建过程,以及预训练模型使用过程,希望本篇博客可以解答一些困惑,同时欢 ...
- Kaggle猫狗大战模型搭建总结
0.前言 基于我暑假内学习的深度学习理论知识,我的学长建议我仿照他所设计的猫狗大战模型来尝试运用tensorflow与keras搭建神经网络,虽然我对python并不是很了解,但我依旧愿意尝试搭建来提 ...
- 如何使用机器学习自动修复bug: 数据处理和模型搭建
如何使用机器学习自动修复bug: 数据处理和模型搭建 上一篇<如何使用机器学习自动修复bug: 上手指南>我们介绍了使用CodeBERT自动修复bug的操作方法. 估计对于很多想了解原理的 ...
- 机器学习——决策树模型:Python实现
机器学习--决策树模型:Python实现 1 决策树模型的代码实现 1.1 分类决策树模型(DecisionTreeClassifier) 1.2 回归决策树模型(DecisionTreeRegres ...
- Paddle模型搭建-从keras转换为Paddle
瞎扯淡的部分 keras其实是高度封装的一个神经网络模块,优点就是可以很方便的进行开发,缺点就是很多情况下只能用现成的Layer去构建模型,比如我需要用神经网络去进行控制,那么在控制量和输出量两层中间 ...
最新文章
- 如何为模型选择合适的损失函数?所有ML学习者应该知道的5种回归损失函数
- WINDOWSPHONE STUDY1:创建一个 Windows Phone 7 下的简单 RSS 阅读器
- Java 多线程使用
- minifilter
- 评论设置----第二章:创建和管理内容
- SQL Server Management Studio消失了
- 机器学习的理论知识点总结
- 我为什么对TypeScript由黑转粉?
- Linux Shell——函数的使用
- python查询oracle数据库_python针对Oracle常见查询操作实例分析
- mysql安装被打断_MySQL安装未响应解决方法
- Ranger开源流水线docker化实践案例
- 【论文研读】【医学图像分割】【FCN+RNN】Recurrent Neural Networks for Aortic Image Sequence Segmentation with ...
- 高质量响应式的 HTML/CSS 网站模板
- MySQL不同版本多实例部署——5.7和8.0
- plc ge c语言编程,GE PLC编程软件是什么
- STM32(3):番外篇之STM32名字解析
- Java类属性字段校验(validation的使用)
- svc预测概率_sklearn-SVC实现与类参数
- mysql内表和外表_hive内表和外表的创建、载入数据、区别
热门文章
- 【数据库认证】OCM准备及考试经验总结
- ajax $.get怎么使用,jquery之ajax之$.get方法的使用
- 【ansys workbench】19.力学计算对比学习
- 台式计算机如何上无线网络,台式机如何无线上网
- 关于win7 环境下安装docker容器的步骤 以及过程中的问题解决
- Deconstructing laws of accessibility and facility distribution in cities
- iOS armv7, armv7s, arm64区别与应用32位、64位配置
- 解决pr调用麦克风的问题
- 2019多校 7.29
- 零基础能不能学计算机专业,零基础能学计算机专业吗?