RSU模块

网络结构

开源github地址:https://github.com/NathanUA/U-2-Net

自己的实现:

u2net.py:

import torch
from torch import nn
from torchvision import models
import torch.nn.functional as F


class Convalution(nn.Module):
    """Basic U2-Net conv unit: 3x3 Conv2d -> BatchNorm2d -> ReLU.

    `dirate` sets both dilation and padding, so spatial size is preserved.
    """

    def __init__(self, in_ch=3, out_ch=3, dirate=1):
        super(Convalution, self).__init__()
        self.conv = nn.Conv2d(in_ch, out_ch, 3, padding=dirate, dilation=dirate)
        self.bn = nn.BatchNorm2d(out_ch)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        return self.relu(self.bn(self.conv(x)))


def upsample(src, tar):
    """Bilinearly resize `src` to the spatial (H, W) size of `tar`.

    Uses F.interpolate: F.upsample is deprecated and merely forwards to it
    (align_corners=False matches the old default behavior).
    """
    return F.interpolate(src, size=tar.shape[2:], mode='bilinear', align_corners=False)


class Encode(nn.Module):
    """Conv unit followed by a 2x2 max-pool (ceil_mode keeps odd sizes valid)."""

    def __init__(self, in_ch, out_ch, dirate):
        super(Encode, self).__init__()
        self.conv = Convalution(in_ch, out_ch, dirate)
        self.pool = nn.MaxPool2d(2, 2, ceil_mode=True)

    def forward(self, x):
        return self.pool(self.conv(x))


class RSU1(nn.Module):
    """Deepest residual U-block: 5 pooling encoder stages + dilated bottom."""

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super(RSU1, self).__init__()
        self.en1 = Convalution(in_ch, out_ch, dirate=1)
        self.en2 = Encode(out_ch, mid_ch, 1)
        self.en3 = Encode(mid_ch, mid_ch, 1)
        self.en4 = Encode(mid_ch, mid_ch, 1)
        self.en5 = Encode(mid_ch, mid_ch, 1)
        self.en6 = Encode(mid_ch, mid_ch, 1)
        self.en7 = Convalution(mid_ch, mid_ch, 2)
        self.en8 = Convalution(mid_ch, mid_ch, 2)
        self.de6 = Convalution(mid_ch * 2, mid_ch, dirate=1)
        self.de5 = Convalution(mid_ch * 2, mid_ch, dirate=1)
        self.de4 = Convalution(mid_ch * 2, mid_ch, dirate=1)
        self.de3 = Convalution(mid_ch * 2, mid_ch, dirate=1)
        self.de2 = Convalution(mid_ch * 2, mid_ch, dirate=1)
        self.de1 = Convalution(mid_ch * 2, out_ch, dirate=1)

    def forward(self, x):
        en1 = self.en1(x)
        en2 = self.en2(en1)
        en3 = self.en3(en2)
        en4 = self.en4(en3)
        en5 = self.en5(en4)
        en6 = self.en6(en5)
        en7 = self.en7(en6)
        en8 = self.en8(en7)
        # Decoder: concat with the matching encoder feature, conv, then
        # upsample to the next (larger) encoder resolution.
        de6 = self.de6(torch.cat([en8, en7], dim=1))
        de5 = self.de5(torch.cat([de6, en6], dim=1))
        de5 = upsample(de5, en5)
        de4 = self.de4(torch.cat([de5, en5], dim=1))
        de4 = upsample(de4, en4)
        de3 = self.de3(torch.cat([de4, en4], dim=1))
        de3 = upsample(de3, en3)
        de2 = self.de2(torch.cat([de3, en3], dim=1))
        de2 = upsample(de2, en2)
        de1 = self.de1(torch.cat([de2, en2], dim=1))
        de1 = upsample(de1, en1)
        # Residual connection around the whole U-block.
        return de1 + en1


