实战:使用Pytorch搭建分类网络(肺结节假阳性剔除)

阅前可看:
实战:使用yolov3完成肺结节检测(Luna16数据集)及肺实质分割
其中的脚本资源getMat.py文件是对肺结节进行切割。
注意:
本博客内容没有上述博客内容完美衔接,这里只是提供思路(在完成肺结节检测之后进行假阳性剔除)以及练习使用pytorch搭建深度学习网络。

  1. 样本生成
  2. Pytorch分类网络搭建
  3. 训练+测试

注意:原始脚本有误,因为进行重采样之后,尺寸就不再是512*512的。!!!
修改如下:

        ImageW = int(512 * real_resize_factor[0])ImageH = int(512 * real_resize_factor[1]) #如果图像进行旋转了,那么坐标也要旋转if isflip :    x = ImageW - xy = ImageH - y

2019.11.27发现新的两个bug。
1.注意:原始脚本(getMat.py)有误,修改如下!!!

        mat = pix_resampled[y - 20: y + 20, x - 20: x + 20,z - 12: z + 12]

原因:
这是因为PIL/opencv图片坐标和数组坐标问题。

  • 图像和数组的坐标均是左上角(0,0)
  • 图像的第一维度是横着的(从做到右),第二维度是竖着的(从上到下)。
  • 数组的第一维度是竖着的(从上到下),第二维度是横着的(从做到右)。
  • x 和 y 均是worldToVoxelCoord函数从世界坐标换算到图像坐标的值
  • 这个因为其中一个转置imgMat = image_array.transpose(1,2,0) #transpose是将(z,x,y)的三维矩阵转为(x,y,z)的矩阵,是把x轴放在了数组的第一个维度

一张图来说明这个问题: 这个图有误,详见下一个图。

这个bug是在可视化.mat文件的时候发现的。
出错图下图所示:


2.注意:原始脚本(traindataset.py)有误,修改如下!!!**
这个是在dataAugmentation函数中
修改部分:

......
dataMat = dataMat.transpose(2,1,0)#转置 从x,y,z到z,y,x
......
dataMat = dataMat[randZ : randZ + cropD, randY : randY + cropH, randX : randX + cropW]
......

这个原因是使用了pytorch的torch.nn.Conv3D
3D的卷积,输入的shape是(N,Cin,D,H,W)\left(N, C_{i n}, D, H, W\right)(N,Cin​,D,H,W),输出shape 是(N,Cout,Dout,Hout,Wout)\left(N, C_{o u t}, D_{o u t}, H_{o u t}, W_{o u t}\right)(N,Cout​,Dout​,Hout​,Wout​),z轴对应的是深度,y轴对应的H,x轴对应的是W,为了保证一致性,转置的地方需要修改
2019.12.5
1.2019.11.27发现的bug确实是bug,但解释的原因有误。
是在这里:imgMat = image_array.transpose(1,2,0) #transpose是将(z,x,y)的三维矩阵转为(x,y,z)的矩阵,正确的理解是#transpose是将(z,y,x)的三维矩阵转为(y,x,z)的矩阵
重新修改图示为:

2. 所以在traindataset.py中上述修改的这个dataMat = dataMat.transpose(2,1,0)#转置 从x,y,z到z,y,x有 误。正确的是:dataMat = dataMat.transpose(2,0,1)#转置 从y,x,z到z,y,x

1.样本生成

肺结节样本来自Luna16数据集,这里主要使用pytorch进行搭建分类网络,对疑似肺结节进行分类,进行假阳性剔除。在进行深度学习训练之前,需要完成样本集的生成。getMat.py脚本已上传到上述博客的资源中,这里对核心代码进行分析。

1.1尺度归一化

因为CT影像的采样间隔是不一致的,这里进行尺度归一化成1mm * 1mm * 1mm。也许在图像检测的时候这个尺度是否归一化可能影像不大,但是如果对单个肺结节进行分类的话,放缩到统一尺度下,是有利于分类的。

def resample(image, spacing, new_spacing=[1,1,1]):# Determine current pixel spacing#print("spacing:",spacing)#spacing = np.array(list(spacing))resize_factor = spacing / new_spacingnew_real_shape = image.shape * resize_factornew_shape = np.round(new_real_shape)real_resize_factor = new_shape / image.shapenew_spacing = spacing / real_resize_factor#print("image.shape",image.shape)#print("new_shape",new_shape)image = ndimage.interpolation.zoom(image, real_resize_factor, mode='nearest')return image, new_spacing, real_resize_factor

