最近在学习U-Net,读完文章之后尝试搭建模型的框架,阅读了前人的模型后,试着自己搭建了一下,适合初学者。

import torch
import torch.nn as nn
from torchsummary import summary#   2次 3*3卷积
#   Remark,第一个3*3卷积承担,升降维的功能
class DoubleConv(nn.Module):def __init__(self, in_channels,mid_channels, out_channels):super(DoubleConv, self).__init__()self.conv = nn.Sequential(nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, stride=1, bias=False),# Same conv. not valid conv. in original papernn.BatchNorm2d(mid_channels),nn.ReLU(inplace=True),nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, stride=1, bias=False),nn.BatchNorm2d(out_channels),nn.ReLU(inplace=True),)def forward(self, x):return self.conv(x)# 定义主网络框架
class UNet(nn.Module):def __init__(self, original_channels=1, num_classes=1):super(UNet, self).__init__()self.original_channels = original_channels  # 输入可能是RGB 3通道,也有可能是灰度图 1通道,所以定义original_channels;self.num_classes = num_classes  # 输出可能是二分类(前后景),也可能是多分类,所以定义num_classes# Contracting path: 卷积2次DoubleCon,下采样Maxpool 1次,一共编码5次self.encoder1 = DoubleConv(self.original_channels,mid_channels=64, out_channels=64)self.down1 = nn.MaxPool2d(kernel_size=2, stride=2)self.encoder2 = DoubleConv(in_channels=64, mid_channels=128,out_channels=128)self.down2 = nn.MaxPool2d(kernel_size=2, stride=2)self.encoder3 = DoubleConv(in_channels=128, mid_channels=256,out_channels=256)self.down3 = nn.MaxPool2d(kernel_size=2, stride=2)self.encoder4 = DoubleConv(in_channels=256, mid_channels=512,out_channels=512)self.down4 = nn.MaxPool2d(kernel_size=2, stride=2)self.encoder5 = DoubleConv(in_channels=512,mid_channels=1024, out_channels=1024)# Expansive path: 上采样ConvTranspose 1次,卷积2次DoubleCon,一共解码5次,最后一次为1*1Conv.# Remark:通道拼接放在正向传播中做,注意编码和上采样的channels匹配的问题self.up1 = nn.ConvTranspose2d(in_channels=1024, out_channels=512, kernel_size=2, stride=2)self.decoder1 = DoubleConv(in_channels=1024, mid_channels=512,out_channels=512)self.up2 = nn.ConvTranspose2d(in_channels=512, out_channels=256, kernel_size=2, stride=2)self.decoder2 = DoubleConv(in_channels=512,mid_channels=256, out_channels=256)self.up3 = nn.ConvTranspose2d(in_channels=256, out_channels=128, kernel_size=2, stride=2)self.decoder3 = DoubleConv(in_channels=256,mid_channels=128, out_channels=128)self.up4 = nn.ConvTranspose2d(in_channels=128, out_channels=64, kernel_size=2, stride=2)self.decoder4 = DoubleConv(in_channels=128, mid_channels=64,out_channels=64)self.decoder5 = nn.Conv2d(64, num_classes, kernel_size=1)# 定义正向传播的过程def forward(self, x):encoder1 = self.encoder1(x)              # original_channel → 64_channels  512*512 → 512*512encoder1_pool = self.down1(encoder1)     # 64_channels → 64_channels       512*512 → 256*256encoder2 = self.encoder2(encoder1_pool)  # 64_channels → 128_channels      256*256 → 256*256encoder2_pool = self.down2(encoder2)     # 128_channels → 128_channels     256*256 → 128*128encoder3 = self.encoder3(encoder2_pool)  # 128_channels → 256_channels     128*128 → 128*128encoder3_pool = self.down3(encoder3)     # 256_channels → 256_channels     128*128 → 64*64encoder4 = self.encoder4(encoder3_pool)  # 256_channels → 512_channels       64*64 → 64*64encoder4_pool = self.down4(encoder4)     # 512_channels → 512_channels       64*64 → 32*32encoder5 = self.encoder5(encoder4_pool)  # 512_channels → 1024_channels      32*32 → 32*32decoder1_up = self.up1(encoder5)         # 1024_channels → 512_channels      32*32 → 64*64decoder1 = self.decoder1(torch.cat((encoder4, decoder1_up), dim=1))# 512+512_channels → 512_channels   64*64 → 64*64decoder2_up = self.up2(decoder1)         # 512_channels → 256_channels       64*64 → 128*128decoder2 = self.decoder2(torch.cat((encoder3, decoder2_up), dim=1))# 256+256_channels → 256_channels   128*128 → 128*128decoder3_up = self.up3(decoder2)         # 256_channels → 128_channels       128*64 → 256*256decoder3 = self.decoder3(torch.cat((encoder2, decoder3_up), dim=1))# 128+128_channels → 128_channels   256*256 → 256*256decoder4_up = self.up4(decoder3)         # 128_channels → 64_channels        256*256 → 256*256decoder4 = self.decoder4(torch.cat((encoder1, decoder4_up), dim=1))# 64+64_channels → 64_channels      256*256 → 512*512out = self.decoder5(decoder4)            # 64_channels → num_classes channels 512*512 → 512*512return out"""
下面三行代码是验证模型架构用,实际不需要
"""
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# unet = UNet().to(device)
# summary(unet, (1, 512, 512))

使用pytorch的summary功能,如果没有可以pip install summary
三行测试代码开启后,可以查看模型的架构如下

----------------------------------------------------------------Layer (type)               Output Shape         Param #
================================================================Conv2d-1         [-1, 64, 512, 512]             576BatchNorm2d-2         [-1, 64, 512, 512]             128ReLU-3         [-1, 64, 512, 512]               0Conv2d-4         [-1, 64, 512, 512]          36,864BatchNorm2d-5         [-1, 64, 512, 512]             128ReLU-6         [-1, 64, 512, 512]               0DoubleConv-7         [-1, 64, 512, 512]               0MaxPool2d-8         [-1, 64, 256, 256]               0Conv2d-9        [-1, 128, 256, 256]          73,728BatchNorm2d-10        [-1, 128, 256, 256]             256ReLU-11        [-1, 128, 256, 256]               0Conv2d-12        [-1, 128, 256, 256]         147,456BatchNorm2d-13        [-1, 128, 256, 256]             256ReLU-14        [-1, 128, 256, 256]               0DoubleConv-15        [-1, 128, 256, 256]               0MaxPool2d-16        [-1, 128, 128, 128]               0Conv2d-17        [-1, 256, 128, 128]         294,912BatchNorm2d-18        [-1, 256, 128, 128]             512ReLU-19        [-1, 256, 128, 128]               0Conv2d-20        [-1, 256, 128, 128]         589,824BatchNorm2d-21        [-1, 256, 128, 128]             512ReLU-22        [-1, 256, 128, 128]               0DoubleConv-23        [-1, 256, 128, 128]               0MaxPool2d-24          [-1, 256, 64, 64]               0Conv2d-25          [-1, 512, 64, 64]       1,179,648BatchNorm2d-26          [-1, 512, 64, 64]           1,024ReLU-27          [-1, 512, 64, 64]               0Conv2d-28          [-1, 512, 64, 64]       2,359,296BatchNorm2d-29          [-1, 512, 64, 64]           1,024ReLU-30          [-1, 512, 64, 64]               0DoubleConv-31          [-1, 512, 64, 64]               0MaxPool2d-32          [-1, 512, 32, 32]               0Conv2d-33         [-1, 1024, 32, 32]       4,718,592BatchNorm2d-34         [-1, 1024, 32, 32]           2,048ReLU-35         [-1, 1024, 32, 32]               0Conv2d-36         [-1, 1024, 32, 32]       9,437,184BatchNorm2d-37         [-1, 1024, 32, 32]           2,048ReLU-38         [-1, 1024, 32, 32]               0DoubleConv-39         [-1, 1024, 32, 32]               0ConvTranspose2d-40          [-1, 512, 64, 64]       2,097,664Conv2d-41          [-1, 512, 64, 64]       4,718,592BatchNorm2d-42          [-1, 512, 64, 64]           1,024ReLU-43          [-1, 512, 64, 64]               0Conv2d-44          [-1, 512, 64, 64]       2,359,296BatchNorm2d-45          [-1, 512, 64, 64]           1,024ReLU-46          [-1, 512, 64, 64]               0DoubleConv-47          [-1, 512, 64, 64]               0ConvTranspose2d-48        [-1, 256, 128, 128]         524,544Conv2d-49        [-1, 256, 128, 128]       1,179,648BatchNorm2d-50        [-1, 256, 128, 128]             512ReLU-51        [-1, 256, 128, 128]               0Conv2d-52        [-1, 256, 128, 128]         589,824BatchNorm2d-53        [-1, 256, 128, 128]             512ReLU-54        [-1, 256, 128, 128]               0DoubleConv-55        [-1, 256, 128, 128]               0ConvTranspose2d-56        [-1, 128, 256, 256]         131,200Conv2d-57        [-1, 128, 256, 256]         294,912BatchNorm2d-58        [-1, 128, 256, 256]             256ReLU-59        [-1, 128, 256, 256]               0Conv2d-60        [-1, 128, 256, 256]         147,456BatchNorm2d-61        [-1, 128, 256, 256]             256ReLU-62        [-1, 128, 256, 256]               0DoubleConv-63        [-1, 128, 256, 256]               0ConvTranspose2d-64         [-1, 64, 512, 512]          32,832Conv2d-65         [-1, 64, 512, 512]          73,728BatchNorm2d-66         [-1, 64, 512, 512]             128ReLU-67         [-1, 64, 512, 512]               0Conv2d-68         [-1, 64, 512, 512]          36,864BatchNorm2d-69         [-1, 64, 512, 512]             128ReLU-70         [-1, 64, 512, 512]               0DoubleConv-71         [-1, 64, 512, 512]               0Conv2d-72          [-1, 1, 512, 512]              65
================================================================
Total params: 31,036,481
Trainable params: 31,036,481
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 1.00
Forward/backward pass size (MB): 3718.00
Params size (MB): 118.39
Estimated Total Size (MB): 3837.39
----------------------------------------------------------------`

U-Net模型搭建python实现相关推荐

  1. python利器app怎么查文献-科研人必备:一个工具搞定文献查阅、数据分析、模型搭建...

    原标题:科研人必备:一个工具搞定文献查阅.数据分析.模型搭建 写论文有多难?这首诗形容得好: 进入学校先选题,踌躇满志万人敌:发现前辈都做过,满脸懵逼加惊奇. 终于找到大空白,我真是个小天才:左试右试 ...

  2. kaggle实战—泰坦尼克(五、模型搭建-模型评估)

    kaggle实战-泰坦尼克(一.数据分析) kaggle实战-泰坦尼克(二.数据清洗及特征处理) kaggle实战-泰坦尼克(三.数据重构) kaggle实战-泰坦尼克(四.数据可视化) kaggle ...

  3. 目标检测——夏侯南溪模型搭建篇(legacy)

    4 定义模型整体结构 4.3.1 输入样本的预处理操作 (注意:训练样本在输入网路之前,需要进行归一化,包括:均值归一化操作 具体可以看看我写的<目标检测--输入数据的归一化操作>) MT ...

  4. 泰坦尼克号 第三章 模型搭建和评估

    第三章 模型搭建和评估–建模 经过前面的两章的知识点的学习,我可以对数数据的本身进行处理,比如数据本身的增删查补,还可以做必要的清洗工作.那么下面我们就要开始使用我们前面处理好的数据了.这一章我们要做 ...

  5. Flask项目实战——10—(前台板块页面搭建、文本编辑页面搭建、发布帖子信息前验证权限、帖子模型搭建、发布帖子功能、帖子信息渲染到前后台页面)

    1.前台板块页面搭建 视图文件查询数据传输到前台界面:前台蓝图文件:apps/front/views.py 注意数据的收集方法和数据传输的类型. # -*- encoding: utf-8 -*- & ...

  6. PyTorch模型搭建和源码详解

    文章目录: 一.VGG模型简单介绍 二.PyTorch源码分析 三.预训练模型的使用 本文是以VGG模型为例,深入介绍了完整的模型搭建过程,以及预训练模型使用过程,希望本篇博客可以解答一些困惑,同时欢 ...

  7. Kaggle猫狗大战模型搭建总结

    0.前言 基于我暑假内学习的深度学习理论知识,我的学长建议我仿照他所设计的猫狗大战模型来尝试运用tensorflow与keras搭建神经网络,虽然我对python并不是很了解,但我依旧愿意尝试搭建来提 ...

  8. 如何使用机器学习自动修复bug: 数据处理和模型搭建

    如何使用机器学习自动修复bug: 数据处理和模型搭建 上一篇<如何使用机器学习自动修复bug: 上手指南>我们介绍了使用CodeBERT自动修复bug的操作方法. 估计对于很多想了解原理的 ...

  9. 机器学习——决策树模型:Python实现

    机器学习--决策树模型:Python实现 1 决策树模型的代码实现 1.1 分类决策树模型(DecisionTreeClassifier) 1.2 回归决策树模型(DecisionTreeRegres ...

  10. Paddle模型搭建-从keras转换为Paddle

    瞎扯淡的部分 keras其实是高度封装的一个神经网络模块,优点就是可以很方便的进行开发,缺点就是很多情况下只能用现成的Layer去构建模型,比如我需要用神经网络去进行控制,那么在控制量和输出量两层中间 ...

最新文章

  1. 如何为模型选择合适的损失函数?所有ML学习者应该知道的5种回归损失函数
  2. WINDOWSPHONE STUDY1:创建一个 Windows Phone 7 下的简单 RSS 阅读器
  3. Java 多线程使用
  4. minifilter
  5. 评论设置----第二章:创建和管理内容
  6. SQL Server Management Studio消失了
  7. 机器学习的理论知识点总结
  8. 我为什么对TypeScript由黑转粉?
  9. Linux Shell——函数的使用
  10. python查询oracle数据库_python针对Oracle常见查询操作实例分析
  11. mysql安装被打断_MySQL安装未响应解决方法
  12. Ranger开源流水线docker化实践案例
  13. 【论文研读】【医学图像分割】【FCN+RNN】Recurrent Neural Networks for Aortic Image Sequence Segmentation with ...
  14. 高质量响应式的 HTML/CSS 网站模板
  15. MySQL不同版本多实例部署——5.7和8.0
  16. plc ge c语言编程,GE PLC编程软件是什么
  17. STM32(3):番外篇之STM32名字解析
  18. Java类属性字段校验(validation的使用)
  19. svc预测概率_sklearn-SVC实现与类参数
  20. mysql内表和外表_hive内表和外表的创建、载入数据、区别

热门文章

  1. 【数据库认证】OCM准备及考试经验总结
  2. ajax $.get怎么使用,jquery之ajax之$.get方法的使用
  3. 【ansys workbench】19.力学计算对比学习
  4. 台式计算机如何上无线网络,台式机如何无线上网
  5. 关于win7 环境下安装docker容器的步骤 以及过程中的问题解决
  6. Deconstructing laws of accessibility and facility distribution in cities
  7. iOS armv7, armv7s, arm64区别与应用32位、64位配置
  8. 解决pr调用麦克风的问题
  9. 2019多校 7.29
  10. 零基础能不能学计算机专业,零基础能学计算机专业吗?