详细Unet网络结构可以查看Unet算法原理详解

深度网络训练之中需要大量的有标样本,Unet作者提供了一种新的训练方法,可以更有效的运用相应的有标样本,使网络即使通过少量的训练图片也可以进行更精确的分割。

这里只是记录一下近期在网站资源上学到的Unet模型项目的代码,代码中有较详细的解释(学习笔记)

代码来源:GitHub - qiaofengsheng/pytorch-UNet: pytorch搭建自己的unet网络,训练自己的数据集。 B站视频地址全程带你手撸代码:https://www.bilibili.com/video/BV11341127iK?spm_id_from=333.999.0.0

一、制作自己的训练集

使用labelme制作数据集标记

1、首先使用cmd命令行,打开labelme软件

2、点击labelme软件左侧的Open Dir,在地址栏输入项目的数据集照片的地址,选择文件即可

3、点击左侧Create Polygons.进行标签标画,裱花完成后点击save,即可。生成的json文件

二、json转为mask

import osimport cv2
import numpy as np
from PIL import Image, ImageDraw
import jsonCLASS_NAMES = ['yellowtrain', 'person']def make_mask(image_dir, save_dir):data = os.listdir(image_dir)temp_data = []for i in data:if i.split('.')[1] == 'json':temp_data.append(i)else:continuefor js in temp_data:json_data = json.load(open(os.path.join(image_dir, js), 'r'))shapes_ = json_data['shapes']#得到标签mask = Image.new('P', Image.open(os.path.join(image_dir, js.replace('json', 'png'))).size)for shape_ in shapes_:#由于一张图片可能有多个标签,所以要遍历一遍shapelabel = shape_['label']points = shape_['points']points = tuple(tuple(i) for i in points)mask_draw = ImageDraw.Draw(mask)mask_draw.polygon(points, fill=CLASS_NAMES.index(label) + 1)#对图片进行画多边形的操作mask=np.array(mask)*255cv2.imshow('mask',mask)cv2.waitKey(0)mask.save(os.path.join(save_dir, js.replace('json', 'png')))def vis_label(img):img=Image.open(img)img=np.array(img)print(set(img.reshape(-1).tolist()))if __name__ == '__main__':make_mask('image', 'SegmentationClass')# vis_label('SegmentationClass/2007_009436.png')# img=Image.open('SegmentationClass/2007_009436.png')# print(np.array(img).shape)# out=np.array(img).reshape(-1)# print(set(out.tolist()))

三、网络结构

import torch
from torch import nn
from torch.nn import functional as F#插值法上采样class Conv_Block(nn.Module):def __init__(self,in_channel,out_channel):super(Conv_Block, self).__init__()self.layer=nn.Sequential(nn.Conv2d(in_channel,out_channel,3,1,1,padding_mode='reflect',bias=False),#卷积3*3,步长为1,padding为1nn.BatchNorm2d(out_channel),nn.Dropout2d(0.3),nn.LeakyReLU(),nn.Conv2d(out_channel, out_channel, 3, 1, 1, padding_mode='reflect', bias=False),nn.BatchNorm2d(out_channel),nn.Dropout2d(0.3),nn.LeakyReLU())def forward(self,x):return self.layer(x)class DownSample(nn.Module):#池化(下采样)def __init__(self,channel):super(DownSample, self).__init__()self.layer=nn.Sequential(#序列构造器nn.Conv2d(channel,channel,3,2,1,padding_mode='reflect',bias=False),#这里不采用最大池化,最大池化特征丢失太多,所以采用步长为2nn.BatchNorm2d(channel),nn.LeakyReLU())def forward(self,x):return self.layer(x)class UpSample(nn.Module):#上采样def __init__(self,channel):super(UpSample, self).__init__()self.layer=nn.Conv2d(channel,channel//2,1,1)#1*1卷积,降低通道,无需特征提取,只是降通道数def forward(self,x,feature_map):up=F.interpolate(x,scale_factor=2,mode='nearest')#最邻近插值法out=self.layer(up)return torch.cat((out,feature_map),dim=1)class UNet(nn.Module):def __init__(self,num_classes):super(UNet, self).__init__()self.c1=Conv_Block(3,64)self.d1=DownSample(64)self.c2=Conv_Block(64,128)self.d2=DownSample(128)self.c3=Conv_Block(128,256)self.d3=DownSample(256)self.c4=Conv_Block(256,512)self.d4=DownSample(512)self.c5=Conv_Block(512,1024)self.u1=UpSample(1024)self.c6=Conv_Block(1024,512)self.u2 = UpSample(512)self.c7 = Conv_Block(512, 256)self.u3 = UpSample(256)self.c8 = Conv_Block(256, 128)self.u4 = UpSample(128)self.c9 = Conv_Block(128, 64)self.out=nn.Conv2d(64,num_classes,3,1,1)def forward(self,x):R1=self.c1(x)R2=self.c2(self.d1(R1))R3 = self.c3(self.d2(R2))R4 = self.c4(self.d3(R3))R5 = self.c5(self.d4(R4))O1=self.c6(self.u1(R5,R4))O2 = self.c7(self.u2(O1, R3))O3 = self.c8(self.u3(O2, R2))O4 = self.c9(self.u4(O3, R1))return self.out(O4)if __name__ == '__main__':x=torch.randn(2,3,256,256)net=UNet()print(net(x).shape)