1.2换算坐标

        #换算成重采样后的坐标 与长宽高x = np.round(worldToVoxelCoord(x_ano, CT.x_offset, CT.x_ElementSpacing) * real_resize_factor[0]).astype(int)y = np.round(worldToVoxelCoord(y_ano, CT.y_offset, CT.y_ElementSpacing) * real_resize_factor[1]).astype(int)z = np.round(worldToVoxelCoord(z_ano, CT.z_offset, CT.z_ElementSpacing) * real_resize_factor[2]).astype(int)"""w = np.round(r/CT.x_ElementSpacing * real_resize_factor[0]).astype(int)h = np.round(r/CT.y_ElementSpacing * real_resize_factor[1]).astype(int)l = np.round(r/CT.z_ElementSpacing * real_resize_factor[2]).astype(int)"""#如果图像进行旋转了,那么坐标也要旋转if isflip :    """x_min = 512 - x_miny_min = 512 - y_minx_max = 512 - x_maxy_max = 512 - y_max"""x = 512 - xy = 512 - y

坐标要换算成尺度重采样之后的坐标,不能理解,但在这里我放弃使用换算之后的w,h,l作为肺结节的尺度信息。而是使用了固定尺寸24 x 40 x 40。
这里是参考了论文
《Accurate Pulmonary Nodule Detection in Computed Tomography Images Using Deep Convolutional Neural Networks》https://arxiv.org/pdf/1706.04303.pdf
本次网络借鉴了其中的假阳性剔除网络,(因为没有给出超参数,所以只好自己搭咯orz)。

注意:原始脚本有误,因为进行重采样之后,尺寸就不再是512*512的。!!!
修改如下:

        ImageW = int(512 * real_resize_factor[0])ImageH = int(512 * real_resize_factor[1]) #如果图像进行旋转了,那么坐标也要旋转if isflip :    x = ImageW - xy = ImageH - y

1.3保存.Mat文件

        mat = pix_resampled[x - 20: x + 20, y - 20: y + 20,z - 12: z + 12]#对mat尺寸进行判断x_mat, y_mat, z_mat = mat.shapeif x_mat * y_mat * z_mat < 38400:print("mat error:mat.shape:{} \n ID:{} \n x_ano:{},y_ano:{},z_ano:{},x:{},y:{},z:{}" \.format(mat.shape, namePre, x_ano, y_ano, z_ano, x, y, z))continue#进行下一个循环if classMat == 1:numNoudle = numNoudle + 1 #结节计数for num in range(20):#连续复制20次matio.savemat(mat_path+'{:05d}.mat'.format(count), {'data': mat, 'class':classMat})count = count + 1else :io.savemat(mat_path+'{:05d}.mat'.format(count), {'data': mat, 'class':classMat})count = count + 1

这里需要注意一下:
1.本次切割的结节是来自data/CSVFILES/candidates.csv,里面有55W个疑似肺结节,其中判定为肺结节的只有1186个,也就是只有1/550的疑似肺结节是肺结节。
2.做个实验,只切割文件ID前5000的肺结节。
3.对判定是结节的数据进行复制20次(数据扩充)。
4.保存疑似肺结节的数据和类别。

2. Pytorch分类网络搭建

2.1定义数据获取的类

def getAllDataPath(dataPath):#获取路径下所有文件路径pathAll=[]for root, dirs, files in os.walk(dataPath):path = [os.path.join(root, name) for name in files]#print(path)pathAll.extend(path)return pathAll
#定义dataset的框架
class MyTrainData(data.Dataset):   #需要繼承data.Dataset def __init__(self, dataPath, cropSize, transform=None ): #初始化文件路進或文件名self.dataPath = getAllDataPath(dataPath)self.cropW = cropSize[0]self.cropH = cropSize[1]self.cropD = cropSize[2]    def __getitem__(self, idx):dataMatPath = self.dataPath[idx]#加载.mat文件#print("dataMatPath:", dataMatPath)load_data = sio.loadmat(dataMatPath)#获取Mat值和Class类别dataMat = load_data["data"]classMat = load_data["class"]#print("1. dataMat.shape:",dataMat.shape)#归一化到 0~1之间dataMat = dataMat/1.0#print("2. dataMat.shape:",dataMat.shape)dataMat = dataAugmentation(dataMat, self.cropW, self.cropH, self.cropD)#裁剪 翻转对称#print("3. dataMat.shape:",dataMat.shape)#添加大小为1的维度#dataMat = torch.unsqueeze(dataMat, 0) # 在第0个维度上扩展return dataMat, classMat #返回Mat的数据(numpy)和类别(0、1)  def __len__(self):return len(self.dataPath)

