import nibabel as nib
import scipy.io as sio
import torch.optim as optim
from torch.utils import Dataset
from torch.utils.data import DataLoader
from sklearn.model_selection import LeaveOneOut

torch.utils.data.DataLoader 为Pytorch 数据读取的重要接口,接口定义在dataloader.py 脚本,用Pytorch训练模型必须调用此接口,该接口用于自定义的数据接口的输出,或者Pytorch已有的数据读取接口的输入 按照batch_size封装成Tensor,后续只需要在包装成variable 即作为模型输入,此接口为数据的承上、启下。

class TensorDataset(Dataset):def __init__(self,data_tensor,target_tensor):self.data_tensor=data_tensorself.target_tensor=target_tensordef __getitem__(self,index):return self.data_tensor[index],self.target_tensor[index]def __len__(self):return self.data_tensor.size(0)

TensorDataset 继承Dataset,重载__init__,getitem,len;共用类,数据集写入后共同继承,写入形式 data_tensor,target_tensor,得出data的索引,data的长度。

def read_data():data=nib.load(path)return data
data_dir='your_data_path'
data_list = os.listdir(data_dir)
data=np.ones((30,121,145,121))#输入sMRI为nii,维度121*145*121,共30例数据,构建data的容器
for i in range(0,len(data_list)):path = os.path.join(data_dir,data_list[i])if os.path.isfile(path):temp=read_data(path).get_data()temp=torch.form_numpy(temp)temp=temp.type(torch.FloatTensor)#array转为tensortemp=temp.premute(2,1,0)#对数据形式转换data[i,:,:,:]=temp#temp 赋值给data后三维
label=np.ones([30,1])
label([0:19])=0
label=torch.LongTensor(label)

完成数据集构造
定义CNN卷积层

class Net(nn.Module):def __init__(self):super(Net,self).__init__()"""in_channels=?  nn.Linear(?,) 输入格式:样本数量*通道数*宽*高*帧数"""self.conv1=nn.Conv3d(in_channels=1,out_channels=4,kernel_size=(6,6,6),stride=1)#输入1*1*121*145*121 一层卷积输出 2*1*116*140*116self.max_pool1=nn.MaxPool3d(kernel_size=(3,3,3),stride=2)#1层池化输出 4*1*57*69*57self.conv2=nn.Conv3d(in_channels=4,out_channels=8,kernel_size=(6,6,6),stride=1)#输入4*1*57*69*57 二层卷积输出 8*1*52*64*52 2层池化 4*1*25*31*25self.max_pool2=nn.MaxPool3d(kernel_size=(3,3,3),stride=2)self.conv3=nn.Conv3d(in_channels=8,out_channels=16,kernel_size=(6,6,6),stride=1)#输入4*1*25*31*25 三层卷积输出 16*1*20*26*20self.max_pool3=nn.MaxPool3d(kernel_size=(3,3,3),stride=2)#输入16*1*20*26*20 输出16*1*9*12*9self.linear1 =nn.Linear(16*1*9*12*9,360)#全连接层输入接口为卷积层结果,卷积层输出维数太高,线性层输入太大,效果影响self.linear2 =nn.Linear(360,120)self.linear3 =nn.Linear(120,2)def forward(self,x):x=F.relu(self.conv1(x))x=self.max_pool1(x)x=F.relu(self.conv2(x))x=self.max_pool2(x)x=F.relu(self.conv3(x))x=self.max_pool3(x)x=x.contiguous().view(-1,16*9*12*9)#调整结构,与全连接层进入的结构匹配  contiguous :连续函数x=F.relu(self.linear1(x))x=F.relu(self.linear2(x))x=self.linear3(x)return x
net=Net()

super(Net,self).init() super函数是继承父类调用一个函数,包含了父类函数的所有性质,此处flag。data.contigugous().view(-1,self.num_flat_features(data)) 把-1至最后卷积层输出的矩阵展成一维数组,以便进入后续线性变化层操作。

