目录

  • 资源下载
  • 实现思路与核心函数解读
    • DecisionTreeClassifier分类决策树
    • tree.plot_tree决策树可视化
  • 1. 对决策树最大深度的研究与可视化
    • 绘图结果
    • 分析
  • 2. 对特征选择标准的研究与可视化
    • 绘图结果
    • 分析
  • 3. 对决策树其他参数的研究与可视化
    • 绘图结果
    • 分析
  • 总结

『机器学习』分享机器学习课程学习笔记,逐步讲述从简单的线性回归、逻辑回归到 ▪ 决策树算法 ▪ 朴素贝叶斯算法 ▪ 支持向量机算法 ▪ 随机森林算法 ▪ 人工神经网络算法 等算法的内容。

欢迎关注 『机器学习』 系列,持续更新中
欢迎关注 『机器学习』 系列,持续更新中

资源下载

拿来即用,所见即所得。
项目仓库:https://gitee.com/miao-zehao/machine-learning/tree/master


实现思路与核心函数解读

基于Python机器学习库sklearn.tree.DecisionTreeClassifier决策树分类模型,对鸢尾花数据集iris.csv建立决策树模型。

DecisionTreeClassifier分类决策树

sklearn.tree.DecisionTreeClassifier
创建一个决策树分类器模型实例:tree_model=DecisionTreeClassifier(criterion=“gini”,max_depth=3,random_state=0,splitter=“best”)

参数解读:

  • 特征选择标准 criterion: string, 默认是 “gini”) 设置为‘gini’(基尼系数)或是‘entropy’(信息熵)
  • 决策树最大深度 max_depth:int或None,可选(默认=None)树的最大深度。如果为 None,则扩展节点直到所有叶子都是纯的或直到所有叶子包含少于 min_samples_split 样本。
  • 随机数生成器使用的种子 random_state:int,建议设置一个常数保证在研究参数时不会被随机数干扰。如果为 None,则随机数生成器是 RandomState 使用的实例np.random
  • 拆分器 splitter:字符串,可选(默认=“最佳”)用于在每个节点处选择拆分的策略。默认是 default=”best”,或者是“random”

tree.plot_tree决策树可视化

sklearn.tree.plot_tree
创建一个决策树可视化实例:tree_model=DecisionTreeClassifier(criterion=“gini”,max_depth=3,random_state=0,splitter=“best”)

参数解读:

  • 模型对象名
  • feature_names 特征名称的列表
  • class_names 分类名称的列表

ps:可能会遇到如下报错,这个时候更新sklearn库到最新版本,报错就不会发生。


1. 对决策树最大深度的研究与可视化

将“决策树最大深度”分别设置为3和5建立决策树模型,并进行结果可视化,对比建模结果。