数据类均使用MyTrainData加载样本生成的.Mat文件

2.2 数据增强

#数据增强
#@torchsnooper.snoop()
def dataAugmentation(dataMat, cropW, cropH, cropD):#随机裁剪dataMat = dataMat.transpose(2,0,1)#转置 从x,y,z到z,x,yD, W, H = dataMat.shape#转换成张量dataMat = torch.from_numpy(dataMat)#print("dataMat.shape: {},W:{},H:{},D:{} \n cropW:{},cropH:{},cropD:{}" \#     .format(dataMat.shape,W,H,D,cropW,cropH,cropD))randX = random.randint(0, W - cropW)randY = random.randint(0, H - cropH)randZ = random.randint(0, D - cropD)#print("randX:{},randY:{},randZ:{}".format(randX,randY,randZ))"""print("randX: ",randX)print("randY: ",randY)print("randZ: ",randZ)"""dataMat = dataMat[randZ : randZ + cropD, randX : randX + cropW, randY : randY + cropH]#随机翻转randDim = random.randint(0, 7)dims = ((0,),(1,),(2,),(0,1),(0,2),(1,2),(0,1,2))if randDim < 7:        dataMat = torch.flip(dataMat, dims[randDim])return dataMat

增强方式:
1.把40 x 40 x 24随机裁剪成cropW x cropH x cropD(本次设计成363620,数据可以扩充125倍)。
2.对肺结节(class = 1)进行翻转对称,扩充8倍。

2.3 分类网络搭建

def conv3x3(in_channel, out_channel, stride=1):return nn.Conv2d(in_channel, out_channel, 3, stride=stride, padding=1, bias=False)
# Conv3d的规定输入数据格式为(batch, channel, Depth, Height, Width)
def conv3x3x3(in_channel, out_channel, stride=1):return nn.Conv3d(in_channel,out_channel,kernel_size=(3,3,3),stride=stride,padding=1,dilation=1,groups=1,bias=False)
#3d残差块
class residual_block_3d(nn.Module):def __init__(self, in_channel, out_channel, same_shape=True):super(residual_block_3d, self).__init__()self.same_shape = same_shapestride = 1 if self.same_shape else 2self.conv1 = conv3x3x3(in_channel, out_channel, stride=stride)self.bn1 = nn.BatchNorm3d(out_channel)self.conv2 = conv3x3x3(out_channel, out_channel)self.bn2 = nn.BatchNorm3d(out_channel)if not self.same_shape:self.conv3 = nn.Conv3d(in_channel, out_channel, 1, stride=stride)#self.max = nn.MaxPool3d(kernel_size = (1,2,2),stride = (1,2,2))def forward(self, x):out = self.conv1(x)out = F.relu(self.bn1(out), True)out = self.conv2(out)out = F.relu(self.bn2(out), True)if not self.same_shape:x = self.conv3(x)#print("x.shape ",x.shape)return F.relu(x + out, True)#实现一个 ResNet3d,它就是 residual block 3d模块的堆叠
class resnet3d(nn.Module):def __init__(self, in_channel, num_classes, verbose=False):super(resnet3d, self).__init__()self.verbose = verbose#1*24*40*40self.block1 = nn.Conv3d(in_channel, 64, 1, 1)#64*20*36*36 self.block2 = nn.Sequential(nn.MaxPool3d(kernel_size = (1,2,2),stride = (1,2,2)),#64*20*18*18residual_block_3d(64, 64),residual_block_3d(64, 64),#64*20*18*18torch.nn.Dropout(0.5))self.block3 = nn.Sequential(residual_block_3d(64, 128, False),#128*10*9*9nn.Conv3d(128, 128, kernel_size=(1,2,2),stride=1, padding=0, dilation=1, groups=1,bias=False),#128*10*8*8residual_block_3d(128, 128),#128*10*8*8torch.nn.Dropout(0.5))self.block4 = nn.Sequential(residual_block_3d(128, 256, False),#256*5*4*4nn.Conv3d(256, 256, kernel_size=(2,1,1),stride=1, padding=0, dilation=1, groups=1,bias=False),#256*4*4*4residual_block_3d(256, 256),#256*4*4*4torch.nn.Dropout(0.5))self.block5 = nn.Sequential(residual_block_3d(256, 512, False),#512*2*2*2residual_block_3d(512, 512),#512*2*2*2nn.AvgPool3d((2,2,2),1),#512*1*1*1torch.nn.Dropout(0.5))self.classifier = nn.Linear(512, num_classes)self.sigmoid = nn.Sigmoid()def forward(self, x):x = self.block1(x)if self.verbose:print('block 1 output: {}'.format(x.shape))x = self.block2(x)if self.verbose:print('block 2 output: {}'.format(x.shape))x = self.block3(x)if self.verbose:print('block 3 output: {}'.format(x.shape))x = self.block4(x)if self.verbose:print('block 4 output: {}'.format(x.shape))x = self.block5(x)if self.verbose:print('block 5 output: {}'.format(x.shape))x = x.view(x.shape[0], -1)x = self.classifier(x)x = self.sigmoid(x)#归一化到 0~1之间#print("end: ",x.shape)return x

