一.前言

这篇博客是为了记录论文补充实验中所遇到的问题,以及解决方法,主要以程序的形式呈现。

二.对象

深度学习框架:keras
研究对象:两分类/多分类

三.技术杂谈

1.K-FOLD交叉验证

1.概念
对一个模型进行K次训练,每次训练将整个数据集分为随机的K份,K-1作为训练集,剩余的1份作为验证集,每次训练结束将验证集上的性能指标保存下来,最后对K个结果进行平均得到最终的模型性能指标。
2.优缺点
优点:模型评估更加鲁棒
缺点:训练时间加大
3.代码
① sklearn与keras独立使用

from sklearn.model_selection import StratifiedKFold
import numpyseed = 7  # 随机种子
numpy.random.seed(seed)  # 生成固定的随机数
num_k = 5  # 多少折# 整个数据集(自己定义)
X =
Y = kfold = StratifiedKFold(n_splits=num_k, shuffle=True, random_state=seed)  # 分层K折,保证类别比例一致cvscores = []
for train, test in kfold.split(X, Y):# 可以用sequential或者function的方式建模(自己定义)model =     model.compile()  # 自定义# 模型训练model.fit(X[train], Y[train], epochs=150, batch_size=10, verbose=0)# 模型测试scores = model.evaluate(X[test], Y[test], verbose=0)print("%s: %.2f%%" % (model.metrics_names[1], scores[1]*100))  # 打印出验证集准确率cvscores.append(scores[1] * 100)print("%.2f%% (+/- %.2f%%)" % (numpy.mean(cvscores), numpy.std(cvscores)))  # 输出k-fold的模型平均和标准差结果

② sklearn与keras结合使用

from keras.wrappers.scikit_learn import KerasClassifier  # 使用keras下的sklearn API
from sklearn.cross_validation import StratifiedKFold, cross_val_score
import numpy as npseed = 7  # 随机种子
numpy.random.seed(seed)  # 生成固定的随机数
num_k = 5  # 多少折# 整个数据集(自己定义)
X =
Y = # 创建模型
def model():# 可以用sequential或者function的方式建模(自己定义)model =  return model  model = KerasClassifier(build_fn=model, epochs=150, batch_size=10)kfold = StratifiedKFold(Y, n_folds=num_k, shuffle=True, random_state=seed)results = cross_val_score(model, X, Y, cv=kfold)print(np.average(results))  # 输出k-fold的模型平均结果

补充:引入keras的callbacks
只需要在①②中的model.fit中加入一个arg:callbacks=[keras.callbacks.ModelCheckpoint()] # 这样可以保存下模型的权重,当然了你也可以使用callbacks.TensorBoard保存下训练过程

2.二分类/多分类评价指标

1.概念
二分类就是说,一个目标的标签只有两种之一(例如:0或1,对应的one-hot标签为[1,0]或[0,1])。对于这种问题,一般可以采用softmax或者logistic回归来完成,分别采用cross-entropy和mse损失函数来进行网络训练,分别输出概率分布和单个的sigmoid预测值(0,1)。
多分类就是说,一个目标的标签是几种之一(如:0,1,2…)
2.评价指标
主要包含了:准确率(accuracy),错误率(error rate),精确率(precision),召回率(recall)= 真阳率(TPR)= 灵敏度(sensitivity),F1-measure(包含了micro和macro两种),假阳率(FPR),特异度(specificity),ROC(receiver operation characteristic curve)(包含了micro和macro两种),AUC(area under curve),P-R曲线(precision-recall),混淆矩阵
① 准确率和错误率
accuracy = (TP+TN)/ (P+N)或者accuracy = (TP+TN)/ (T+F)
error rate = (FP+FN) / (P+N)或者(FP+FN) / (T+F)
accuracy = 1 - error rate
可见:准确率、错误率是对分类器在整体数据上的评价指标。
② 精确率
precision=TP /(TP+FP)
可见:精确率是对分类器在预测为阳性的数据上的评价指标。
③ 召回率/真阳率/灵敏度
recall = TPR = sensitivity = TP/(TP+FN)
可见:召回率/真阳率/灵敏度是对分类器在整个阳性数据上的评价指标。
④ F1-measure
F1-measure = 2 * (recall * precision / (recall + precision))
包含两种:micro和macro(对于多类别分类问题,注意区别于多标签分类问题)
1)micro
计算出所有类别总的precision和recall,然后计算F1-measure
2)macro
计算出每一个类的precison和recall后计算F1-measure,最后将F1-measure平均
可见:F1-measure是对两个矛盾指标precision和recall的一种调和。
⑤ 假阳率
FPR=FP / (FP+TN)
可见:假阳率是对分类器在整个阴性数据上的评价指标,针对的是假阳。
⑥ 特异度
specificity = 1- FPR
可见:特异度是对分类器在整个阴性数据上的评价指标,针对的是真阴。
⑦ ROC曲线和AUC
作用:灵敏度与特异度的综合指标
横坐标:FPR/1-specificity
纵坐标:TPR/sensitivity/recall
AUC是ROC右下角的面积,越大,表示分类器的性能越好
包含两种:micro和macro(对于多类别分类问题,注意区别于多标签分类问题)
假设一共有M个样本,N个类别。预测出来的概率矩阵P(M,N),标签矩阵L (M,N)
1)micro
根据P和L中的每一列(对整个数据集而言),计算出各阈值下的TPR和FPR,总共可以得到N组数据,分别画出N个ROC曲线,最后取平均
2)macro
将P和L按行展开,然后转置为两列,最后画出一个ROC曲线
⑧ P-R曲线
横轴:recall
纵轴:precision
评判:1)直观看,P-R包围的面积越大越好,P=R的点越大越好;2)通过F1-measure来看
比较ROC和P-R: 当样本中的正、负比例不平衡的时候,ROC曲线基本保持不变,而P-R曲线变化很大,原因如下:
当负样本的比例增大时,在召回率一定的情况下,那么表现较差的模型必然会召回更多的负样本,TP降低,FP迅速增加(对于性能差的分类器而言),precision就会降低,所以P-R曲线包围的面积会变小。
⑨ 混淆矩阵
行表示的是样本中的一种真类别被预测的结果,列表示的是一种被预测的标签所对应的真类别。
3.代码
注意:以下的代码是合在一起写的,有注释。

