MachineLearning(7)-决策树基础+sklearn.DecisionTreeClassifier简单实践
sklearn.DecisionTreeClassifier决策树简单使用
- 1.决策树算法基础
- 2.sklearn.DecisionTreeClassifier简单实践
- 2.1 决策树类
- 2.3 决策树构建
- 2.3.1全数据集拟合,决策树可视化
- 2.3.2交叉验证实验
- 2.3.3超参数搜索
- 2.3.4模型保存与导入
- 2.3.5固定随机数种子
- 参考资料
1.决策树算法基础
决策树模型可以用来做 回归/分类 任务。
每次选择一个属性/特征,依据特征的阈值,将特征空间划分为 与 坐标轴平行的一些决策区域。如果是分类问题,每个决策区域的类别为该该区域中多数样本的类别;如果为回归问题,每个决策区域的回归值为该区域中所有样本值的均值。
决策树复杂程度 依赖于 特征空间的几何形状。根节点->叶子节点的一条路径产生一条决策规则。
决策树最大优点:可解释性强
决策树最大缺点:不是分类正确率最高的模型
决策树的学习是一个NP-Complete问题,所以实际中使用启发性的规则来构建决策树。
step1:选最好的特征来划分数据集
step2:对上一步划分的子集重复步骤1,直至停止条件(节点纯度/分裂增益/树深度)
不同的特征衡量标准,产生了不同的决策树生成算法:
算法 | 最优特征选择标准 |
---|---|
ID3 | 信息增益:Gain(A)=H(D)−H(D∥A)Gain(A)=H(D)-H(D\|A)Gain(A)=H(D)−H(D∥A) |
C4.5 | 信息增益率:GainRatio(A)=Gain(A)/Split(A)GainRatio(A)=Gain(A)/Split(A)GainRatio(A)=Gain(A)/Split(A) |
CART | gini指数增益:Gini(D)−Gini(D∥A)Gini(D)-Gini(D\|A)Gini(D)−Gini(D∥A) |
k个类别,类别分布的gini 指数如下,gini指数越大,样本的不确定性越大:
Gini(D)=∑k=1Kpk(1−pk)=1−∑k=1Kpk2Gini(D) =\sum_{k=1}^Kp_k(1-p_k)=1-\sum_{k=1}^Kp_k^2Gini(D)=k=1∑Kpk(1−pk)=1−k=1∑Kpk2
CART – Classification and Regression Trees 的缩写1984年提出的一个特征选择算法,对特征进行是/否判断,生成一棵二叉树。且每次选择完特征后不对特征进行剔除操作,所有同一条决策规则上可能出现重复特征的情况。
2.sklearn.DecisionTreeClassifier简单实践
Scikit-learn(sklearn)是机器学习中常用的第三方模块,其建立在NumPy、Scipy、MatPlotLib之上,包括了回归,降维,分类,聚类方法。
sklearn 通过以下两个类实现了 决策分类树 和 决策回归树
sklearn 实现了ID3和Cart 算法,criterion默认为"gini"系数,对应为CART算法。还可设置为"entropy",对应为ID3。(计算机最擅长做的事:规则重复计算,sklearn通过对每个特征的每个切分点计算信息增益/gini增益,得到当前数据集合最优的特征及最优划分点)
2.1 决策树类
sklearn.tree.DecisionTreeClassifier(criterion=’gini’*,splitter=’best’, max_depth=None,
min_samples_split=2, min_samples_leaf=1, min_weight_fraction_leaf=0.0,
max_features=None, random_state=None, max_leaf_nodes=None,
min_impurity_decrease=0.0, min_impurity_split=None, class_weight=None, presort=False)
DecisionTreeRegressor(criterion=’mse’, splitter=’best’,
max_depth=None, min_samples_split=2, min_samples_leaf=1,
min_weight_fraction_leaf=0.0, max_features=None, random_state=None,
max_leaf_nodes=None, min_impurity_decrease=0.0,
min_impurity_split=None, presort=False)
Criterion | 选择属性的准则–gini–cart算法 |
---|---|
splitter | 特征划分点的选择策略:best 特征的所有划分点中找最优 |
random 部分划分点中找最优 | |
max_depth | 决策树的最大深度,none/int 限制/不限制决策树的深度 |
min_samples_split | 节点 继续划分需要的最小样本数,如果少于这个数,节点将不再划分 |
min_samples_leaf | 限制叶子节点的最少样本数量,如果叶子节点的样本数量过少会被剪枝 |
min_weight_fraction_leaf | 叶子节点的剪枝规则 |
max_features | 选取用于分类的特征的数量 |
random_state | 随机数生成的一些规则、 |
max_leaf_nodes | 限制叶子节点的数量,防止过拟合 |
min_impurity_decrease | 表示结点减少的最小不纯度,控制节点的继续分割规律 |
min_impurity_split | 表示结点划分的最小不纯度,控制节点的继续分割规律 |
class_weight | 设置各个类别的权重,针对类别不均衡的数据集使用 |
不适用于决策树回归 | |
presort | 控制决策树划分的速度 |
2.3 决策树构建
采用sklearn内置数据集鸢尾花数据集做实验。
导入第三方库
from sklearn import tree
from sklearn.tree import DecisionTreeClassifier
from sklearn.datasets import load_iris
import graphviz
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, precision_score, recall_score
import joblib
plt.switch_backend('agg')
2.3.1全数据集拟合,决策树可视化
def demo1():# 全数据集拟合,决策树可视化iris = load_iris()x, y = load_iris(return_X_y = True) # x[list]-feature,y[]-label clf = tree.DecisionTreeClassifier() # 实例化了一个类,可以指定类参数,定制决策树模型clf = clf.fit(x,y) # 训练模型print("feature name ", iris.feature_names) # 特征列表, 自己的数据可视化时,构建一个特征列表即可print("label name ",iris.target_names) # 类别列表dot_data = tree.export_graphviz(clf, out_file = None, feature_names = iris.feature_names, class_names = iris.target_names ) graph = graphviz.Source(dot_data) # 能绘制树节点的一个接口graph.render("iris") # 存成pdf图
tree.export_graphviz 参数 | |
---|---|
feature_names | 特征列表list,和训练时的特征列表排列顺序对其即可 |
class_names | 类别l列表ist,和训练时的label列表排列顺序对其即可 |
filled | False/True,会依据criterion的纯度将节点显示成不同的颜色 |
value中的值显示的是各个类别样本的数量(二分类就是[负样本数,正样本数])
2.3.2交叉验证实验
def demo2():# n-折实验iris = load_iris()iris_feature = iris.data # 与demo1中的x,y是同样的数据iris_target = iris.target# 数据集合划分参数:train_x, test_x, train_y, test_y = train_test_split(iris_feature,iris_target,test_size = 0.2, random_state = 1)dt_model = DecisionTreeClassifier()dt_model.fit(train_x, train_y) # 模型训练predict_y = dt_model.predict(test_x) # 模型预测输出# score = dt_model.score(test_x,test_y) # 模型测试性能: 输入:feature_test,target_test , 输出acc# print(score) # 性能指标print("label: \n{0}".format(test_y[:5])) # 输出前5个labelprint("predict: \n{0}".format(predict_y[:5])) # 输出前5个label# sklearn 内置acc, recall, precision统计接口print("test acc: %.3f"%(accuracy_score(test_y, predict_y)))# print("test recall: %.3f"%(recall_score(test_y, predict_y))) # 多类别统计召回率需要指定平均方式# print("test precision: %.3f"%(precision_score(test_y, predict_y))) # 多类别统计准确率需要指定平均方式
2.3.3超参数搜索
def model_search(feas,labels):# 模型参数选择,全数据5折交叉验证,出结果min_impurity_de_entropy = np.linspace(0, 0.01, 10) # 纯度增益下界,划分后降低量少于这个值,将不进行分裂min_impurity_split_entropy = np.linspace(0, 0.4, 10) # 当前节点纯度小于这个值将不分裂,较高版本中已经取消这个参数max_depth_entropy = np.arange(1,11) # 决策树的深度# param_grid = {"criterion" : ["entropy"], "min_impurity_decrease" : min_impurity_de_entropy,"max_depth" : max_depth_entropy,"min_impurity_split" : min_impurity_split_entropy }param_grid = {"criterion" : ["entropy"], "max_depth" : max_depth_entropy, "min_impurity_split" : min_impurity_split_entropy }clf = GridSearchCV(DecisionTreeClassifier(), param_grid, cv = 5) # 遍历以上超参, 通过多次五折交叉验证得出最优的参数选择clf.fit(feas, label) print("best param:", clf.best_params_) # 输出最优参数选择print("best score:", clf.best_score_)
2.3.4模型保存与导入
模型保存
joblib.dump(clf,"./dtc_model.pkl")
模型导入
model_path = “./dtc_model.pkl”
clf = joblib.load(model_path)
2.3.5固定随机数种子
1.五折交叉验证,数据集划分随机数设置 random_state
train_test_split(feas, labels, test_size = 0.2, random_state = 1 )
2.模型随机数设置 andom_state
DecisionTreeClassifier(random_state = 1)
参考资料
1.官网类接口说明:
https://scikit-learn.org/dev/modules/generated/sklearn.tree.DecisionTreeClassifier.html#sklearn.tree.DecisionTreeClassifier
可视化接口说明https://scikit-learn.org/stable/modules/generated/sklearn.tree.export_graphviz.html
2.决策树超参数调参技巧:https://www.jianshu.com/p/230be18b08c2
3.Sklearn.metrics 简介及应用示例:https://blog.csdn.net/Yqq19950707/article/details/90169913
4.sklearn的train_test_split()各函数参数含义解释(非常全):https://www.cnblogs.com/Yanjy-OnlyOne/p/11288098.html
5.sklearn.tree.DecisionTreeClassifier 详细说明:https://www.jianshu.com/p/8f3f1e706f11
6.使用scikit-learn中的metrics以及DecisionTreeClassifier重做《机器学习实战》中的隐形眼镜分类问题:http://keyblog.cn/article-235.html
7.决策树算法:https://www.cnblogs.com/yanqiang/p/11600569.html
MachineLearning(7)-决策树基础+sklearn.DecisionTreeClassifier简单实践相关推荐
- 【自然语言处理】word2vec/doc2vec基础学习以及简单实践
文章目录 一.前言 二. 向量化算法word2vec 2.1 引言 2.2 word2vec原理 2.3 词的表示 三.神经网络语言模型 四.C&W模型 五.CBOW模型 5.1 CBOW模型 ...
- MachineLearning(8)-PCA,LDA基础+sklearn 简单实践
PCA,LDA基础+sklearn 简单实践 1.PCA+sklearn.decomposition.PCA 1.PCA理论基础 2.sklearn.decomposition.PCA简单实践 2.L ...
- 【sklearn入门】决策树在sklearn中的实现--实战红酒分类案例
scikit-learn简介 scikit-learn,又写作sklearn,是一个开源的基于python语言的机器学习工具包.它通过NumPy, SciPy和 Matplotlib等python数值 ...
- 决策树在sklearn中的实现
1 概述 1.1 决策树是如何工作的 1.2 构建决策树 1.2.1 ID3算法构建决策树 1.2.2 简单实例 1.2.3 ID3的局限性 1.3 C4.5算法 & CART算法 1.3.1 ...
- 监督学习 | 决策树之Sklearn实现
文章目录 1. Sklearn中决策树的超参数 1.1 最大深度 max_depth 1.2 每片叶子的最小样本数 min_samples_leaf 1.3 每次分裂的最小样本数 min_sample ...
- C++(11)--编程实践1-经典养成类游戏简单实践
经典养成类游戏简单实践-小公主养成记 <老九学堂C++课程>学习笔记.<老九学堂C++课程>详情请到B站搜索<老九零基础学编程C++入门> ------------ ...
- Java数字图像处理基础-------Java Swing简单使用,图形绘画---画五角星
Java数字图像处理基础-------Java Swing简单使用,图形绘画-画五角星 一:简介 要画出五角星出来,我们只需要在面板上产生5个点,然后把这5个点进行连接就可实现: 二:代码演示 imp ...
- 运用京东云代码托管、云编译、云部署等产品进行蓝绿部署简单实践
干货 | 运用京东云代码托管.云编译.云部署等产品进行蓝绿部署简单实践 前几天我们以一种较为传统的方式在京东云上简单实践了基于Jenkins+Docker+Git 的CI流程,主要利用一些开源技术来实 ...
- 计算机实践学什么作用,大学计算机基础:计算机操作实践
大学计算机基础:计算机操作实践 语音 编辑 锁定 讨论 上传视频 <大学计算机基础:计算机操作实践>是人民邮电出版社出版的图书,ISBN是9787115182425.[1] 书 名 ...
最新文章
- display函数怎么使用_Chapter19:拷贝构造函数
- 令AI费解的图像层出不穷 计算机视觉远未达到完美
- 放弃java转战kotlin,我的心路历程
- 10道海量数据处理的面试题
- 5 select 选择的值_表单元素之选择类型
- 5G RRC——为NAS层提供连接管理,消息传递等服务; 对接入网的底层协议实体提供参数配置的功能; 负责UE移动性管理相关的测量、控制等功能...
- sdut 2139BFS
- 异常解决(二)-- AttributeError: cannot assign module before Module.__init__() call
- 番石榴15 –新功能
- java中数组输出空格_如何使用数字元素和空格分割字符串并将其存储到Java中的可索引数组中?...
- 精细化的风险管理,评分的应用策略之道
- bootstrap table中文文档_用Python完成一件小事:自动生成文档报告
- iOS开发实践之网络检測Reachability
- java通过TscLibDll调用佳博热敏票据打印机(580130IVC)打印小票
- Shellsploit注入器简单利用
- 炉石兄弟 修复图腾师问题 by大神beebee102, 还有阴燃电鳗
- Nginx报错:nginx: [error] invalid PID number in /run/nginx.pid的解决方案
- 使用自签证书利用浏览器进行HTTPS接口的安全访问
- 如何将数据移动到新硬盘(装机)
- lightoj1219Mafia
热门文章
- 如何查看Linux版本号(内核版本号和发行版本号)
- 设计模式C++实现(7)——装饰模式
- Windows CE创建桌面快捷方式
- jsp思维导图_2019年经济法基础思维导图
- 神经网络与深度学习——TensorFlow2.0实战(笔记)(二)(Anaconda软件使用)
- 【转】DCMTK 开源库的学习笔记2:直接操作dcm文件中像素数据的尝试
- 【转】.net框架读书笔记---CLR内存管理\垃圾收集(六)
- 【转】人工智能-1.1.1 什么是神经网络
- 【转】The underlying connection was closed
- 第十五节:Asp.Net Core MVC和WebApi路由规则的总结和对比-第二十节