因为本次数据是3D,所以所有的操作都必须是三维的。本次分类网络较简单,使用3D残差块堆叠而成。

3.训练+测试

3.1 加载数据与模型

    #os.makedirs() 方法用于递归创建目录。model = resnet3d(1,1).to(device)#Get dataloadertrian_dataset = MyTrainData(opt.train_path, opt.crop_size)test_dataset = MyTrainData(opt.test_path, opt.crop_size)    train_data = torch.utils.data.DataLoader(trian_dataset,batch_size=opt.batch_size,shuffle=True,num_workers=opt.n_cpu,pin_memory=True,)test_data = torch.utils.data.DataLoader(test_dataset,batch_size=opt.batch_size,shuffle=True,num_workers=opt.n_cpu,pin_memory=True,)#优化器optimizer = torch.optim.Adam(model.parameters())#损失函数criterion = nn.BCELoss()

3.2 训练/测试结果

我在网络的最后一层使用的是全连接层到一个神经元上,并使用sigmoid函数使其值限制在0~1之间。在制作数据的时候,定义:1为真肺结节,0为假肺结节。所以在预测的时候必须设定阈值opt.thresh。本次设定的为0.5,当预测值大于0.5时,另其为1(真肺结节),否则另其为0(假肺结节)。

            predict = torch.tensor(outputs)#复制predict[predict >= opt.thresh] = 1predict[predict < opt.thresh] = 0

那么正确率的判断为:

correct += (predict == targets).squeeze().sum().cpu().numpy()

有**加号“+”**是因为在一个epoch中有很多个batch,每当累计一定的batch时,则会计算其平均准确度,所以会使用加号累计正确的个数。
在本次测试时使用5000+个.mat文件,batchsize设置为64,epoch为40,达到以下效果:

训练集到达98~99%左右,测试集97%左右。

脚本资源已上传,按需下载。所有核心代码已在本博客讲解。

思考:

1.数据样本扩充的必要性。
2.增加focal loss之后,准确度是否还会提高。
3.对肺结节进行单独测试,因为真实肺结节只占很小一部分,所以分类器的准确度很有可能在骗人