class RSU2(nn.Module):
    """Residual U-block with 4 pooling encoder stages."""

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super(RSU2, self).__init__()
        self.en1 = Convalution(in_ch, out_ch)
        self.en2 = Encode(out_ch, mid_ch, 1)
        self.en3 = Encode(mid_ch, mid_ch, 1)
        self.en4 = Encode(mid_ch, mid_ch, 1)
        self.en5 = Encode(mid_ch, mid_ch, 1)
        self.en6 = Convalution(mid_ch, mid_ch)
        self.en7 = Convalution(mid_ch, mid_ch, dirate=2)
        self.de5 = Convalution(mid_ch * 2, mid_ch)
        self.de4 = Convalution(mid_ch * 2, mid_ch)
        self.de3 = Convalution(mid_ch * 2, mid_ch)
        self.de2 = Convalution(mid_ch * 2, mid_ch)
        self.de1 = Convalution(mid_ch * 2, out_ch)

    def forward(self, x):
        en1 = self.en1(x)
        en2 = self.en2(en1)
        en3 = self.en3(en2)
        en4 = self.en4(en3)
        en5 = self.en5(en4)
        en6 = self.en6(en5)
        en7 = self.en7(en6)
        de5 = self.de5(torch.cat([en7, en6], dim=1))
        de4 = self.de4(torch.cat([de5, en5], dim=1))
        de4 = upsample(de4, en4)
        de3 = self.de3(torch.cat([de4, en4], dim=1))
        de3 = upsample(de3, en3)
        de2 = self.de2(torch.cat([de3, en3], dim=1))
        de2 = upsample(de2, en2)
        de1 = self.de1(torch.cat([de2, en2], dim=1))
        de1 = upsample(de1, en1)
        return de1 + en1


class RSU3(nn.Module):
    """Residual U-block with 3 pooling encoder stages."""

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super(RSU3, self).__init__()
        self.en1 = Convalution(in_ch, out_ch)
        self.en2 = Encode(out_ch, mid_ch, 1)
        self.en3 = Encode(mid_ch, mid_ch, 1)
        self.en4 = Encode(mid_ch, mid_ch, 1)
        self.en5 = Convalution(mid_ch, mid_ch)
        self.en6 = Convalution(mid_ch, mid_ch, 2)
        self.de4 = Convalution(mid_ch * 2, mid_ch)
        self.de3 = Convalution(mid_ch * 2, mid_ch)
        self.de2 = Convalution(mid_ch * 2, mid_ch)
        self.de1 = Convalution(mid_ch * 2, out_ch)

    def forward(self, x):
        en1 = self.en1(x)
        en2 = self.en2(en1)
        en3 = self.en3(en2)
        en4 = self.en4(en3)
        en5 = self.en5(en4)
        en6 = self.en6(en5)
        de4 = self.de4(torch.cat([en6, en5], dim=1))
        de3 = self.de3(torch.cat([de4, en4], dim=1))
        de3 = upsample(de3, en3)
        de2 = self.de2(torch.cat([de3, en3], dim=1))
        de2 = upsample(de2, en2)
        de1 = self.de1(torch.cat([de2, en2], dim=1))
        de1 = upsample(de1, en1)
        return de1 + en1


class RSU4(nn.Module):
    """Residual U-block with 2 pooling encoder stages."""

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super(RSU4, self).__init__()
        self.en1 = Convalution(in_ch, out_ch)
        self.en2 = Encode(out_ch, mid_ch, 1)
        self.en3 = Encode(mid_ch, mid_ch, 1)
        self.en4 = Convalution(mid_ch, mid_ch)
        self.en5 = Convalution(mid_ch, mid_ch, 2)
        self.de3 = Convalution(mid_ch * 2, mid_ch)
        self.de2 = Convalution(mid_ch * 2, mid_ch)
        self.de1 = Convalution(mid_ch * 2, out_ch)

    def forward(self, x):
        en1 = self.en1(x)
        en2 = self.en2(en1)
        en3 = self.en3(en2)
        en4 = self.en4(en3)
        en5 = self.en5(en4)
        de3 = self.de3(torch.cat([en5, en4], dim=1))
        de2 = self.de2(torch.cat([de3, en3], dim=1))
        de2 = upsample(de2, en2)
        de1 = self.de1(torch.cat([de2, en2], dim=1))
        de1 = upsample(de1, en1)
        return de1 + en1