四、训练代码

import osimport tqdm
from torch import nn, optim
import torch
from torch.utils.data import DataLoader#数据集加载器
from data import *
from net import *
from torchvision.utils import save_imagedevice = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
weight_path = 'params/unet.pth'#权重地址
data_path = r'data'#数据集地址
save_path = 'train_image'#
if __name__ == '__main__':num_classes = 2 + 1  # +1是背景也为一类data_loader = DataLoader(MyDataset(data_path), batch_size=1, shuffle=True)#加载数据集,batch_size批次,根据自身电脑的情况进行修改net = UNet(num_classes).to(device)#实例化Unet网路if os.path.exists(weight_path):#判断权重是否存在net.load_state_dict(torch.load(weight_path))print('successful load weight!')else:print('not successful load weight')opt = optim.Adam(net.parameters())loss_fun = nn.CrossEntropyLoss()epoch = 1while epoch < 200:for i, (image, segment_image) in enumerate(tqdm.tqdm(data_loader)):image, segment_image = image.to(device), segment_image.to(device)out_image = net(image)train_loss = loss_fun(out_image, segment_image.long())opt.zero_grad()train_loss.backward()opt.step()if i % 1 == 0:print(f'{epoch}-{i}-train_loss===>>{train_loss.item()}')_image = image[0]_segment_image = torch.unsqueeze(segment_image[0], 0) * 255_out_image = torch.argmax(out_image[0], dim=0).unsqueeze(0) * 255img = torch.stack([_segment_image, _out_image], dim=0)save_image(img, f'{save_path}/{i}.png')if epoch % 20 == 0:#每20次保存一次权重torch.save(net.state_dict(), weight_path)print('save successfully!')epoch += 1

五、测试代码

import osimport cv2
import numpy as np
import torchfrom net import *
from utils import *
from data import *
from torchvision.utils import save_image
from PIL import Image
net=UNet(3).cuda()weights='params/unet.pth'
if os.path.exists(weights):net.load_state_dict(torch.load(weights))print('successfully')
else:print('no loading')_input=input('please input JPEGImages path:')img=keep_image_size_open_rgb(_input)
img_data=transform(img).cuda()
img_data=torch.unsqueeze(img_data,dim=0)
net.eval()
out=net(img_data)
out=torch.argmax(out,dim=1)
out=torch.squeeze(out,dim=0)
out=out.unsqueeze(dim=0)
print(set((out).reshape(-1).tolist()))
out=(out).permute((1,2,0)).cpu().detach().numpy()
cv2.imwrite('result/result.png',out)
cv2.imshow('out',out*255.0)
cv2.waitKey(0)

