在验证分类算法的好坏时,经常需要用到AUC曲线,而在做不同分类模型的对比实验时,需要将不同模型的AUC曲线绘制到一张图里。

计算机视觉——python在一张图中绘制多个模型的对比ROC线

  • 1. 小型分类模型对比,可以直接调用的
  • 2. 大型的CNN模型,无法直接得到结果。
    • 2.1 先分别运行每个分类模型,将预测的结果存入csv文件中。
    • 2.2 从csv文件读取每个模型的预测结果,绘制AUC曲线

1. 小型分类模型对比,可以直接调用的

用一样的数据集做示例,简单地直接分别得到每个分类模型预测的结果。

from sklearn.datasets import load_breast_cancer
from sklearn import metrics
from sklearn.ensemble import RandomForestClassifier, ExtraTreesClassifier
from sklearn.model_selection import train_test_split
import pylab as plt
import warnings; warnings.filterwarnings('ignore')dataset = load_breast_cancer()
data = dataset.data
target = dataset.target
X_train, X_test, y_train, y_test = train_test_split(data,target,test_size=0.2)
# 模型1调用sklearn中的RandomForestClassifier
rf1 = RandomForestClassifier(n_estimators=5)
rf1.fit(X_train, y_train)
pred1 = rf1.predict_proba(X_test)[:,1]
# 模型2调用sklearn中的ExtraTreesClassifier
rf2 = ExtraTreesClassifier(n_estimators=5)
rf2.fit(X_train, y_train)
pred2 = rf2.predict_proba(X_test)[:,1]# 画图部分
fpr1, tpr1, threshold1 = metrics.roc_curve(y_test, pred1)       # <class 'numpy.ndarray'> <class 'numpy.ndarray'>
roc_auc1 = metrics.auc(fpr1, tpr1)fpr2, tpr2, threshold2 = metrics.roc_curve(y_test, pred2)       # <class 'numpy.ndarray'> <class 'numpy.ndarray'>
roc_auc2 = metrics.auc(fpr2, tpr2)plt.figure(figsize=(6,6))
plt.title('Validation ROC')
plt.plot(fpr1, tpr1, 'b', label = 'RandomForestClassifier AUC = %0.3f' % roc_auc1)
plt.plot(fpr2, tpr2, 'b', label = 'ExtraTreesClassifier AUC = %0.3f' % roc_auc2)
plt.legend(loc = 'lower right')
plt.plot([0, 1], [0, 1],'r--')
plt.xlim([0, 1])
plt.ylim([0, 1])
plt.ylabel('True Positive Rate')
plt.xlabel('False Positive Rate')
plt.savefig("filename.png")
plt.show()
plt.close()

2. 大型的CNN模型,无法直接得到结果。

2.1 先分别运行每个分类模型,将预测的结果存入csv文件中。

from sklearn.datasets import load_breast_cancer
from sklearn import metrics
from sklearn.ensemble import RandomForestClassifier, ExtraTreesClassifier
from sklearn.model_selection import train_test_split
import pylab as plt
import warnings; warnings.filterwarnings('ignore')
import pandas as pd
import numpy as np
from sklearn.linear_model import LinearRegression   #线性回归
import csvdataset = load_breast_cancer()
data = dataset.data
target = dataset.target
X_train, X_test, y_train, y_test = train_test_split(data,target,test_size=0.2)
rf1 = RandomForestClassifier(n_estimators=5)
rf1.fit(X_train, y_train)
pred1 = rf1.predict_proba(X_test)[:,1]dataframe = pd.DataFrame({'label':y_test,'pred':pred1})
dataframe.to_csv("test1.csv",index=False,sep=',')rf2 = ExtraTreesClassifier(n_estimators=5)
rf2.fit(X_train, y_train)
pred2 = rf2.predict_proba(X_test)[:,1]dataframe = pd.DataFrame({'label':y_test,'pred':pred2})
dataframe.to_csv("test2.csv",index=False,sep=',')

2.2 从csv文件读取每个模型的预测结果,绘制AUC曲线

import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from sklearn.metrics import roc_curve, aucdef Draw_ROC(file1,file2):'''这里注意读取csv的编码方式,如果csv里有中文,在windows系统上可以直接用encoding='ANSI',但是到了Mac或者Linux系统上会报错:`LookupError: unknown encoding: ansi`。解决方法:1. 可以改成encoding='gbk';2. 或者把csv文件里的列名改成英文,就不用选择encoding的方式了。'''data1=pd.read_csv(file1, encoding='ANSI')data1=pd.DataFrame(data1)data2=pd.read_csv(file2, encoding='ANSI')data2=pd.DataFrame(data2)print(list(data1['label']), list(data1['pred']))print(list(data2['label']), list(data2['pred']))fpr_CSNN,tpr_CSNN,thresholds=roc_curve(list(data1['label']),list(data1['pred']))roc_auc_CSSSNN=auc(fpr_CSNN,tpr_CSNN)fpr_NN,tpr_NN,thresholds=roc_curve(list(data2['label']),list(data2['pred']))roc_auc_DL=auc(fpr_NN,tpr_NN)font = {'family': 'Times New Roman','size': 12,}'''这里很多电脑上也许默认是'DejaVu Sans'格式,但是在写论文时,往往需要'Times New Roman'格式,可以参考[这篇教程](https://blog.csdn.net/weixin_43543177/article/details/109723328)'''sns.set(font_scale=1.2)plt.rc('font',family='Times New Roman')plt.plot(fpr_NN,tpr_NN,'purple',label='NN_AUC = %0.2f'% roc_auc_DL)plt.plot(fpr_CSNN,tpr_CSNN,'blue',label='CSNN_AUC = %0.2f'% roc_auc_CSSSNN)plt.legend(loc='lower right',fontsize = 12)plt.plot([0,1],[0,1],'r--')plt.ylabel('True Positive Rate',fontsize = 14)plt.xlabel('Flase Positive Rate',fontsize = 14)plt.show()if __name__=="__main__":Draw_ROC('./test1.csv','./test2.csv')