class RSU5(nn.Module):
    """Dilation-only residual U-block (the paper's RSU-4F): no pooling,
    receptive field grows via dilation rates 1/1/2/4/8 instead."""

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super(RSU5, self).__init__()
        self.en1 = Convalution(in_ch, out_ch)
        self.en2 = Convalution(out_ch, mid_ch)
        self.en3 = Convalution(mid_ch, mid_ch, dirate=2)
        self.en4 = Convalution(mid_ch, mid_ch, dirate=4)
        self.en5 = Convalution(mid_ch, mid_ch, dirate=8)
        self.de3 = Convalution(mid_ch * 2, mid_ch, dirate=4)
        self.de2 = Convalution(mid_ch * 2, mid_ch, dirate=2)
        self.de1 = Convalution(mid_ch * 2, out_ch)

    def forward(self, x):
        en1 = self.en1(x)
        en2 = self.en2(en1)
        en3 = self.en3(en2)
        en4 = self.en4(en3)
        # BUGFIX: the original forward never called self.en5, skipping the
        # dilation-8 bottleneck, and concatenated en3 twice (de3 with en3,
        # then de2 with en3 again). Pair each decoder stage with its
        # mirrored encoder stage as in the reference RSU-4F.
        en5 = self.en5(en4)
        de3 = self.de3(torch.cat([en5, en4], dim=1))
        de2 = self.de2(torch.cat([de3, en3], dim=1))
        de1 = self.de1(torch.cat([de2, en2], dim=1))
        return de1 + en1


class U2NET(nn.Module):
    """U2-Net: outer U-structure of RSU blocks with deep side supervision.

    forward() returns 7 sigmoid maps: the fused output (sup0) followed by
    the six per-stage side outputs, all at the input resolution.
    """

    def __init__(self, in_ch=3, out_ch=1):
        super(U2NET, self).__init__()
        # Encoder: each stage is an RSU block followed by a 2x2 max-pool.
        self.En_1 = nn.Sequential(RSU1(in_ch, 32, 64), nn.MaxPool2d(2, 2, ceil_mode=True))
        self.En_2 = nn.Sequential(RSU2(64, 32, 128), nn.MaxPool2d(2, 2, ceil_mode=True))
        self.En_3 = nn.Sequential(RSU3(128, 64, 256), nn.MaxPool2d(2, 2, ceil_mode=True))
        self.En_4 = nn.Sequential(RSU4(256, 128, 512), nn.MaxPool2d(2, 2, ceil_mode=True))
        self.En_5 = nn.Sequential(RSU5(512, 256, 512), nn.MaxPool2d(2, 2, ceil_mode=True))
        self.En_6 = RSU5(512, 256, 512)
        # Decoder: inputs are concatenations, hence the doubled in-channels.
        self.De_5 = RSU5(1024, 256, 512)
        self.De_4 = RSU4(1024, 128, 256)
        self.De_3 = RSU3(512, 64, 128)
        self.De_2 = RSU2(256, 32, 64)
        self.De_1 = RSU1(128, 16, 64)
        # 3x3 side heads producing one saliency map per stage.
        self.side1 = nn.Conv2d(64, out_ch, 3, padding=1)
        self.side2 = nn.Conv2d(64, out_ch, 3, padding=1)
        self.side3 = nn.Conv2d(128, out_ch, 3, padding=1)
        self.side4 = nn.Conv2d(256, out_ch, 3, padding=1)
        self.side5 = nn.Conv2d(512, out_ch, 3, padding=1)
        self.side6 = nn.Conv2d(512, out_ch, 3, padding=1)
        # 1x1 conv fusing the six side maps into the final output.
        self.sup0 = nn.Conv2d(6, 1, 1)

    def forward(self, x):
        en1 = self.En_1(x)
        en2 = self.En_2(en1)
        en3 = self.En_3(en2)
        en4 = self.En_4(en3)
        en5 = self.En_5(en4)
        en6 = self.En_6(en5)
        de5 = self.De_5(torch.cat([en6, en5], dim=1))
        de5 = upsample(de5, en4)
        de4 = self.De_4(torch.cat([de5, en4], dim=1))
        de4 = upsample(de4, en3)
        de3 = self.De_3(torch.cat([de4, en3], dim=1))
        de3 = upsample(de3, en2)
        de2 = self.De_2(torch.cat([de3, en2], dim=1))
        de2 = upsample(de2, en1)
        de1 = self.De_1(torch.cat([de2, en1], dim=1))
        de1 = upsample(de1, x)
        side1 = self.side1(de1)
        side2 = self.side2(de2)
        side3 = self.side3(de3)
        side4 = self.side4(de4)
        side5 = self.side5(de5)
        side6 = self.side6(en6)
        # Bring every side map up to the input resolution before fusing.
        sup1 = side1
        sup2 = upsample(side2, x)
        sup3 = upsample(side3, x)
        sup4 = upsample(side4, x)
        sup5 = upsample(side5, x)
        sup6 = upsample(side6, x)
        sup0 = self.sup0(torch.cat([sup1, sup2, sup3, sup4, sup5, sup6], dim=1))
        # torch.sigmoid replaces the deprecated F.sigmoid.
        return (torch.sigmoid(sup0), torch.sigmoid(sup1), torch.sigmoid(sup2),
                torch.sigmoid(sup3), torch.sigmoid(sup4), torch.sigmoid(sup5),
                torch.sigmoid(sup6))


