作者主页(文火冰糖的硅基工坊):文火冰糖(王文兵)的博客_文火冰糖的硅基工坊_CSDN博客

本文网址:


目录

第1章 scikit-learn对决策树的支持

1.1 决策树的基本原理

1.2 决策树的核心问题

1.3 sklearn中的决策树

1.4 sklearn实现决策树的三步骤

第2章 代码实现示例

2.1 DecisionTreeClassifier类接口

2.2 构建数据集

2.3 构建模型并训练模型

2.4 显示训练好决策树

2.5 剪枝

2.6 精修参数max_features & min_impurity_decrease

2.7 确认最优的剪枝参数


第1章 scikit-learn对决策树的支持

1.1 决策树的基本原理

https://blog.csdn.net/HiWangWenBing/article/details/123340741

在这个决策过程中,我们一直在对记录的特征进行提问。

最初的问题所在的地方叫做根节点,

在得到结论前的每一个问题都是中间节点,

而得到的每一个结论(动物的类别)都叫做叶子节点。

关键概念:节点

根节点:没有进边,有出边。包含最初的,针对特征的提问。

中间节点:既有进边也有出边,进边只有一条,出边可以有很多条。都是针对特征的提问。

叶子节点:有进边,没有出边,每个叶子节点都是一个类别标签。

*子节点和父节点:在两个相连的节点中,更接近根节点的是父节点,另一个是子节点。

1.2 决策树的核心问题

决策树算法的核心是要解决两个问题:

1)如何从数据表中找出最佳根节点、最佳分枝、叶子节点?

2)如何让决策树停止生长,防止过拟合?

几乎所有决策树有关的模型调整方法,都围绕这两个问题展开。

1.3 sklearn中的决策树

本文重点关注:分类决策树tree.DecisionTreeClassifier和tree.export_graphviz。

1.4 sklearn实现决策树的三步骤

第2章 代码实现示例

2.1 DecisionTreeClassifier类接口

DecisionTreeClassifier(criterion='entropy', min_samples_leaf=3) 类为创建一个决策树模型,

其类的参数含义如下所示:

  • criterion:判决依据,gini或者entropy,前者是基尼系数,后者是信息熵。这两种差别不大,信息上的最大值为1,而激励系数最大值为0.5.
  • splitter: best or random,前者是在所有特征中找最好的切分点,后者是在部分特征中,默认的”best”适合样本量不大的时候,而如果样本数据量非常大,此时决策树构建推荐”random” 。
  • max_features:None(所有),log2,sqrt,N  特征小于50的时候一般使用所有的特征。
  • max_depth:  int or None, optional (default=None) 设置决策随机森林中的决策树的最大深度,深度越大,越容易过拟合,推荐树的深度为:5-20之间。
  • min_samples_split:设置结点的最小样本数量,当样本数量可能小于此值时,结点将不会在划分。
  • min_samples_leaf: 这个值限制了叶子节点包含的最少的样本数,如果某叶子节点数目小于样本数,则会和兄弟节点一起被剪枝,并合并到上一层节点中。防止过拟合。叶子节点包含的样本数少,表明该判断分支,只能判断出极少的样本,不具备普遍性, 叶子节点中,包含的样本数越多,则泛化能力越强,普遍性越强。
  • min_weight_fraction_leaf: 这个值限制了叶子节点所有样本权重和的最小值,如果小于这个值,则会和兄弟节点一起被剪枝默认是0,就是不考虑权重问题。
  • max_leaf_nodes: 通过限制最大叶子节点数,可以防止过拟合,默认是"None”,即不限制最大的叶子节点数。
  • class_weight: 指定样本各类别的的权重,主要是为了防止训练集某些类别的样本过多导致训练的决策树过于偏向这些类别。这里可以自己指定各个样本的权重,如果使用“balanced”,则算法会自己计算权重,样本量少的类别所对应的样本权重会高。
  • min_impurity_split: 这个值限制了决策树的增长,如果某节点的不纯度(基尼系数,信息增益,均方差,绝对差)小于这个阈值则该节点不再生成子节点。即为叶子节点 。即没有必要使得不纯度得到0,只要不纯度小于某个门限即可。

