项目工程文件结构如下:

参考了Retina_Unet项目,决定自己用代码来实现一遍,数据增强不是像Retina_Unet那样随机裁剪,而是将20个训练数据集按顺序裁剪,每张裁剪成48x48大小的144个patch,20张一共裁剪出2880个patch。

Unet模型:模型输入的张量形状为(Npatch,1,48,48),输出为(Npatch,2304,2)。Npatch表示训练集的样本数,本例中训练时为2880,预测时为144。
训练:把原作者代码中的SGD改为Adam,效果有提升。
预测:也需要先把待预测图像分割成48×48的小图,输入模型,然后把结果整理还原为完整图像,再和专家标注结果进行对比。代码中以测试集第一张图片为例,可自行修改为其他眼底图片路径。

util.py

import numpy as np
import cv2
import os
from PIL import Imagedef read_image_and_name(path):imgdir = os.listdir(path)imglst = []imgs = []for v in imgdir:imglst.append(path + v)imgs.append(cv2.imread(path + v))print(imglst)print('original images shape: ' + str(np.array(imgs).shape))return imglst,imgsdef read_label_and_name(path):labeldir = os.listdir(path)labellst = []labels = []for v in labeldir:labellst.append(path + v)labels.append(np.asarray(Image.open(path + v)))print(labellst)print('original labels shape: ' + str(np.array(labels).shape))return labellst,labelsdef resize(imgs,resize_height, resize_width):img_resize = []for file in imgs:img_resize.append(cv2.resize(file,(resize_height,resize_width)))return img_resize#将N张576x576的图片裁剪成48x48
def crop(image,dx):list = []for i in range(image.shape[0]):for x in range(image.shape[1] // dx):for y in range(image.shape[2] // dx):list.append(image[ i,  y*dx : (y+1)*dx,  x*dx : (x+1)*dx]) #这里的list一共append了20x12x12=2880次所以返回的shape是(2880,48,48)return np.array(list)# 网络预测输出转换成图像子块
# 网络预测输出 size=[Npatches, patch_height*patch_width, 2]
def pred_to_imgs(pred, patch_height, patch_width, mode="original"):assert (len(pred.shape)==3)  #3D array: (Npatches,height*width,2)assert (pred.shape[2]==2 )  #check the classes are 2  # 确认是否为二分类pred_images = np.empty((pred.shape[0],pred.shape[1]))  #(Npatches,height*width)if mode=="original": # 网络概率输出for i in range(pred.shape[0]):for pix in range(pred.shape[1]):pred_images[i,pix]=pred[i,pix,1] #pred[:, :, 0] 是反分割图像输出 pred[:, :, 1]是分割输出elif mode=="threshold": # 网络概率-阈值输出for i in range(pred.shape[0]):for pix in range(pred.shape[1]):if pred[i,pix,1]>=0.5:pred_images[i,pix]=1else:pred_images[i,pix]=0else:print("mode " +str(mode) +" not recognized, it can be 'original' or 'threshold'")exit()# 输出形式改写成(Npatches,1, patch_height, patch_width)pred_images = np.reshape(pred_images,(pred_images.shape[0],1, patch_height, patch_width))return pred_images

unet.py

from keras.models import Model
from keras.layers import Input, concatenate, Conv2D, MaxPooling2D, UpSampling2D, Reshape, core, Dropout #core内部定义了一系列常用的网络层,包括全连接、激活层等
from keras.optimizers import Adam
from keras.callbacks import ModelCheckpoint, LearningRateSchedulerdef get_unet(n_ch,patch_height,patch_width):inputs = Input(shape=(n_ch,patch_height,patch_width))#data_format:字符串,“channels_first”或“channels_last”之一,代表图像的通道维的位置。#以128x128的RGB图像为例,“channels_first”应将数据组织为(3,128,128),而“channels_last”应将数据组织为(128,128,3)。该参数的默认值是~/.keras/keras.json中设置的值,若从未设置过,则为“channels_last”。conv1 = Conv2D(32, (3, 3), activation='relu', padding='same',data_format='channels_first')(inputs)conv1 = Dropout(0.2)(conv1)conv1 = Conv2D(32, (3, 3), activation='relu', padding='same',data_format='channels_first')(conv1)pool1 = MaxPooling2D((2, 2))(conv1)#conv2 = Conv2D(64, (3, 3), activation='relu', padding='same',data_format='channels_first')(pool1)conv2 = Dropout(0.2)(conv2)conv2 = Conv2D(64, (3, 3), activation='relu', padding='same',data_format='channels_first')(conv2)pool2 = MaxPooling2D((2, 2))(conv2)#conv3 = Conv2D(128, (3, 3), activation='relu', padding='same',data_format='channels_first')(pool2)conv3 = Dropout(0.2)(conv3)conv3 = Conv2D(128, (3, 3), activation='relu', padding='same',data_format='channels_first')(conv3)up1 = UpSampling2D(size=(2, 2))(conv3)up1 = concatenate([conv2,up1],axis=1)conv4 = Conv2D(64, (3, 3), activation='relu', padding='same',data_format='channels_first')(up1)conv4 = Dropout(0.2)(conv4)conv4 = Conv2D(64, (3, 3), activation='relu', padding='same',data_format='channels_first')(conv4)#up2 = UpSampling2D(size=(2, 2))(conv4)up2 = concatenate([conv1,up2], axis=1)conv5 = Conv2D(32, (3, 3), activation='relu', padding='same',data_format='channels_first')(up2)conv5 = Dropout(0.2)(conv5)conv5 = Conv2D(32, (3, 3), activation='relu', padding='same',data_format='channels_first')(conv5)##1×1的卷积的作用#大概有两个方面的作用:1. 实现跨通道的交互和信息整合2. 进行卷积核通道数的降维和升维。conv6 = Conv2D(2, (1, 1), activation='relu',padding='same',data_format='channels_first')(conv5)conv6 = core.Reshape((2,patch_height*patch_width))(conv6) #此时output的shape是(batchsize,2,patch_height*patch_width)conv6 = core.Permute((2,1))(conv6)    #此时output的shape是(Npatch,patch_height*patch_width,2)即输出维度是(Npatch,2304,2)############conv7 = core.Activation('softmax')(conv6)model = Model(inputs=inputs, outputs=conv7)# sgd = SGD(lr=0.01, decay=1e-6, momentum=0.3, nesterov=False)model.compile(optimizer=Adam(lr=0.001), loss='categorical_crossentropy',metrics=['accuracy'])return model'''模型Model的compile方法:compile(self, optimizer, loss, metrics=None, loss_weights=None, sample_weight_mode=None, weighted_metrics = None, target_tensors=None)本函数编译模型以供训练,参数有optimizer:         优化器,为预定义优化器名或优化器对.可以在调用model.compile()之前初始化一个优化器对象,然后传入该函数。loss:              损失函数,为预定义损失函数名或一个目标函数metrics:           列表,包含评估模型在训练和测试时的性能的指标,典型用法是metrics=['accuracy']如果要在多输出模型中为不同的输出指定不同的指标,可像该参数传递一个字典,例如metrics={'ouput_a': 'accuracy'}sample_weight_mode:如果需要按时间步为样本赋权(2D权矩阵),将该值设为“temporal”。默认为“None”,代表按样本赋权(1D权)。如果模型有多个输出,可以向该参数传入指定sample_weight_mode的字典或列表。在下面fit函数的解释中有相关的参考内容。weighted_metrics:   metrics列表,在训练和测试过程中,这些metrics将由sample_weight或clss_weight计算并赋权target_tensors:     默认情况下,Keras将为模型的目标创建一个占位符,该占位符在训练过程中将被目标数据代替。如果你想使用自己的目标张量(相应的,Keras将不会在训练时期望为这些目标张量载入外部的numpy数据),你可以通过该参数手动指定。目标张量可以是一个单独的张量(对应于单输出模型),也可以是一个张量列表,或者一个name->tensor的张量字典。kwargs:            使用TensorFlow作为后端请忽略该参数,若使用Theano/CNTK作为后端,kwargs的值将会传递给 K.function。如果使用TensorFlow为后端,这里的值会被传给tf.Session.run在Keras中,compile主要完成损失函数和优化器的一些配置,是为训练服务的。'''

train.py

from util import *
from unet import *if __name__ == '__main__':#参数和路径resize_height, resize_width = (576, 576)dx = 48img_path = 'DRIVE/training/images/'label_path = 'DRIVE/training/1st_manual/'#读取数据并resizeimglst,images = read_image_and_name(img_path)labellst,labels = read_label_and_name(label_path)imgs_resize = resize(images,resize_height, resize_width)labels_resize = resize(labels,resize_height, resize_width)#将imgs列表和manuals列表转换成numpy数组X_train = np.array(imgs_resize)Y_train = np.array(labels_resize)print(X_train.shape)print(Y_train.shape)#标准化X_train = X_train.astype('float32')/255Y_train = Y_train.astype('float32')/255#提取训练集的G通道X_train = X_train[...,1]#对训练数据进行裁剪X_train = crop(X_train,dx)Y_train = crop(Y_train,dx)print('X_train shape: '+str(X_train.shape)) #X_train(2880,48,48)print('Y_train shape: '+str(Y_train.shape)) #Y_train(2880,48,48)#X_train增加一维变成(2880,1,48,48)X_train = X_train[:,np.newaxis, ...]print('X_train shape: '+str(X_train.shape))#Y_train改变shape变成(2880,2304),保持第一维不变,其他维合并Y_train = Y_train.reshape(Y_train.shape[0],-1)print('Y_train shape: '+str(Y_train.shape))Y_train =Y_train[..., np.newaxis]  #增加一维变成(2880,2304,1)print('Y_train shape: '+str(Y_train.shape))temp = 1 - Y_trainY_train = np.concatenate([Y_train, temp], axis=2) #变成(2880,2304,2)print('Y_train shape: '+str(Y_train.shape))#获得modelmodel = get_unet(X_train.shape[1],X_train.shape[2],X_train.shape[3])model.summary() #输出参数Param计算过程checkpointer = ModelCheckpoint(filepath='best_weights.h5',verbose=1,monitor='val_acc',mode='auto',save_best_only=True)model.compile(optimizer=Adam(lr=0.001),loss='categorical_crossentropy',metrics=['accuracy'])model.fit(X_train,Y_train,batch_size=64,epochs=20,verbose=2,shuffle=True,validation_split=0.2,callbacks=[checkpointer])print('ok')

predict.py

import cv2
import numpy as np
from PIL import Image
from unet import get_unet
from util import pred_to_imgs
import matplotlib.pyplot as pltif __name__ == '__main__':resize_height, resize_width = (576, 576)dx = 48#读取预测图片imgs = cv2.imread('DRIVE/test/images/01_test.tif')[...,1] #读取G通道imgs = np.array(cv2.resize(imgs,(resize_height,resize_width))) #imgs现在是576x576大小#读取预测图片的标签label = np.array(Image.open('DRIVE/test/1st_manual/01_manual1.gif'))#预测图片和标签标准化X_test = imgs.astype('float32')/255print('X_test original shape: '+str(X_test.shape))Y_test = label.astype('float32')/255#对预测图片进行裁剪按行优先,裁剪成(144,48,48)list = []for i in range(resize_height//dx):for j in range(resize_width//dx):list.append(X_test[i*dx:(i+1)*dx, j*dx:(j+1)*dx])X_test = np.array(list)[:,np.newaxis,...] #增加一维变成(144,1,48,48)print('input shape: '+str(X_test.shape))#加载模型和权重并预测model = get_unet(1,dx,dx)model.load_weights('best_weights.h5')Y_pred = model.predict(X_test)print('predict shape: '+str(Y_pred.shape)) #预测结果的shape是(Npatches,patch_height*patch_width,2)#把预测输出的numpy数组拼接还原再显示Y_pred = Y_pred[..., 0]  #二分类提取出分割前景 现在Y_pred的shape是(144,2304) 且这个144是按照行优先来拼接的#对预测结果进行拼接,将(144,2304)拼接成(576,576)t=0image = np.zeros((resize_height,resize_width))for i in range(resize_height//dx):for j in range(resize_width//dx):temp = Y_pred[t].reshape(dx,dx)image[i*dx:(i+1)*dx, j*dx:(j+1)*dx] = tempt = t+1image = cv2.resize(image,((Y_test.shape[1], Y_test.shape[0]))) #将576x576大小的图像还原成原图像大小plt.figure(figsize=(6, 6))plt.imshow(image)plt.figure(figsize=(6, 6))plt.imshow(Y_test)plt.show()

预测结果

参考链接:
https://github.com/orobix/retina-unet
https://blog.csdn.net/Brikie/article/details/100177873

简明代码实现Unet眼底图像血管分割相关推荐

  1. Unet简明代码实现眼底图像血管分割

      Unet是一种自编码器网络结构,常用于医学图像分割任务,比如眼底图像血管分割.这位大佬已经开源了非常棒的代码,但是这套代码比较复杂,我初学菜鸟硬是啃了好几天才啃下来.现在我代码进行重写,仅保留最必 ...

  2. 基于U-Net的眼底图像血管分割实例

    [英文说明]https://github.com/orobix/retina-unet#retina-blood-vessel-segmentation-with-a-convolution-neur ...

  3. 零基础基于U-Net网络实战眼底图像血管提取

    文章目录 1 前言 2 血管提取任务概述 3 U-Net架构简介 4 眼底图像血管分割代码 5 结果评估可视化(ROC曲线) 6 改进U-Net网络完成眼底图像血管提取任务思路 1 前言 本文基于U- ...

  4. 眼底图像血管增强与分割--(5)基于Hessian矩阵的Frangi滤波算法

    在最优化里面提到过的hessian矩阵(http://blog.csdn.net/piaoxuezhong/article/details/60135153),本篇讲的方法主要是基于Hessian矩阵 ...

  5. 基于MATLAB的眼底视网膜静脉血管分割实现

    基于MATLAB的眼底视网膜静脉血管分割实现 眼底的视网膜图像对于眼科医生来说是非常重要的.其中,视网膜上血流情况可以为医生提供丰富的信息,如视网膜动脉硬化等.因此,对于眼底图像的分割和特征提取,对于 ...

  6. 基于matlab的眼底视网膜静脉血管分割仿真

    目录 1.算法概述 2.仿真效果 3.MATLAB源码 1.算法概述 随着图像数字化处理的快速发展,医学图像处理越来越受到人们的广泛关注.研究表明,人体许多全身性疾病都与眼底血管的异常有着密切的联系, ...

  7. 深度学习:使用UNet做图像语义分割,训练自己制作的数据集,详细教程

    语义分割(Semantic Segmentation)是图像处理和机器视觉一个重要分支.与分类任务不同,语义分割需要判断图像每个像素点的类别,进行精确分割.语义分割目前在自动驾驶.自动抠图.医疗影像等 ...

  8. 【深度学习】【U-net】医学图像(血管)分割实验记录

    医学图像分割实验记录 U-net介绍 数据集 实验记录 实验1 实验2(fail) 实验3(fail) 实验4(fail) 实验5(fail) 实验6(fail) 本项目仅用于大创实验,使用pytor ...

  9. 眼底影像血管分割(一):选择通道

    一:通道选择 一张眼底影像是RGB三色的,我们在做血管分割时,需要选择比较适合的图像来作为原始图像进行分割.那么选择哪个通道呢? 绿色通道?红色通道?蓝色通道? 好了,上图: 上图中四张图均来自同一张 ...

最新文章

  1. Only the original thread that created a view hierarchy can touch its views
  2. 通过RS232发送和接收短信(二)
  3. 在excel中如何增加组合框──EXCEL VBA的使用
  4. Caffe部署中的几个train-test-solver-prototxt-deploy等说明二
  5. Linux中的cp命令和mv命令
  6. 【渝粤教育】电大中专学前儿童发展心理学3作业 题库
  7. python print 输出到txt_(Python基础教程之七)Python字符串操作
  8. 5-10多分支网络结构
  9. python中列表、元组、字符串都属于有序序列_列表、元组、字符串是Python的有序序列。...
  10. CentOS RabbitMQ安装
  11. Android JNI 学习(十):String Operations Api Other Apis
  12. [转]installshield for VC++6 如何使用
  13. Neo4j-import导入CSV的数据
  14. 强大的SQL计算利器-SPL
  15. python中pdfplumber解析pdf_Python中pdfplumber如何提取pdf中的表格数据
  16. 如何注册域名的详细图文过程分享
  17. 数据库完整性详细解释
  18. eclipse安卓 DDMS中打不开Sdcard文件夹的问题
  19. 自然语言处理中的文本聚类
  20. 谷歌浏览器主页被挟持篡改2345www.dh012.com

热门文章

  1. Unity编辑器扩展——标签属性Attribute
  2. 《方与圆》序人生控制论 第二章 认识自己
  3. 区块链带来第四次技术革命 融入生产大幅提高企业收入
  4. 电脑操作相关英语词汇以及用法(持续更新)
  5. ArcMap 标注、注记、图形文本
  6. Echarts定制化组件展示网站(包括3d饼环图,3d柱状图,三维柱状图,水滴图)
  7. 【计算机网络】第四章:数据链路层(Part2.广播信道的数据链路)
  8. Model 处理模型数据
  9. 在vs2008 vc++ 中添加mfc中消息处理函数
  10. 沟通、务实、平等——读《Scrum and XP from the Trenches》