文章目录

  • 一. Scikit-Learn阀值概述
  • 二. 代码实现
  • 参考:

一. Scikit-Learn阀值概述

Scikit-Learn不允许直接设置阈值,但它可以得到决策分数,调用其decision_function()方法,而不是调用分类器的predict()方法,该方法返回每个实例的分数,然后使用想要的阈值根据这些分数进行预测。

对于这种分类问题,不同的分类阈值可以给出不同的输出结果,但是在sklearn中,无法直接通过直接修改阈值而输出结果,但是我们可以首先得到决策函数得到的结果,然后再手动确定阈值,得到预测的结果。
为了使得模型更加完善,我们需要选择合适的阈值,即使得准确率和召回率都比较大,因此在这里我们可以首先绘制出准确率和召回率随阈值的变化关系,然后再选择合适的阈值。

二. 代码实现

代码:

import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import KFold, cross_val_score
from sklearn.metrics import confusion_matrix,recall_score,classification_report
import itertoolsdef plot_confusion_matrix(cm, classes,title='Confusion matrix',cmap=plt.cm.Blues):"""This function prints and plots the confusion matrix."""plt.imshow(cm, interpolation='nearest', cmap=cmap)plt.title(title)plt.colorbar()tick_marks = np.arange(len(classes))plt.xticks(tick_marks, classes, rotation=0)plt.yticks(tick_marks, classes)thresh = cm.max() / 2.for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):plt.text(j, i, cm[i, j],horizontalalignment="center",color="white" if cm[i, j] > thresh else "black")plt.tight_layout()plt.ylabel('True label')plt.xlabel('Predicted label')data = pd.read_csv("E:/file/creditcard.csv")# 将金额数据处理成 范围为[-1,1] 之间的数值
# 机器学习默认数值越大,特征就越重要,不处理容易造成的问题是 金额这个特征值的重要性远大于V1-V28特征
data['normAmount'] = StandardScaler().fit_transform(data['Amount'].values.reshape(-1, 1))
# 删除暂时不用的特征值
data = data.drop(['Time','Amount'],axis=1)X = data.values[:, data.columns != 'Class']
y = data.values[:, data.columns == 'Class']# 获取异常交易数据的总行数及索引
number_records_fraud = len(data[data.Class == 1])
fraud_indices = np.array(data[data.Class == 1].index)# 获取正常交易数据的索引值
normal_indices = data[data.Class == 0].index# 在正常样本当中, 随机采样得到指定个数的样本, 并取其索引
random_normal_indices = np.random.choice(normal_indices, number_records_fraud, replace = False)
random_normal_indices = np.array(random_normal_indices)# 有了正常和异常的样本后把他们的索引都拿到手
under_sample_indices = np.concatenate([fraud_indices,random_normal_indices])# 根据索引得到下采样的所有样本点
under_sample_data = data.iloc[under_sample_indices,:]X_undersample = under_sample_data.loc[:, under_sample_data.columns != 'Class']
y_undersample = under_sample_data.loc[:, under_sample_data.columns == 'Class']# 对整个数据集进行划分, X为特征数据, Y为标签, test_size为测试集比列, random_state 为随机种子, 目的是使得每次随机的结果都一样
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=0)# 下采样数据集进行划分
X_train_undersample, X_test_undersample, y_train_undersample, y_test_undersample = train_test_split(X_undersample,y_undersample,test_size = 0.3,random_state = 0)# 计算混淆矩阵
lr = LogisticRegression(C=0.01, penalty='l2')
lr.fit(X_train_undersample, y_train_undersample.values.ravel())
y_pred_undersample_proba = lr.predict_proba(X_test_undersample.values)# 指定阀值
thresholds = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]plt.figure(figsize=(10, 10))j = 1
for i in thresholds:y_test_predictions_high_recall = y_pred_undersample_proba[:, 1] > iplt.subplot(3, 3, j)j += 1# Compute confusion matrixcnf_matrix = confusion_matrix(y_test_undersample, y_test_predictions_high_recall)np.set_printoptions(precision=2)print("Recall metric in the testing dataset: ", cnf_matrix[1, 1] / (cnf_matrix[1, 0] + cnf_matrix[1, 1]))# Plot non-normalized confusion matrixclass_names = [0, 1]plot_confusion_matrix(cnf_matrix, classes=class_names, title='Threshold >= %s' % i)plt.show()

测试记录:
Recall metric in the testing dataset: 0.9931972789115646
Recall metric in the testing dataset: 0.9523809523809523
Recall metric in the testing dataset: 0.9319727891156463
Recall metric in the testing dataset: 0.8979591836734694
Recall metric in the testing dataset: 0.8775510204081632
Recall metric in the testing dataset: 0.8503401360544217
Recall metric in the testing dataset: 0.8503401360544217
Recall metric in the testing dataset: 0.8163265306122449
Recall metric in the testing dataset: 0.7619047619047619

参考:

  1. https://study.163.com/course/introduction.htm?courseId=1003590004#/courseDetail?tab=1

