代码主要核心思想来自:https://www.cnblogs.com/JadenFK3326/p/12164519.html

K折交叉交叉验证的过程如下:

以200条数据,十折交叉验证为例子,十折也就是将数据分成10组,进行10组训练,每组用于测试的数据为:数据总条数/组数,即每组20条用于valid,180条用于train,每次valid的都是不同的。

(1)将200条数据,分成按照 数据总条数/组数(折数),进行切分。然后取出第i份作为第i次的valid,剩下的作为train

(2)将每组中的train数据利用DataLoader和Dataset,进行封装。

(3)将train数据用于训练,epoch可以自己定义,然后利用valid做验证。得到一次的train_loss和 valid_loss。

(4)重复(2)(3)步骤,得到最终的 averge_train_loss和averge_valid_loss

上述过程如下图所示:

上述的代码如下:

import torch
import torch.nn as nn
from torch.utils.data import DataLoader,Dataset
import torch.nn.functional as F
from torch.autograd import Variable#####构造的训练集####
x = torch.rand(100,28,28)
y = torch.randn(100,28,28)
x = torch.cat((x,y),dim=0)
label =[1] *100 + [0]*100
label = torch.tensor(label,dtype=torch.long)######网络结构##########
class Net(nn.Module):#定义Netdef __init__(self):super(Net, self).__init__() self.fc1   = nn.Linear(28*28, 120) self.fc2   = nn.Linear(120, 84)self.fc3   = nn.Linear(84, 2)def forward(self, x):x = x.view(-1, self.num_flat_features(x)) x = F.relu(self.fc1(x)) x = F.relu(self.fc2(x)) x = self.fc3(x) return xdef num_flat_features(self, x):size = x.size()[1:] num_features = 1for s in size:num_features *= sreturn num_features##########定义dataset##########
class TraindataSet(Dataset):def __init__(self,train_features,train_labels):self.x_data = train_featuresself.y_data = train_labelsself.len = len(train_labels)def __getitem__(self,index):return self.x_data[index],self.y_data[index]def __len__(self):return self.len########k折划分############
def get_k_fold_data(k, i, X, y):  ###此过程主要是步骤(1)# 返回第i折交叉验证时所需要的训练和验证数据,分开放,X_train为训练数据,X_valid为验证数据assert k > 1fold_size = X.shape[0] // k  # 每份的个数:数据总条数/折数(组数)X_train, y_train = None, Nonefor j in range(k):idx = slice(j * fold_size, (j + 1) * fold_size)  #slice(start,end,step)切片函数##idx 为每组 validX_part, y_part = X[idx, :], y[idx]if j == i: ###第i折作validX_valid, y_valid = X_part, y_partelif X_train is None:X_train, y_train = X_part, y_partelse:X_train = torch.cat((X_train, X_part), dim=0) #dim=0增加行数,竖着连接y_train = torch.cat((y_train, y_part), dim=0)#print(X_train.size(),X_valid.size())return X_train, y_train, X_valid,y_validdef k_fold(k, X_train, y_train, num_epochs=3,learning_rate=0.001, weight_decay=0.1, batch_size=5):train_loss_sum, valid_loss_sum = 0, 0train_acc_sum ,valid_acc_sum = 0,0for i in range(k):data = get_k_fold_data(k, i, X_train, y_train) # 获取k折交叉验证的训练和验证数据net =  Net()  ### 实例化模型### 每份数据进行训练,体现步骤三####train_ls, valid_ls = train(net, *data, num_epochs, learning_rate,\weight_decay, batch_size) print('*'*25,'第',i+1,'折','*'*25)print('train_loss:%.6f'%train_ls[-1][0],'train_acc:%.4f\n'%valid_ls[-1][1],\'valid loss:%.6f'%valid_ls[-1][0],'valid_acc:%.4f'%valid_ls[-1][1])train_loss_sum += train_ls[-1][0]valid_loss_sum += valid_ls[-1][0]train_acc_sum += train_ls[-1][1]valid_acc_sum += valid_ls[-1][1]print('#'*10,'最终k折交叉验证结果','#'*10) ####体现步骤四#####print('train_loss_sum:%.4f'%(train_loss_sum/k),'train_acc_sum:%.4f\n'%(train_acc_sum/k),\'valid_loss_sum:%.4f'%(valid_loss_sum/k),'valid_acc_sum:%.4f'%(valid_acc_sum/k))#########训练函数##########
def train(net, train_features, train_labels, test_features, test_labels, num_epochs, learning_rate,weight_decay, batch_size):train_ls, test_ls = [], [] ##存储train_loss,test_lossdataset = TraindataSet(train_features, train_labels) train_iter = DataLoader(dataset, batch_size, shuffle=True) ### 将数据封装成 Dataloder 对应步骤(2)#这里使用了Adam优化算法optimizer = torch.optim.Adam(params=net.parameters(), lr= learning_rate, weight_decay=weight_decay)for epoch in range(num_epochs):for X, y in train_iter:  ###分批训练 output  = net(X)loss = loss_func(output,y)optimizer.zero_grad()loss.backward()optimizer.step()### 得到每个epoch的 loss 和 accuracy train_ls.append(log_rmse(0,net, train_features, train_labels)) if test_labels is not None:test_ls.append(log_rmse(1,net, test_features, test_labels))#print(train_ls,test_ls)return train_ls, test_lsdef log_rmse(flag,net,x,y):if flag == 1: ### valid 数据集net.eval()output = net(x)result = torch.max(output,1)[1].view(y.size())corrects = (result.data == y.data).sum().item()accuracy = corrects*100.0/len(y)  #### 5 是 batch_sizeloss = loss_func(output,y)net.train()return (loss.data.item(),accuracy)loss_func = nn.CrossEntropyLoss() ###申明loss函
k_fold(10,x,label) ### k=10,十折交叉验证