数据集训练测试

loo=LeaveOneOut()
all_correct=0
for train_index,test_index in loo.split(data):train_data=data[train_index,:,:,:]test_data=data[test_index,:,:,:]train_label=label(train_index)test_label=label(test_index)train_data=torch.FloatTensor(train_data)test_data=torch.FloatTensor(test_data)train_data=train_data.unsqueeze(1)test_data=test_data.unsqueeze(1)#3维卷积输入为5维数据,在数据进入网络前完成维度匹配trainset=TensorDataset(train_data,train_label)trainloader=DataLoader(trainset,bitch_size=1,shuffle=True,num_workers=0)criterion=nn.CrossEntropyLoss()#出来结果为分类,采用交叉熵损失函数 criterion 规则optimizer=optim.SGD(net.parameters(),lr=0.01,momentum=0.9) #优化函数  momentum 动量for  epoch in range(10):runnging_loss=0for i ,input_data in enumerate(trainloader,0):inputs,labels=input_dataoptimizer.zero_grad()outputs= net(inputs)loss=criterion(outputs,labels.view(-1))loss.backward()optimizer.step()running_loss +=loss.item()if i %8=7:print('[%d,%5d] loss:%0.3f' %(epoch +1,i+1,running_loss/8)) #每8个样本损失函数求一次平均running_loss=0.0with torch.no_grad():outputs=net(test_data)_,predicted = torch.max(outputs,1)print('predicted:'+str(predicted))if predicted == test_label:correct=1all_correct+=correctprint('all_correct:'+str(all_correct))
print('mean_correct:' +str(all_correct/10))
print('%:'+str(all_correct*100/300))

torch.optim 优化网络结构,用optim创建optimizer对象,用于保存当前的参数状态,并基于计算梯度进行更新。
构造损失函数:
criterion=nn.CrossEntropyLoss() 梯度下降,通过criterion(outputs,labels.view(-1)).backward() 梯度反向传播,多次迭代,获取最优结果。
构造优化器:optimizer=optim.SGD(model.parameters(),lr=0.01,momentum=0.9) 或者 optimizer=optim.Adam([variable1,vaiable2],lr=0.0001)
model.parameters()获取model网络的参数,SGD参数有opfunc、x、config、state .其中config 配置变量用于优化梯度下降,防止求得的最优解为局部最优解而非全局最优。config中的配置变量包括:learningRate(梯度下降速率),learningRateDecay(梯度下降速率衰减),weightDecay(权重衰减),momentum(动量),其中:learningRate较小时,收敛到极值的速度慢,较大时,容易在搜索过程中发生震荡,产生过拟合;而weightDecay 为有效限制模型中自由参数数量避免过拟合,调整成本函数,其方法可以通过在权重上引入零均值高斯先验值,将代价函数改为E(w)=E(w)+λ2w2.在实现过程中,惩罚较大的权重,并有效限制模型的自由度,而正则化参数λ决定如何将原始成本E 与大权重惩罚进行折衷。 learningRateDecay 越小,lr衰减越慢,当decay=0时,lr保持不变;decay越大,lr衰减越快,当decay=1时,lr衰减最快;momentum 冲量表示损失对时间上的累计效应。v = -dxlr + momentumv ;当本次梯度下降 -dx*lr的方向与上次更新的V方向相同时,上次的更新量能够对本次的搜索起到正向加速的作用;当方向相反时,上次的更新量对本次的搜索起到减速的作用。
epoch
当一个完整的数据集通过网络一次并返回一次预测分类结果,整个过程为一个epoch。epoch=n,完成在网络中的多次传递,而非一次,在有限的数据集,使用迭代过程梯度下降,更新权重,拟合的过程从欠拟合 过渡 到过拟合。

