目录

1. 介绍

2. 搭建 UNet 网络

3. dataset 数据加载

4. train 训练网络

5. predict 分割图像

6. show

7. 完整代码


1. 介绍

项目的目录如下所示

  1. DRIVE 存放的是数据集
  2. predict 是待分割的图像
  3. result 里面放分割predict 的结果
  4. dataset 是处理数据的文件、model存放unet网络、predict是预测、train是网络的训练、UNet.pth 是训练好的权重文件

之前做了一个图像分割的例子,里面大部分的代码和本篇的内容重合,所以每个脚本的代码只会做简单的介绍。具体的可以参考之前的内容,这里给出链接:

model :  UNet - unet网络

dataset :UNet - 数据加载 Dataset

train : UNet - 训练数据train

predict : UNet - 预测数据predict(多个图像的分割)

DRIVE ( Digital Retinal Images for Vessel Extraction ):用于血管提取的数字视网膜图像

训练样本:灰度图像

对应的标签:二值图像

因为这个分割项目完成几周了,最近才整理。所以,原数据集 DRIVE 可能是彩色图像 + mask 掩膜(具体的记不清了)

  • 这里没有使用 mask
  • 如果是彩色图像的话,在生成unet网络的时候,传入的channel设置成3就行了。或者想用灰度图像的形式,要么用opencv转一下,可以看见灰度化的效果类似于展示的那样;要么在预处理的里面转成灰度图片 transform.Grayscale()

2. 搭建 UNet 网络

和之前unet网络不同的是,这里通过填充size,可以保证任意图像维度的输入

之前的代码需要经过4此下采样,每次维度扩展,size减半,所以需要保证输入图像的大小是 2的4次方

具体这块怎么实现我也看不懂,经过测试,可以实现任意输入的size

3. dataset 数据加载

数据加载的时候,将图像的预处理也放到了这里

这里训练的图像要 ToTensor ,归一化+改变通道顺序+转为tensor等等。同时,为了加快训练,对图像正规化,因为训练的图像是灰度图,所以只需要单通道的均值和标准差


然后是 数据加载 的初始化

这里的imgs里面的内容是,传入路径root下的图像路径,这里是:

['01.png', '02.png', '03.png', '04.png', '05.png', '06.png', '07.png', '08.png', '09.png']

self.imgs 是将root 路径和root 里面每个图像的路径 拼接在一块的路径,这里是:

['./DRIVE/test/image\\01.png', './DRIVE/test/image\\02.png', './DRIVE/test/image\\03.png', './DRIVE/test/image\\04.png', './DRIVE/test/image\\05.png', './DRIVE/test/image\\06.png', './DRIVE/test/image\\07.png', './DRIVE/test/image\\08.png', './DRIVE/test/image\\09.png']

如图:


初始化路径和预处理后,需要对图像进行处理

这里训练的样本和对应的二值图像的label文件名要保证一样,否则需要做别的处理。例如,这里只需要将训练样本的图像路径里面的image 替换(replace)成label 就能找到对应的分割图像

然后读取图像,预处理之后,在进行返回即可。

这里为了防止label不是严格的二值图像,在归一化(灰度值 / 255)后,将中间的灰度值也映射为前景像素点

4. train 训练网络

训练网络的代码基本上没有改变,这里简单介绍

判断网络运行的设备,将网络to到device上

加载训练集+测试集

这里传入的是训练的样本,因为Data_loader 会将样本的路径替换成 label找到对应分割的标签图像

因为内存不足,所以这里将batch size 设置成 1

然后定义优化器+损失函数,并且保存网络的训练权重文件

有关BCEWithLogitsLoss可以参考这个:聊聊关于图像分割的损失函数 - BCEWithLogitsLoss

训练的时候,需要网络在train模式下,然后就是正确的前向传播预测+反向梯度下降的内容

最后是计算正确率,需要将网络放到eval模式下

这里将网络的预测转为二值图像,然后计算准确率的方式是预测的二值图像和label进行逐个像素点的比对,最后比上整幅图像的空间分辨率,即图像的大小。

test_label 的通道顺序是:batch、channel、height、width

5. predict 分割图像

这里的预处理要和处理样本的预处理一致

加载网络+读取网络参数