import matplotlib
import sklearn.tree
from sklearn.datasets import load_iris
# 导入决策树分类器
from sklearn.tree import DecisionTreeClassifier, plot_tree
# 导入分割数据集的方法
from sklearn.model_selection import train_test_split
# 导入科学计算包
import numpy as np
# 导入绘图库
import matplotlib.pyplot as plt# 加载鸢尾花数据集
iris_dataset = load_iris()
# 分割训练集与测试集
X_train,X_test,y_train,y_test=train_test_split(iris_dataset['data'],iris_dataset['target'],test_size=0.2,random_state=0)def Mytest_max_depth(my_max_depth):# 创建决策时分类器-tree_model=DecisionTreeClassifier(criterion="gini",max_depth=my_max_depth,random_state=0,splitter="best")# - 特征选择标准 criterion: string, 默认是 “gini”) 设置为‘gini’(基尼系数)或是‘entropy’(信息熵)# - 决策树最大深度 max_depth:int或None,可选(默认=None)树的最大深度。如果为 None,则扩展节点直到所有叶子都是纯的或直到所有叶子包含少于 min_samples_split 样本。# - 随机数生成器使用的种子 random_state:int,建议设置一个常数保证在研究参数时不会被随机数干扰。如果为 None,则随机数生成器是 RandomState 使用的实例np.random# - 拆分器 splitter:字符串,可选(默认=“最佳”)用于在每个节点处选择拆分的策略。默认是 default=”best”,或者是“random”# 喂入数据tree_model.fit(X_train,y_train)# 打印模型评分print("模型评分:{}".format(tree_model.score(X_test,y_test)))# 随机生成一组数据使用我们的模型预测分类X_iris_test=np.array([[1.0,3.4,1.5,0.2]])# 用训练好的模型预测随机生成的样本数据的出的分类结果predict_result=tree_model.predict(X_iris_test)# 打印预测分类结果print(predict_result)print("分类结果:{}".format(iris_dataset['target_names'][predict_result]))# 模型可视化iris_feature_names=iris_dataset.feature_names#鸢尾花特征名列表 ['sepal length (cm)', 'sepal width (cm)', 'petal length (cm)', 'petal width (cm)']iris_class_names=iris_dataset.target_names#鸢尾花分类类名列表 ['setosa' 'versicolor' 'virginica']fig = plt.figure(figsize=(20, 12))#图片画布大小比例matplotlib.rcParams['font.sans-serif'] = [u'SimHei']  # 用来设置字体样式以正常显示中文标签matplotlib.rcParams['axes.unicode_minus'] = False  # 设置为 Fasle 来解决负号的乱码问题sklearn.tree.plot_tree(tree_model,feature_names=iris_feature_names, class_names=iris_class_names, rounded=True, filled= True, fontsize=14)# 模型对象名# feature_names 特征名称的列表# class_names 分类名称的列表plt.title("决策树最大深度={}的可视化图".format(my_max_depth))plt.savefig("1/决策树最大深度={}的可视化图.png".format(my_max_depth))plt.show()Mytest_max_depth(3)
Mytest_max_depth(4)
Mytest_max_depth(5)

绘图结果

  • 决策树最大深度=3的可视化图

  • 决策树最大深度=4的可视化图

  • 决策树最大深度=5的可视化图

分析

最大深度限制树的最大深度,超过设定深度的树枝全部剪掉。
高维度低样本量时效果比较好,但是决策树生长层数的增加会导致对样本量的需求会增加一倍,树深度较低是能够有效地限制过拟合。实际使用时,要逐步尝试,比如从3开始看看拟合的效果再决定是否增加设定深度。


2. 对特征选择标准的研究与可视化

将“特征选择标准”分别设置为‘gini’(基尼系数)和‘entropy’(信息熵)建立决策树模型,并进行结果可视化,对比建模结果;