epoch 的选取,根据不同的数据集选定,数据的多样性会影响到合适epoch的数据,如二分类的数据集 与多分类数据集,epoch不同。
batch_size
Dataloader(trainset,batch_size=1,shuffle=True,num_workers=0)
数据集不能一次性通过网络,将数据集分成N个batch,batch_size 批量大小决定训练一次的样本数目。batch_size 影响到模型的优化程度和速度,在内存效率和内存容量之间寻找最佳平衡。
batch_size取值:

全批次:(蓝色)针对小样本数据集,采用全部数据集作为每次训练样本数目,全数据集确定的方向能够更好的代表样本总体,从而更准确朝向极值所在的方向。
Mini-batch:(绿色) 选的合适的batch_size,数据输入深度学习网络,计算这个batch的所有样本的平均损失,即代价函数为所有样本的平均。
Stochasti(随机)(红色): batch_size=1,随机取入1个,每次修正方向以各自样本的梯度方向修正,方向多样性 使结果难达到收敛。

适当增加batch_size: 可以并行化提供内存利用率,让GPU内存满载运行,提高训练速度;单次的epoch迭代次数减小,提高运行速度。(单次epoch=(全部训练样本/batchsize)/iteration=1 )训练集1000个样本,batch_size=10,训练完整个样本集需要:1次epoch,100次iteration;同时梯度下降方向准确度增加,训练过程中震荡幅度减小。Batch_size 过小,训练数据难收敛,从而underfitting。batch_size 增加后处理速度加快,但所需内存容量增加,往往出现内存爆炸, epoch增加过多使运算过程增加耗时,运算速度下降。没使用batch_size 之前(默认batch_size=1),意味网络训练时,把一次所有的数据输入网络,然后计算梯度进行反向传播,使用整个样本集,计算出的不同梯度值差别巨大,难以使用一个全局学习率。
for i ,input in enumerate(trainloader,0) enumerate()用于将一个可遍历的数据对象(list(),tuple(),str() )组合为一个索引序列,同时列出trainloader的数据和数据下标,对应到input 和i;
with torch.no_grad() torch.mo_grad()为一个上下文管理器,被该语句wrap(包)起来的部分将不会track(追踪)梯度。torch.no_grad()也可以作为装饰器,在网络测试的函数前加上。

@torch.no_grad()def eval():.....

在训练集中使用model.train(),在测试集中使用model.eval(),主要时保证模型的在测试阶段参数不进行更新。with torch.no_grad 在eval阶段,使用no_grad 让梯度Autograd=False,而训练时默认的是True,主要保证方向过程为纯粹的测试,而不变参数。同时也避免参数设置,节约GPU底层的时间开销。
_,predicted = torch.max(outputs,1) torch.max()的第一个输入为tensor结构,第二个参数1 代表dim,取每一行的最大值,也就是取预测概率最大的index,第三个默认参数loss 为torch.autograd.Variable格式

准确值的累计:correct +=(predicted==labels).sum()

参考:添加链接描述

