UNet 网络做图像分割DRIVE数据集
目录
1. 介绍
2. 搭建 UNet 网络
3. dataset 数据加载
4. train 训练网络
5. predict 分割图像
6. show
7. 完整代码
1. 介绍
项目的目录如下所示
- DRIVE 存放的是数据集
- predict 是待分割的图像
- result 里面放分割predict 的结果
- dataset 是处理数据的文件、model存放unet网络、predict是预测、train是网络的训练、UNet.pth 是训练好的权重文件
之前做了一个图像分割的例子,里面大部分的代码和本篇的内容重合,所以每个脚本的代码只会做简单的介绍。具体的可以参考之前的内容,这里给出链接:
model : UNet - unet网络
dataset :UNet - 数据加载 Dataset
train : UNet - 训练数据train
predict : UNet - 预测数据predict(多个图像的分割)
DRIVE ( Digital Retinal Images for Vessel Extraction ):用于血管提取的数字视网膜图像
训练样本:灰度图像
对应的标签:二值图像
因为这个分割项目完成几周了,最近才整理。所以,原数据集 DRIVE 可能是彩色图像 + mask 掩膜(具体的记不清了)
- 这里没有使用 mask
- 如果是彩色图像的话,在生成unet网络的时候,传入的channel设置成3就行了。或者想用灰度图像的形式,要么用opencv转一下,可以看见灰度化的效果类似于展示的那样;要么在预处理的里面转成灰度图片 transform.Grayscale()
2. 搭建 UNet 网络
和之前unet网络不同的是,这里通过填充size,可以保证任意图像维度的输入
之前的代码需要经过4此下采样,每次维度扩展,size减半,所以需要保证输入图像的大小是 2的4次方
具体这块怎么实现我也看不懂,经过测试,可以实现任意输入的size
3. dataset 数据加载
数据加载的时候,将图像的预处理也放到了这里
这里训练的图像要 ToTensor ,归一化+改变通道顺序+转为tensor等等。同时,为了加快训练,对图像正规化,因为训练的图像是灰度图,所以只需要单通道的均值和标准差
然后是 数据加载 的初始化
这里的imgs里面的内容是,传入路径root下的图像路径,这里是:
['01.png', '02.png', '03.png', '04.png', '05.png', '06.png', '07.png', '08.png', '09.png']
self.imgs 是将root 路径和root 里面每个图像的路径 拼接在一块的路径,这里是:
['./DRIVE/test/image\\01.png', './DRIVE/test/image\\02.png', './DRIVE/test/image\\03.png', './DRIVE/test/image\\04.png', './DRIVE/test/image\\05.png', './DRIVE/test/image\\06.png', './DRIVE/test/image\\07.png', './DRIVE/test/image\\08.png', './DRIVE/test/image\\09.png']
如图:
初始化路径和预处理后,需要对图像进行处理
这里训练的样本和对应的二值图像的label文件名要保证一样,否则需要做别的处理。例如,这里只需要将训练样本的图像路径里面的image 替换(replace)成label 就能找到对应的分割图像
然后读取图像,预处理之后,在进行返回即可。
这里为了防止label不是严格的二值图像,在归一化(灰度值 / 255)后,将中间的灰度值也映射为前景像素点
4. train 训练网络
训练网络的代码基本上没有改变,这里简单介绍
判断网络运行的设备,将网络to到device上
加载训练集+测试集
这里传入的是训练的样本,因为Data_loader 会将样本的路径替换成 label找到对应分割的标签图像
因为内存不足,所以这里将batch size 设置成 1
然后定义优化器+损失函数,并且保存网络的训练权重文件
有关BCEWithLogitsLoss可以参考这个:聊聊关于图像分割的损失函数 - BCEWithLogitsLoss
训练的时候,需要网络在train模式下,然后就是正确的前向传播预测+反向梯度下降的内容
最后是计算正确率,需要将网络放到eval模式下
这里将网络的预测转为二值图像,然后计算准确率的方式是预测的二值图像和label进行逐个像素点的比对,最后比上整幅图像的空间分辨率,即图像的大小。
test_label 的通道顺序是:batch、channel、height、width
5. predict 分割图像
这里的预处理要和处理样本的预处理一致
加载网络+读取网络参数
预测的时候,需要扩展维度。保存图像的时候,需要将batch和channel减去
然后将预测的结果转为二值图像就可以了
6. show
训练了20个epoch,结果显示如下
这里来预测的图像在test数据集里面,predict里面的图像为:
UNet 分割的结果:
真实的label为:
分割了大部分的信息,但是仍有细节没有分割出来
图像的size 是 565*584 的,大概预测的准确率是 0.96 左右
也就是说 还有 565*584*0.04 = 13198 ,这些损失的像素点就是缺少的细节
7. 完整代码
model部分:
import torch.nn as nn
import torch
import torch.nn.functional as F# 搭建unet 网络
class DoubleConv(nn.Module): # 连续两次卷积def __init__(self,in_channels,out_channels):super(DoubleConv,self).__init__()self.double_conv = nn.Sequential(nn.Conv2d(in_channels,out_channels,kernel_size=3,padding=1,bias=False),nn.BatchNorm2d(out_channels), # 用 BN 代替 Dropoutnn.ReLU(inplace=True),nn.Conv2d(out_channels,out_channels,kernel_size=3,padding=1,bias=False),nn.BatchNorm2d(out_channels),nn.ReLU(inplace=True))def forward(self,x):x = self.double_conv(x)return xclass Down(nn.Module): # 下采样def __init__(self,in_channels,out_channels):super(Down, self).__init__()self.downsampling = nn.Sequential(nn.MaxPool2d(kernel_size=2,stride=2),DoubleConv(in_channels,out_channels))def forward(self,x):x = self.downsampling(x)return xclass Up(nn.Module): # 上采样def __init__(self, in_channels, out_channels):super(Up,self).__init__()self.upsampling = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2) # 转置卷积self.conv = DoubleConv(in_channels, out_channels)def forward(self, x1, x2):x1 = self.upsampling(x1)diffY = torch.tensor([x2.size()[2] - x1.size()[2]]) # 确保任意size的图像输入diffX = torch.tensor([x2.size()[3] - x1.size()[3]])x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,diffY // 2, diffY - diffY // 2])x = torch.cat([x2, x1], dim=1) # 从channel 通道拼接x = self.conv(x)return xclass OutConv(nn.Module): # 最后一个网络的输出def __init__(self, in_channels, num_classes):super(OutConv, self).__init__()self.conv = nn.Conv2d(in_channels, num_classes, kernel_size=1)def forward(self, x):return self.conv(x)class UNet(nn.Module): # unet 网络def __init__(self, in_channels = 1, num_classes = 1):super(UNet, self).__init__()self.in_channels = in_channelsself.num_classes = num_classesself.in_conv = DoubleConv(in_channels, 64)self.down1 = Down(64, 128)self.down2 = Down(128, 256)self.down3 = Down(256, 512)self.down4 = Down(512, 1024)self.up1 = Up(1024, 512)self.up2 = Up(512, 256)self.up3 = Up(256, 128)self.up4 = Up(128, 64)self.out_conv = OutConv(64, num_classes)def forward(self, x):x1 = self.in_conv(x)x2 = self.down1(x1)x3 = self.down2(x2)x4 = self.down3(x3)x5 = self.down4(x4)x = self.up1(x5, x4)x = self.up2(x, x3)x = self.up3(x, x2)x = self.up4(x, x1)x = self.out_conv(x)return x
dataset 数据处理部分:
import os
from torch.utils.data import Dataset
from PIL import Image
from torchvision import transformsdata_transform = {"train": transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5, ), (0.5, ))]),"test": transforms.Compose([transforms.ToTensor()])
}# 数据处理文件
class Data_Loader(Dataset): # 加载数据def __init__(self, root, transforms_train=data_transform['train'],transforms_test=data_transform['test']): # 初始化imgs = os.listdir(root) # 读取图像的路径self.imgs = [os.path.join(root,img) for img in imgs] # 取出路径下所有的图片self.transforms_train = transforms_train # 预处理self.transforms_test = transforms_testdef __getitem__(self, index): # 获取数据、预处理等等image_path = self.imgs[index] # 根据index读取图片label_path = image_path.replace('image', 'label') # 根据image_path生成label_pathimage = Image.open(image_path) # 读取图片和对应的label图label = Image.open(label_path)image = self.transforms_train(image) # 样本预处理label = self.transforms_test(label) # label 预处理label[label > 0] = 1return image, labeldef __len__(self): # 返回样本的数量return len(self.imgs)
train 网络训练部分:
from model import UNet
from dataset import Data_Loader
from torch import optim
import torch.nn as nn
import torch# 网络训练模块
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # GPU or CPU
print(device)
net = UNet(in_channels=1, num_classes=1) # 加载网络
net.to(device) # 将网络加载到device上# 加载训练集
trainset = Data_Loader("./DRIVE/train/image")
train_loader = torch.utils.data.DataLoader(dataset=trainset,batch_size=1,shuffle=True)
len = len(trainset) # 样本总数为 31# 加载测试集
testset = Data_Loader("./DRIVE/test/image")
test_loader = torch.utils.data.DataLoader(dataset=testset,batch_size=1)# 加载优化器和损失函数
optimizer = optim.RMSprop(net.parameters(), lr=0.00001,weight_decay=1e-8, momentum=0.9) # 定义优化器
criterion = nn.BCEWithLogitsLoss() # 定义损失函数# 保存网络参数
save_path = './UNet.pth' # 网络参数的保存路径
best_acc = 0.0 # 保存最好的准确率# 训练
for epoch in range(20):net.train() # 训练模式running_loss = 0.0for image,label in train_loader:optimizer.zero_grad() # 梯度清零pred = net(image.to(device)) # 前向传播loss = criterion(pred, label.to(device)) # 计算损失loss.backward() # 反向传播optimizer.step() # 梯度下降running_loss += loss.item() # 计算损失和net.eval() # 测试模式acc = 0.0 # 正确率total = 0with torch.no_grad():for test_image, test_label in test_loader:outputs = net(test_image.to(device)) # 前向传播outputs[outputs >= 0] = 1 # 将预测图片转为二值图片outputs[outputs < 0] = 0# 计算预测图片与真实图片像素点一致的精度:acc = 相同的 / 总个数acc += (outputs == test_label.to(device)).sum().item() / (test_label.size(2) * test_label.size(3))total += test_label.size(0)accurate = acc / total # 计算整个test上面的正确率print('[epoch %d] train_loss: %.3f test_accuracy: %.3f %%' %(epoch + 1, running_loss/len, accurate*100))if accurate > best_acc: # 保留最好的精度best_acc = accuratetorch.save(net.state_dict(), save_path) # 保存网络参数
predict 预测部分:
import numpy as np
import torch
import cv2
from model import UNet
from torchvision import transforms
from PIL import Imagetransform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5,),(0.5))])# 加载模型
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
net = UNet(in_channels=1, num_classes=1)
net.load_state_dict(torch.load('UNet.pth', map_location=device))
net.to(device)# 测试模式
net.eval()
with torch.no_grad():img = Image.open('./predict/img.png') # 读取预测的图片img = transform(img) # 预处理img = torch.unsqueeze(img,dim = 0) # 增加batch维度pred = net(img.to(device)) # 网络预测pred = torch.squeeze(pred) # 将(batch、channel)维度去掉pred = np.array(pred.data.cpu()) # 保存图片需要转为cpu处理pred[pred >=0 ] =255 # 转为二值图片pred[pred < 0 ] =0pred = np.uint8(pred) # 转为图片的形式cv2.imwrite('./result/res.png', pred) # 保存图片
UNet 网络做图像分割DRIVE数据集相关推荐
- 使用Pytorch搭建U-Net网络并基于DRIVE数据集训练(语义分割)学习笔记
使用Pytorch搭建U-Net网络并基于DRIVE数据集训练(语义分割)学习笔记 https://www.bilibili.com/video/BV1rq4y1w7xM?spm_id_from=33 ...
- tensorflow版使用uNet进行医学图像分割(Skin数据集)
tensorflow版使用uNet进行医学图像分割(Skin数据集) 深度学习.计算机视觉学习笔记.医学图像分割.uNet.Skin皮肤数据集 tensorflow版使用uNet进行医学图像分割(Sk ...
- Unet网络实现叶子病虫害图像分割
作者|李秋键 出品|AI科技大本营(ID:rgznai100) 智能化农业作为人工智能应用的重要领域,对较高的图像处理能力要求较高,其中图像分割作为图像处理方法在其中起着重要作用.图像分割是图像分析的 ...
- 【CV】基于UNet网络实现的人像分割 | 附数据集
文章来源于AI算法与图像处理,作者AI_study 今天要分享的是人像分割相关的内容,如果你喜欢的话,欢迎三连哦 主要内容 人像分割简介 UNet的简介 UNet实现人像分割 人像分割简介 人像分割的 ...
- unet训练自己的数据集_基于UNet网络实现的人像分割 | 附数据集
点击上方↑↑↑"OpenCV学堂"关注我 来源:公众号 AI算法与图像处理 授权 以后我会在公众号分享一些关于算法的应用(美颜相关的),工作之后,发现更重要的能力如何理解业务并将算 ...
- 基于UNet网络实现的人像分割 | 附数据集
点击上方"AI算法与图像处理",选择加"星标"或"置顶" 重磅干货,第一时间送达 以后我会在公众号分享一些关于算法的应用(美颜相关的),工作 ...
- FCN、Unet、Unet++:医学图像分割网络一览
点击上方"小白学视觉",选择加"星标"或"置顶" 重磅干货,第一时间送达 作者丨Error@知乎 来源丨https://zhuanlan.z ...
- PyTorch:Unet网络实现脑肿瘤图像分割
1 介绍 U-Net是一篇基本结构非常好的论文,主要是针对生物医学图片的分割,而且,在今后的许多对医学图像的分割网络中,很大一部分会采取U-Net作为网络的主干.相对于当年的,在EM segmenta ...
- 【医学图像分割网络】之Res U-Net网络PyTorch复现
[医学图像分割网络]之Res U-Net网络PyTorch复现 1.内容 U-Net网络算是医学图像分割领域的开山之作,我接触深度学习到现在大概将近大半年时间,看到了很多基于U-Net网络的变体,后续 ...
最新文章
- PaddleClas
- Kaggle 机器学习竞赛冠军及优胜者的源代码汇总
- C#中格式化小数位数为指定位数的工具类
- re.containerbase.startinternal 子容器启动失败_微服务架构:基于微服务和Docker容器技术的PaaS云平台架构设计(微服务架构实施原理)...
- C++与Java语法上的不同,互联网 面试官 如何面试
- Coherence Step by Step 第三篇 缓存(一) 介绍(翻译)
- python动态心形代码_Python实现酷炫的动态交互式数据可视化,附代码!
- 斯坦福吴恩达《机器学习》--增强学习
- SQL Server中的即时文件初始化概述
- 你真的会用Jupyter吗?这里有7个进阶功能助你效率翻倍
- 城市不透水面空间分析——以宁波为例
- KVM 介绍(2):CPU 和内存虚拟化
- java js cookie_[Java教程]js简单操作Cookie
- matlab 画点标号,学习笔记(四)——MATLAB画图
- 阮一峰ES6之Generator函数理解
- 古体字与简体字对照表_古代汉语必备简化字与繁体字对照表
- Java 开发手册 - 编程规约 【控制语句】
- python byte什么意思_python bytes是什么
- 对List集合嵌套了map集合的排序
- C语言 有两个矩形 求重叠面积,计算两个矩形重叠面积的简单方法