import matplotlib
import sklearn.tree
from sklearn.datasets import load_iris
# 导入决策树分类器
from sklearn.tree import DecisionTreeClassifier, plot_tree
# 导入分割数据集的方法
from sklearn.model_selection import train_test_split
# 导入科学计算包
import numpy as np
# 导入绘图库
import matplotlib.pyplot as plt# 加载鸢尾花数据集
iris_dataset = load_iris()
# 分割训练集与测试集
X_train,X_test,y_train,y_test=train_test_split(iris_dataset['data'],iris_dataset['target'],test_size=0.2,random_state=0)def Mytest_criterion(my_criterion):# 创建决策时分类器-tree_model=DecisionTreeClassifier(criterion=my_criterion,max_depth=4,random_state=0,splitter="best")# - 特征选择标准 criterion: string, 默认是 “gini”) 设置为‘gini’(基尼系数)或是‘entropy’(信息熵)# - 决策树最大深度 max_depth:int或None,可选(默认=None)树的最大深度。如果为 None,则扩展节点直到所有叶子都是纯的或直到所有叶子包含少于 min_samples_split 样本。# - 随机数生成器使用的种子 random_state:int,建议设置一个常数保证在研究参数时不会被随机数干扰。如果为 None,则随机数生成器是 RandomState 使用的实例np.random# - 拆分器 splitter:字符串,可选(默认=“最佳”)用于在每个节点处选择拆分的策略。默认是 default=”best”,或者是“random”# 喂入数据tree_model.fit(X_train,y_train)# 打印模型评分print("模型评分:{}".format(tree_model.score(X_test,y_test)))# 随机生成一组数据使用我们的模型预测分类X_iris_test=np.array([[1.0,3.4,1.5,0.2]])# 用训练好的模型预测随机生成的样本数据的出的分类结果predict_result=tree_model.predict(X_iris_test)# 打印预测分类结果print(predict_result)print("分类结果:{}".format(iris_dataset['target_names'][predict_result]))# 模型可视化iris_feature_names=iris_dataset.feature_names#鸢尾花特征名列表 ['sepal length (cm)', 'sepal width (cm)', 'petal length (cm)', 'petal width (cm)']iris_class_names=iris_dataset.target_names#鸢尾花分类类名列表 ['setosa' 'versicolor' 'virginica']fig = plt.figure(figsize=(20, 12))#图片画布大小比例matplotlib.rcParams['font.sans-serif'] = [u'SimHei']  # 用来设置字体样式以正常显示中文标签matplotlib.rcParams['axes.unicode_minus'] = False  # 设置为 Fasle 来解决负号的乱码问题sklearn.tree.plot_tree(tree_model,feature_names=iris_feature_names, class_names=iris_class_names, rounded=True, filled= True, fontsize=14)# 模型对象名# feature_names 特征名称的列表# class_names 分类名称的列表plt.title("决策树特征选择标准={}的可视化图".format(my_criterion))plt.savefig("2/特征选择标准={}的可视化图.png".format(my_criterion))plt.show()Mytest_criterion("gini")#基尼系数
Mytest_criterion("entropy")#信息熵

绘图结果

  • 特征选择标准=entropy的可视化图
  • 特征选择标准=gini的可视化图

分析

  • 从图上看到我们的两种特征选择标准实际上在这个鸢尾花数据集里对于分类的结果趋向一致。实际上因为在数学建模中经常使用熵权法,我一般比较喜欢信息熵ID3算法,所以着重讲讲信息熵。

  • 信息熵:ID3算法—信息增益
    信息增益是针对一个具体的特征而言的,某个特征的有无对于整个系统、集合的影响程度就可以用“信息增益”来描述。我们知道,经过一次 if-else 判别后,原来的类别集合就被被分裂成两个集合,而我们的目的是让其中一个集合的某一类别的“纯度”尽可能高,如果分裂后子集的纯度比原来集合的纯度要高,那就说明这是一次 if-else 划分是有效过的。通过比较使的“纯度”最高的那个划分条件,也就是我们要找的“最合适”的特征维度判别条件。


3. 对决策树其他参数的研究与可视化

尝试修改决策树模型中的其他参数进行建模,并对比建模结果。