计算机视觉——python在一张图中画多条ROC线相关推荐

  1. python在一张图上画多个线条

    python 在一张图上画多个roc ptyon在一张图上添加图例 python将多个roc曲线画到一张图上 说明 我写了一个画图函数,这个函数可以画很多图在一个图上: 可以自由的确定画图个数 调用 ...

  2. R语言可视化包ggplot2在一张图中画出两条线实战

    R语言可视化包ggplot2在一张图中画出两条线实战 目录 R语言可视化包ggplot2在一张图中画出两条线实战

  3. R语言ggplot2可视化、在一张图中画出两条曲线(two lines in same ggplot2 graph)、使用pdf函数将ggplot2可视化图像保存到指定目录的pdf格式文件中

    R语言ggplot2可视化.在一张图中画出两条曲线(two lines in same ggplot2 graph).使用pdf函数将ggplot2可视化图像保存到指定目录的pdf格式文件中 目录

  4. origin如何在一张图中画多种线

    origin如何在一张图中画多种线 1.选中第一条线的数据后,绘图 右键单击空白处 2.新建后,右键单击图层2 选择图层内容 首先选中所需数据,然后绘制

  5. MATLAB中利用cftool导出代码实现一张图中拟合多条平滑曲线

    MATLAB中自带的cftool拟合工具箱不能将多条曲线同时画在同一副图中,而常规的plot()函数又不能拟合平滑直线,接下来总结一种可以利用cftool导出的代码,在一张图中拟合多条平滑曲线. 首先 ...

  6. R语言使用pROC包在同一图中绘制两条ROC曲线并通过假设检验检验ROC曲线的AUC或者偏AUC的差异(输出p值)

    R语言使用pROC包在同一图中绘制两条ROC曲线并通过假设检验检验ROC曲线的AUC或者偏AUC的差异(输出p值) 目录

  7. matlab在一张图上画两条折线图,excel2013怎么在一张曲线图上绘制多条曲线?

    office软件每年的更新速度很快,虽然更新后的版本肯定能够实现更新前的功能,但是由于版式方面的改进,使得使用者初次使用时不是特别得心应手.下面重点讲述一下,如何利用excel2013在同一张图中做多 ...

  8. Matlab在一张图上画多条曲线或分别画

    1.在plot曲线时,有时想在一张图上重合画多条曲线,我们只需要在画图命令之前加上hold on就好,比如: t = 1:0.1:10: y1 = sin(2*pi*t); y1 = cos(2*pi ...

  9. python在图中画一条垂直线(matplotlib)

    matplotlib.pyplot.axvline https://matplotlib.org/stable/api/_as_gen/matplotlib.pyplot.axvline.html?h ...

最新文章

  1. 分享cropper剪切单张图片demo
  2. Learning ROS: Service and Client (C++)
  3. [置顶]       ibatis做分页
  4. 第三次学JAVA再学不好就吃翔(part22)--匿名对象
  5. golang实现聊天室(三)
  6. php返回200,关于API 使用 HTTP 状态码还是全部返回 200
  7. Dependency Walker使用说明[转]
  8. 小程序字符串拼接_小程序突袭预约!Yeezy 350quot;氧化满天星quot;拼接配色本月发售!...
  9. eclipse New菜单项的显示问题
  10. L3-019 代码排版 (30 分)-PAT 团体程序设计天梯赛 GPLT
  11. GDB+coredump定位段错误
  12. SM3算法的C++实现(代码)
  13. 第四方支付跟第三方支付的区别,支付源码有什么用
  14. java 合并两个音频_如何利用音乐合成软件将多段音频合并为一段?快速合并音频的方法...
  15. PS的快捷键(7.14)
  16. 泰拉瑞亚 阿里云服务器搭建记录
  17. 458、Java框架112 -【MyBatis - 一级缓存、二级缓存】 2020.12.28
  18. Endnote格式下载
  19. 大数据专业毕业后前景如何?能做什么职位?
  20. git克隆项目带用户名密码

热门文章

  1. linux 连接宽带
  2. Android安卓平板设备获取唯一标识
  3. linux查看cpt硬盘命令,常用Linux命令、文件操作解压缩相关、Linux命令大全、测试查询...
  4. 20170609问题记录1
  5. 直升飞机java游戏_java飞机游戏
  6. 17225 狼人游戏
  7. 史上最全docker安装方法!
  8. Python实现淘宝京东秒杀!源码拿去吧!
  9. 图数据库解决了什么问题?和关系型数据库的对比有哪些优势
  10. oracle知识点总结