1. sklearn介绍

scikit-learn, 又写作sklearn, 是一个开源的基于python语言的机器学习工具包. 它通过numpy, Scipy和 Matplotlib等python数值计算的库实现的算法应用, 并且涵盖了几乎所有主流机器学习算法.

在工程应用中, 用python手写代码来从头实现一个算法的可能性非常低, 这样不仅耗时耗力, 还不一定能够写出构架清晰, 稳定性强的模型. 更多情况下, 是分析采集到的数据, 根据数据特征选择适合的算法, 在工具包中调用算法, 调整算法的参数, 获取需要的信息, 从而实现算法效率和效果之间的平衡. 而sklearn, 正是这样一个可以帮助我们高效实现算法应用的工具包.

官方网站: scikit-learn: machine learning in Python — scikit-learn 0.24.2 documentation

2. sklearn中的决策树

sklearn中决策树的类都在“tree”这个模块之下. 这个模块总共包含五个类:

tree.DecisionTreeClassifier 分类树
tree.DecisionTreeRegressor 回归树
tree.export_graphviz 将生成的决策树导出位DOT格式, 画图专用
tree.ExtraTreeClassifier 高随机版本的分类树
tree.ExtraTreeRegressor 高随机版本的回归树

sklearn的基本建模流程:

1.实例化, 建立评估对象

2.通过模型接口训练模型

3.通过模型接口提取需要的信息

分类树所对应代码:

import sklearn import treeclf = tree.DecisionTreeClassifier()
clf = clf.fit(X_train, Y_train)
result = clf.score(X_test, Y_test)

3. DecisionTreeClassifier

class sklearn.tree.DecisionTreeClassifier(criterion = 'gini', splitter = 'best', max_depth = None, min_samples_split = 2, min_samples_leaf = 1, min_weight_fraction_leaf = 0.0, max_features = None, random_state =None, max_leaf_nodes = None, min_impurity_decrease = 0.0, min_impurity_split = None, class_weight = None, presort =Flase)

3.1 重要参数

3.1.1 criterion

Criterion这个参数正是用来决定不纯度的计算方法的。

sklearn提供了两种选择:

1)输入“entropy”, 使用信息熵(Entropy)

2)输入“gini”,使用基尼系数(Gini Impurity)

注意, 当使用信息熵时, sklearn实际计算的是基于信息熵的信息增益(Information Gain), 即父节点的信息熵和子节点的信息熵之差. 比起基尼系数, 信息熵对不纯度更加敏感, 对不纯度的惩罚最强. 但是在实际使用中,信息熵和基尼系数的效果基本相同. 信息熵的计算比基尼系数缓慢一些, 因为基尼系数的计算不涉及对数. 另外, 因为信息熵对不纯度更加敏感, 所以信息熵作为指标时, 决策树的生长会更加“精细”, 因此对于高维数据或者噪音很多的数据, 信息熵很容易过拟合, 基尼系数在这种情况下效果往往比较好. 当模型拟合程度不足的时候, 即当模型在训练集和测试集上都表 现不太好的时候, 使用信息熵. 当然, 这些不是绝对的.

3.1.2 random_state & splitter

random_state用来设置分枝中的随机模式的参数, 默认None, 在高维度时随机性会表现更明显,低维度的数据 (比如鸢尾花数据集), 随机性几乎不会显现. 输入任意整数, 会一直长出同一棵树, 让模型稳定下来.  splitter也是用来控制决策树中的随机选项的, 有两种输入值, 输入”best", 决策树在分枝时虽然随机, 但是还是会优先选择更重要的特征进行分枝(重要性可以通过属性feature_importances_查看), 输入“random”, 决策树在 分枝时会更加随机, 树会因为含有更多的不必要信息而更深更大,并因这些不必要信息而降低对训练集的拟合. 这也是防止过拟合的一种方式. 当你预测到你的模型会过拟合, 用这两个参数来帮助你降低树建成之后过拟合的可能性. 当然, 树一旦建成,我们依然是使用剪枝参数来防止过拟合.

3.2 剪枝参数

不加限制的情况下, 一棵决策树会生长到衡量不纯度的指标最优, 或者没有更多的特征可用为止. 这样的决策树往往会过拟合, 这就是说, 它会在训练集上表现很好, 在测试集上却表现糟糕. 我们收集的样本数据不可能和整体的状况完全一致, 因此当一棵决策树对训练数据有了过于优秀的解释性, 它找出的规则必然包含了训练样本中的噪声, 并使它对未知数据的拟合程度不足.