import matplotlib
import sklearn.tree
from sklearn.datasets import load_iris
# 导入决策树分类器
from sklearn.tree import DecisionTreeClassifier, plot_tree
# 导入分割数据集的方法
from sklearn.model_selection import train_test_split
# 导入科学计算包
import numpy as np
# 导入绘图库
import matplotlib.pyplot as plt# 加载鸢尾花数据集
iris_dataset = load_iris()
# 分割训练集与测试集
X_train,X_test,y_train,y_test=train_test_split(iris_dataset['data'],iris_dataset['target'],test_size=0.2,random_state=0)def Mytest_splitter(my_splitter):# 创建决策时分类器-tree_model=DecisionTreeClassifier(criterion="entropy",max_depth=4,random_state=0,splitter=my_splitter)# - 特征选择标准 criterion: string, 默认是 “gini”) 设置为‘gini’(基尼系数)或是‘entropy’(信息熵)# - 决策树最大深度 max_depth:int或None,可选(默认=None)树的最大深度。如果为 None,则扩展节点直到所有叶子都是纯的或直到所有叶子包含少于 min_samples_split 样本。# - 随机数生成器使用的种子 random_state:int,建议设置一个常数保证在研究参数时不会被随机数干扰。如果为 None,则随机数生成器是 RandomState 使用的实例np.random# - 拆分器 splitter:字符串,可选(默认=“最佳”)用于在每个节点处选择拆分的策略。默认是 default=”best”,或者是“random”# 喂入数据tree_model.fit(X_train,y_train)# 打印模型评分print("模型评分:{}".format(tree_model.score(X_test,y_test)))# 随机生成一组数据使用我们的模型预测分类X_iris_test=np.array([[1.0,3.4,1.5,0.2]])# 用训练好的模型预测随机生成的样本数据的出的分类结果predict_result=tree_model.predict(X_iris_test)# 打印预测分类结果print(predict_result)print("分类结果:{}".format(iris_dataset['target_names'][predict_result]))# 模型可视化iris_feature_names=iris_dataset.feature_names#鸢尾花特征名列表 ['sepal length (cm)', 'sepal width (cm)', 'petal length (cm)', 'petal width (cm)']iris_class_names=iris_dataset.target_names#鸢尾花分类类名列表 ['setosa' 'versicolor' 'virginica']fig = plt.figure(figsize=(20, 12))#图片画布大小比例matplotlib.rcParams['font.sans-serif'] = [u'SimHei']  # 用来设置字体样式以正常显示中文标签matplotlib.rcParams['axes.unicode_minus'] = False  # 设置为 Fasle 来解决负号的乱码问题sklearn.tree.plot_tree(tree_model,feature_names=iris_feature_names, class_names=iris_class_names, rounded=True, filled= True, fontsize=14)# 模型对象名# feature_names 特征名称的列表# class_names 分类名称的列表plt.title("决策树拆分器={}的可视化图".format(my_splitter))plt.savefig("3/决策树拆分器={}的可视化图.png".format(my_splitter))plt.show()Mytest_splitter("best")#
Mytest_splitter("random")#

绘图结果

  • 决策树拆分器=best的可视化图
  • 决策树拆分器=random的可视化图

分析

在每个节点处选择拆分的策略,“best”与“random”这两种模式使得决策树有了很大的区别,实际上来拿个个决策树的最终分类效果相同,只是位置有些变动。


总结

决策树是一个预测模型;他代表的是对象属性与对象值之间的一种映射关系。树中每个节点表示某个对象,而每个分叉路径则代表的某个可能的属性值,而每个叶结点则对应从根节点到该叶节点所经历的路径所表示的对象的值。决策树仅有单一输出,若欲有复数输出,可以建立独立的决策树以处理不同输出。
我在高等数学建模大赛中经常使用熵权法(被指导老师dis为比较初级的方法,烂大街的方法)实际上这个就是决策树的基本原理,通过对属性进行分割,从而降低整体的混乱程度。即对一个属性的不同取值进行分组以后,每一组的混乱程度做个加权和,根据权重大小衡量属性的重要性。

大家喜欢的话,给个