Python数据分析与机器学习21- 逻辑回归项目实战5-阀值相关推荐

  1. 逻辑回归三部曲——逻辑回归项目实战(信贷数据+Python代码实现)

         逻辑回归已经在各大银行和公司都实际运用于业务,已经有很多前辈写过逻辑回归.本文将从我实际应用的角度阐述逻辑回归的由来,致力于让逻辑回归变得清晰.易懂.逻辑回归又叫对数几率回归,是一种广义线性 ...

  2. python之逻辑回归项目实战——信用卡欺诈检测

    信用卡欺诈检测 1.项目介绍 2.项目背景 3.分析项目 4.数据读取与分析 4.1 加载数据 4.2 查看数据的标签分布 5.数据预处理 5.1 特征标准化 5.2. 使用下采样解决样本数据不均衡 ...

  3. 跟着迪哥学python 经管之家_跟着迪哥学Python数据分析与机器学习实战

    本书结合了机器学习.数据分析和Python语言,通过案例以通俗易懂的方式讲解了如何将算法应用到实际任务. 全书共20章,大致分为4个部分.第一部分介绍了Python的工具包,包括科学计算库Numpy. ...

  4. 【机器学习】逻辑回归原理介绍

    [机器学习]逻辑回归原理介绍 [机器学习]逻辑回归python实现 [机器学习]逻辑回归sklearn实现 Logistic 回归模型是目前广泛使用的学习算法之一,通常用来解决二分类问题,虽然名字中有 ...

  5. 机器学习(三)逻辑回归以及python简单实现

    虽然有回归两个字,但是依然是解决的时分类问题,是最经典的二分类算法. 分类算法有很多,例如支持向量机和神经网络.而逻辑回归算法应用的比较广,往往是优先选择的算法. Sigmod函数 表达式: g(z) ...

  6. 机器学习:逻辑回归(logistics regression)

    title: 机器学习:逻辑回归(logistics regression) date: 2019-11-30 20:55:06 mathjax: true categories: 机器学习 tags ...

  7. 基于python的数据建模与分析案例_基于案例详解Python数据分析与机器学习

    课程概述: 使用数据领域最主流语言Python及其分析与建模库作为核心武器.对于机器学习经典算法给出完整的原理推导并基于实例进行讲解,基于案例演示如何应用机器学习算法解决实际问题. 课程特色: 通俗易 ...

  8. 机器学习-了解逻辑回归的逻辑过程

    机器学习-逻辑回归 预测乳腺癌案例 import numpy as np import pandas as pd # 机器学习 import sklearn # 逻辑回归 from sklearn.l ...

  9. 23神经网络 :唐宇迪《python数据分析与机器学习实战》学习笔记

    唐宇迪<python数据分析与机器学习实战>学习笔记 23神经网络 1.初识神经网络 百度深度学习研究院的图,当数据规模较小时差异较小,但当数据规模较大时深度学习算法的效率明显增加,目前大 ...

  10. 机器学习_2逻辑回归

    机器学习_逻辑回归 分类问题 二分类--Sigmoid函数 Sigmoid函数代码实现 逻辑回归 数学原理 求解方式 正则化 逻辑回归数据集应用样例--代码实现 样例1:有清晰的线性决策边界 决策边界 ...

最新文章

  1. R语言与数据分析(12)向量
  2. FD_WRITE触发条件
  3. 最舒适的路线 第六届
  4. USTC English Club Note20171021
  5. 星特朗望远镜怎么样_入手曝光评测双筒望远镜星特朗和博冠有何区别?哪个好?体验报告揭秘...
  6. Mysql 会导致锁表的语法
  7. java8 stream多次map_java8streamapi:如何将列表转换为在列表中具有重复键的MapLong,Set?...
  8. 5款最适合 Windows 命令行/控制台的替代品
  9. BZOJ2795: [Poi2012]A Horrible Poem
  10. 深入解读Docker底层技术cgroup系列(2)——cgroup的初始化
  11. CurvySplines基础
  12. 工作薄与工作表的区别:
  13. 设备树基本语法及属性分析
  14. 计算机硬件和软件的主要功能,网络技术在计算机软硬件的作用
  15. python多个箱线图_python-matplotlib | 箱线图及解读
  16. linux下文件对比工具详解(diff、diff3、sdiff、vimdiff和comm)
  17. 白菜个人导航页2.0
  18. JAVA与MAVEN打包
  19. github以网页的方式查看.html
  20. 英语数学不好能学Java吗?

热门文章

  1. Unity 音频理解与优化
  2. 仿牛客论坛项目(上)
  3. cocos2dx的图片加载
  4. 网狐棋牌手端无法进入登录页面
  5. 安卓开发的一些uuid,imei,meid,imsi,clientid,uuid
  6. QQ自动给好友点赞(网页的QQ空间)
  7. 分公司和子公司的法律地位
  8. python股票量化交易(3)---趋势类指标MACD
  9. IDEA安装、配置及卸载
  10. Micrel的1588方案