Unet是一种自编码器网络结构,常用于医学图像分割任务,比如眼底图像血管分割。这位大佬已经开源了非常棒的代码,但是这套代码比较复杂,我初学菜鸟硬是啃了好几天才啃下来。现在我代码进行重写,仅保留最必要的部分,并尽量简明,全部代码不到100行,便于初学者快速看懂能用。

1.任务简介

  本任务基于DRIVE数据集,将眼底图像中的血管树给分离出来。DRIVE数据集包含40张眼底图像,尺寸为565×584,其中20张为训练集,20张测试集,40张图片都给出了专家标注结果。数据集下载可自行百度,或者官网下载。
  本代码包含以下几个部分:
  数据加载及预处理:把图片分割成若干48×48的小图片,由于原图尺寸不能被48整除,这里先把原图尺寸resize为576×576。和大佬开源代码不同,这里没有使用随机选取的方式,而只使用原图分割出来的全部小图,相当于没有用数据增强,这样总共得到训练集2880个,训练精度比原作者使用190000个稍低一些,但训练速度会快很多,便于快速运行和调参。如果需要提高精度,可自行设计数据增强方法。
  Unet模型:模型输入的张量形状为(?,1,48,48),输出为(?,2340,2)。?表示训练集的样本数,本例中为2880。
  训练:把原作者代码中的SGD改为Adam,效果有提升。
  推理:也需要先把待预测图像分割成48×48的小图,输入模型,然后把结果整理还原为完整图像,再和专家标注结果进行对比。代码中以测试集第一张图片为例,可自行修改为其他眼底图片路径。

2.完整代码