预测的时候,需要扩展维度。保存图像的时候,需要将batch和channel减去

然后将预测的结果转为二值图像就可以了

6. show

训练了20个epoch,结果显示如下

这里来预测的图像在test数据集里面,predict里面的图像为:

UNet 分割的结果:

真实的label为:

分割了大部分的信息,但是仍有细节没有分割出来

图像的size 是 565*584 的,大概预测的准确率是 0.96 左右

也就是说 还有 565*584*0.04 = 13198 ,这些损失的像素点就是缺少的细节

7. 完整代码

model部分:

import torch.nn as nn
import torch
import torch.nn.functional as F# 搭建unet 网络
class DoubleConv(nn.Module):    # 连续两次卷积def __init__(self,in_channels,out_channels):super(DoubleConv,self).__init__()self.double_conv = nn.Sequential(nn.Conv2d(in_channels,out_channels,kernel_size=3,padding=1,bias=False),nn.BatchNorm2d(out_channels),                           # 用 BN 代替 Dropoutnn.ReLU(inplace=True),nn.Conv2d(out_channels,out_channels,kernel_size=3,padding=1,bias=False),nn.BatchNorm2d(out_channels),nn.ReLU(inplace=True))def forward(self,x):x = self.double_conv(x)return xclass Down(nn.Module):   # 下采样def __init__(self,in_channels,out_channels):super(Down, self).__init__()self.downsampling = nn.Sequential(nn.MaxPool2d(kernel_size=2,stride=2),DoubleConv(in_channels,out_channels))def forward(self,x):x = self.downsampling(x)return xclass Up(nn.Module):    # 上采样def __init__(self, in_channels, out_channels):super(Up,self).__init__()self.upsampling = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2) # 转置卷积self.conv = DoubleConv(in_channels, out_channels)def forward(self, x1, x2):x1 = self.upsampling(x1)diffY = torch.tensor([x2.size()[2] - x1.size()[2]])         # 确保任意size的图像输入diffX = torch.tensor([x2.size()[3] - x1.size()[3]])x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,diffY // 2, diffY - diffY // 2])x = torch.cat([x2, x1], dim=1)  # 从channel 通道拼接x = self.conv(x)return xclass OutConv(nn.Module):   # 最后一个网络的输出def __init__(self, in_channels, num_classes):super(OutConv, self).__init__()self.conv = nn.Conv2d(in_channels, num_classes, kernel_size=1)def forward(self, x):return self.conv(x)class UNet(nn.Module):   # unet 网络def __init__(self, in_channels = 1, num_classes = 1):super(UNet, self).__init__()self.in_channels = in_channelsself.num_classes = num_classesself.in_conv = DoubleConv(in_channels, 64)self.down1 = Down(64, 128)self.down2 = Down(128, 256)self.down3 = Down(256, 512)self.down4 = Down(512, 1024)self.up1 = Up(1024, 512)self.up2 = Up(512, 256)self.up3 = Up(256, 128)self.up4 = Up(128, 64)self.out_conv = OutConv(64, num_classes)def forward(self, x):x1 = self.in_conv(x)x2 = self.down1(x1)x3 = self.down2(x2)x4 = self.down3(x3)x5 = self.down4(x4)x = self.up1(x5, x4)x = self.up2(x, x3)x = self.up3(x, x2)x = self.up4(x, x1)x = self.out_conv(x)return x

dataset 数据处理部分:

import os
from torch.utils.data import Dataset
from PIL import Image
from torchvision import transformsdata_transform = {"train": transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5, ), (0.5, ))]),"test": transforms.Compose([transforms.ToTensor()])
}# 数据处理文件
class Data_Loader(Dataset):     # 加载数据def __init__(self, root, transforms_train=data_transform['train'],transforms_test=data_transform['test']):    # 初始化imgs = os.listdir(root)                                                         # 读取图像的路径self.imgs = [os.path.join(root,img) for img in imgs]                            # 取出路径下所有的图片self.transforms_train = transforms_train                                        # 预处理self.transforms_test = transforms_testdef __getitem__(self, index):                      # 获取数据、预处理等等image_path = self.imgs[index]                  # 根据index读取图片label_path = image_path.replace('image', 'label')   # 根据image_path生成label_pathimage = Image.open(image_path)                      # 读取图片和对应的label图label = Image.open(label_path)image = self.transforms_train(image)        # 样本预处理label = self.transforms_test(label)         # label 预处理label[label > 0] = 1return image, labeldef __len__(self):  # 返回样本的数量return len(self.imgs)