2.2 构建数据集

from sklearn import tree
from sklearn.datasets import load_wine
from sklearn.model_selection import train_test_split# 加载内置数据集
wine = load_wine()
print("\n数据形状: 178个样本,13列特征\n",wine.data.shape)
print("\n分类标签名:\n",wine.target)# 显示数据集表格
import pandas as pdpd.concat([pd.DataFrame(wine.data),pd.DataFrame(wine.target)],axis=1)print("\n特征名:\n", wine.feature_names)
print("\n分类标签名:\n",wine.target_names)# 分割数据集
Xtrain, Xtest, Ytrain, Ytest = train_test_split(wine.data,wine.target,test_size=0.3)
print("\n训练集:", Xtrain.shape)
print("\n测试集:", Xtest.shape)
数据形状: 178个样本,13列特征(178, 13)分类标签名:[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 00 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 11 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 11 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 22 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2]特征名:['alcohol', 'malic_acid', 'ash', 'alcalinity_of_ash', 'magnesium', 'total_phenols', 'flavanoids', 'nonflavanoid_phenols', 'proanthocyanins', 'color_intensity', 'hue', 'od280/od315_of_diluted_wines', 'proline']分类标签名:['class_0' 'class_1' 'class_2']训练集: (124, 13)测试集: (54, 13)

2.3 构建模型并训练模型

# 构建模型对象
#DecisionTreeObj = tree.DecisionTreeClassifier(criterion="entropy", random_state=30, splitter="random")
DecisionTreeObj = tree.DecisionTreeClassifier(criterion="entropy", random_state=30, splitter="random")# 用数据集训练模型,生成决策树
DecisionTreeObj = DecisionTreeObj.fit(Xtrain, Ytrain)# 评估模型:预测的准确度
score = DecisionTreeObj.score(Xtest, Ytest)
print("测试集打分:",score)score = DecisionTreeObj.score(Xtrain, Ytrain)
print("训练集打分:",score)
测试集打分: 0.9629629629629629
训练集打分: 1.0

在训练集上,准确率100%。

2.4 显示训练好决策树