import numpy as np
import cv2
from PIL import Image
import matplotlib.pyplot as plt
import os
from keras.models import Model
from keras.layers import Input, concatenate, Conv2D, MaxPooling2D, UpSampling2D, Reshape, core, Dropout
from keras.optimizers import Adam, SGD
from keras.callbacks import ModelCheckpoint, LearningRateScheduler
from keras import backend as Kimg_x, img_y = (576, 576)
dx = 48
filelst = os.listdir('DRIVE/training/images/')
filelst = ['DRIVE/training/images/'+v for v in filelst]
imgs = [cv2.imread(file) for file in filelst]
filelst = os.listdir('DRIVE/training/1st_manual/')
filelst = ['DRIVE/training/1st_manual/'+v for v in filelst]
manuals = [np.asarray(Image.open(file)) for file in filelst]
imgs = [cv2.resize(v,(img_x, img_y)) for v in imgs]
manuals = [cv2.resize(v,(img_x, img_y)) for v in manuals]
X_train = np.array(imgs)
Y_train = np.array(manuals)
X_train = X_train.astype('float32')/255.
Y_train = Y_train.astype('float32')/255.
X_train = X_train[...,1] # the G channel
X_train = np.array([[X_train[:,v*dx:(v+1)*dx, vv*dx:(vv+1)*dx] for v in range(img_y//dx)] for vv in range(img_x//dx)]).reshape(-1,dx,dx)[:,np.newaxis,...]
Y_train = np.array([[Y_train[:,v*dx:(v+1)*dx, vv*dx:(vv+1)*dx] for v in range(img_y//dx)] for vv in range(img_x//dx)]).reshape(-1,dx*dx)[...,np.newaxis]
temp = 1-Y_train
Y_train = np.concatenate([Y_train,temp],axis=2)def unet_model(n_ch,patch_height,patch_width):inputs = Input(shape=(n_ch,patch_height,patch_width))conv1 = Conv2D(32, (3, 3), activation='relu', padding='same',data_format='channels_first')(inputs)conv1 = Dropout(0.2)(conv1)conv1 = Conv2D(32, (3, 3), activation='relu', padding='same',data_format='channels_first')(conv1)pool1 = MaxPooling2D((2, 2))(conv1)conv2 = Conv2D(64, (3, 3), activation='relu', padding='same',data_format='channels_first')(pool1)conv2 = Dropout(0.2)(conv2)conv2 = Conv2D(64, (3, 3), activation='relu', padding='same',data_format='channels_first')(conv2)pool2 = MaxPooling2D((2, 2))(conv2)conv3 = Conv2D(128, (3, 3), activation='relu', padding='same',data_format='channels_first')(pool2)conv3 = Dropout(0.2)(conv3)conv3 = Conv2D(128, (3, 3), activation='relu', padding='same',data_format='channels_first')(conv3)up1 = UpSampling2D(size=(2, 2))(conv3)up1 = concatenate([conv2,up1],axis=1)conv4 = Conv2D(64, (3, 3), activation='relu', padding='same',data_format='channels_first')(up1)conv4 = Dropout(0.2)(conv4)conv4 = Conv2D(64, (3, 3), activation='relu', padding='same',data_format='channels_first')(conv4)up2 = UpSampling2D(size=(2, 2))(conv4)up2 = concatenate([conv1,up2], axis=1)conv5 = Conv2D(32, (3, 3), activation='relu', padding='same',data_format='channels_first')(up2)conv5 = Dropout(0.2)(conv5)conv5 = Conv2D(32, (3, 3), activation='relu', padding='same',data_format='channels_first')(conv5)conv6 = Conv2D(2, (1, 1), activation='relu',padding='same',data_format='channels_first')(conv5)conv6 = core.Reshape((2,patch_height*patch_width))(conv6)conv6 = core.Permute((2,1))(conv6)conv7 = core.Activation('softmax')(conv6)model = Model(inputs=inputs, outputs=conv7)return modelmodel = unet_model(X_train.shape[1],X_train.shape[2],X_train.shape[3])
model.summary()checkpointer = ModelCheckpoint(filepath='best_weights.h5', verbose=1, monitor='val_acc', mode='auto', save_best_only=True)
model.compile(optimizer=Adam(lr=0.001), loss='categorical_crossentropy',metrics=['accuracy'])model.fit(X_train, Y_train, batch_size=64, epochs=20, verbose=2,shuffle=True, validation_split=0.2,callbacks=[checkpointer])imgs = cv2.imread('DRIVE/test/images/01_test.tif')[...,1] #the G channel
imgs = cv2.resize(imgs,(img_x, img_y))
manuals = np.asarray(Image.open('DRIVE/test/1st_manual/01_manual1.gif'))
X_test = imgs.astype('float32')/255.
Y_test = manuals.astype('float32')/255.
X_test = np.array([[X_test[v*dx:(v+1)*dx, vv*dx:(vv+1)*dx] for v in range(img_y//dx)] for vv in range(img_x//dx)]).reshape(-1,dx,dx)[:,np.newaxis,...]
model.load_weights('best_weights.h5')
Y_pred = model.predict(X_test)
Y_pred = Y_pred[...,0].reshape(img_x//dx,img_y//dx,dx,dx)
Y_pred = [Y_pred[:,v,...] for v in range(img_x//dx)]
Y_pred = np.concatenate(np.concatenate(Y_pred,axis=1),axis=1)
Y_pred = cv2.resize(Y_pred,(Y_test.shape[1], Y_test.shape[0]))
plt.figure(figsize=(6,6))
plt.imshow(Y_pred)
plt.figure(figsize=(6,6))
plt.imshow(Y_test)

3.运行结果

Epoch 7/20- 1s - loss: 0.1522 - acc: 0.9509 - val_loss: 0.1013 - val_acc: 0.9660Epoch 00007: val_acc improved from 0.96389 to 0.96602, saving model to best_weights.h5
......

这里给出测试集中第一张01_test.tif)

其标注结果01manual1.gif :

预测结果为:

4.进一步讨论:

  从上面的结果看,预测和标注值还是比较一致的,隐约可看到小图片的拼接线,如果小图片尺寸更改为64×64,拼接线会轻的多。下面用全部40张(包含测试集)进行训练,然后来处理一些其他眼底图像数据集的图片。这里选择了odir2019的几张图片。可以看出,血管分割的效果还是比较好的,但是有出血点或者屈光介质严重浑浊情况下效果会降低(训练集中没有这些情况)。


Unet简明代码实现眼底图像血管分割相关推荐

  1. 基于U-Net的眼底图像血管分割实例

    [英文说明]https://github.com/orobix/retina-unet#retina-blood-vessel-segmentation-with-a-convolution-neur ...

  2. 简明代码实现Unet眼底图像血管分割

    项目工程文件结构如下: 参考了Retina_Unet项目,决定自己用代码来实现一遍,数据增强不是像Retina_Unet那样随机裁剪,而是将20个训练数据集按顺序裁剪,每张裁剪成48x48大小的144 ...

  3. 零基础基于U-Net网络实战眼底图像血管提取

    文章目录 1 前言 2 血管提取任务概述 3 U-Net架构简介 4 眼底图像血管分割代码 5 结果评估可视化(ROC曲线) 6 改进U-Net网络完成眼底图像血管提取任务思路 1 前言 本文基于U- ...

  4. 眼底图像血管增强与分割--(5)基于Hessian矩阵的Frangi滤波算法

    在最优化里面提到过的hessian矩阵(http://blog.csdn.net/piaoxuezhong/article/details/60135153),本篇讲的方法主要是基于Hessian矩阵 ...

  5. 基于MATLAB的眼底视网膜静脉血管分割实现

    基于MATLAB的眼底视网膜静脉血管分割实现 眼底的视网膜图像对于眼科医生来说是非常重要的.其中,视网膜上血流情况可以为医生提供丰富的信息,如视网膜动脉硬化等.因此,对于眼底图像的分割和特征提取,对于 ...

  6. 基于matlab的眼底视网膜静脉血管分割仿真

    目录 1.算法概述 2.仿真效果 3.MATLAB源码 1.算法概述 随着图像数字化处理的快速发展,医学图像处理越来越受到人们的广泛关注.研究表明,人体许多全身性疾病都与眼底血管的异常有着密切的联系, ...

  7. java 绘制六边形_JAVA代码怎么实现图像六边形网格分割效果

    下面给大家介绍JAVA代码怎么实现图像六边形网格分割效果,希望能给大家提供帮助. 一:原理 根据输入参数blockSize的大小,将图像分块,决定每块的中心通过该像素块内所有像素之和的均值与该块内部每 ...

  8. 眼底影像血管分割(一):选择通道

    一:通道选择 一张眼底影像是RGB三色的,我们在做血管分割时,需要选择比较适合的图像来作为原始图像进行分割.那么选择哪个通道呢? 绿色通道?红色通道?蓝色通道? 好了,上图: 上图中四张图均来自同一张 ...

  9. 基于PaddleSeg实现眼底血管分割——助力医疗人员更高效检测视网膜疾病

    点击左上方蓝字关注我们 [飞桨开发者说]郑博培,北京联合大学机器人学院2018级自动化专业本科生,飞桨开发者技术专家PPDE,深圳市柴火创客空间认证会员,百度大脑智能对话训练师 项目背景 研究表明,各 ...

最新文章

  1. 机器学习中的数学全集 tsinghua 石溪
  2. zoj 1698 Easier Done Than Said?
  3. Intellij IDEA常用配置详解
  4. 有关“双重检查锁定失效”的说明
  5. 如何安装MiniGUI 3.0在Linux PC
  6. 银行利率涨了,定期存款有必要取出再存吗?
  7. vtun中setsockopt fcntl等有关套接字设置
  8. c语言程序设计移动字母,C语言程序设计模拟试题二(含答案)
  9. [软件测试_LAB1]安装junit和hamcrest及其使用
  10. Linux下2号进程的kthreadd--Linux进程的管理与调度(七)
  11. 电路与模拟电子技术(作业答案)
  12. 有道词典单词本导入到欧路词典单词本
  13. PLC 控制柜常用电气元件整理表
  14. 9x9九宫格java_数独9x9九宫格的口诀 9×9数独技巧
  15. Python拉宾米勒(判断素数)
  16. 【坐标转换】四参数和七参数计算,并正向转换坐标(附完整源代码地址)
  17. 手机助手+for+linux,你的手机助手(com.microsoft.appmanager) - 3.5.8 - 应用 - 酷安
  18. python中的repr_python中的 __repr__和__str__
  19. 华为Ebackup模板部署
  20. 搭建一个STC8H的最小系统

热门文章

  1. 需求:在微信h5页面中下载第三方app —— 安卓, 直接下载apk文件包;iphone,跳转AppStore
  2. iOS新浪微博分享SDK Check List
  3. 低层次特征提取(一)------------边缘检测(转载)
  4. Qt for MCUs
  5. 1069 微博转发抽奖 (20 分)
  6. 微信小程序HTTPS证书部署案例 1
  7. 计算机请假,计算机学院2020请假条模板(短期临时专用).docx
  8. 每个人的商学院--管理基础(第五章:管理常见病)--读书笔记
  9. linux基本命令及文件管理
  10. python统计excel出现次数_Excel-统计元素出现次数和统计不重复元素的个数