train 网络训练部分:

from model import UNet
from dataset import Data_Loader
from torch import optim
import torch.nn as nn
import torch# 网络训练模块
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')   # GPU or CPU
print(device)
net = UNet(in_channels=1, num_classes=1)        # 加载网络
net.to(device)                                  # 将网络加载到device上# 加载训练集
trainset = Data_Loader("./DRIVE/train/image")
train_loader = torch.utils.data.DataLoader(dataset=trainset,batch_size=1,shuffle=True)
len = len(trainset)                         # 样本总数为 31# 加载测试集
testset = Data_Loader("./DRIVE/test/image")
test_loader = torch.utils.data.DataLoader(dataset=testset,batch_size=1)# 加载优化器和损失函数
optimizer = optim.RMSprop(net.parameters(), lr=0.00001,weight_decay=1e-8, momentum=0.9)     # 定义优化器
criterion = nn.BCEWithLogitsLoss()                             # 定义损失函数# 保存网络参数
save_path = './UNet.pth'       # 网络参数的保存路径
best_acc = 0.0                 # 保存最好的准确率# 训练
for epoch in range(20):net.train()     # 训练模式running_loss = 0.0for image,label in train_loader:optimizer.zero_grad()                          # 梯度清零pred = net(image.to(device))                   # 前向传播loss = criterion(pred, label.to(device))       # 计算损失loss.backward()                                # 反向传播optimizer.step()                               # 梯度下降running_loss += loss.item()                    # 计算损失和net.eval()  # 测试模式acc = 0.0   # 正确率total = 0with torch.no_grad():for test_image, test_label in test_loader:outputs = net(test_image.to(device))     # 前向传播outputs[outputs >= 0] = 1  # 将预测图片转为二值图片outputs[outputs < 0] = 0# 计算预测图片与真实图片像素点一致的精度:acc = 相同的 / 总个数acc += (outputs == test_label.to(device)).sum().item() / (test_label.size(2) * test_label.size(3))total += test_label.size(0)accurate = acc / total  # 计算整个test上面的正确率print('[epoch %d] train_loss: %.3f  test_accuracy: %.3f %%' %(epoch + 1, running_loss/len, accurate*100))if accurate > best_acc:     # 保留最好的精度best_acc = accuratetorch.save(net.state_dict(), save_path)     # 保存网络参数

predict 预测部分:

import numpy as np
import torch
import cv2
from model import UNet
from torchvision import transforms
from PIL import Imagetransform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5,),(0.5))])# 加载模型
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
net = UNet(in_channels=1, num_classes=1)
net.load_state_dict(torch.load('UNet.pth', map_location=device))
net.to(device)# 测试模式
net.eval()
with torch.no_grad():img = Image.open('./predict/img.png')           # 读取预测的图片img = transform(img)                            # 预处理img = torch.unsqueeze(img,dim = 0)              # 增加batch维度pred = net(img.to(device))                      # 网络预测pred = torch.squeeze(pred)                      # 将(batch、channel)维度去掉pred = np.array(pred.data.cpu())                # 保存图片需要转为cpu处理pred[pred >=0 ] =255                            # 转为二值图片pred[pred < 0 ] =0pred = np.uint8(pred)                           # 转为图片的形式cv2.imwrite('./result/res.png', pred)           # 保存图片

