Python Unet网络结构pytorch简单实现+torchsummary可视化(可以直接运行)
Unet的网络结构:
根据该结构,用Pytorch实现Unet:
# -*- coding: utf-8 -*-
import torch
import torch.nn as nn
from torch.autograd import Variable
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torch.nn.functional as F
import numpy as np
import torch.utils.data as Data seed = 2019
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed) # if you are using multi-GPU.
import random
np.random.seed(seed) # Numpy module.
random.seed(seed) # Python random module.
torch.manual_seed(seed)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True##定义卷积核
def default_conv(in_channels,out_channels,kernel_size,bias=True):return nn.Conv2d(in_channels,out_channels,kernel_size,padding=0, bias=bias)
##定义ReLU
def default_relu():return nn.ReLU(inplace=True)class Up_Sample(nn.Module):def __init__(self,in_channels,conv=default_conv,relu=default_relu):super(Up_Sample,self).__init__()up1 = nn.Upsample(scale_factor=2,mode='nearest')up2 = conv(in_channels,in_channels//2,1)self.module_up = nn.Sequential(up1,up2,relu())def forward(self,input_down,input_left):x = self.module_up(input_down) dif = (input_left.shape[3] - x.shape[3])/2input_left = input_left[:,:,int(dif):int(dif+x.shape[3]),int(dif):int(dif+x.shape[3])]return torch.cat((x,input_left),1)class Unet(nn.Module):def __init__(self,in_channels,out_channels,conv=default_conv,relu=default_relu,n_feats=64):super(Unet,self).__init__()left1 = [conv(in_channels,n_feats,3),relu(),conv(n_feats,n_feats,3)]left2 = [conv(n_feats,2*n_feats,3),relu(),conv(2*n_feats,2*n_feats,3)]left3 = [conv(2*n_feats,4*n_feats,3),relu(),conv(4*n_feats,4*n_feats,3)]left4 = [conv(4*n_feats,8*n_feats,3),relu(),conv(8*n_feats,8*n_feats,3)]bottom = [conv(8*n_feats,16*n_feats,3),relu(),conv(16*n_feats,16*n_feats,3)]right1 = [conv(2*n_feats,n_feats,3),relu(),conv(n_feats,n_feats,3)]right2 = [conv(4*n_feats,2*n_feats,3),relu(),conv(2*n_feats,2*n_feats,3)]right3 = [conv(8*n_feats,4*n_feats,3),relu(),conv(4*n_feats,4*n_feats,3)]right4 = [conv(16*n_feats,8*n_feats,3),relu(),conv(8*n_feats,8*n_feats,3)]self.left1 = nn.Sequential(*left1) self.left2 = nn.Sequential(*left2) self.left3 = nn.Sequential(*left3) self.left4 = nn.Sequential(*left4)self.bottom = nn.Sequential(*bottom)self.right1 = nn.Sequential(*right1)self.right2 = nn.Sequential(*right2)self.right3 = nn.Sequential(*right3)self.right4 = nn.Sequential(*right4)self.tail = conv(n_feats,out_channels,1)down = []for layer in range(4):down.append(nn.MaxPool2d(kernel_size = 1,stride = 2))self.down = nn.Sequential(*down)up = nn.ModuleList()for layer in range(4):up.append(Up_Sample(in_channels=(2**(layer+1))*n_feats))self.up = nn.Sequential(*up)def forward(self,x):x1 = self.left1(x)x1d = self.down[0](x1)x2 = self.left2(x1d)x2d = self.down[1](x2)x3 = self.left3(x2d)x3d = self.down[2](x3)x4 = self.left4(x3d)x4d = self.down[3](x4)x_b = self.bottom(x4d)y4d = self.up[3](x_b,x4)y3 = self.right4(y4d)y3d = self.up[2](y3,x3)y2 = self.right3(y3d)y2d = self.up[1](y2,x2)y1 = self.right2(y2d)y1d = self.up[0](y1,x1)y = self.right1(y1d)out = self.tail(y)return outdef main():model = Unet(in_channels=1,out_channels=2)from torchsummary import summary summary(model.cuda(), (1, 572, 572))if __name__=='__main__':main()
打印模型:
----------------------------------------------------------------Layer (type) Output Shape Param #
================================================================Conv2d-1 [-1, 64, 570, 570] 640ReLU-2 [-1, 64, 570, 570] 0Conv2d-3 [-1, 64, 568, 568] 36,928MaxPool2d-4 [-1, 64, 284, 284] 0Conv2d-5 [-1, 128, 282, 282] 73,856ReLU-6 [-1, 128, 282, 282] 0Conv2d-7 [-1, 128, 280, 280] 147,584MaxPool2d-8 [-1, 128, 140, 140] 0Conv2d-9 [-1, 256, 138, 138] 295,168ReLU-10 [-1, 256, 138, 138] 0Conv2d-11 [-1, 256, 136, 136] 590,080MaxPool2d-12 [-1, 256, 68, 68] 0Conv2d-13 [-1, 512, 66, 66] 1,180,160ReLU-14 [-1, 512, 66, 66] 0Conv2d-15 [-1, 512, 64, 64] 2,359,808MaxPool2d-16 [-1, 512, 32, 32] 0Conv2d-17 [-1, 1024, 30, 30] 4,719,616ReLU-18 [-1, 1024, 30, 30] 0Conv2d-19 [-1, 1024, 28, 28] 9,438,208Upsample-20 [-1, 1024, 56, 56] 0Conv2d-21 [-1, 512, 56, 56] 524,800ReLU-22 [-1, 512, 56, 56] 0Up_Sample-23 [-1, 1024, 56, 56] 0Conv2d-24 [-1, 512, 54, 54] 4,719,104ReLU-25 [-1, 512, 54, 54] 0Conv2d-26 [-1, 512, 52, 52] 2,359,808Upsample-27 [-1, 512, 104, 104] 0Conv2d-28 [-1, 256, 104, 104] 131,328ReLU-29 [-1, 256, 104, 104] 0Up_Sample-30 [-1, 512, 104, 104] 0Conv2d-31 [-1, 256, 102, 102] 1,179,904ReLU-32 [-1, 256, 102, 102] 0Conv2d-33 [-1, 256, 100, 100] 590,080Upsample-34 [-1, 256, 200, 200] 0Conv2d-35 [-1, 128, 200, 200] 32,896ReLU-36 [-1, 128, 200, 200] 0Up_Sample-37 [-1, 256, 200, 200] 0Conv2d-38 [-1, 128, 198, 198] 295,040ReLU-39 [-1, 128, 198, 198] 0Conv2d-40 [-1, 128, 196, 196] 147,584Upsample-41 [-1, 128, 392, 392] 0Conv2d-42 [-1, 64, 392, 392] 8,256ReLU-43 [-1, 64, 392, 392] 0Up_Sample-44 [-1, 128, 392, 392] 0Conv2d-45 [-1, 64, 390, 390] 73,792ReLU-46 [-1, 64, 390, 390] 0Conv2d-47 [-1, 64, 388, 388] 36,928Conv2d-48 [-1, 2, 388, 388] 130
================================================================
Total params: 28,941,698
Trainable params: 28,941,698
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 1.25
Forward/backward pass size (MB): 2275.74
Params size (MB): 110.40
Estimated Total Size (MB): 2387.39
----------------------------------------------------------------
Python Unet网络结构pytorch简单实现+torchsummary可视化(可以直接运行)相关推荐
- 基于Python Unet的医学影像分割系统源码,含皮肤病的数据及皮肤病分割的模型,用户输入图像,模型可以自动分割去皮肤病的区域
手把手教你用Unet做医学图像分割 我们用Unet来做医学图像分割.我们将会以皮肤病的数据作为示范,训练一个皮肤病分割的模型出来,用户输入图像,模型可以自动分割去皮肤病的区域和正常的区域.废话不多说, ...
- PyTorch深度学习训练可视化工具tensorboardX
点击上方"小白学视觉",选择加"星标"或"置顶" 重磅干货,第一时间送达 之前笔者提到了PyTorch的专属可视化工具visdom,参看Py ...
- pytorch | 深度学习分割网络U-net的pytorch模型实现
原文:https://blog.csdn.net/u014722627/article/details/60883185 pytorch | 深度学习分割网络U-net的pytorch模型实现 这个是 ...
- pythonplotly k线 动态_GitHub - 846626465/PythonPlotlyCodes: 《Python 数据分析:基于 Plotly 的动态可视化绘图》 源代码...
PythonPlotlyCodes <Python 数据分析:基于 Plotly 的动态可视化绘图> 源代码 前言 Python是一门非常优秀的编程语言,其语法简捷.易学易用,越来越受到编 ...
- pythonplotly k线 动态_GitHub - chitandacc/PythonPlotlyCodes: 《Python 数据分析:基于 Plotly 的动态可视化绘图》 源代码...
PythonPlotlyCodes <Python 数据分析:基于 Plotly 的动态可视化绘图> 源代码 前言 Python是一门非常优秀的编程语言,其语法简捷.易学易用,越来越受到编 ...
- pythonplotly k线 动态_GitHub - Yanglian666/PythonPlotlyCodes: 《Python 数据分析:基于 Plotly 的动态可视化绘图》 源代码...
PythonPlotlyCodes <Python 数据分析:基于 Plotly 的动态可视化绘图> 源代码 前言 Python是一门非常优秀的编程语言,其语法简捷.易学易用,越来越受到编 ...
- 【医学图像分割网络】之Res U-Net网络PyTorch复现
[医学图像分割网络]之Res U-Net网络PyTorch复现 1.内容 U-Net网络算是医学图像分割领域的开山之作,我接触深度学习到现在大概将近大半年时间,看到了很多基于U-Net网络的变体,后续 ...
- 通过带Flask的REST API在Python中部署PyTorch
通过带Flask的REST API在Python中部署PyTorch 在本文中,将使用Flask来部署PyTorch模型,并用讲解用于模型推断的 REST API.特别是,将部署一个预训练的Dense ...
- 大学python和vb哪个简单-python和vb哪个简单
Visual Basic(简称VB)是Microsoft公司开发的一种通用的基于对象的程序设计语言,为结构化的.模块化的.面向对象的.包含协助开发环境的事件驱动为机制的可视化程序设计语言.是一种可用于 ...
最新文章
- 循环IRNNv2Layer实现
- qtmessagebox对话框里自定义按钮文本_Word里表格都是这么来的 — 生成绘制表格有技巧...
- 字符串数组中查找字符串
- 打开高效文本编辑之门_Linux Awk之条件判断与循环
- Linux基础命令---killall
- 华尔街弃儿:雷曼兄弟158岁被清算
- C语言之去掉https链接的默认443端口
- apache.camel_在即将发布的Camel 2.21版本中改进了使用Apache Camel和ActiveMQ Artemis处理大型消息的功能...
- React开发(124):ant design学习指南之form中的this.props.form
- 燃爆了!胡歌秒变最帅产品经理发布荣耀V20!
- 947. 移除最多的同行或同列石头2021-01-23
- MySQL存储过程-循环遍历查询到的结果集
- Revit2018下载和安装教程
- 孪生网络(1)_孪生网络的分类
- 拼多多破1000亿美金,黄峥自述:我的人生经历和创业理念
- hivesql uv
- Logo小变动,心境大不同,SVG矢量动画格式网站Logo图片制作与实践教程(Python3)
- python提取html中的href标签,如何使用Python从HTML获取href链接?
- 【Anki 牌组+Markdown笔记分享】汇编语言
- 定义一个复数类Complex,重载运算符“+”,“ -”,“*”,“/”使之能用于计算两个复数的加减乘除。