This post uses a 3D U-Net to segment the BRATS2015 dataset.

Reading BRATS2015

BRATS2015 is a brain tumor segmentation dataset. The cases fall into two groups: HGG (high-grade glioma) and LGG (low-grade glioma).

The directory layout is shown in the figure.

Each case contains four MRI modalities: FLAIR, T1, T1c, and T2. OT is the ground truth, with five label values (0, 1, 2, 3, 4).

All data files are .mha volumes and can be read directly with SimpleITK:

import SimpleITK as sitk

def sitk_read(img_path):
    nda = sitk.ReadImage(img_path)
    nda = sitk.GetArrayFromImage(nda)  # (155, 240, 240)
    nda = nda.transpose(1, 2, 0)       # (240, 240, 155)
    return nda

SimpleITK.ReadImage returns an Image object, which GetArrayFromImage then converts to a NumPy array. The array comes out with shape (depth, height, width), so it is transposed to (height, width, depth).
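As a quick check, the helper can be called on any modality file; the path below is only a hypothetical example:

import SimpleITK as sitk

t1_path = "D:\\pyproject\\data\\BRATS2015\\BRATS2015_Training\\BRATS2015_Training\\HGG\\some_case\\VSD.Brain.XX.O.MR_T1.mha"  # hypothetical path
vol = sitk_read(t1_path)
print(vol.shape)  # expected: (240, 240, 155) after the transpose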

Only the BRATS2015 training set ships with ground truth (the official validation set has none), so the training set itself has to be split into train, val, and test parts for the network. First read all case folder names into a list, shuffle it with random, then take 8 parts for train, 1 part for val, and 1 part for test, and write the three name lists to txt files so they can be loaded directly at training time.

    def write_train_val_test_name_list(self):
        data_name_list = os.listdir(self.train_root_path + self.type + "\\")
        random.shuffle(data_name_list)
        length = len(data_name_list)
        n_train_file = int(length / 10 * 8)
        n_val_file = int(length / 10 * 1)
        train_name_list = data_name_list[0:n_train_file]
        val_name_list = data_name_list[n_train_file:(n_train_file + n_val_file)]
        test_name_list = data_name_list[(n_train_file + n_val_file):len(data_name_list)]
        self.write_name_list(train_name_list, "train_name_list.txt")
        self.write_name_list(val_name_list, "val_name_list.txt")
        self.write_name_list(test_name_list, "test_name_list.txt")

    def write_name_list(self, name_list, file_name):
        f = open(self.train_root_path + file_name, 'w')
        for i in range(len(name_list)):
            f.write(name_list[i] + "\n")
        f.close()
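As a quick sanity check of the split arithmetic (assuming the HGG training folder holds 220 cases):

length = 220                          # assumed number of HGG cases
n_train_file = int(length / 10 * 8)   # 176
n_val_file = int(length / 10 * 1)     # 22
n_test_file = length - n_train_file - n_val_file  # 22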

The complete code of the data reader is as follows:

import os
import glob
import random

import numpy as np
import SimpleITK as sitk


def make_one_hot_3d(x, n):
    # convert a (H, W, D) label volume into a (H, W, D, n) one-hot volume
    one_hot = np.zeros([x.shape[0], x.shape[1], x.shape[2], n])
    for i in range(x.shape[0]):
        for j in range(x.shape[1]):
            for v in range(x.shape[2]):
                one_hot[i, j, v, int(x[i, j, v])] = 1
    return one_hot


class brast_reader:
    def __init__(self, train_batch_size, val_batch_size, test_batch_size, type='HGG'):
        self.train_root_path = "D:\\pyproject\\data\\BRATS2015\\BRATS2015_Training\\BRATS2015_Training\\"
        self.type = type
        self.train_name_list = self.load_file_name_list(self.train_root_path + "train_name_list.txt")
        self.val_name_list = self.load_file_name_list(self.train_root_path + "val_name_list.txt")
        self.test_name_list = self.load_file_name_list(self.train_root_path + "test_name_list.txt")
        self.n_train_file = len(self.train_name_list)
        self.n_val_file = len(self.val_name_list)
        self.n_test_file = len(self.test_name_list)
        self.train_batch_size = train_batch_size
        self.val_batch_size = val_batch_size
        self.test_batch_size = test_batch_size
        self.n_train_steps_per_epoch = self.n_train_file // self.train_batch_size
        self.n_val_steps_per_epoch = self.n_val_file // self.val_batch_size
        self.img_height = 240
        self.img_width = 240
        self.img_depth = 160
        self.n_labels = 5
        self.train_batch_index = 0
        self.val_batch_index = 0

    def load_file_name_list(self, file_path):
        file_name_list = []
        with open(file_path, 'r') as file_to_read:
            while True:
                lines = file_to_read.readline().strip()  # read one case name per line
                if not lines:
                    break
                file_name_list.append(lines)
        return file_name_list

    def write_train_val_test_name_list(self):
        data_name_list = os.listdir(self.train_root_path + self.type + "\\")
        random.shuffle(data_name_list)
        length = len(data_name_list)
        n_train_file = int(length / 10 * 8)
        n_val_file = int(length / 10 * 1)
        train_name_list = data_name_list[0:n_train_file]
        val_name_list = data_name_list[n_train_file:(n_train_file + n_val_file)]
        test_name_list = data_name_list[(n_train_file + n_val_file):len(data_name_list)]
        self.write_name_list(train_name_list, "train_name_list.txt")
        self.write_name_list(val_name_list, "val_name_list.txt")
        self.write_name_list(test_name_list, "test_name_list.txt")

    def write_name_list(self, name_list, file_name):
        f = open(self.train_root_path + file_name, 'w')
        for i in range(len(name_list)):
            f.write(name_list[i] + "\n")
        f.close()

    def next_train_batch_2d(self):
        if self.train_batch_index >= self.n_train_file:
            self.train_batch_index = 0
        data_path = self.train_root_path + self.type + '\\' + self.train_name_list[self.train_batch_index]
        # flair, t1, t1c, t2, ot = self.get_np_data(data_path)
        t1, ot = self.get_np_data_2d(data_path)
        train_imgs = t1[:, :, :, np.newaxis]               # (155, 240, 240, 1)
        train_labels = make_one_hot_3d(ot, self.n_labels)  # (155, 240, 240, 5)
        self.train_batch_index += 1
        return train_imgs, train_labels

    def next_val_batch_2d(self):
        if self.val_batch_index >= self.n_val_file:
            self.val_batch_index = 0
        data_path = self.train_root_path + self.type + '\\' + self.val_name_list[self.val_batch_index]
        # flair, t1, t1c, t2, ot = self.get_np_data(data_path)
        t1, ot = self.get_np_data_2d(data_path)
        val_imgs = t1[:, :, :, np.newaxis]                 # (155, 240, 240, 1)
        val_labels = make_one_hot_3d(ot, self.n_labels)
        self.val_batch_index += 1
        return val_imgs, val_labels

    def next_train_batch_3d(self):
        train_imgs = np.zeros((self.train_batch_size, self.img_height, self.img_width, self.img_depth, 1))
        train_labels = np.zeros([self.train_batch_size, self.img_height, self.img_width, self.img_depth, self.n_labels])
        if self.train_batch_index >= self.n_train_steps_per_epoch:
            self.train_batch_index = 0
        for i in range(self.train_batch_size):
            data_path = self.train_root_path + self.type + '\\' + \
                        self.train_name_list[self.train_batch_size * self.train_batch_index + i]
            # flair, t1, t1c, t2, ot = self.get_np_data(data_path)
            t1, ot = self.get_np_data_3d(data_path)
            # flair = flair[:, :, :, np.newaxis]
            t1 = t1[:, :, :, np.newaxis]
            # t1c = t1c[:, :, :, np.newaxis]
            # t2 = t2[:, :, :, np.newaxis]
            train_imgs[i] = t1
            one_hot = make_one_hot_3d(ot, self.n_labels)
            train_labels[i] = one_hot
        self.train_batch_index += 1
        return train_imgs, train_labels

    def next_val_batch_3d(self):
        val_imgs = np.zeros((self.val_batch_size, self.img_height, self.img_width, self.img_depth, 1))
        val_labels = np.zeros([self.val_batch_size, self.img_height, self.img_width, self.img_depth, self.n_labels])
        if self.val_batch_index >= self.n_val_steps_per_epoch:
            self.val_batch_index = 0
        for i in range(self.val_batch_size):
            data_path = self.train_root_path + self.type + '\\' + \
                        self.val_name_list[self.val_batch_size * self.val_batch_index + i]
            # flair, t1, t1c, t2, ot = self.get_np_data(data_path)
            t1, ot = self.get_np_data_3d(data_path)
            # flair = flair[:, :, :, np.newaxis]
            t1 = t1[:, :, :, np.newaxis]
            # t1c = t1c[:, :, :, np.newaxis]
            # t2 = t2[:, :, :, np.newaxis]
            val_imgs[i] = t1
            one_hot = make_one_hot_3d(ot, self.n_labels)
            val_labels[i] = one_hot
        self.val_batch_index += 1
        return val_imgs, val_labels

    def get_np_data_3d(self, data_path):
        for i in glob.glob(os.path.join(data_path, 'VSD.Brain.XX.O.MR_T1.*\\VSD.Brain.XX.O.MR_T1.*.mha')):
            t1_file_path = i
        for i in glob.glob(os.path.join(data_path, 'VSD.Brain_3more.XX.*\\VSD.Brain_3more.XX.*.mha')):
            ot_file_path = i
        '''
        for i in glob.glob(os.path.join(data_path, 'VSD.Brain.XX.O.MR_Flair.*\\VSD.Brain.XX.O.MR_Flair.*.mha')):
            flair_file_path = i
        for i in glob.glob(os.path.join(data_path, 'VSD.Brain.XX.O.MR_T1c.*\\VSD.Brain.XX.O.MR_T1c.*.mha')):
            t1c_file_path = i
        for i in glob.glob(os.path.join(data_path, 'VSD.Brain.XX.O.MR_T2.*\\VSD.Brain.XX.O.MR_T2.*.mha')):
            t2_file_path = i
        flair = sitk_read(flair_file_path)
        t1c = sitk_read(t1c_file_path)
        t2 = sitk_read(t2_file_path)
        '''
        t1 = sitk_read(t1_file_path)
        ot = sitk_read(ot_file_path)
        return t1, ot
        # return flair, t1, t1c, t2, ot

    def get_np_data_2d(self, data_path):
        for i in glob.glob(os.path.join(data_path, 'VSD.Brain.XX.O.MR_T1.*\\VSD.Brain.XX.O.MR_T1.*.mha')):
            t1_file_path = i
        for i in glob.glob(os.path.join(data_path, 'VSD.Brain_3more.XX.*\\VSD.Brain_3more.XX.*.mha')):
            ot_file_path = i
        '''
        for i in glob.glob(os.path.join(data_path, 'VSD.Brain.XX.O.MR_Flair.*\\VSD.Brain.XX.O.MR_Flair.*.mha')):
            flair_file_path = i
        for i in glob.glob(os.path.join(data_path, 'VSD.Brain.XX.O.MR_T1c.*\\VSD.Brain.XX.O.MR_T1c.*.mha')):
            t1c_file_path = i
        for i in glob.glob(os.path.join(data_path, 'VSD.Brain.XX.O.MR_T2.*\\VSD.Brain.XX.O.MR_T2.*.mha')):
            t2_file_path = i
        flair = sitk_read(flair_file_path)
        t1c = sitk_read(t1c_file_path)
        t2 = sitk_read(t2_file_path)
        '''
        t1 = sitk_read_row(t1_file_path)  # sitk_read_row (not shown in this post) presumably reads the volume without the transpose
        ot = sitk_read_row(ot_file_path)
        return t1, ot
        # return flair, t1, t1c, t2, ot
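For reference, a rough sketch of how the reader might be driven in a training loop (hypothetical usage, batch size 1). Note that, as posted, sitk_read returns 155 slices while the batch arrays are allocated with img_depth = 160, so the volumes need to be padded to 160 slices (or img_depth adjusted) before this loop runs; the GitHub repository linked at the end carries the author's latest fixes.

# hypothetical driver loop; the *_name_list.txt split files must already exist
reader = brast_reader(train_batch_size=1, val_batch_size=1, test_batch_size=1, type='HGG')

for step in range(reader.n_train_steps_per_epoch):
    imgs, labels = reader.next_train_batch_3d()
    # imgs:   (1, 240, 240, 160, 1)  single-channel T1 volumes
    # labels: (1, 240, 240, 160, 5)  one-hot ground truth
    # model.train_on_batch(imgs, labels)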

U-Net3D

import keras.backend as K
from keras.engine import Input, Model
import keras
from keras.optimizers import Adam
from keras.layers import BatchNormalization, Activation, Conv3D, Conv3DTranspose, MaxPooling3D, UpSampling3D
import metrics as m  # local module (not shown in the post); a sketch is given after the model code
from keras.layers.core import Lambda
import numpy as np


def up_and_concate_3d(down_layer, layer):
    in_channel = down_layer.get_shape().as_list()[4]
    out_channel = in_channel // 2
    up = Conv3DTranspose(out_channel, [2, 2, 2], strides=[2, 2, 2], padding='valid')(down_layer)
    print("--------------")
    print(str(up.get_shape()))
    print(str(layer.get_shape()))
    print("--------------")
    my_concat = Lambda(lambda x: K.concatenate([x[0], x[1]], axis=4))
    concate = my_concat([up, layer])
    # must use lambda
    # concate = K.concatenate([up, layer], 3)
    return concate


def attention_block_3d(x, g, inter_channel):
    '''
    :param x: input from the same level of the down-sampling path, x(?,x_height,x_width,x_depth,x_channel)
    :param g: gate input from the previous up-sampling layer, g(?,g_height,g_width,g_depth,g_channel)
              where g_height, g_width, g_depth = x_height/2, x_width/2, x_depth/2
    :return: attention-weighted x
    '''
    # theta_x(?,g_height,g_width,g_depth,inter_channel)
    theta_x = Conv3D(inter_channel, [2, 2, 2], strides=[2, 2, 2])(x)
    # phi_g(?,g_height,g_width,g_depth,inter_channel)
    phi_g = Conv3D(inter_channel, [1, 1, 1], strides=[1, 1, 1])(g)
    # f(?,g_height,g_width,g_depth,inter_channel)
    f = Activation('relu')(keras.layers.add([theta_x, phi_g]))
    # psi_f(?,g_height,g_width,g_depth,1)
    psi_f = Conv3D(1, [1, 1, 1], strides=[1, 1, 1])(f)
    # sigm_psi_f(?,g_height,g_width,g_depth)
    sigm_psi_f = Activation('sigmoid')(psi_f)
    # rate(?,x_height,x_width,x_depth)
    rate = UpSampling3D(size=[2, 2, 2])(sigm_psi_f)
    # att_x(?,x_height,x_width,x_depth,x_channel)
    att_x = keras.layers.multiply([x, rate])
    return att_x


def unet_model_3d(input_shape, n_labels, batch_normalization=False, initial_learning_rate=0.00001,
                  metrics=m.dice_coefficient):
    """
    input_shape: without batch_size, (img_height, img_width, img_depth, n_channels)
    metrics: metric function reported during training
    """
    inputs = Input(input_shape)
    down_layer = []
    layer = inputs

    # down_layer_1
    layer = res_block_v2_3d(layer, 64, batch_normalization=batch_normalization)
    down_layer.append(layer)
    layer = MaxPooling3D(pool_size=[2, 2, 2], strides=[2, 2, 2], padding='same')(layer)
    print(str(layer.get_shape()))

    # down_layer_2
    layer = res_block_v2_3d(layer, 128, batch_normalization=batch_normalization)
    down_layer.append(layer)
    layer = MaxPooling3D(pool_size=[2, 2, 2], strides=[2, 2, 2], padding='same')(layer)
    print(str(layer.get_shape()))

    # down_layer_3
    layer = res_block_v2_3d(layer, 256, batch_normalization=batch_normalization)
    down_layer.append(layer)
    layer = MaxPooling3D(pool_size=[2, 2, 2], strides=[2, 2, 2], padding='same')(layer)
    print(str(layer.get_shape()))

    # down_layer_4
    layer = res_block_v2_3d(layer, 512, batch_normalization=batch_normalization)
    down_layer.append(layer)
    layer = MaxPooling3D(pool_size=[2, 2, 2], strides=[2, 2, 2], padding='same')(layer)
    print(str(layer.get_shape()))

    # bottle_layer
    layer = res_block_v2_3d(layer, 1024, batch_normalization=batch_normalization)
    print(str(layer.get_shape()))

    # up_layer_4
    layer = up_and_concate_3d(layer, down_layer[3])
    layer = res_block_v2_3d(layer, 512, batch_normalization=batch_normalization)
    print(str(layer.get_shape()))

    # up_layer_3
    layer = up_and_concate_3d(layer, down_layer[2])
    layer = res_block_v2_3d(layer, 256, batch_normalization=batch_normalization)
    print(str(layer.get_shape()))

    # up_layer_2
    layer = up_and_concate_3d(layer, down_layer[1])
    layer = res_block_v2_3d(layer, 128, batch_normalization=batch_normalization)
    print(str(layer.get_shape()))

    # up_layer_1
    layer = up_and_concate_3d(layer, down_layer[0])
    layer = res_block_v2_3d(layer, 64, batch_normalization=batch_normalization)
    print(str(layer.get_shape()))

    # score_layer
    layer = Conv3D(n_labels, [1, 1, 1], strides=[1, 1, 1])(layer)
    print(str(layer.get_shape()))

    # softmax
    layer = Activation('softmax')(layer)
    print(str(layer.get_shape()))

    outputs = layer
    model = Model(inputs=inputs, outputs=outputs)
    metrics = [metrics]
    model.compile(optimizer=Adam(lr=initial_learning_rate), loss=m.dice_coefficient_loss, metrics=metrics)
    return model


def res_block_v2_3d(input_layer, out_n_filters, batch_normalization=False, kernel_size=[3, 3, 3], stride=[1, 1, 1],
                    padding='same'):
    # number of input channels (channels-last, so the last axis of the 5D tensor)
    input_n_filters = input_layer.get_shape().as_list()[-1]
    print(str(input_layer.get_shape()))
    layer = input_layer
    for i in range(2):
        if batch_normalization:
            layer = BatchNormalization()(layer)
        layer = Activation('relu')(layer)
        layer = Conv3D(out_n_filters, kernel_size, strides=stride, padding=padding)(layer)
    if out_n_filters != input_n_filters:
        skip_layer = Conv3D(out_n_filters, [1, 1, 1], strides=stride, padding=padding)(input_layer)
    else:
        skip_layer = input_layer
    out_layer = keras.layers.add([layer, skip_layer])
    return out_layer
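The metrics module imported at the top is not included in the post. A minimal sketch of what it plausibly provides, given the two names used above (dice_coefficient and dice_coefficient_loss), is a standard soft-Dice metric and its loss:

import keras.backend as K

def dice_coefficient(y_true, y_pred, smooth=1.0):
    # soft Dice computed over all voxels and classes
    y_true_f = K.flatten(y_true)
    y_pred_f = K.flatten(y_pred)
    intersection = K.sum(y_true_f * y_pred_f)
    return (2.0 * intersection + smooth) / (K.sum(y_true_f) + K.sum(y_pred_f) + smooth)

def dice_coefficient_loss(y_true, y_pred):
    return 1 - dice_coefficient(y_true, y_pred)

With that in place, the model can be built to match the reader's batch shape. This is only illustrative: a full 240x240x160 volume at these filter widths will not fit on most GPUs, so in practice smaller patches whose sides are divisible by 16 (four pooling stages) are the realistic input size.

model = unet_model_3d(input_shape=(240, 240, 160, 1), n_labels=5, batch_normalization=True)
model.summary()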

The network structure is basically the same as the 2D U-Net used for VOC segmentation in the earlier post; the only change is that the convolution, pooling, and concatenation operations are replaced with their 3D counterparts.

Some bugs turned up while running this code; the latest version is available on GitHub, and the repository should be taken as the reference:

https://github.com/panxiaobai/brats_keras