UNet 网络做图像分割DRIVE数据集相关推荐

  1. 使用Pytorch搭建U-Net网络并基于DRIVE数据集训练(语义分割)学习笔记

    使用Pytorch搭建U-Net网络并基于DRIVE数据集训练(语义分割)学习笔记 https://www.bilibili.com/video/BV1rq4y1w7xM?spm_id_from=33 ...

  2. tensorflow版使用uNet进行医学图像分割(Skin数据集)

    tensorflow版使用uNet进行医学图像分割(Skin数据集) 深度学习.计算机视觉学习笔记.医学图像分割.uNet.Skin皮肤数据集 tensorflow版使用uNet进行医学图像分割(Sk ...

  3. Unet网络实现叶子病虫害图像分割

    作者|李秋键 出品|AI科技大本营(ID:rgznai100) 智能化农业作为人工智能应用的重要领域,对较高的图像处理能力要求较高,其中图像分割作为图像处理方法在其中起着重要作用.图像分割是图像分析的 ...

  4. 【CV】基于UNet网络实现的人像分割 | 附数据集

    文章来源于AI算法与图像处理,作者AI_study 今天要分享的是人像分割相关的内容,如果你喜欢的话,欢迎三连哦 主要内容 人像分割简介 UNet的简介 UNet实现人像分割 人像分割简介 人像分割的 ...

  5. unet训练自己的数据集_基于UNet网络实现的人像分割 | 附数据集

    点击上方↑↑↑"OpenCV学堂"关注我 来源:公众号 AI算法与图像处理 授权 以后我会在公众号分享一些关于算法的应用(美颜相关的),工作之后,发现更重要的能力如何理解业务并将算 ...

  6. 基于UNet网络实现的人像分割 | 附数据集

    点击上方"AI算法与图像处理",选择加"星标"或"置顶" 重磅干货,第一时间送达 以后我会在公众号分享一些关于算法的应用(美颜相关的),工作 ...

  7. FCN、Unet、Unet++:医学图像分割网络一览

    点击上方"小白学视觉",选择加"星标"或"置顶" 重磅干货,第一时间送达 作者丨Error@知乎 来源丨https://zhuanlan.z ...

  8. PyTorch:Unet网络实现脑肿瘤图像分割

    1 介绍 U-Net是一篇基本结构非常好的论文,主要是针对生物医学图片的分割,而且,在今后的许多对医学图像的分割网络中,很大一部分会采取U-Net作为网络的主干.相对于当年的,在EM segmenta ...

  9. 【医学图像分割网络】之Res U-Net网络PyTorch复现

    [医学图像分割网络]之Res U-Net网络PyTorch复现 1.内容 U-Net网络算是医学图像分割领域的开山之作,我接触深度学习到现在大概将近大半年时间,看到了很多基于U-Net网络的变体,后续 ...

最新文章

  1. PaddleClas
  2. Kaggle 机器学习竞赛冠军及优胜者的源代码汇总
  3. C#中格式化小数位数为指定位数的工具类
  4. re.containerbase.startinternal 子容器启动失败_微服务架构:基于微服务和Docker容器技术的PaaS云平台架构设计(微服务架构实施原理)...
  5. C++与Java语法上的不同,互联网 面试官 如何面试
  6. Coherence Step by Step 第三篇 缓存(一) 介绍(翻译)
  7. python动态心形代码_Python实现酷炫的动态交互式数据可视化,附代码!
  8. 斯坦福吴恩达《机器学习》--增强学习
  9. SQL Server中的即时文件初始化概述
  10. 你真的会用Jupyter吗?这里有7个进阶功能助你效率翻倍
  11. 城市不透水面空间分析——以宁波为例
  12. KVM 介绍(2):CPU 和内存虚拟化
  13. java js cookie_[Java教程]js简单操作Cookie
  14. matlab 画点标号,学习笔记(四)——MATLAB画图
  15. 阮一峰ES6之Generator函数理解
  16. 古体字与简体字对照表_古代汉语必备简化字与繁体字对照表
  17. Java 开发手册 - 编程规约 【控制语句】
  18. python byte什么意思_python bytes是什么
  19. 对List集合嵌套了map集合的排序
  20. C语言 有两个矩形 求重叠面积,计算两个矩形重叠面积的简单方法

热门文章

  1. python循环语句中的乘法_python循环语句详细讲解
  2. super关键字的使用详解
  3. 火狐浏览器代理服务器拒绝连接
  4. java使用poi导出excel设置颜色问题
  5. vue中velocity
  6. python双重直方图_Python 2.x中两幅图像的直方图匹配?
  7. Mysql之IGNORE关键字
  8. 开脑洞,买买买网站的皮肤
  9. 七层负载均衡HAproxy生产环境LVS+Keepalived+HAproxy(三)
  10. 【知识星球】视频分析/光流估计网络系列上线