from sklearn import datasets
import numpy as np
from sklearn.preprocessing import label_binarize
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import confusion_matrix, precision_score, accuracy_score,recall_score, f1_score,roc_auc_score, precision_recall_fscore_support, roc_curve, classification_report
import matplotlib.pyplot as pltiris = datasets.load_iris()
x, y = iris.data, iris.target
print("label:", y)
n_class = len(set(iris.target))
y_one_hot = label_binarize(y, np.arange(n_class))# alpha = np.logspace(-2, 2, 20)  #设置超参数范围
# model = LogisticRegressionCV(Cs = alpha, cv = 3, penalty = 'l2')  #使用L2正则化
model = LogisticRegression()  # 内置了最大迭代次数了,可修改
model.fit(x, y)
y_score = model.predict(x)  # 输出的是整数标签
mean_accuracy = model.score(x, y)
print("mean_accuracy: ", mean_accuracy)
print("predict label:", y_score)
print(y_score==y)
print(y_score.shape)
y_score_pro = model.predict_proba(x)  # 输出概率
print(y_score_pro)
print(y_score_pro.shape)
y_score_one_hot = label_binarize(y_score, np.arange(n_class))  # 这个函数的输入必须是整数的标签哦
print(y_score_one_hot.shape)obj1 = confusion_matrix(y, y_score)  # 注意输入必须是整数型的,shape=(n_samples, )
print('confusion_matrix\n', obj1)print(y)
print('accuracy:{}'.format(accuracy_score(y, y_score)))  # 不存在average
print('precision:{}'.format(precision_score(y, y_score,average='micro')))
print('recall:{}'.format(recall_score(y, y_score,average='micro')))
print('f1-score:{}'.format(f1_score(y, y_score,average='micro')))
print('f1-score-for-each-class:{}'.format(precision_recall_fscore_support(y, y_score)))  # for macro
# print('AUC y_pred = one-hot:{}\n'.format(roc_auc_score(y_one_hot, y_score_one_hot,average='micro')))  # 对于multi-class输入必须是proba,所以这种是错误的# AUC值
auc = roc_auc_score(y_one_hot, y_score_pro,average='micro')  # 使用micro,会计算n_classes个roc曲线,再取平均
print("AUC y_pred = proba:", auc)
# 画ROC曲线
print("one-hot label ravelled shape:", y_one_hot.ravel().shape)
fpr, tpr, thresholds = roc_curve(y_one_hot.ravel(),y_score_pro.ravel())   # ravel()表示平铺开来,因为输入的shape必须是(n_samples,)
print("threshold: ", thresholds)
plt.plot(fpr, tpr, linewidth = 2,label='AUC=%.3f' % auc)
plt.plot([0,1],[0,1], 'k--')  # 画一条y=x的直线,线条的颜色和类型
plt.axis([0,1.0,0,1.0])  # 限制坐标范围
plt.xlabel('False Postivie Rate')
plt.ylabel('True Positive Rate')
plt.legend()
plt.show()# p-r曲线针对的是二分类,这里就不描述了ans = classification_report(y, y_score,digits=5)  # 小数点后保留5位有效数字
print(ans)

本人现在的研究方向是:
图像的语义分割,如果有志同道合的朋友,可以组队学习
haiyangpengai@gmail.com qq:1355365561