实战:使用Pytorch搭建分类网络(肺结节假阳性剔除)相关推荐

  1. 使用Pytorch搭建U-Net网络并基于DRIVE数据集训练(语义分割)学习笔记

    使用Pytorch搭建U-Net网络并基于DRIVE数据集训练(语义分割)学习笔记 https://www.bilibili.com/video/BV1rq4y1w7xM?spm_id_from=33 ...

  2. [深度应用]·实战掌握PyTorch图片分类简明教程

    [深度应用]·实战掌握PyTorch图片分类简明教程 个人网站--> http://www.yansongsong.cn 项目GitHub地址--> https://github.com/ ...

  3. Pytorch搭建FCN网络

    Pytorch搭建FCN网络 前言 原理 代码实现 前言 FCN 全卷积网络,用卷积层替代CNN的全连接层,最后通过转置卷积层得到一个和输入尺寸一致的预测结果: 原理 为了得到更好的分割结果,论文中提 ...

  4. 使用PyTorch搭建ResNet50网络

    ResNet18的搭建请移步:使用PyTorch搭建ResNet18网络并使用CIFAR10数据集训练测试 ResNet34的搭建请移步:使用PyTorch搭建ResNet34网络 ResNet101 ...

  5. pytorch 搭建 VGG 网络

    目录 1. VGG 网络介绍 2. 搭建VGG 网络 3. code 1. VGG 网络介绍 VGG16 的网络结构如图: VGG 网络是由卷积层和池化层构成基础的CNN 它的CONV卷积层的参数全部 ...

  6. Pytorch搭建LeNet5网络

    本讲目标:   介绍Pytorch搭建LeNet5网络的流程. Pytorch八股法搭建LeNet5网络 1.LeNet5网络介绍 2.Pytorch搭建LeNet5网络 2.1搭建LeNet网络 2 ...

  7. CNN简单实战:PyTorch搭建CNN对猫狗图片进行分类

    在上一篇文章:CNN训练前的准备:PyTorch处理自己的图像数据(Dataset和Dataloader),大致介绍了怎么利用pytorch把猫狗图片处理成CNN需要的数据,今天就用该数据对自己定义的 ...

  8. pytorch搭建孪生网络比较人脸相似性

    参考文献: 神经网络学习小记录52--Pytorch搭建孪生神经网络(Siamese network)比较图片相似性_Bubbliiiing的博客-CSDN博客_神经网络图片相似性 Python - ...

  9. 实战:使用Faster R-CNN完成肺结节检测(LUNA16)(1)/LUNA16数据集初探

    LUNA16 LUNA16,全称Lung Nodule Analysis 16.该数据集来自另一个更大的数据集LIDC-IDRI,旨在推动更多计算机视觉领域的SOTA算法用于CAD领域.官方网站 比赛 ...

最新文章

  1. 隐马尔可夫模型:HMM
  2. plsql job执行多个存储过程_spring-boot-micro-job一款分布式任务调度执行框架
  3. shell tr 替换 空格_shell tr命令
  4. CDU集训代码:基础算法和数据结构2
  5. android - 小技巧合集(不断更新)
  6. python开发软件的实例-如何编写Python软件开发文档(7个技巧)
  7. 谈谈主策划需要的能力
  8. 第一次使用Pocket-PowerBuilder和开发DLL的经历
  9. win7小工具打不开_有了这个工具,小白也能设置一键网络共享文件夹与打印机...
  10. 金蝶KIS专业版V14.1生产任务单|销售单等单据图片打印
  11. 百度云下载不限速方式集合
  12. Matlab逻辑运算符/与/或/非/异或/all/any
  13. 悲剧四个月python培训班,需要踩完坑犯过错,这些免费的编程资源,值得一生推
  14. jQuery 案例-图片抽奖
  15. aardio 模拟键盘按键,实现msgbox对话框自动关闭
  16. Tensorflow Test1
  17. 原来工业互联网和工业物联网是两个东西啊
  18. 现代化医院PACS/RIS系统概述
  19. 2021 CCF 非专业级别软件能力认证第一轮(CSP-S1)提高级
  20. AI项目被谷歌撂挑子 美国防部愤而狂挖硅谷AI人才

热门文章

  1. C语言实现人民币小写转大写
  2. chrome浏览器安全检查_为您的Chrome浏览器检查皮肤
  3. paperswithcode使用方法
  4. 根据体重和身高获取BMI值
  5. 家庭教育:怎样安慰不想上学的人
  6. ZCANPRO 周立功CAN通道配置方法
  7. 高性能浏览器网络(High Performance Browser Networking) 第四章
  8. iOS 在TabBarController视图切换的时候添加动画
  9. OpenJudge百炼习题解答(C++)--题4040:买书问题
  10. Android客户端连接服务器- OKHttp的简单实用方法