原文来自https://www.cnblogs.com/bymo/p/9026198.html

记下来,以便以后忘记了可以找到

目录

  • 一.自动切分
  • 二.手动切分
  • 三.K折交叉验证(k-fold cross validation)

在训练深度学习模型的时候,通常将数据集切分为训练集和验证集.Keras提供了两种评估模型性能的方法:

  • 使用自动切分的验证集
  • 使用手动切分的验证集

回到顶部

一.自动切分

在Keras中,可以从数据集中切分出一部分作为验证集,并且在每次迭代(epoch)时在验证集中评估模型的性能.

具体地,调用model.fit()训练模型时,可通过validation_split参数来指定从数据集中切分出验证集的比例.

# MLP with automatic validation set
from keras.models import Sequential
from keras.layers import Dense
import numpy
# fix random seed for reproducibility
numpy.random.seed(7)
# load pima indians dataset
dataset = numpy.loadtxt("pima-indians-diabetes.csv", delimiter=",")
# split into input (X) and output (Y) variables
X = dataset[:,0:8]
Y = dataset[:,8]
# create model
model = Sequential()
model.add(Dense(12, input_dim=8, activation='relu'))
model.add(Dense(8, activation='relu'))
model.add(Dense(1, activation='sigmoid'))
# Compile model
model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])
# Fit the model
model.fit(X, Y, validation_split=0.33, epochs=150, batch_size=10)

validation_split:0~1之间的浮点数,用来指定训练集的一定比例数据作为验证集。验证集将不参与训练,并在每个epoch结束后测试的模型的指标,如损失函数、精确度等。

注意,validation_split的划分在shuffle之前,因此如果你的数据本身是有序的,需要先手工打乱再指定validation_split,否则可能会出现验证集样本不均匀。 

回到顶部

二.手动切分

Keras允许在训练模型的时候手动指定验证集.

例如,用sklearn库中的train_test_split()函数将数据集进行切分,然后在kerasmodel.fit()的时候通过validation_data参数指定前面切分出来的验证集.

# MLP with manual validation set
from keras.models import Sequential
from keras.layers import Dense
from sklearn.model_selection import train_test_split
import numpy
# fix random seed for reproducibility
seed = 7
numpy.random.seed(seed)
# load pima indians dataset
dataset = numpy.loadtxt("pima-indians-diabetes.csv", delimiter=",")
# split into input (X) and output (Y) variables
X = dataset[:,0:8]
Y = dataset[:,8]
# split into 67% for train and 33% for test
X_train, X_test, y_train, y_test = train_test_split(X, Y, test_size=0.33, random_state=seed)
# create model
model = Sequential()
model.add(Dense(12, input_dim=8, activation='relu'))
model.add(Dense(8, activation='relu'))
model.add(Dense(1, activation='sigmoid'))
# Compile model
model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])
# Fit the model
model.fit(X_train, y_train, validation_data=(X_test,y_test), epochs=150, batch_size=10)

回到顶部

三.K折交叉验证(k-fold cross validation)

将数据集分成k份,每一轮用其中(k-1)份做训练而剩余1份做验证,以这种方式执行k轮,得到k个模型.将k次的性能取平均,作为该算法的整体性能.k一般取值为5或者10.

  • 优点:能比较鲁棒性地评估模型在未知数据上的性能.
  • 缺点:计算复杂度较大.因此,在数据集较大,模型复杂度较高,或者计算资源不是很充沛的情况下,可能不适用,尤其是在训练深度学习模型的时候.

sklearn.model_selection提供了KFold以及RepeatedKFold, LeaveOneOut, LeavePOut, ShuffleSplit, StratifiedKFold, GroupKFold, TimeSeriesSplit等变体.

下面的例子中用的StratifiedKFold采用的是分层抽样,它保证各类别的样本在切割后每一份小数据集中的比例都与原数据集中的比例相同.

# MLP for Pima Indians Dataset with 10-fold cross validation
from keras.models import Sequential
from keras.layers import Dense
from sklearn.model_selection import StratifiedKFold
import numpy
# fix random seed for reproducibility
seed = 7
numpy.random.seed(seed)
# load pima indians dataset
dataset = numpy.loadtxt("pima-indians-diabetes.csv", delimiter=",")
# split into input (X) and output (Y) variables
X = dataset[:,0:8]
Y = dataset[:,8]
# define 10-fold cross validation test harness
kfold = StratifiedKFold(n_splits=10, shuffle=True, random_state=seed)
cvscores = []
for train, test in kfold.split(X, Y):# create modelmodel = Sequential()model.add(Dense(12, input_dim=8, activation='relu'))model.add(Dense(8, activation='relu'))model.add(Dense(1, activation='sigmoid'))# Compile modelmodel.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])# Fit the modelmodel.fit(X[train], Y[train], epochs=150, batch_size=10, verbose=0)# evaluate the modelscores = model.evaluate(X[test], Y[test], verbose=0)print("%s: %.2f%%" % (model.metrics_names[1], scores[1]*100))cvscores.append(scores[1] * 100)
print("%.2f%% (+/- %.2f%%)" % (numpy.mean(cvscores), numpy.std(cvscores)))