keras sklearn下两分类/多分类的技术杂谈(交叉验证和评价指标)相关推荐

  1. ML之分类预测之ElasticNet:利用ElasticNet回归对二分类数据集构建二分类器(DIY交叉验证+分类的两种度量PK)

    ML之分类预测之ElasticNet:利用ElasticNet回归对二分类数据集构建二分类器(DIY交叉验证+分类的两种度量PK) 目录 输出结果 设计思路 核心代码 输出结果 设计思路 核心代码 # ...

  2. 【大数据专业】机器学习分类模型评估和优化之交叉验证的多种方法

    学习目标: 机器学习: 分类评估模型及优化之交叉验证 交叉验证的三种基本方法: 1.将拆分与评价合并执行 sklearn.model_selection.cross_val_score 2.同时使用多 ...

  3. R语言使用caret包的train函数构建xgboost模型(基于linear算法)模型构建分类模型、trainControl函数设置交叉验证参数、自定义调优评估指标

    R语言使用caret包的train函数构建xgboost模型(基于linear算法)模型构建分类模型.trainControl函数设置交叉验证参数.自定义调优评估指标.tuneLength参数和tun ...

  4. R语言使用caret包的train函数构建xgboost模型(基于gbtree算法)模型构建分类模型、trainControl函数设置交叉验证参数、自定义调优评估指标

    R语言使用caret包的train函数构建xgboost模型(基于gbtree算法)模型构建分类模型.trainControl函数设置交叉验证参数.自定义调优评估指标.tuneLength参数和tun ...

  5. R语言使用caret包的train函数构建多项式核SVM模型(多项式核函数)模型构建分类模型、trainControl函数设置交叉验证参数、自定义调优评估指标

    R语言使用caret包的train函数构建多项式核SVM模型(多项式核函数)模型构建分类模型.trainControl函数设置交叉验证参数.自定义调优评估指标.tuneLength参数和tuneGri ...

  6. R语言使用caret包的train函数构建xgboost模型(基于dart算法)模型构建分类模型、trainControl函数设置交叉验证参数、自定义调优评估指标

    R语言使用caret包的train函数构建xgboost模型(基于dart算法)模型构建分类模型.trainControl函数设置交叉验证参数.自定义调优评估指标.tuneLength参数和tuneG ...

  7. R语言使用caret包的train函数构建惩罚判别分析模型(pda)构建分类模型、trainControl函数设置交叉验证参数、自定义调优评估指标

    R语言使用caret包的train函数构建惩罚判别分析模型(pda)构建分类模型.trainControl函数设置交叉验证参数.自定义调优评估指标.tuneLength参数和tuneGrid参数超参数 ...

  8. R语言惩罚逻辑回归、线性判别分析LDA、广义加性模型GAM、多元自适应回归样条MARS、KNN、二次判别分析QDA、决策树、随机森林、支持向量机SVM分类优质劣质葡萄酒十折交叉验证和ROC可视化

    最近我们被客户要求撰写关于葡萄酒的研究报告,包括一些图形和统计输出. 介绍 数据包含有关葡萄牙"Vinho Verde"葡萄酒的信息.该数据集有1599个观测值和12个变量,分别是 ...

  9. 数据分享|R语言逻辑回归、线性判别分析LDA、GAM、MARS、KNN、QDA、决策树、随机森林、SVM分类葡萄酒交叉验证ROC...

    全文链接:http://tecdat.cn/?p=27384 在本文中,数据包含有关葡萄牙"Vinho Verde"葡萄酒的信息(点击文末"阅读原文"获取完整代 ...

最新文章

  1. 请问:这里的空应怎么填呀?
  2. 排序算法——(2)Python实现十大常用排序算法
  3. NGINX + TOMCAT7 + MEMCACHED 实现SESSION 共享
  4. SQL Server 2005 Service Broker 初探 [摘抄]
  5. 卸载oracle——详细版
  6. Java设计模式(四)行为型 设计模式
  7. 大数据在职研究生哪个好_在职研究生大数据专业怎么样?
  8. 可编程式坐标--单位圆坐标
  9. STM32 ADC多通道采样声音传感器和环境光传感器
  10. hbuilderx运行支付宝小程序
  11. 亚马逊账号被关联能申诉得回来吗
  12. Python cv2读取/存储图片中含中文路径失败的解决方法
  13. 心情是一盏温茶的宁静
  14. 甘超波:什么是个人定位
  15. TSN(时间敏感网络)纯干货分享
  16. ZStack - 创建云主机
  17. 蓝桥杯训练(python)Day2
  18. oracle数据库block、tigger、function、package
  19. 沪嘉杭共建G60科创走廊
  20. esp12s 第十二章 舵机控制

热门文章

  1. 计算机驱动伺服的程序,伺服调试软件V-ASSISTANT始终找不到驱动-工业支持中心-西门子中国...
  2. matlab机液位置伺服系统,基于MATLAB的电液位置伺服系统仿真分析
  3. 网页防篡改技术_大数据让档案“活”起来:用区块链技术防篡改,用量子加密技术防盗窃...
  4. Javaspring 7-13课 Spring Bean
  5. Python利用openpyxl来操作Excel(一)
  6. 周期(KMP算法之Next数组的性质)
  7. CSDN Cookbook by Eric
  8. C语言中 . 和 - 区别详解(举例解释)
  9. OneNET物联网云平台HTTP数据流上传与下发,使用Fiddler调试开关应用,stm32 esp8266物联网家居远程开关
  10. 数据结构|-二叉查找树(二叉搜索树)的链式存储结构的实现