# 展现决策树
# conda install graphviz
import graphvizfeature_name = ['酒精','苹果酸','灰','灰的碱性','镁','总酚','类黄酮','非黄烷类酚类','花青素','颜色强度','色调','稀释葡萄酒','脯氨酸']
print("特征名称:", feature_name)
print("特征个数:", len(feature_name))dot_data = tree.export_graphviz(DecisionTreeObj,out_file = None, feature_names = feature_name, class_names = ["分类1","分类2","分类3"]  #用来替代'class_0' 'class_1' 'class_2', filled=True    # 填充颜色, rounded=True   # 圆角图形)
graph = graphviz.Source(dot_data)
graph
特征名称: ['酒精', '苹果酸', '灰', '灰的碱性', '镁', '总酚', '类黄酮', '非黄烷类酚类', '花青素', '颜色强度', '色调', '稀释葡萄酒', '脯氨酸']
特征个数: 13

2.5 剪枝

在不加限制的情况下,一棵决策树会生长到衡量不纯度的指标最优,或者没有更多的特征可用为止。这样的决策树往往会过拟合,这就是说,它会在训练集上表现很好,在测试集上却表现糟糕。我们收集的样本数据不可能和整体的状况完全一致,因此当一棵决策树对训练数据有了过于优秀的解释性,它找出的规则必然包含了训练样本中的噪声,并使它对未知数据的拟合程度不足。

为了让决策树有更好的泛化性,我们要对决策树进行剪枝。剪枝策略对决策树的影响巨大,正确的剪枝策略是优化决策树算法的核心。sklearn为我们提供了不同的剪枝策略:
(1)max_depth
限制树的最大深度,超过设定深度的树枝全部剪掉这是用得最广泛的剪枝参数,在高维度低样本量时非常有效。

决策树多生长一层,对样本量的需求会增加一倍,所以限制树深度能够有效地限制过拟合。

在集成算法中也非常实用。

实际使用时,建议从=3开始尝试,看看拟合的效果再决定是否增加设定深度。逐步增加层数,训练后再测试集上检验,最后寻找最好的层数。

(2)min_samples_leaf

min_samples_leaf限定,一个节点在分枝后的每个子节点都必须包含至少min_samples_leaf个训练样本,否则分枝就不会发生,或者,分枝会朝着满足每个子节点都包含min_samples_leaf个样本的方向去发生。

一般搭配max_depth使用,在回归树中有神奇的效果,可以让模型变得更加平滑。

这个参数的数量设置得太小会引起过拟合,设置得太大就会阻止模型学习数据。

一般来说,建议从=5开始使用。如果叶节点中含有的样本量变化很大,建议输入浮点数作为样本量的百分比来使用。同时,这个参数可以保证每个叶子的最小尺寸,可以在回归问题
中避免低方差,过拟合的叶子节点出现。对于类别不多的分类问题,=1通常就是最佳选择。

(3)min_samples_split

min_samples_split限定,一个节点必须要包含至少min_samples_split个训练样本,这个节点才允许被分枝,否则分枝就不会发生。

min_samples_leaf限制的是子节点中样本的个数。

min_samples_split限制的是当前被分隔的节点中样本的个数。

节点中样本的个数越多,泛化能力越强!!!

DecisionTreeObj = tree.DecisionTreeClassifier(criterion="entropy",random_state=30,splitter="random",max_depth=3,min_samples_leaf=10,min_samples_split=10)DecisionTreeObj = DecisionTreeObj.fit(Xtrain, Ytrain)print("训练集分数:",DecisionTreeObj.score(Xtrain,Ytrain))print("测试集分数:",DecisionTreeObj.score(Xtest,Ytest))dot_data = tree.export_graphviz(DecisionTreeObj,feature_names= feature_name,class_names=["琴酒","雪莉","贝尔摩德"],filled=True,rounded=True) graph = graphviz.Source(dot_data)graph
训练集分数: 0.8709677419354839
测试集分数: 0.9259259259259259
训练集分数: 0.8709677419354839
测试集分数: 0.9259259259259259

2.6 精修参数max_features & min_impurity_decrease

一般max_depth使用,用作树的”精修“。
max_features限制分枝时考虑的特征个数,超过限制个数的特征都会被舍弃。和max_depth异曲同工,
max_features是用来限制高维度数据的过拟合的剪枝参数,但其方法比较暴力,是直接限制可以使用的特征数量而强行使决策树停下的参数,在不知道决策树中的各个特征的重要性的情况下,强行设定这个参数可能会导致模型学习不足。如果希望通过降维的方式防止过拟合,建议使用PCA,ICA或者特征选择模块中的降维算法。
min_impurity_decrease限制信息增益的大小,信息增益小于设定数值的分枝不会发生。这是在0.19版本中更新的功能,在0.19版本之前时使用min_impurity_split。

2.7 确认最优的剪枝参数

那具体怎么来确定每个参数填写什么值呢?这时候,我们就要使用确定超参数的曲线来进行判断了,继续使用我们已经训练好的决策树模型clf。超参数的学习曲线,是一条以超参数的取值为横坐标,模型的度量指标为纵坐标的曲线,它是用来衡量不同超参数取值下模型的表现的线。在我们建好的决策树里,我们的模型度量指标就是score。

import matplotlib.pyplot as plt# 存放参数调优后的打分值
test = []# 动态调参:max_depth
for i in range(10):clf = tree.DecisionTreeClassifier(max_depth=i+1,criterion="entropy",random_state=30,splitter="random")clf = clf.fit(Xtrain, Ytrain)score = clf.score(Xtest, Ytest)test.append(score)plt.plot(range(1,11),test,color="red",label="max_depth")
plt.legend()
plt.show()
[0.46296296296296297, 0.9259259259259259, 0.9444444444444444, 0.9629629629629629, 0.9629629629629629, 0.9629629629629629, 0.9629629629629629, 0.9629629629629629, 0.9629629629629629, 0.9629629629629629]

从上图可以看出:

当max_depth变为4之后,模型的打分的分数就不再进一步增长,也就是说,max_depth>4后,以及对模型的优化没有优化作用了。

备注:自动优化其他参数,也可以采用上述类似的方法。

2.8 目标权重参数:class_weight & min_weight_fraction_leaf

完成样本标签平衡的参数。

样本不平衡是指在一组数据集中,有一类样本特征的样本数量特别少,如标签的一类天生占有很大的比例。比如说,在银行要判断“一个办了信用卡的人是否会违约”,就是是vs否(1%:99%)的比例。

这种分类状况下,即便模型什么也不做,全把结果预测成“否”,正确率也能有99%。

因此我们要使用class_weight参数对样本标签进行一定的均衡,给少量的标签更多的权重,让模型更偏向少数类,向捕获少数类的方向建模。该参数默认None,此模式表示自动给与数据集中的所有标签相同的权重。
有了权重之后,样本量就不再是单纯地记录数目,而是受输入的权重影响了,因此这时候剪枝,就需要搭配min_weight_fraction_leaf这个基于权重的剪枝参数来使用。另请注意,基于权重的剪枝参数(例如min_weight_fraction_leaf)将比不知道样本权重的标准(比如min_samples_leaf)更少偏向主导类。如果样本是加权的,则使用基于权重的预修剪标准来更容易优化树结构,这确保叶节点至少包含样本权重的总和的一小部分。

2.9 其他重要的属性与API

属性是在模型训练之后,能够调用查看的模型的各种性质。对决策树来说,最重要的是feature_importances_,能够查看各个特征对模型的重要性。

sklearn中许多算法的接口都是相似的,比如说我们之前已经用到的fit和score,几乎对每个算法都可以使用。

除了这两个接口之外,决策树最常用的接口还有apply和predict。

apply中输入测试集返回每个测试样本所在的叶子节点的索引,

predict输入测试集返回每个测试样本的标签。返回的内容一目了然并且非常容易,

另外,所有接口中要求输入X_train和X_test的部分,输入的特征矩阵必须至少是一个二维矩阵。
sklearn不接受任何一维矩阵作为特征矩阵被输入。

如果你的数据的确只有一个特征,那必须用reshape(-1,1)来给矩阵增维;

如果你的数据只有一个特征和一个样本,使用reshape(1,-1)来给你的数据增维。

#apply返回每个测试样本在决策上中的位置
print("\n每个样本在决策树上的位置", clf.apply(Xtest))#predict返回每个测试样本的预测结果
print("\n每个样本的预测结果\n", clf.predict(Xtest))print("\n每个特征的重要程度:\n", clf.feature_importances_)
每个样本在决策树上的位置 [13  7 17 17 13 13 11 13 17 11 11  8  4  4  7  4  7 17 11 13 17 11  4 174  4 11 17  4 13  4  8 17  4 13 13  4  4 17 17  4 11 15 17 11 17  4 114  4  4 11 11  7]每个样本的预测结果[1 2 0 0 1 1 1 1 0 1 1 1 2 2 2 2 2 0 1 1 0 1 2 0 2 2 1 0 2 1 2 1 0 2 1 1 22 0 0 2 1 1 0 1 0 2 1 2 2 2 1 1 2]每个特征的重要程度:[0.22216858 0.         0.         0.03661358 0.02534749 0.0.44666311 0.         0.         0.         0.07418021 0.146158320.0488687 ]

作者主页(文火冰糖的硅基工坊):文火冰糖(王文兵)的博客_文火冰糖的硅基工坊_CSDN博客

本文网址:

[机器学习与scikit-learn-15]:算法-决策树-分类问题代码详解相关推荐

  1. 机器学习06|两万字:决策树 【jupyter代码详解篇】

    文章目录 任务一: 导入包和创建数据集 任务二:ID3树 2.1完成香农熵计算函数 2.2 完成基本功能函数 2.3 用信息增益选择待分类的特征 2.4 生成ID3决策树 备注: 任务三:C4.5树 ...

  2. 《机器学习实战》kNN算法及约会网站代码详解

    使用kNN算法进行分类的原理是:从训练集中选出离待分类点最近的kkk个点,在这kkk个点中所占比重最大的分类即为该点所在的分类.通常kkk不超过202020 kNN算法步骤: 计算数据集中的点与待分类 ...

  3. 基于PyTorch搭建CNN实现视频动作分类任务代码详解

    数据及具体讲解来源: 基于PyTorch搭建CNN实现视频动作分类任务 import torch import torch.nn as nn import torchvision.transforms ...

  4. 标准oc算法的推导与99行代码详解

    文章目录 标准oc算法的推导与代码详解 问题描述 OC算法的数学描述 结果展示 OC算法的matlab代码及注释 参考文献 标准oc算法的推导与代码详解 对于变密度的参数化方法,设计变量x为材料相对密 ...

  5. 机器学习应用篇(五)——决策树分类实例

    机器学习应用篇(五)--决策树分类实例 文章目录 机器学习应用篇(五)--决策树分类实例 一.数据集 二.实现过程 1 数据特征分析 2 利用决策树模型在二分类上进行训练和预测 3 利用决策树模型在多 ...

  6. kmeans python interation flag_机器学习经典算法-logistic回归代码详解

    一.算法简要 我们希望有这么一种函数:接受输入然后预测出类别,这样用于分类.这里,用到了数学中的sigmoid函数,sigmoid函数的具体表达式和函数图象如下: 可以较为清楚的看到,当输入的x小于0 ...

  7. 【OpenCV/C++】KNN算法识别数字的实现原理与代码详解

    KNN算法识别数字 一.KNN原理 1.1 KNN原理介绍 1.2 KNN的关键参数 二.KNN算法识别手写数字 2.1 训练过程代码详解 2.2 预测分类的实现过程 三.KNN算法识别印刷数字 2. ...

  8. [Pytorch系列-61]:循环神经网络 - 中文新闻文本分类详解-3-CNN网络训练与评估代码详解

    作者主页(文火冰糖的硅基工坊):文火冰糖(王文兵)的博客_文火冰糖的硅基工坊_CSDN博客 本文网址:https://blog.csdn.net/HiWangWenBing/article/detai ...

  9. python模拟手写笔迹_Python实现基于KNN算法的笔迹识别功能详解

    本文实例讲述了Python实现基于KNN算法的笔迹识别功能.分享给大家供大家参考,具体如下: 需要用到: Numpy库 Pandas库 手写识别数据 点击此处本站下载. 数据说明: 数据共有785列, ...

最新文章

  1. 金黄色的LED灯带感光特性测量
  2. 复选框 全选 全不选 反选 实现
  3. 机器学习线性回归算法实验报告_从零实现机器学习算法(九)线性回归
  4. JS中函数和变量声明的提升
  5. Java程序员应在2018年学习的3种JVM语言
  6. swift x输入流_Swift 中不同窗体的切换和传递数据 (segue 的用法)
  7. java中使用lua脚本
  8. 赚钱只要找到方法,就如吸空气
  9. c语言中sprintf函数_在C / C ++中使用sprintf()函数
  10. android 高通替换开机logo,高通平台 开机logo 替换
  11. 安装Windows XP出现0X0000007B的解决方法
  12. ODC预端接光纤配线箱
  13. 数据库概念设计与逻辑设计
  14. virt a mate(vam)版本1.20.77.9介绍和下载
  15. “生死看淡”的雷军要造车,这对中国的汽车产业意味着什么?
  16. Linux安装配置MySQL8.0 打war包 启动项目
  17. 绿色石化高质量发展 茂名天源石化碳三碳四资源利用项目开工
  18. 爬虫-大众点评评论信息(思路)
  19. Cadence Allegro如何加密PCB文件?
  20. 个人收款平台 XorPay 对比 Payjs

热门文章

  1. 学习Python, 没有工作经验没学历能找到工作吗?
  2. python3 获取商店里App评论+解析+存档+筛选
  3. 原创超简单代码(1.21.50)
  4. JZOJ 1403.渡河
  5. 《萌小甜动图字帖》使用简介
  6. 电商卖家如何有效提升转化率?
  7. ML.net 3-情绪预测
  8. Word文档插入图片的问题
  9. SyntaxError: Non-UTF-8 code starting with ‘\xc6‘ in file xxxbut no encoding declared
  10. HTML中的meta标签