【机器学习】07. 决策树模型DecisionTreeClassifier(代码注释,思路推导)相关推荐

  1. 《scikit-learn机器学习》决策树③ -泰坦尼克号幸存者预测【思路+代码】

    泰坦尼克号预测: 1.思路 1.1 数据处理 1.2 选择模型并训练 1.3 用前剪枝对模型进行优化 1.4 试试其他的决策树模型 2.具体代码实现(代码来源于本书,不做详细解释) 2.1 数据处理 ...

  2. 模型开发-GBDT决策树模型开发代码

    GBDT(Gradient Boosting Decision Tree) 又叫 MART(Multiple Additive Regression Tree),是一种迭代的决策树算法,该算法由多棵决 ...

  3. 机器学习实战 —— 决策树(完整代码)

    声明: 此笔记是学习<机器学习实战> -- Peter Harrington 上的实例并结合西瓜书上的理论知识来完成,使用Python3 ,会与书上一些地方不一样. 机器学习实战-- 决策 ...

  4. sklearn机器学习:决策树tree.DecisionTreeClassifier()

    sklearn中的决策树分类器 sklearn中的决策树分类器函数,格式如下: sklearn.tree.DecisionTreeClassifier(criterion='gini', splitt ...

  5. 机器学习之决策树模型最优属性选择方法

    决策树模型是用于解决分类问题的一个模型,它的特点是简答.逻辑清晰.可解释性好. 决策树是基于"树"结构进行决策的. 每个"内部结点"对应于某个属性上的" ...

  6. 贝叶斯机器学习:经典模型与代码实现!

    Datawhale干货 Author:louwill,贝叶斯机器学习 贝叶斯定理是概率模型中最著名的理论之一,在机器学习中也有着广泛的应用.基于贝叶斯理论常用的机器学习概率模型包括朴素贝叶斯和贝叶斯网 ...

  7. 【机器学习】贝叶斯机器学习:经典模型与代码实现

    贝叶斯机器学习 Author:louwill Machine Learning Lab 贝叶斯定理是概率模型中最著名的理论之一,在机器学习中也有着广泛的应用.基于贝叶斯理论常用的机器学习概率模型包括朴 ...

  8. 贝叶斯机器学习:经典模型与代码实现

    贝叶斯定理是概率模型中最著名的理论之一,在机器学习中也有着广泛的应用.基于贝叶斯理论常用的机器学习概率模型包括朴素贝叶斯和贝叶斯网络.本章在对贝叶斯理论进行简介的基础上,分别对朴素贝叶斯和贝叶斯网络理 ...

  9. 机器学习实战-决策树 java版代码开发实现

    话不多说,直接上代码,若有帮助,帮忙点赞哦 python版,或其他机器学习算法,可发邮箱:476562571@qq.com 主要实现功能: 特征 二值判别 递归遍历文件目录加载训练数据集 召回率计算 ...

最新文章

  1. 【 MATLAB 】信号处理工具箱之fft简介及案例分析
  2. [转]十分钟搞定Vue搭建
  3. 第二阶段冲刺10天 第五天
  4. lower_bound upper_bound
  5. python调研报告总结体会_学习调研心得体会
  6. Oracle 建立包 和 包体
  7. 函数参数传递、数组指针、二级指针、左值、引用
  8. C# File类的操作
  9. flume学习(十):如何使用Spooling Directory Source
  10. Spark核心编程原理
  11. IOS的Application以及IOS目录的介绍
  12. Java中文姓名拆分
  13. 拉丁字母表及中英文发音
  14. AOP之基于Schema配置总结与案例
  15. 2017年sfdc工作总结_Name 顺序
  16. 使用fit函数时,报错KeyError: ‘squared_error‘
  17. mysql 查询view_MySQL之视图(VIEW)
  18. 阐述商务礼仪的重要性
  19. 易基因技术推介|高通量单细胞甲基化测序技术介绍(sc-RBS)
  20. 7-1 判断两个数是否互质

热门文章

  1. 浅谈5G通信面临的电磁兼容挑战及解决方法
  2. Geany下载与安装
  3. vue-seamless-scroll数据量少时,暂停滚动,继续滚动
  4. 服务器的系统日志路径,DirectAdmin 日志路径各种系统中查看方法Windows服务器操作系统 -电脑资料...
  5. vue中Uncaught (in promise) TypeError: Object(...) is not a function报错
  6. python 爬取财经新闻_如何用 100 行 Python 代码实现新闻爬虫?
  7. python毕业设计能做什么工作_用python可以做什么毕业设计项目|融资公司的主要业务...
  8. HTML5Canvas实现简易画图工具(铅笔,直线,矩形,圆,文本框,橡皮擦等)
  9. php对接腾讯云直播,聊天,im,云录制产生回放
  10. RMAN下CROSSCHECK命令详解