在文章Keras入门(一)搭建深度神经网络(DNN)解决多分类问题中,笔者介绍了如何搭建DNN模型来解决IRIS数据集的多分类问题。
  本文将在此基础上介绍如何在Keras中实现K折交叉验证。

什么是K折交叉验证?

  K折交叉验证是机器学习中的一个专业术语,它指的是将原始数据随机分成K份,每次选择K-1份作为训练集,剩余的1份作为测试集。交叉验证重复K次,取K次准确率的平均值作为最终模型的评价指标。一般取K=10,即10折交叉验证,如下图所示:

  用交叉验证的目的是为了得到可靠稳定的模型。K折交叉验证能够有效提高模型的学习能力,类似于增加了训练样本数量,使得学习的模型更加稳健,鲁棒性更强。选择合适的K值能够有效避免过拟合。

Keras实现K折交叉验证

  我们仍采用文章Keras入门(一)搭建深度神经网络(DNN)解决多分类问题中的模型,如下:

同时,我们对IRIS数据集采用10折交叉验证,完整的实现代码如下:

# -*- coding: utf-8 -*-
# model_train.py
# Python 3.6.8, TensorFlow 2.3.0, Keras 2.4.3
# 导入模块
import keras as K
import pandas as pd
from sklearn.model_selection import KFold# 读取CSV数据集
# 该函数的传入参数为csv_file_path: csv文件路径
def load_data(sv_file_path):iris = pd.read_csv(sv_file_path)target_var = 'class'  # 目标变量# 数据集的特征features = list(iris.columns)features.remove(target_var)# 目标变量的类别Class = iris[target_var].unique()# 目标变量的类别字典Class_dict = dict(zip(Class, range(len(Class))))# 增加一列target, 将目标变量转化为类别变量iris['target'] = iris[target_var].apply(lambda x: Class_dict[x])return features, 'target', iris# 创建模型
def create_model():init = K.initializers.glorot_uniform(seed=1)simple_adam = K.optimizers.Adam()model = K.models.Sequential()model.add(K.layers.Dense(units=5, input_dim=4, kernel_initializer=init, activation='relu'))model.add(K.layers.Dense(units=6, kernel_initializer=init, activation='relu'))model.add(K.layers.Dense(units=3, kernel_initializer=init, activation='softmax'))model.compile(loss='sparse_categorical_crossentropy', optimizer=simple_adam, metrics=['accuracy'])return modeldef main():# 1. 读取CSV数据集print("Loading Iris data into memory")n_split = 10features, target, data = load_data("./iris_data.csv")x = data[features]y = data[target]avg_accuracy = 0avg_loss = 0for train_index, test_index in KFold(n_split).split(x):print("test index: ", test_index)x_train, x_test = x.iloc[train_index], x.iloc[test_index]y_train, y_test = y.iloc[train_index], y.iloc[test_index]print("create model and train model")model = create_model()model.fit(x_train, y_train, batch_size=1, epochs=80, verbose=0)print('Model evaluation: ', model.evaluate(x_test, y_test))avg_accuracy += model.evaluate(x_test, y_test)[1]avg_loss += model.evaluate(x_test, y_test)[0]print("K fold average accuracy: {}".format(avg_accuracy / n_split))print("K fold average accuracy: {}".format(avg_loss / n_split))main()

模型的输出结果如下:

Iteration loss accuracy
1 0.00056 1.0
2 0.00021 1.0
3 0.00022 1.0
4 0.00608 1.0
5 0.21925 0.8667
6 0.52390 0.8667
7 0.00998 1.0
8 0.04431 1.0
9 0.14590 1.0
10 0.21286 0.8667
avg 0.11633 0.9600

10折交叉验证的平均loss为0.11633,平均准确率为96.00%。

总结

  本文代码已存放至Github,网址为:https://github.com/percent4/Keras-K-fold-test 。
  感谢大家的阅读~
  2020.1.24于上海浦东