if __name__ == '__main__':
    x = torch.rand((1, 3, 512, 512))
    m = U2NET()
    # BUGFIX: forward returns a 7-tuple, so the original `m(x).shape`
    # raised AttributeError; print the fused output's shape instead.
    print(m(x)[0].shape)

dataloader.py:

from torch.utils.data import Dataset
import os
from PIL import Image
from torchvision import transforms

# Root folders for input photos and their same-named grayscale masks.
data_dir = r'D:\MASKpicture\train_pig_body'
label_dir = r'D:\MASKpicture\label_pig_body'


class MydataSet(Dataset):
    """(image, mask) tensor pairs for pig-body segmentation.

    Images live in `data_dir`; the mask with the same filename lives in
    `label_dir`. Each sample is center-padded to a square (CenterCrop with a
    target larger than the image zero-pads), resized to 490, and converted
    to a float tensor in [0, 1].
    """

    def __init__(self):
        super(MydataSet, self).__init__()
        # BUGFIX: dropped the no-op `self.dataset = self.dataset`.
        self.dataset = os.listdir(data_dir)

    def __getitem__(self, index):
        try:
            name = self.dataset[index]
            image = Image.open(os.path.join(data_dir, name)).convert('RGB')
            label = Image.open(os.path.join(label_dir, name)).convert('L')
        except OSError:
            # Best-effort skip of unreadable/missing files, as before.
            # BUGFIX: the original bare `except` also swallowed the
            # IndexError raised once `index + 1` ran off the end of the
            # list, recursing forever; wrap around and catch only I/O
            # errors instead.
            return self.__getitem__((index + 1) % len(self.dataset))
        pad = max(image.size)
        transform = transforms.Compose([
            transforms.CenterCrop((pad, pad)),
            transforms.Resize(490),
            transforms.ToTensor(),
        ])
        return transform(image), transform(label)

    def __len__(self):
        return len(self.dataset)

# s = MydataSet()
# for i in range(s.__len__()):
#     a,b = s.__getitem__(i)
#     print(a.shape,b.shape)

train.py

from dataloader import MydataSet
from u2net import U2NET
from torch.utils.data import DataLoader
from torch import optim
from torch import nn
import torchif __name__ == '__main__':net = U2NET()net.cuda()dataset = MydataSet()dataloader = DataLoader(dataset,batch_size=3,shuffle=True,num_workers=4)Lossfuction = nn.BCELoss()optimizer = optim.Adam(net.parameters(), lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0)net.train()for epoch in range(300):for data,label in dataloader:data = data.cuda()label = label.cuda()o0,o1,o2,o3,o4,o5,o6 = net(data)loss0 = Lossfuction(o0,label)loss1 = Lossfuction(o1, label)loss2 = Lossfuction(o2, label)loss3 = Lossfuction(o3, label)loss4 = Lossfuction(o4, label)loss5 = Lossfuction(o5, label)loss6 = Lossfuction(o6, label)loss = loss0 + loss1 + loss2 + loss3 + loss4 + loss5 + loss6optimizer.zero_grad()loss.backward()optimizer.step()print(loss)torch.save(net.state_dict(),'netpig.pt')

test.py:

from u2net import U2NET
import cv2
import torch
from PIL import Image
from torchvision import transforms
import os
import timeif __name__ == '__main__':net = U2NET()net.load_state_dict(torch.load('netpig.pt'))net.cuda()net.eval()for dir in os.listdir(r'D:\Mask_Unet'):image = Image.open(r'C:\Users\Administrator\Desktop\企业微信截图_16050696914678.png').convert('RGB')pad = max(image.size)size = (pad, pad)transform = transforms.Compose([transforms.CenterCrop(size),transforms.Resize(490),transforms.ToTensor()])data = transform(image)data = torch.unsqueeze(data,0)data = data.cuda()s = time.time()o0,o1,o2,o3,o4,o5,o6 = net(data)print(time.time()-s)o0 = o0.view(490,490).cpu().detach().numpy()cv2.imshow('',o0)cv2.waitKey(0)