sMRI影像数据3维CNN卷积相关推荐

  1. 深度学习--TensorFlow(项目)识别自己的手写数字(基于CNN卷积神经网络)

    目录 基础理论 一.训练CNN卷积神经网络 1.载入数据 2.改变数据维度 3.归一化 4.独热编码 5.搭建CNN卷积神经网络 5-1.第一层:第一个卷积层 5-2.第二层:第二个卷积层 5-3.扁 ...

  2. DeepLearning tutorial(4)CNN卷积神经网络原理简介+代码详解

    FROM: http://blog.csdn.net/u012162613/article/details/43225445 DeepLearning tutorial(4)CNN卷积神经网络原理简介 ...

  3. Deep Learning论文笔记之(五)CNN卷积神经网络代码理解

    Deep Learning论文笔记之(五)CNN卷积神经网络代码理解 zouxy09@qq.com http://blog.csdn.net/zouxy09          自己平时看了一些论文,但 ...

  4. Deep Learning论文笔记之(四)CNN卷积神经网络推导和实现

    Deep Learning论文笔记之(四)CNN卷积神经网络推导和实现 zouxy09@qq.com http://blog.csdn.net/zouxy09          自己平时看了一些论文, ...

  5. TensorFlow 2.0 mnist手写数字识别(CNN卷积神经网络)

    TensorFlow 2.0 (五) - mnist手写数字识别(CNN卷积神经网络) 源代码/数据集已上传到 Github - tensorflow-tutorial-samples 大白话讲解卷积 ...

  6. CNN卷积神经网络分析

    五月两场 | NVIDIA DLI 深度学习入门课程 5月19日/5月26日一天密集式学习  快速带你入门阅读全文> 正文共2566个字,4张图,预计阅读时间13分钟. CNN最大的优势在特征提 ...

  7. python神经网络案例——CNN卷积神经网络实现mnist手写体识别

    分享一个朋友的人工智能教程.零基础!通俗易懂!风趣幽默!还带黄段子!大家可以看看是否对自己有帮助:点击打开 全栈工程师开发手册 (作者:栾鹏) python教程全解 CNN卷积神经网络的理论教程参考 ...

  8. 深度学习之CNN卷积神经网络

    详解卷积神经网络(CNN) 卷积神经网络(Convolutional Neural Network, CNN)是一种前馈神经网络,它的人工神经元可以响应一部分覆盖范围内的周围单元,对于大型图像处理有出 ...

  9. cnn(卷积神经网络)比较系统的讲解

    本文整理了网上几位大牛的博客,详细地讲解了CNN的基础结构与核心思想,欢迎交流. [1]Deep learning简介 [2]Deep Learning训练过程 [3]Deep Learning模型之 ...

  10. Deep Learning模型之:CNN卷积神经网络(一)深度解析CNN

    本文整理了网上几位大牛的博客,详细地讲解了CNN的基础结构与核心思想,欢迎交流. [1]Deep learning简介 [2]Deep Learning训练过程 [3]Deep Learning模型之 ...

最新文章

  1. 经典mysql数据库表案例_MySQL数据库的“十宗罪”(附10大经典错误案例)
  2. java 遍历xml子节点,Axiom解析XML,axiomxml,1、遍历XML全部节点,
  3. 为什么 WebAssembly 更快?
  4. C语言:构建一个二级链表并完成增删改查
  5. Angular中使用JS实现路由跳转、动态路由传值、get方式传值
  6. java putnextentry_Java对zip格式压缩和解压缩
  7. TEG《选择》乘风破浪 · 披荆斩棘
  8. 已知函数func的C语言代码框架,第三章习题-ddg..doc
  9. Interlocked原子访问系列函数
  10. 网卡驱动怎么安装方法教程
  11. C++11多线程join()和detach()
  12. PMP试题 | 每日一练,快速提分 9.1
  13. 日常开发效率神器【Hutool工具类】的使用
  14. 网易视频云:搜索意图识别浅析
  15. android模拟器 vm版,怎样用vmware虚拟机安装android模拟器
  16. 插件!最好用的翻译插件!!
  17. C++ stack 遍历
  18. Xgen Animwires 随笔01
  19. HBuilder X下载安装,运行微信小程序教程(官网)
  20. OpenCV数字图像处理学习平台

热门文章

  1. ue4 材质翻转法线开关控制
  2. Windows下使用Thunderbird实现邮箱的加密解密和签名认证
  3. android A problem occurred starting process
  4. 奇点云宣布完成1.2亿元B1轮融资,首次公开数据星图
  5. vue打包篇-分析包数据再进行CDN配置图片文件压缩等优化
  6. 【可达编程】P0063. 小武老师的烤全羊
  7. JAVA自学-day11-eclipse工具的使用、API、Object类
  8. 入职阿里巴巴,成为年薪百万阿里P7高级架构师需要必备哪些技术栈,带你来观望一下
  9. c语言解除键盘锁定,笔记本键盘被锁怎么办|笔记本解除键盘锁定的四种方法
  10. Todo Tree插件配置