Keras入门(八)K折交叉验证相关推荐

  1. 5折交叉验证_[Machine Learning] 模型评估——交叉验证/K折交叉验证

    首先区分两个概念:'模型评估' 与 '模型性能度量' 模型评估:这里强调的是如何划分和利用数据,对模型学习能力的评估,重点在数据的划分方法. Keywords: 划分.利用数据 模型性能度量:是在研究 ...

  2. 交叉验证(cross validation)是什么?K折交叉验证(k-fold crossValidation)是什么?

    交叉验证(cross validation)是什么?K折交叉验证(k-fold crossValidation)是什么? 交叉验证(cross validation)是什么?  交叉验证是一种模型的验 ...

  3. 机器学习(MACHINE LEARNING)交叉验证(简单交叉验证、k折交叉验证、留一法)

    文章目录 1 简单的交叉验证 2 k折交叉验证 k-fold cross validation 3 留一法 leave-one-out cross validation 针对经验风险最小化算法的过拟合 ...

  4. 【Python-ML】SKlearn库Pipeline工作流和K折交叉验证

    # -*- coding: utf-8 -*- ''' Created on 2018年1月18日 @author: Jason.F @summary: Pipeline,流水线工作流,串联模型拟合. ...

  5. K折交叉验证(StratifiedKFold与KFold比较)

    文章目录 一.交叉验证 二.K折交叉验证 KFold()方法 StratifiedKFold()方法 一.交叉验证 交叉验证的基本思想是把在某种意义下将原始数据(dataset)进行分组,一部分做为训 ...

  6. k折交叉验证法python实现_Jason Brownlee专栏| 如何解决不平衡分类的k折交叉验证-不平衡分类系列教程(十)...

    作者:Jason Brownlee 编译:Florence Wong – AICUG 本文系AICUG翻译原创,如需转载请联系(微信号:834436689)以获得授权 在对不可见示例进行预测时,模型评 ...

  7. 机器学习--K折交叉验证(K-fold cross validation)

    K 折交叉验证(K-flod cross validation) 当样本数据不充足时,为了选择更好的模型,可以采用交叉验证方法. 基本思想:把给定的数据进行划分,将划分得到的数据集组合为训练集与测试集 ...

  8. k折交叉验证matlab 流程_第51集 python机器学习:分层K折交叉验证及其他方式

    由于出现类似鸢尾花数据集这种分段数据可能简单的交叉验证无法适用,所以这里引用了分层K折交叉验证.在分层交叉验证中,我们划分数据,使得每个折中类别之间的比例整数与数据集中的比例相同,如下图所示: mgl ...

  9. k折交叉验证优缺点_R语言中K邻近算法的初学者指南:从菜鸟到大神(附代码&链接)...

    作者:Leihua Ye, UC Santa Barbara 翻译:陈超 校对:冯羽 本文约2300字,建议阅读10分钟 本文介绍了一种针对初学者的K临近算法在R语言中的实现方法. 本文呈现了一种在R ...

  10. k折交叉验证优缺点_都说K折交叉验证最常见,你会做吗?

    在临床研究领域,大家特别希望能够未仆先知,于是临床研究者尝试去建立各种预测模型.比如,凭借孕妇的信息预测低出生体重儿的结局.怎么建立预测模型呢?常见的做法是这样的:以低出生体重儿为因变量,以相关的孕妇 ...

最新文章

  1. 异常:android.os.NetworkOnMainThreadException
  2. 无限极分类,子集跟着父集排列,用于后台显示菜单管理
  3. 灰盒测试—数据库软件
  4. 小工匠聊架构 - 构建架构思维
  5. Unity的Lerp函数实现缓动
  6. linux apache gzip filters,Linux Apache2如何开启gzip (deflate module) 压缩功能
  7. 刷新器-Java EE 7后端十大功能
  8. node中的Stream-Readable和Writeable解读
  9. 修改marathon源码后,如何编译,部署到集群中?
  10. 李宏毅2017机器学习homework1-利用gradient descent拟合宝可梦CP值代码并利用adagrad进行优化
  11. Greenplum数据库配置管理-参数配置管理和常用参数优化建议
  12. 车牌识别,车辆检测,车牌检测和识别,与车相关的点点滴滴
  13. 2019年1月2日申请美国F1学生签证记录
  14. 辅修计算机编程,求帮忙~计算机C语言的编程题!大学选的辅修课没去过,要考试了不会? 爱问知识人...
  15. 【朝花夕拾】Android自定义View篇之(六)Android事件分发机制(中)从源码分析事件分发机制...
  16. roscore失败,提示RLException: Unable to contact my own server
  17. 爬取某猫即将上映电影数据,写入excel保存
  18. 传统企业数字转型,主要面临哪些问题?
  19. HDU 6187 Destroy Walls
  20. 上海人工智能强在哪?不妨看魔都AI企业50家

热门文章

  1. 医疗服务机器人市场复合年增长率将达15.7%
  2. 白话区块链 之1: 为什么账本要这么记?
  3. ubuntu文件名乱码(转载)
  4. 基于 OpenLayers3 实现的 HTML5 GIS 电信资源管理系统
  5. fluidsim元件库下载_模块七FluidSIM软件应用 (1)
  6. 总结:Postman测试、IP:POST测试、Postman转换到python测试(Linux下的docker应用部署web容器并存储数据到mysql,调用API)
  7. 格雷码和二进制码的互相转换
  8. 医院管理系统服务器,医院管理的十大运行系统!
  9. 拼多多运营模式分析 | 如何杀出电商重围?
  10. Java聊天室系统的设计与实现(完整源码 sql文件 论文)