keras实现交叉验证以及K折交叉验证相关推荐

  1. 机器学习(MACHINE LEARNING)交叉验证(简单交叉验证、k折交叉验证、留一法)

    文章目录 1 简单的交叉验证 2 k折交叉验证 k-fold cross validation 3 留一法 leave-one-out cross validation 针对经验风险最小化算法的过拟合 ...

  2. 交叉验证(简单交叉验证、k折交叉验证、留一法)

    针对经验风险最小化算法的过拟合的问题,给出交叉验证的方法,这个方法在做分类问题时很常用: 一:简单的交叉验证的步骤如下: 1. 从全部的训练数据 S中随机选择 中随机选择 s的样例作为训练集 trai ...

  3. 交叉验证的几个方法的解释(简单交叉验证、k折交叉验证、留一法)

    针对经验风险最小化算法的过拟合的问题,给出交叉验证的方法,这个方法在做分类问题时很常用: 一:简单的交叉验证的步骤如下: 1. 从全部的训练数据 S中随机选择 中随机选择 s的样例作为训练集 trai ...

  4. 你真的会用K折交叉吗?对于K折交叉的思考 | K折交叉的坑

    本文目的: 对于K折交叉,想必大家都知道是什么原理.但是在具体实践中让你写的时候,你可能就会突然疑惑:"咦?道理我都懂,可是这个玩意儿到底怎么用." 本文就是为了探讨一下什么时候 ...

  5. k折交叉验证法python实现_Jason Brownlee专栏| 如何解决不平衡分类的k折交叉验证-不平衡分类系列教程(十)...

    作者:Jason Brownlee 编译:Florence Wong – AICUG 本文系AICUG翻译原创,如需转载请联系(微信号:834436689)以获得授权 在对不可见示例进行预测时,模型评 ...

  6. python机器学习库sklearn——交叉验证(K折、留一、留p、随机)

    分享一个朋友的人工智能教程.零基础!通俗易懂!风趣幽默!还带黄段子!大家可以看看是否对自己有帮助:点击打开 全栈工程师开发手册 (作者:栾鹏) python数据挖掘系列教程 学习预测函数的参数,并在相 ...

  7. python 实现k折交叉验证

    k折交叉验证原理: k折交叉验证是将数据分为k份,选取其中的k-1份为训练数据,剩余的一份为测试数据.k份数据循环做测试集进行测试.此原理适用于数据量小的数据. # k-折交叉验证(此处设置k=10) ...

  8. k折交叉验证优缺点_k折交叉验证(R语言)

    "机器学习中需要把数据分为训练集和测试集,因此如何划分训练集和测试集就成为影响模型效果的重要因素.本文介绍一种常用的划分最优训练集和测试集的方法--k折交叉验证." k折交叉验证 ...

  9. 参数调优:K折交叉验证与GridSearch网格搜索

    本文代码及数据集来自<Python大数据分析与机器学习商业案例实战> 一.K折交叉验证 在机器学习中,因为训练集和测试集的数据划分是随机的,所以有时会重复地使用数据,以便更好地评估模型的有 ...

  10. 【技术分享】什么是K折交叉验证?

    文章目录 1.什么是训练集.验证集和测试集? 2.什么是K折交叉验证? 3.数据集划分过程 3.应用场景及注意事项 3.1.应用场景 3.2.注意事项 1.什么是训练集.验证集和测试集? 训练集,即: ...

最新文章

  1. android -各种适配器
  2. Apache服务器部署(2)
  3. 夏日里的激情——FE鹅和鸭农庄行
  4. 《中国人工智能学会通讯》——12.58 大数据不确定性学习的研究
  5. docker构建镜像 发布镜像
  6. windbg bp condition
  7. ASP.NET身份验证机制membership入门——配置篇(1){转}
  8. hbase 的shell操作中相关属性说明
  9. android 如何使用aar,Android Studio如何使用aar依赖包?
  10. mac/linux 解决启动命令行出现declare问题
  11. mysql 传统数据恢复_MySQL误操作后如何快速恢复数据 传统解法 利用binlog2sql快速闪回 常见问题 参考资料...
  12. leetcode 题解 java_leetcode-java题解(每天更新)
  13. AR研究-Demo集
  14. Android开发系列(十七):读取assets文件夹下的数据库文件
  15. 6. 卷2(进程间通信)---System V 消息队列
  16. [转贴]怎样学好法语?
  17. 二维小游戏,飞机大战,图片素材
  18. 机器学习——共享单车数据集预测
  19. 按教师名单分配学生抽签程序
  20. 阿里安全人机行为识别比赛 前五名队伍分享

热门文章

  1. 计算机网络 - ECMAScript和Javascript、jscript关系
  2. 一流科技携手小米、旷视等多家企业共同发起成立中关村数智人工智能产业联盟...
  3. 放慢你的额脚步_放慢脚步使我成为更好的领导者
  4. 使用vuex实现一个简单的小应用
  5. druid加密数据库密码
  6. 第一届全国大学生GIS应用技能大赛试题答案及数据下载(下午)
  7. mysql通过Navcat 备份数据.psc 还原数据时 只有表没有数据解决方法
  8. LabVIEW 编程更改波形图Plots是否可见
  9. 数据结构系列之三红黑树
  10. Eclipse代码格式化无效解决方案