U2NET的pytorch实现相关推荐

  1. 使用U2net+cpu+pytorch完成照片素描化

    使用U2net+cpu+pytorch完成照片素描化 u2net_portrait_test.py import os from skimage import io, transform import ...

  2. U2-Net: Going Deeper with Nested U-Structure for Salient Object Detection|环境搭建|人物素描 测试 简记 |

    这个代码非常强大,最近作者更新了模型 我也特别更新一篇博文 [最新同步更新教程链接 – 2021-9-3 ]-- 敬请移步 文章目录 U2-Net: Going Deeper with Nested ...

  3. 使用U2-Net深层网络实现——证件照生成程序

    效果预览:http://map.gnnu.work/rm21/qy/profilepicture 使用到: 1.pytorch 加载分割模型 2.PIL 更加方便地操作图像 2.U2-net 网络将人 ...

  4. U2Net论文解读及代码测试

    论文名称: U2-Net: Going Deeper with Nested U-Structure for Salient Object Detection 论文地址: https://arxiv. ...

  5. U2-net网络详解

    学习视频:U2Net网络结构讲解_哔哩哔哩_bilibili 论文名称:U2-Net: Going Deeper with Nested U-Structure forSalient Object ...

  6. Pytorch使用GPU训练模型加速

    Pytorch使用GPU训练模型加速 深度学习神经网络训练经常很耗时,耗时主要来自两个部分,数据准备和自参数迭代. 当数据准备是主要耗时原因时,采用多进程准备数据.当迭代过程是训练耗时主力时,采用GP ...

  7. 通过anaconda2安装python2.7和安装pytorch

    ①由于官网下载anaconda2太慢,最好去byrbt下载,然后安装就行 ②安装完anaconda2会自动安装了python2.7(如终端输入python即进入python模式) 但是可能没有设置环境 ...

  8. 记录一次简单、高效、无错误的linux上安装pytorch的过程

    1 准备miniconda Miniconda Miniconda 可以理解成Anaconda的免费.浓缩版.它非常小,只包含了conda.python以及它们依赖的一些包.我们可以根据我们的需要再安 ...

  9. 各种注意力机制PyTorch实现

    给出了整个系列的PyTorch的代码实现,以及使用方法. 各种注意力机制 Pytorch implementation of "Beyond Self-attention: External ...

最新文章

  1. MySQL数据库(八) 一一 数据库信息和使用序列
  2. (十)MySQL日志
  3. 基于Pytorch再次解析AlexNet现代卷积神经网络
  4. simulink中mask设置_Mask Editor 概述
  5. 信息检索报告_iFixR:缺陷报告驱动程序修复
  6. MySQL数据库的备份和还原
  7. rmi远程代码执行漏洞_【最新漏洞简讯】WebLogic远程代码执行漏洞 (CVE202014645)
  8. java中迭代器要导包吗_java 中迭代器的使用方法详解
  9. 美国只有两样东西比中国贵
  10. php类的实例化方法,php中类的定义和实例化方法
  11. SEO关键词优化:如何理解被百度快速索引?
  12. jmc线程转储_Java线程转储– VisualVM,jstack,kill -3,jcmd
  13. 逆向知识第七讲,三目运算符在汇编中的表现形式,以及编译器优化方式
  14. 【安装包】XMind-ZEN-Update-2019-for-Windows-64bit-9.2.1
  15. Linux与git库建立连接,Linux 下建立 Git 与 GitHub 的连接
  16. 火山安卓文件名类操作
  17. 什么是非接触式IC卡
  18. 1112day10:考前复习50题:断言
  19. 计算机弹歌光年之外谱子,邓紫棋《光年之外》完整钢琴谱
  20. MobileNetV3基于NNI剪枝操作

热门文章

  1. 某机房ups电源更换蓄电池的一次作业过程
  2. 东北大学大一下暑期实训项目--活力长者社区(JAVAFX+scenebuilder)
  3. android viewpager实现画廊效果,android viewpager 实现画廊效果
  4. 天风掌财社可靠?科创板开户掌财学堂合法吗?
  5. c语言土壤墒情监测系统,土壤墒情监测系统
  6. EOS钱包EosToken开始空投代币了,新用户可领取1.05个EOS
  7. 合振动的初相位推导_合振动的初相位确定方法
  8. 箱梁终张拉后弹性上拱度计算_后张法预制桥梁弹性上拱度的检测方法与流程
  9. 使用 Google 高级搜索的一些技巧
  10. Sicily Hansel and Grethel