pytorch学习--UNet模型相关推荐

  1. Pytorch学习 - 保存模型和重新加载

    Pytorch学习 - 保存和加载模型 1. 3个函数 2. 模型不同后缀名的区别 3. 保存和重载模型 参考资料: Pytorch官方文档链接 某博客 1. 3个函数 torch.save() : ...

  2. 手把手教程:零基础使用MATLAB完成基于深度学习U-Net模型的遥感影像分类

    背景: 很多初入深度学习的学生都会遇到各种环境配置问题,环境搭建不好模型就跑不了,所以这是限制新手的一大难点,MATLAB具有成熟的运行环境,无需配置,这点对于想跑通一个深度学习模型的新手是非常有利的 ...

  3. PyTorch学习笔记(五):模型定义、修改、保存

    往期学习资料推荐: 1.Pytorch实战笔记_GoAI的博客-CSDN博客 2.Pytorch入门教程_GoAI的博客-CSDN博客 本系列目录: PyTorch学习笔记(一):PyTorch环境安 ...

  4. 【深度学习】带有 CRF-RNN 层的 U-Net模型

    [深度学习]带有 CRF-RNN 层的 U-Net模型 文章目录 1 图像语义分割之FCN和CRF 2 CRF as RNN语义分割 3 全连接条件随机场与稀疏条件随机场的区别 4 CRF as RN ...

  5. 人脸口罩检测现开源PyTorch、TensorFlow、MXNet等全部五大主流深度学习框架模型和代码...

    号外!号外! 现在,AIZOO开源PyTorch.TensorFlow.MXNet.Keras和Caffe五大主流深度学习框架的人脸检测模型和代码啦! 先附上Github链接为敬. https://g ...

  6. 速成pytorch学习——7天模型层layers

    深度学习模型一般由各种模型层组合而成. torch.nn中内置了非常丰富的各种模型层.它们都属于nn.Module的子类,具备参数管理功能. 例如: nn.Linear, nn.Flatten, nn ...

  7. 使用Pytorch搭建U-Net网络并基于DRIVE数据集训练(语义分割)学习笔记

    使用Pytorch搭建U-Net网络并基于DRIVE数据集训练(语义分割)学习笔记 https://www.bilibili.com/video/BV1rq4y1w7xM?spm_id_from=33 ...

  8. 深度学习【使用pytorch实现基础模型、优化算法介绍、数据集的加载】

    文章目录 一 Pytorch完成基础模型 1. Pytorch完成模型常用API 1.1 `nn.Module` 1.2 优化器类 1.3 损失函数 1.4 线性回归完整代码 2. 在GPU上运行代码 ...

  9. PyTorch学习笔记2:nn.Module、优化器、模型的保存和加载、TensorBoard

    文章目录 一.nn.Module 1.1 nn.Module的调用 1.2 线性回归的实现 二.损失函数 三.优化器 3.1.1 SGD优化器 3.1.2 Adagrad优化器 3.2 分层学习率 3 ...

最新文章

  1. 编译原理 - 实验三 - 递归下降语法分析器的调试及扩展
  2. 在Centos上编译安装nginx
  3. java 上传文件编码_(java)有什么办法把MultipartFile上传的文件转为utf-8的编码吗
  4. wcf wpf mfc 区别
  5. 02怎么取整php,php取整的几种方式
  6. java 下载文件大小_如何在浏览器中显示使用角度5下载的文件的文件大小?
  7. Asp.Net Core Mvc上Json序列化首字母大小写的问题
  8. Go基础-核心特性和前景
  9. tilte和body标签
  10. html手机端富文本,移动端富文本踩坑
  11. 10分钟健身法读书笔记(2/5)
  12. 【Linux】万字总结Linux 基本指令,绝对详细!!!
  13. java se  计算机专业技能-Java专项练习(选择题)(三)
  14. Python爬取2345天气网
  15. 论文笔记:An Adaptive Feature Norm Approach for Unsupervised Domain Adaptation
  16. shell 十三问:
  17. PS2019 cc for Mac语言切换
  18. 2022-2028全球与中国制造业物联网市场现状及未来发展趋势
  19. WINDOWS时间服务启动失败的原因
  20. 差分升级(Diff and Patch)

热门文章

  1. 利用matlab进行三维曲线拟合(cftool工具箱实现)
  2. 查询建立连接的IP地址
  3. android 在线预览pdf文件(目前最全)
  4. 利用vlan交换机(网管交换机)打trunk实现单线复用
  5. 双机热备和磁盘阵列柜
  6. CodeForces - 1008D - Pave the Parallelepiped (容斥原理+重复组合公式+状态压缩+思维)
  7. Windows8和MacOS10.9双系统安装及Mac常用软件安装--联想E49A
  8. Android手机开发总结——Android核心分析
  9. HMI-7-[高分屏支持]:Qt 适配高分屏
  10. 阿里云-云存储OSS