为了让决策树有更好的泛化性, 我们要对决策树进行剪枝. 剪枝策略对决策树的影响巨大, 正确的剪枝策略是优化决策树算法的核心. sklearn为我们提供了不同的剪枝策略:

1) max_depth

限制树的最大深度, 超过设定深度的树枝全部剪掉.

这是用得最广泛的剪枝参数, 在高维度低样本量时非常有效. 决策树多生长一层, 对样本量的需求会增加一倍, 所以限制树深度能够有效地限制过拟合. 在集成算法中也非常实用. 实际使用时, 建议从=3开始尝试, 看看拟合的效果再决定是否增加设定深度.

2)min_samples_leaf & min_samples_split

min_samples_leaf限定, 一个节点在分枝后的每个子节点都必须包含至少min_samples_leaf个训练样本, 否则分枝就不会发生;  或者, 分枝会朝着满足每个子节点都包含min_samples_leaf个样本的方向去发生.

一般搭配max_depth使用, 在回归树中有神奇的效果, 可以让模型变得更加平滑. 这个参数的数量设置得太小会引 起过拟合, 设置得太大就会阻止模型学习数据. 一般来说, 建议从=5开始使用. 如果叶节点中含有的样本量变化很大, 建议输入浮点数作为样本量的百分比来使用. 同时, 这个参数可以保证每个叶子的最小尺寸, 可以在回归问题 中避免低方差, 过拟合的叶子节点出现. 对于类别不多的分类问题, =1通常就是最佳选择. min_samples_split限定, 一个节点必须要包含至少min_samples_split个训练样本, 这个节点才允许被分枝, 否则分枝就不会发生.

3) max_features & min_impurity_decrease

max_features限制分枝时考虑的特征个数, 超过限制个数的特征都会被舍弃. 和max_depth异曲同工,  max_features是用来限制高维度数据的过拟合的剪枝参数, 但其方法比较暴力, 是直接限制可以使用的特征数量而强行使决策树停下的参数, 在不知道决策树中的各个特征的重要性的情况下, 强行设定这个参数可能会导致模型学习不足. 如果希望通过降维的方式防止过拟合, 建议使用PCA, ICA或者特征选择模块中的降维算法.

下面用红酒数据集画一颗决策树:

from sklearn import tree
from sklearn.datasets import load_wine
from sklearn.model_selection import train_test_split
import pandas as pd
import graphvizwine = load_wine()
# print(wine.target)
# pd.concat([pd.DataFrame(wine.data),pd.DataFrame(wine.target)],axis=1)
Xtrain, Xtest, Ytrain, Ytest = train_test_split(wine.data, wine.target, test_size=0.3)clf = tree.DecisionTreeClassifier(criterion="entropy"# ,random_state=30, splitter="random", max_depth=3# ,min_samples_leaf=10# ,min_samples_split=10)
clf = clf.fit(Xtrain, Ytrain)
score = clf.score(Xtest, Ytest)  # 返回预测的准确度
feature_name = ['酒精', '苹果酸', '灰', '灰的碱性', '镁', '总酚', '类黄酮','非黄烷类酚类', '花青素', '颜色强度', '色调', '稀释葡萄酒', '脯氨酸']
target_name = ['琴酒', '雪莉', '贝尔摩德']
dot_data = tree.export_graphviz(clf, out_file=None, feature_names=feature_name, class_names=target_name, filled=True, rounded=True)
graph = graphviz.Source(dot_data.replace('helvetica', '"Microsoft YaHei"'), encoding='utf-8')graph.view("tree")

结果如下图所示:


原作者: B站up主菜菜

利用sklearn对红酒数据集分类相关推荐

  1. 【机器学习】决策树案例二:利用决策树进行鸢尾花数据集分类预测

    利用决策树进行鸢尾花数据集分类预测 2 利用决策树进行鸢尾花数据集分类预测 2.1 导入模块与加载数据 2.2 划分数据 2.3 模型创建与应用 2.4 模型可视化 手动反爬虫,禁止转载: 原博地址 ...

  2. 利用神经网络对鸢尾花数据集分类

    利用神经网络对鸢尾花数据集分类 详细实现代码请见:https://download.csdn.net/download/weixin_43521269/12578696 一.简介 一个人工神经元网络是 ...

  3. 以红酒数据集分类为例做决策树的可视化

    文章目录 前言 决策树原理 可视化决策树举例 gini entropy 总结 前言 本文是决策树可视化例子 决策树原理 决策树的分类原理有ID3(信息增益最大准则).C4.5(信息增益比准则).CAR ...

  4. 利用决策树算法对sklearn中红酒数据集进行可视化分类

    '''决策树是一种在分类和回归中都广泛应用的算法,它的原理是通过对一系列问题进行if/else进行推导,最终实现决策''' '''决策树最大的优势就是可以轻易的将模型可视化,而且决策树算法对每个样本的 ...

  5. TF之CNN:利用sklearn(自带手写数字图片识别数据集)使用dropout解决学习中overfitting的问题+Tensorboard显示变化曲线

    TF之CNN:利用sklearn(自带手写数字图片识别数据集)使用dropout解决学习中overfitting的问题+Tensorboard显示变化曲线 目录 输出结果 设计代码 输出结果 设计代码 ...

  6. PyTorch:采用sklearn 工具生成这样的合成数据集+利用PyTorch实现简单合成数据集上的线性回归进行数据分析

    PyTorch:采用sklearn 工具生成这样的合成数据集+利用PyTorch实现简单合成数据集上的线性回归进行数据分析 目录 输出结果 核心代码 输出结果 核心代码 #PyTorch:采用skle ...

  7. python决策树分类 导入数据集_python+sklearn实现决策树(分类树)

    整理今天的代码-- 采用的是150条鸢尾花的数据集fishiris.csv # 读入数据,把Name列取出来作为标签(groundtruth) import pandas as pd data = p ...

  8. (NO.1)利用sklearn进行鸢尾花分类

    文章目录 利用sklearn进行鸢尾花分类 preheat 联库 版本查询 practice summary 利用sklearn进行鸢尾花分类 preheat 联库 sklearn是基于Numpy和S ...

  9. ML之Xgboost:利用Xgboost模型对数据集(比马印第安人糖尿病)进行二分类预测(5年内是否患糖尿病)

    ML之Xgboost:利用Xgboost模型对数据集(比马印第安人糖尿病)进行二分类预测(5年内是否患糖尿病) 目录 输出结果 设计思路 核心代码 输出结果 X_train内容: [[ 3. 102. ...

  10. ML之分类预测之ElasticNet:利用ElasticNet回归对二分类数据集构建二分类器(DIY交叉验证+分类的两种度量PK)

    ML之分类预测之ElasticNet:利用ElasticNet回归对二分类数据集构建二分类器(DIY交叉验证+分类的两种度量PK) 目录 输出结果 设计思路 核心代码 输出结果 设计思路 核心代码 # ...

最新文章

  1. php中file_get_contents如何读取大容量文件
  2. 软件测试2019:第四次作业
  3. java ee核心技术_【科普】JavaEE的核心技术
  4. ubuntu android设备 no permissions
  5. 对Android源码分析总结(Z)
  6. 全网最新Spring Boot2.5.1整合Activiti5.22.0企业实战教程<流程挂起与激活篇>
  7. 深度学习stride_深度强化学习成名作——DQN
  8. 堆 堆栈 java_java的栈和堆
  9. Java怎样获取Content-Type的文件类型Mime Type
  10. mysql唯一性约束冲突_如何解决逻辑删除与数据库唯一约束冲突
  11. 从俄罗斯方块,迈向强化学习大门
  12. scrapy如何指定生成python3的项目_python3+Scrapy爬虫实战(一)—— 初识Scrapy
  13. Accurate, Large Minibatch SGD
  14. 双硬盘安装ubuntu18.04踩坑及解决全过程
  15. APP专项测试——弱网测试
  16. 线性方程组求解——基于MTALAB/Octave,Numpy,Sympy和Maxima
  17. 《地球帝国2》中文版秘籍
  18. 《那些年啊,那些事——一个程序员的奋斗史》——41
  19. 开启子进程的两种方式,孤儿进程与僵尸进程,守护进程,互斥锁,IPC机制,生产者与消费者模型...
  20. 便签插件可以贴在手机桌面上显示吗?怎么设置呢?

热门文章

  1. 我的Android进阶之旅------Android ListView优化详解
  2. jQuery、Ajax,DataTable数据如何转换成Json格式
  3. RHEL 6.2 Error: Cannot create GC thread. Out of system resources.
  4. Repeater使用方法---基础数据绑定+多级嵌套
  5. 那些年,我们一起学过的汇编----之子程序设计
  6. 也谈UpdatePanel与UrlRewrite一起work时出现Form Action属性的问题
  7. 09月28日 pytorch与resnet(四)三种主要的转移学习方案,微调ConvNet,ConvNet 作为固定特征提取器
  8. 【翻译】Geometric Features-Based Parking Slot Detection
  9. 【导入篇】Robotics:Perception课程_导入篇、四周课程内容、week 1st Perspective Projection
  10. Android内存泄漏(转)