Unet的网络结构:

根据该结构,用Pytorch实现Unet:

# -*- coding: utf-8 -*-
import torch
import torch.nn as nn
from torch.autograd import Variable
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torch.nn.functional as F
import numpy as np
import torch.utils.data as Data seed = 2019
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)  # if you are using multi-GPU.
import random
np.random.seed(seed)  # Numpy module.
random.seed(seed)  # Python random module.
torch.manual_seed(seed)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True##定义卷积核
def default_conv(in_channels,out_channels,kernel_size,bias=True):return nn.Conv2d(in_channels,out_channels,kernel_size,padding=0, bias=bias)
##定义ReLU
def default_relu():return nn.ReLU(inplace=True)class Up_Sample(nn.Module):def __init__(self,in_channels,conv=default_conv,relu=default_relu):super(Up_Sample,self).__init__()up1 = nn.Upsample(scale_factor=2,mode='nearest')up2 = conv(in_channels,in_channels//2,1)self.module_up = nn.Sequential(up1,up2,relu())def forward(self,input_down,input_left):x = self.module_up(input_down)        dif = (input_left.shape[3] - x.shape[3])/2input_left = input_left[:,:,int(dif):int(dif+x.shape[3]),int(dif):int(dif+x.shape[3])]return torch.cat((x,input_left),1)class Unet(nn.Module):def __init__(self,in_channels,out_channels,conv=default_conv,relu=default_relu,n_feats=64):super(Unet,self).__init__()left1 = [conv(in_channels,n_feats,3),relu(),conv(n_feats,n_feats,3)]left2 = [conv(n_feats,2*n_feats,3),relu(),conv(2*n_feats,2*n_feats,3)]left3 = [conv(2*n_feats,4*n_feats,3),relu(),conv(4*n_feats,4*n_feats,3)]left4 = [conv(4*n_feats,8*n_feats,3),relu(),conv(8*n_feats,8*n_feats,3)]bottom = [conv(8*n_feats,16*n_feats,3),relu(),conv(16*n_feats,16*n_feats,3)]right1 = [conv(2*n_feats,n_feats,3),relu(),conv(n_feats,n_feats,3)]right2 = [conv(4*n_feats,2*n_feats,3),relu(),conv(2*n_feats,2*n_feats,3)]right3 = [conv(8*n_feats,4*n_feats,3),relu(),conv(4*n_feats,4*n_feats,3)]right4 = [conv(16*n_feats,8*n_feats,3),relu(),conv(8*n_feats,8*n_feats,3)]self.left1 = nn.Sequential(*left1)       self.left2 = nn.Sequential(*left2)       self.left3 = nn.Sequential(*left3)       self.left4 = nn.Sequential(*left4)self.bottom = nn.Sequential(*bottom)self.right1 = nn.Sequential(*right1)self.right2 = nn.Sequential(*right2)self.right3 = nn.Sequential(*right3)self.right4 = nn.Sequential(*right4)self.tail = conv(n_feats,out_channels,1)down = []for layer in range(4):down.append(nn.MaxPool2d(kernel_size = 1,stride = 2))self.down = nn.Sequential(*down)up = nn.ModuleList()for layer in range(4):up.append(Up_Sample(in_channels=(2**(layer+1))*n_feats))self.up = nn.Sequential(*up)def forward(self,x):x1 = self.left1(x)x1d = self.down[0](x1)x2 = self.left2(x1d)x2d = self.down[1](x2)x3 = self.left3(x2d)x3d = self.down[2](x3)x4 = self.left4(x3d)x4d = self.down[3](x4)x_b = self.bottom(x4d)y4d = self.up[3](x_b,x4)y3 = self.right4(y4d)y3d = self.up[2](y3,x3)y2 = self.right3(y3d)y2d = self.up[1](y2,x2)y1 = self.right2(y2d)y1d = self.up[0](y1,x1)y = self.right1(y1d)out = self.tail(y)return outdef main():model = Unet(in_channels=1,out_channels=2)from torchsummary import summary    summary(model.cuda(), (1, 572, 572))if __name__=='__main__':main()

打印模型:

----------------------------------------------------------------Layer (type)               Output Shape         Param #
================================================================Conv2d-1         [-1, 64, 570, 570]             640ReLU-2         [-1, 64, 570, 570]               0Conv2d-3         [-1, 64, 568, 568]          36,928MaxPool2d-4         [-1, 64, 284, 284]               0Conv2d-5        [-1, 128, 282, 282]          73,856ReLU-6        [-1, 128, 282, 282]               0Conv2d-7        [-1, 128, 280, 280]         147,584MaxPool2d-8        [-1, 128, 140, 140]               0Conv2d-9        [-1, 256, 138, 138]         295,168ReLU-10        [-1, 256, 138, 138]               0Conv2d-11        [-1, 256, 136, 136]         590,080MaxPool2d-12          [-1, 256, 68, 68]               0Conv2d-13          [-1, 512, 66, 66]       1,180,160ReLU-14          [-1, 512, 66, 66]               0Conv2d-15          [-1, 512, 64, 64]       2,359,808MaxPool2d-16          [-1, 512, 32, 32]               0Conv2d-17         [-1, 1024, 30, 30]       4,719,616ReLU-18         [-1, 1024, 30, 30]               0Conv2d-19         [-1, 1024, 28, 28]       9,438,208Upsample-20         [-1, 1024, 56, 56]               0Conv2d-21          [-1, 512, 56, 56]         524,800ReLU-22          [-1, 512, 56, 56]               0Up_Sample-23         [-1, 1024, 56, 56]               0Conv2d-24          [-1, 512, 54, 54]       4,719,104ReLU-25          [-1, 512, 54, 54]               0Conv2d-26          [-1, 512, 52, 52]       2,359,808Upsample-27        [-1, 512, 104, 104]               0Conv2d-28        [-1, 256, 104, 104]         131,328ReLU-29        [-1, 256, 104, 104]               0Up_Sample-30        [-1, 512, 104, 104]               0Conv2d-31        [-1, 256, 102, 102]       1,179,904ReLU-32        [-1, 256, 102, 102]               0Conv2d-33        [-1, 256, 100, 100]         590,080Upsample-34        [-1, 256, 200, 200]               0Conv2d-35        [-1, 128, 200, 200]          32,896ReLU-36        [-1, 128, 200, 200]               0Up_Sample-37        [-1, 256, 200, 200]               0Conv2d-38        [-1, 128, 198, 198]         295,040ReLU-39        [-1, 128, 198, 198]               0Conv2d-40        [-1, 128, 196, 196]         147,584Upsample-41        [-1, 128, 392, 392]               0Conv2d-42         [-1, 64, 392, 392]           8,256ReLU-43         [-1, 64, 392, 392]               0Up_Sample-44        [-1, 128, 392, 392]               0Conv2d-45         [-1, 64, 390, 390]          73,792ReLU-46         [-1, 64, 390, 390]               0Conv2d-47         [-1, 64, 388, 388]          36,928Conv2d-48          [-1, 2, 388, 388]             130
================================================================
Total params: 28,941,698
Trainable params: 28,941,698
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 1.25
Forward/backward pass size (MB): 2275.74
Params size (MB): 110.40
Estimated Total Size (MB): 2387.39
----------------------------------------------------------------

Python Unet网络结构pytorch简单实现+torchsummary可视化(可以直接运行)相关推荐

  1. 基于Python Unet的医学影像分割系统源码,含皮肤病的数据及皮肤病分割的模型,用户输入图像,模型可以自动分割去皮肤病的区域

    手把手教你用Unet做医学图像分割 我们用Unet来做医学图像分割.我们将会以皮肤病的数据作为示范,训练一个皮肤病分割的模型出来,用户输入图像,模型可以自动分割去皮肤病的区域和正常的区域.废话不多说, ...

  2. PyTorch深度学习训练可视化工具tensorboardX

    点击上方"小白学视觉",选择加"星标"或"置顶" 重磅干货,第一时间送达 之前笔者提到了PyTorch的专属可视化工具visdom,参看Py ...

  3. pytorch | 深度学习分割网络U-net的pytorch模型实现

    原文:https://blog.csdn.net/u014722627/article/details/60883185 pytorch | 深度学习分割网络U-net的pytorch模型实现 这个是 ...

  4. pythonplotly k线 动态_GitHub - 846626465/PythonPlotlyCodes: 《Python 数据分析:基于 Plotly 的动态可视化绘图》 源代码...

    PythonPlotlyCodes <Python 数据分析:基于 Plotly 的动态可视化绘图> 源代码 前言 Python是一门非常优秀的编程语言,其语法简捷.易学易用,越来越受到编 ...

  5. pythonplotly k线 动态_GitHub - chitandacc/PythonPlotlyCodes: 《Python 数据分析:基于 Plotly 的动态可视化绘图》 源代码...

    PythonPlotlyCodes <Python 数据分析:基于 Plotly 的动态可视化绘图> 源代码 前言 Python是一门非常优秀的编程语言,其语法简捷.易学易用,越来越受到编 ...

  6. pythonplotly k线 动态_GitHub - Yanglian666/PythonPlotlyCodes: 《Python 数据分析:基于 Plotly 的动态可视化绘图》 源代码...

    PythonPlotlyCodes <Python 数据分析:基于 Plotly 的动态可视化绘图> 源代码 前言 Python是一门非常优秀的编程语言,其语法简捷.易学易用,越来越受到编 ...

  7. 【医学图像分割网络】之Res U-Net网络PyTorch复现

    [医学图像分割网络]之Res U-Net网络PyTorch复现 1.内容 U-Net网络算是医学图像分割领域的开山之作,我接触深度学习到现在大概将近大半年时间,看到了很多基于U-Net网络的变体,后续 ...

  8. 通过带Flask的REST API在Python中部署PyTorch

    通过带Flask的REST API在Python中部署PyTorch 在本文中,将使用Flask来部署PyTorch模型,并用讲解用于模型推断的 REST API.特别是,将部署一个预训练的Dense ...

  9. 大学python和vb哪个简单-python和vb哪个简单

    Visual Basic(简称VB)是Microsoft公司开发的一种通用的基于对象的程序设计语言,为结构化的.模块化的.面向对象的.包含协助开发环境的事件驱动为机制的可视化程序设计语言.是一种可用于 ...

最新文章

  1. 循环IRNNv2Layer实现
  2. qtmessagebox对话框里自定义按钮文本_Word里表格都是这么来的 — 生成绘制表格有技巧...
  3. 字符串数组中查找字符串
  4. 打开高效文本编辑之门_Linux Awk之条件判断与循环
  5. Linux基础命令---killall
  6. 华尔街弃儿:雷曼兄弟158岁被清算
  7. C语言之去掉https链接的默认443端口
  8. apache.camel_在即将发布的Camel 2.21版本中改进了使用Apache Camel和ActiveMQ Artemis处理大型消息的功能...
  9. React开发(124):ant design学习指南之form中的this.props.form
  10. 燃爆了!胡歌秒变最帅产品经理发布荣耀V20!
  11. 947. 移除最多的同行或同列石头2021-01-23
  12. MySQL存储过程-循环遍历查询到的结果集
  13. Revit2018下载和安装教程
  14. 孪生网络(1)_孪生网络的分类
  15. 拼多多破1000亿美金,黄峥自述:我的人生经历和创业理念
  16. hivesql uv
  17. Logo小变动,心境大不同,SVG矢量动画格式网站Logo图片制作与实践教程(Python3)
  18. python提取html中的href标签,如何使用Python从HTML获取href链接?
  19. 【Anki 牌组+Markdown笔记分享】汇编语言
  20. 定义一个复数类Complex,重载运算符“+”,“ -”,“*”,“/”使之能用于计算两个复数的加减乘除。

热门文章

  1. SQL Server高级查询
  2. C语言基础知识总结(简单算法套路)
  3. Java运算符的优先级和C语言中有何异同,C语言运算符优先级小结
  4. EmEditor注册码
  5. 英文网站注册常用词汇
  6. 跳槽后“好马也吃回头草”
  7. 考研数学 概率论争议题 [Python验证版]
  8. ORACLE EBS WORKFLOW实现多附件下载
  9. K8SEASY:一键安装K8S高可用集群
  10. java采购管理系统设计_Java毕业设计——采购管理系统的设计参考