pytorch版本的Unet网络可以去github上面下载,网址为https://github.com/milesial/Pytorch-UNet,话不多说,还是以代码为例吧。
有小伙伴问我pytorch的型号,发图给大家参考一下,文章写得有点久了…好多东西我自己都记不太清楚了,体谅一下~

1、dataset.py

  这个数据集采用的是汽车的数据集,数据集当中返回的是一个字典:

        return {'image': torch.from_numpy(img).type(torch.FloatTensor),'mask': torch.from_numpy(mask).type(torch.FloatTensor)}

  image返回的则是汽车的图片,如下图:

  mask则返回的是图层蒙版,如下图:

2、Unet模型

  代码分为Unet_model.py以及Unet_part.py
  Unet网络图如下所示:

  再看一下网络大体的代码结构:

class UNet(nn.Module):def __init__(self, n_channels, n_classes, bilinear=True):super(UNet, self).__init__()self.n_channels = n_channelsself.n_classes = n_classesself.bilinear = bilinear

  n_classes:希望获得的每个像素的概率数,对于一个类和背景,使用n_classes=1,这里输出的就是黑白对照,所以使用1;n_channels=3是因为输入的图片是RGB 图像,因此是三维;bilinear则用于上采样。

        self.inc = DoubleConv(n_channels, 64)

  首先输入一张图片,通过DoubleConv将通道数变为64,图片的大小改变就对应的公式[(n1-n2)/s+1],(其中n1是图片大小,n2是卷积核大小,s是滑动步长,默认为1,因此图片大小由572->570->568),DoubleConv对应的就是下采样中的每一行的卷积与Relu,可以看到每一行的通道数是没有发生改变,找到这部分的代码:

class DoubleConv(nn.Module):"""(convolution => [BN] => ReLU) * 2"""def __init__(self, in_channels, out_channels, mid_channels=None):super().__init__()if not mid_channels:mid_channels = out_channelsself.double_conv = nn.Sequential(nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1),nn.BatchNorm2d(mid_channels),nn.ReLU(inplace=True),nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1),nn.BatchNorm2d(out_channels),nn.ReLU(inplace=True))def forward(self, x):return self.double_conv(x)

  DoubleConv主要是用于两次卷积,如果有中间通道的话可以作为桥梁,先卷积到中间通道数,然后再卷积到输出通道数。

        self.down1 = Down(64, 128)self.down2 = Down(128, 256)self.down3 = Down(256, 512)factor = 2 if bilinear else 1self.down4 = Down(512, 1024 // factor)

  再看一下下采样的后续过程,找到Down代码:

class Down(nn.Module):def __init__(self, in_channels, out_channels):super(Down,self).__init__()self.maxpool_conv = nn.Sequential(nn.MaxPool2d(2),DoubleConv(in_channels, out_channels))def forward(self, x):return self.maxpool_conv(x)

  主要是先进行最大池化,将图片大小变为原来的一半,然后再采用DoubleConv增加通道数。这样经过3次下采样,可以看到图片通道数为512,大小为64*64,此时还需要进行第4次下采样,由于后续要进行上采样,需要将每一层上采样对应的特征图与下采样对应的特征图进行融合,能够充分获得有用信息,融合时需要通道数进行对应,因此输出通道数为512,且图片大小为28*28,对应的forward代码为:

    def forward(self, x):x1 = self.inc(x)x2 = self.down1(x1)x3 = self.down2(x2)x4 = self.down3(x3)x5 = self.down4(x4)

  接着看一下后续上采样初始化的代码:

        self.up1 = Up(1024, 512 // factor, bilinear)self.up2 = Up(512, 256 // factor, bilinear)self.up3 = Up(256, 128 // factor, bilinear)self.up4 = Up(128, 64, bilinear)

  上采样则采用了bilinear,看一下Up的代码:

class Up(nn.Module):"""Upscaling then double conv"""def __init__(self, in_channels, out_channels, bilinear=True):super(Up,self).__init__()if bilinear:self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)else:self.up = nn.ConvTranspose2d(in_channels , in_channels // 2, kernel_size=2, stride=2)self.conv = DoubleConv(in_channels, out_channels)

  在nn.Upsample函数中,scale_factor指定输出大小为输入的多少倍数,mode:可使用的上采样算法,align_corners为True,输入的角像素将与输出张量对齐,因此将保存下来这些像素的值,nn.ConvTranspose2d是反卷积,对卷积层进行上采样,使其回到原始图片的分辨率。而对应的forward代码为:

  def forward(self, x1, x2):x1 = self.up(x1)diffY = x2.size()[2] - x1.size()[2]diffX = x2.size()[3] - x1.size()[3]x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,diffY // 2, diffY - diffY // 2])x = torch.cat([x2, x1], dim=1)#进行融合裁剪return self.conv(x)

  上采样的过程中需要对两个特征图进行融合,通道数一样并且尺寸也应该一样,x1是上采样获得的特征,而x2是下采样获得的特征,首先对x1进行反卷积使其大小变为输入时的2倍,首先需要计算两张图长宽的差值,作为填补padding的依据,由于此时图片的表示为(C,H,W),因此diffY对应的图片的高,diffX对应图片的宽度, F.pad指的是(左填充,右填充,上填充,下填充),其数值代表填充次数,因此需要/2,最后进行融合剪裁。
  上采样所对应的forword代码:

        x = self.up1(x5, x4)x = self.up2(x, x3)x = self.up3(x, x2)x = self.up4(x, x1)

  以第一层上采样为例,x5对应的是最后一次下采样获得的图片,通道数为512,大小为28*28,x4是第三次下采样获得的图片,通道为512,大小为64*64,首先将x5的特征大小变为2倍为56*56,然后长宽差距为8,所以周围分别补4个0,再和x4进行竖向拼接,因此输出通道数为1024,大小为64*64,然后就继续进行三次上采样,最终获得的图片通道为64,大小为572(跟图不符,但是用过程是没问题的,用代码测试过了),此时就已经变成了跟原来图片大小了,接着:

 self.outc = OutConv(64, n_classes)

  看一下OutConv代码:

class OutConv(nn.Module):def __init__(self, in_channels, out_channels):super(OutConv, self).__init__()self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)def forward(self, x):return self.conv(x)

  仍然采用了两次卷积,此时由于卷积核为1*1大小,因此不改变图片的大小,forword代码为:

     logits = self.outc(x)

  看一下每一层输出的结果吧,设原始图片大小为[1, 3, 572, 572]:

输入图片: torch.Size([1, 3, 572, 572])
下采样x1: torch.Size([1, 64, 572, 572])
下采样x2: torch.Size([1, 128, 286, 286])
下采样x3: torch.Size([1, 256, 143, 143])
下采样x4: torch.Size([1, 512, 71, 71])
下采样x5: torch.Size([1, 512, 35, 35])
上采样x4: torch.Size([1, 256, 71, 71])
上采样x3: torch.Size([1, 128, 143, 143])
上采样x2: torch.Size([1, 64, 286, 286])
上采样x1: torch.Size([1, 64, 572, 572])
输出图片: torch.Size([1, 1, 572, 572])

  此时输出的则是黑白图片了,黑白图片plt输出要压缩到2维才行。
  尝试进行一下输出:

Pytorch:Unet网络代码详解相关推荐

  1. CLIP(Contrastive Language-Image Pretraining)主体网络代码详解

    CLIP是OpenAI于2021年发表的工作,其采用无监督学习中的对比学习的训练方法,使用了规模巨大的数据集(4亿个图片文本对)来进行训练,其在多个数据集上均得到了让人欣喜的结果,有效地证实了NLP与 ...

  2. pytorch中resnet_ResNet代码详解

    代码学习第一天! fighting! import torch.nn as nn import math import torch.utils.model_zoo as model_zoo# 这个文件 ...

  3. Resnet50残差网络代码详解

    Resnet50是Resnet残差网络系列的代表网络,由Kaiming于2016年发表于CVPR 论文地址:CVPR 2016 Open Access Repository 参考代码:https:// ...

  4. DenseNet网络代码详解

    这个代码是pytorch官方实现的代码,自己做了些备注,主要是方便自己以后学习和使用. 下图自己根据代码画的densenet169的网络结构图,输入图片的尺寸跟官方有所不同,而且对过度层的平均池化也做 ...

  5. Pytorch Bi-LSTM + CRF 代码详解

    久闻LSTM + CRF的效果强大,最近在看Pytorch官网文档的时候,看到了这段代码,前前后后查了很多资料,终于把代码弄懂了.我希望在后来人看这段代码的时候,直接就看我的博客就能完全弄懂这段代码. ...

  6. pytorch BiLSTM+CRF代码详解 重点

    一. BILSTM + CRF介绍 https://www.jianshu.com/p/97cb3b6db573 1.介绍 基于神经网络的方法,在命名实体识别任务中非常流行和普遍. 如果你不知道Bi- ...

  7. 【Gans入门】Pytorch实现Gans代码详解【70+代码】

    简述 由于科技论文老师要求阅读Gans论文并在网上找到类似的代码来学习. 文章目录 简述 代码来源 代码含义概览 代码分段解释 导入包: 设置参数: 给出标准数据: 构建模型: 构建优化器 迭代细节 ...

  8. 基于U-Net的的图像分割代码详解及应用实现

    摘要 U-Net是基于卷积神经网络(CNN)体系结构设计而成的,由Olaf Ronneberger,Phillip Fischer和Thomas Brox于2015年首次提出应用于计算机视觉领域完成语 ...

  9. PyTorch 迁移学习 (Transfer Learning) 代码详解

    PyTorch 迁移学习 代码详解 概述 为什么使用迁移学习 更好的结果 节省时间 加载模型 ResNet152 冻层实现 模型初始化 获取需更新参数 训练模型 获取数据 完整代码 概述 迁移学习 ( ...

  10. Alphapose论文代码详解

    注:B站有相应视频,点击此链接即可跳转观看https://www.bilibili.com/video/BV1hb4y117mu/ 第1节 人体姿态估计的基本概念 第2节:Alphapose 2.1A ...

最新文章

  1. 嵌入式linux仪器,一种基于嵌入式Linux设备双系统的启动方法
  2. java学习教程之代码块
  3. centos 输入密码正确进不去系统
  4. python资源管理器选择文件_Python:在资源管理器中获取选定文件的列表(windows7)...
  5. java面向对象三个关键字,Java 面向对象(三)static 关键字
  6. 华为 “OSPF” 单区域配置
  7. ViBe算法核心思想
  8. 用matlab的毕业设计,毕业设计课题: 用 MATLAB.ppt
  9. js 动态生成表格案例
  10. 视音频编解码H264,265,MPEG-4,VP8,VP9知识总结
  11. NO JVM installation found. please install a 64-bit JDK,解决方法   Error launching android studio   NO J
  12. window7 安装grldr
  13. 彻底关闭Win10自动更新(Win10企业版或专业版)
  14. Axure的基本原件
  15. Python - 体脂率
  16. 19、android面试题整理(自己给自己充充电吧)
  17. 记录学习《流畅的python》的一些知识-----对象引用,可变性和垃圾回收
  18. 山野村夫的总提纲!……还是羞于见人啦=////=
  19. 怎么把mkv文件转成mp4格式,3招立马处理
  20. 电信黑莓手机出国漫游注意事项

热门文章

  1. php加密解密文件内容,php文件加密解密 - osc_0g0vbf0z的个人空间 - OSCHINA - 中文开源技术交流社区...
  2. kali实现ARP断网
  3. JavaWeb医院挂号系统
  4. idea怎么运行c语言程序,IntelliJ IDEA 10.0 64位运行方法
  5. 近期活动盘点:个人消费信贷与风险控制讲座、清华大学教育大数据论坛
  6. 接口动态签名,防止被人恶意调用
  7. linux smb无法访问服务器,samba服务器访问失败
  8. 自动搜索关键词点击广告或网站,自动换ip实现过程
  9. 逻辑门 与 买猫电路升级版
  10. python实现字符串去重