上述代码中,直接按照顺序从x中每次截取20条作为valid,也可以先打乱然后在截取,这样效果应该会更好。如下所示:

import random
import torchx = torch.rand(100,28,28)
y = torch.randn(100,28,28)
x = torch.cat((x,y),dim=0)
label =[1] *100 + [0]*100
label = torch.tensor(label,dtype=torch.long)index = [i for i in range(len(x))]
random.shuffle(index)
x = x[index]
label = label[index]

pytorch - K折交叉验证过程说明及实现相关推荐

  1. Pytorch最简单的图像分类——K折交叉验证处理小型鸟类数据集分类2.0版本ing

    https://blog.csdn.net/hb_learing/article/details/110411532 https://blog.csdn.net/Pl_Sun/article/deta ...

  2. 在Mnist数据上使用k折交叉验证训练,pytorch代码到底怎么写

    前言 最近学到了K折交叉验证,已经迫不及待去实验一下他的效果是不是如老师讲的一样好,特此写下本文. 本文运行环境为:sklearn.pytorch .jupyter notebook k折交叉验证介绍 ...

  3. Python:K折交叉验证,将数据集分成训练集与测试集

    注意文件夹格式:父文件夹/类别/图像(同torch读取图像格式保存一致),传入路径为父文件夹路径. """ 对图像进行交叉验证, 用于检验分类效果 对每个类别的n张图像进 ...

  4. 学习笔记5-梯度爆炸和梯度消失(K折交叉验证)

    1.梯度消失.梯度爆炸 梯度消失和梯度爆炸 考虑到环境因素的其他问题 1.1 本次课程主要任务 了解学习梯度消失与梯度爆炸产生的原因以及怎么样解决. 考虑到环境因素的其他问题(协变量偏移,标签偏移,概 ...

  5. 5折交叉验证_[Machine Learning] 模型评估——交叉验证/K折交叉验证

    首先区分两个概念:'模型评估' 与 '模型性能度量' 模型评估:这里强调的是如何划分和利用数据,对模型学习能力的评估,重点在数据的划分方法. Keywords: 划分.利用数据 模型性能度量:是在研究 ...

  6. 交叉验证(cross validation)是什么?K折交叉验证(k-fold crossValidation)是什么?

    交叉验证(cross validation)是什么?K折交叉验证(k-fold crossValidation)是什么? 交叉验证(cross validation)是什么?  交叉验证是一种模型的验 ...

  7. 机器学习(MACHINE LEARNING)交叉验证(简单交叉验证、k折交叉验证、留一法)

    文章目录 1 简单的交叉验证 2 k折交叉验证 k-fold cross validation 3 留一法 leave-one-out cross validation 针对经验风险最小化算法的过拟合 ...

  8. 【Python-ML】SKlearn库Pipeline工作流和K折交叉验证

    # -*- coding: utf-8 -*- ''' Created on 2018年1月18日 @author: Jason.F @summary: Pipeline,流水线工作流,串联模型拟合. ...

  9. K折交叉验证(StratifiedKFold与KFold比较)

    文章目录 一.交叉验证 二.K折交叉验证 KFold()方法 StratifiedKFold()方法 一.交叉验证 交叉验证的基本思想是把在某种意义下将原始数据(dataset)进行分组,一部分做为训 ...

最新文章

  1. npm run dev 报错:missing script:dev
  2. final finally finalize 的区别
  3. 亿级短视频社交美拍架构实践
  4. unigui中TUniDBEdit的OnEndDrag问题
  5. react钩子_迷上了钩子:如何使用React的useReducer()
  6. Spring之AOP(面向切面编程)_入门Demo
  7. 设计模式学习之--Singleton(单例)模式
  8. 关于C语言中fseek函数的使用
  9. 处理一份内心煎熬的工作有两种方法——只有一种是正确的
  10. php实现五维雷达图,Unity 属性雷达图
  11. 【视频目标检测】|Towards High Performance Video Object Detection
  12. Quorum工作原理
  13. 【Agni-s Philosophy】使用的图形技术解说(后篇)Volume渲染和粒子处理
  14. 【天光学术】汉语言文学论文:浅谈农村初中文言文教学现状及有效策略
  15. 如何快速的了解某种数据库
  16. Vue2.0的三种常用传值方式、父传子、子传父、非父子组件传值
  17. Linux命令-程序启动
  18. 计算机工程学院运动会加油稿50字,运动会加油稿50字(30篇)
  19. SQL提高查询效率知识拾忆
  20. 论文篇 | 2020-Facebook-DETR :利用Transformers端到端的目标检测=>翻译及理解(持续更新中)

热门文章

  1. [转]最封闭的开源系统,话说Android的八宗罪
  2. mysql关于group by加count的优化
  3. 网易游戏大咖访谈实录丨走进《阴阳师:妖怪屋》的视听盛宴
  4. 如何清洁Mac键盘,显示器等
  5. 计算机毕业设计Java智能健身房管理(源码+系统+mysql数据库+lw文档)
  6. Android界面设计5--管理系统UI
  7. 射线追踪(ray tracing)
  8. html 代码 border,HTML Style border用法及代码示例
  9. 人工智能 归结原理实验报告
  10. SIP 协议消息应答代码解释详录