作者|Nikhil Adithyan

编译|VK

来源|Towards Data Science

决策树

决策树是当今最强大的监督学习方法的组成部分。决策树基本上是一个二叉树的流程图,其中每个节点根据某个特征变量将一组观测值拆分。

决策树的目标是将数据分成多个组,这样一个组中的每个元素都属于同一个类别。决策树也可以用来近似连续的目标变量。在这种情况下,树将进行拆分,使每个组的均方误差最小。

决策树的一个重要特性是它们很容易被解释。你根本不需要熟悉机器学习技术就可以理解决策树在做什么。决策树图很容易解释。

利弊

决策树方法的优点是:

  • 决策树能够生成可理解的规则。

  • 决策树在不需要大量计算的情况下进行分类。

  • 决策树能够处理连续变量和分类变量。

  • 决策树提供了一个明确的指示,哪些字段是最重要的。

决策树方法的缺点是:

  • 决策树不太适合于目标是预测连续属性值的估计任务。

  • 决策树在类多、训练样本少的分类问题中容易出错。

  • 决策树的训练在计算上可能很昂贵。生成决策树的过程在计算上非常昂贵。在每个节点上,每个候选拆分字段都必须进行排序,才能找到其最佳拆分。在某些算法中,使用字段组合,必须搜索最佳组合权重。剪枝算法也可能是昂贵的,因为许多候选子树必须形成和比较。

Python决策树

Python是一种通用编程语言,它为数据科学家提供了强大的机器学习包和工具。在本文中,我们将使用python最著名的机器学习包scikit-learn来构建决策树模型。我们将使用scikit learn提供的“DecisionTreeClassifier”算法创建模型,然后使用“plot_tree”函数可视化模型。

步骤1:导入包

我们构建模型的主要软件包是pandas、scikit learn和NumPy。按照代码在python中导入所需的包。

import pandas as pd # 数据处理
import numpy as np # 使用数组
import matplotlib.pyplot as plt # 可视化
from matplotlib import rcParams # 图大小
from termcolor import colored as cl # 文本自定义from sklearn.tree import DecisionTreeClassifier as dtc # 树算法
from sklearn.model_selection import train_test_split # 拆分数据
from sklearn.metrics import accuracy_score # 模型准确度
from sklearn.tree import plot_tree # 树图rcParams['figure.figsize'] = (25, 20)

在导入构建我们的模型所需的所有包之后,是时候导入数据并对其进行一些EDA了。

步骤2:导入数据和EDA

在这一步中,我们将使用python中提供的“Pandas”包来导入并在其上进行一些EDA。我们将建立我们的决策树模型,数据集是一个药物数据集,它是基于特定的标准给病人开的处方。让我们用python导入数据!

Python实现:
df = pd.read_csv('drug.csv')
df.drop('Unnamed: 0', axis = 1, inplace = True)print(cl(df.head(), attrs = ['bold']))

「输出:」

   Age Sex      BP Cholesterol  Na_to_K   Drug
0   23   F    HIGH        HIGH   25.355  drugY
1   47   M     LOW        HIGH   13.093  drugC
2   47   M     LOW        HIGH   10.114  drugC
3   28   F  NORMAL        HIGH    7.798  drugX
4   61   F     LOW        HIGH   18.043  drugY

现在我们对数据集有了一个清晰的概念。导入数据后,让我们使用“info”函数获取有关数据的一些基本信息。此函数提供的信息包括条目数、索引号、列名、非空值计数、属性类型等。

Python实现:
df.info()

「输出:」

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 200 entries, 0 to 199
Data columns (total 6 columns):#   Column       Non-Null Count  Dtype
---  ------       --------------  -----  0   Age          200 non-null    int64  1   Sex          200 non-null    object 2   BP           200 non-null    object 3   Cholesterol  200 non-null    object 4   Na_to_K      200 non-null    float645   Drug         200 non-null    object
dtypes: float64(1), int64(1), object(4)
memory usage: 9.5+ KB

步骤3:数据处理

我们可以看到像Sex, BP和Cholesterol这样的属性在本质上是分类的和对象类型的。问题是,scikit-learn中的决策树算法本质上不支持X变量(特征)是“对象”类型。因此,有必要将这些“object”值转换为“binary”值。让我们用python来实现

Python实现:
for i in df.Sex.values:if i  == 'M':df.Sex.replace(i, 0, inplace = True)else:df.Sex.replace(i, 1, inplace = True)for i in df.BP.values:if i == 'LOW':df.BP.replace(i, 0, inplace = True)elif i == 'NORMAL':df.BP.replace(i, 1, inplace = True)elif i == 'HIGH':df.BP.replace(i, 2, inplace = True)for i in df.Cholesterol.values:if i == 'LOW':df.Cholesterol.replace(i, 0, inplace = True)else:df.Cholesterol.replace(i, 1, inplace = True)print(cl(df, attrs = ['bold']))

「输出:」

     Age  Sex  BP  Cholesterol  Na_to_K   Drug
0     23    1   2            1   25.355  drugY
1     47    1   0            1   13.093  drugC
2     47    1   0            1   10.114  drugC
3     28    1   1            1    7.798  drugX
4     61    1   0            1   18.043  drugY
..   ...  ...  ..          ...      ...    ...
195   56    1   0            1   11.567  drugC
196   16    1   0            1   12.006  drugC
197   52    1   1            1    9.894  drugX
198   23    1   1            1   14.020  drugX
199   40    1   0            1   11.349  drugX[200 rows x 6 columns]

我们可以观察到所有的“object”值都被处理成“binary”值来表示分类数据。例如,在胆固醇属性中,显示“低”的值被处理为0,“高”则被处理为1。现在我们准备好从数据中创建因变量和自变量。

步骤4:拆分数据

在将我们的数据处理为正确的结构之后,我们现在设置“X”变量(自变量),“Y”变量(因变量)。让我们用python来实现

Python实现:
X_var = df[['Sex', 'BP', 'Age', 'Cholesterol', 'Na_to_K']].values # 自变量
y_var = df['Drug'].values # 因变量print(cl('X variable samples : {}'.format(X_var[:5]), attrs = ['bold']))
print(cl('Y variable samples : {}'.format(y_var[:5]), attrs = ['bold']))

「输出:」

X variable samples : [[ 1.     2.    23.     1.    25.355][ 1.     0.    47.     1.    13.093][ 1.     0.    47.     1.    10.114][ 1.     1.    28.     1.     7.798][ 1.     0.    61.     1.    18.043]]
Y variable samples : ['drugY' 'drugC' 'drugC' 'drugX' 'drugY']

我们现在可以使用scikit learn中的“train_test_split”算法将数据分成训练集和测试集,其中包含我们定义的X和Y变量。按照代码在python中拆分数据。

Python实现:
X_train, X_test, y_train, y_test = train_test_split(X_var, y_var, test_size = 0.2, random_state = 0)print(cl('X_train shape : {}'.format(X_train.shape), attrs = ['bold'], color = 'black'))
print(cl('X_test shape : {}'.format(X_test.shape), attrs = ['bold'], color = 'black'))
print(cl('y_train shape : {}'.format(y_train.shape), attrs = ['bold'], color = 'black'))
print(cl('y_test shape : {}'.format(y_test.shape), attrs = ['bold'], color = 'black'))

「输出:」

X_train shape : (160, 5)
X_test shape : (40, 5)
y_train shape : (160,)
y_test shape : (40,)

现在我们有了构建决策树模型的所有组件。所以,让我们继续用python构建我们的模型。

步骤5:建立模型和预测

在scikit学习包提供的“DecisionTreeClassifier”算法的帮助下,构建决策树是可行的。之后,我们可以使用我们训练过的模型来预测我们的数据。最后,我们的预测结果的精度可以用“准确度”评估指标来计算。让我们用python来完成这个过程!

Python实现:
model = dtc(criterion = 'entropy', max_depth = 4)
model.fit(X_train, y_train)pred_model = model.predict(X_test)print(cl('Accuracy of the model is {:.0%}'.format(accuracy_score(y_test, pred_model)), attrs = ['bold']))

「输出:」

Accuracy of the model is 88%

在代码的第一步中,我们定义了一个名为“model”变量的变量,我们在其中存储DecisionTreeClassifier模型。接下来,我们将使用我们的训练集对模型进行拟合和训练。之后,我们定义了一个变量,称为“pred_model”变量,其中我们将模型预测的所有值存储在数据上。最后,我们计算了我们的预测值与实际值的精度,其准确率为88%。

步骤6:可视化模型

现在我们有了决策树模型,让我们利用python中scikit learn包提供的“plot_tree”函数来可视化它。按照代码从python中的决策树模型生成一个漂亮的树图。

Python实现:
feature_names = df.columns[:5]
target_names = df['Drug'].unique().tolist()plot_tree(model, feature_names = feature_names, class_names = target_names, filled = True, rounded = True)plt.savefig('tree_visualization.png')

「输出:」

结论

有很多技术和其他算法用于优化决策树和避免过拟合,比如剪枝。虽然决策树通常是不稳定的,这意味着数据的微小变化会导致最优树结构的巨大变化,但其简单性使其成为广泛应用的有力候选。在神经网络流行之前,决策树是机器学习中最先进的算法。其他一些集成模型,比如随机森林模型,比普通决策树模型更强大。

决策树由于其简单性和可解释性而非常强大。决策树和随机森林在用户注册建模、信用评分、故障预测、医疗诊断等领域有着广泛的应用。我为本文提供了完整的代码。

完整代码:

import pandas as pd # 数据处理
import numpy as np # 使用数组
import matplotlib.pyplot as plt # 可视化
from matplotlib import rcParams # 图大小
from termcolor import colored as cl # 文本自定义from sklearn.tree import DecisionTreeClassifier as dtc # 树算法
from sklearn.model_selection import train_test_split # 拆分数据
from sklearn.metrics import accuracy_score # 模型准确度
from sklearn.tree import plot_tree # 树图rcParams['figure.figsize'] = (25, 20)df = pd.read_csv('drug.csv')
df.drop('Unnamed: 0', axis = 1, inplace = True)print(cl(df.head(), attrs = ['bold']))df.info()for i in df.Sex.values:if i  == 'M':df.Sex.replace(i, 0, inplace = True)else:df.Sex.replace(i, 1, inplace = True)for i in df.BP.values:if i == 'LOW':df.BP.replace(i, 0, inplace = True)elif i == 'NORMAL':df.BP.replace(i, 1, inplace = True)elif i == 'HIGH':df.BP.replace(i, 2, inplace = True)for i in df.Cholesterol.values:if i == 'LOW':df.Cholesterol.replace(i, 0, inplace = True)else:df.Cholesterol.replace(i, 1, inplace = True)print(cl(df, attrs = ['bold']))X_var = df[['Sex', 'BP', 'Age', 'Cholesterol', 'Na_to_K']].values # 自变量
y_var = df['Drug'].values # 因变量print(cl('X variable samples : {}'.format(X_var[:5]), attrs = ['bold']))
print(cl('Y variable samples : {}'.format(y_var[:5]), attrs = ['bold']))X_train, X_test, y_train, y_test = train_test_split(X_var, y_var, test_size = 0.2, random_state = 0)print(cl('X_train shape : {}'.format(X_train.shape), attrs = ['bold'], color = 'red'))
print(cl('X_test shape : {}'.format(X_test.shape), attrs = ['bold'], color = 'red'))
print(cl('y_train shape : {}'.format(y_train.shape), attrs = ['bold'], color = 'green'))
print(cl('y_test shape : {}'.format(y_test.shape), attrs = ['bold'], color = 'green'))model = dtc(criterion = 'entropy', max_depth = 4)
model.fit(X_train, y_train)pred_model = model.predict(X_test)print(cl('Accuracy of the model is {:.0%}'.format(accuracy_score(y_test, pred_model)), attrs = ['bold']))feature_names = df.columns[:5]
target_names = df['Drug'].unique().tolist()plot_tree(model, feature_names = feature_names, class_names = target_names, filled = True, rounded = True)plt.savefig('tree_visualization.png')

原文链接:https://towardsdatascience.com/building-and-visualizing-decision-tree-in-python-2cfaafd8e1bb

往期精彩回顾适合初学者入门人工智能的路线及资料下载机器学习及深度学习笔记等资料打印机器学习在线手册深度学习笔记专辑《统计学习方法》的代码复现专辑
AI基础下载机器学习的数学基础专辑
获取本站知识星球优惠券,复制链接直接打开:
https://t.zsxq.com/y7uvZF6
本站qq群704220115。加入微信群请扫码:

【机器学习基础】用Python构建和可视化决策树相关推荐

  1. 用Python构建和可视化决策树

    决策树 决策树是当今最强大的监督学习方法的组成部分.决策树基本上是一个二叉树的流程图,其中每个节点根据某个特征变量将一组观测值拆分. 决策树的目标是将数据分成多个组,这样一个组中的每个元素都属于同一个 ...

  2. Python机器学习基础之Python的基本语法(一)

    当今世界已经进入了大数据的时代.随着信息化的不断发展,人工智能.机器学习等词语越来越被人们所熟知,而他们也渐渐地成了这个时代的弄潮儿,走在了信息时代的前端.从本篇博客开始,小编将带领大家一起走进人工智 ...

  3. 【机器学习】使用 Python 构建电影推荐系统

    本文将余弦相似度与 KNN.Seaborn.Scikit-learn 和 Pandas 结合使用,创建一个使用用户评分数据的电影推荐系统. 在日常数据挖掘工作中,除了会涉及到使用Python处理分类或 ...

  4. 【机器学习基础】Python机器学习的神器- Scikit-learn使用说明

    全文共 26745 字,106 幅图表, 预计阅读时间 67 分钟. 0 引言 Sklearn (全称 Scikit-Learn) 是基于 Python 语言的机器学习工具.它建立在 NumPy, S ...

  5. 【机器学习基础】Python机器学习入门指南(全)

    前言 机器学习 作为人工智能领域的核心组成,是计算机程序学习数据经验以优化自身算法,并产生相应的"智能化的"建议与决策的过程. 一个经典的机器学习的定义是: A computer ...

  6. 【机器学习基础】Python实现行转列?!超简单,赶快get起来

    ◆ ◆ ◆  ◆ ◆ 前言 数据的行转列操作,在实际工作过程中应用非常广泛. 由于不同人员.不同部门对数据结构的认识是不大相同的,尤其是从基层人员手里拿到的数据,更是五花八门,横七竖八. 比如有这样一 ...

  7. 【机器学习基础】Python数据预处理:彻底理解标准化和归一化

    数据预处理 数据中不同特征的量纲可能不一致,数值间的差别可能很大,不进行处理可能会影响到数据分析的结果,因此,需要对数据按照一定比例进行缩放,使之落在一个特定的区域,便于进行综合分析. 常用的方法有两 ...

  8. 1.2机器学习基础下--python深度机器学习

    1. 机器学习更多应用举例: 人脸识别 2. 机器学习就业需求: LinkedIn所有职业技能需求量第一:机器学习,数据挖掘和统计分析人才      http://blog.linkedin.com/ ...

  9. Python机器学习基础之Matplotlib库的使用

    声明:代码的运行环境为Python3.Python3与Python2在一些细节上会有所不同,希望广大读者注意.本博客以代码为主,代码中会有详细的注释.相关文章将会发布在我的个人博客专栏<Pyth ...

最新文章

  1. 赋值、浅拷贝、深拷贝
  2. r语言必学的十个包肖凯_家长专栏自闭症儿童语言康复训练
  3. 高性能服务器架构思路「不仅是思路」
  4. H264和AAC合成FLV案例
  5. 聚合搜索V2.0泛目录站群二开源码 可做指定关键词
  6. 一些在Android中的小设置~~~持续添加
  7. android数据存放map_Android存储数据到本地文件
  8. LEFT OUTER JOIN
  9. Python urllib2 设置超时时间并处理超时异常
  10. Javaweb技术的校运会报名及比赛管理系统
  11. 浏览器 之 无头浏览器
  12. 【题解】UVA177 分治
  13. 《深度学习--基于python的理论与实现》学习笔记6:第三章神经网络(2)
  14. 笔记 How Powerful are Spectral Graph Neural Networks
  15. 菜鸟学数据库——大话 char、varchar、 nchar、nvarchar之间剪不断理还乱的关系
  16. android 远程调试工具,【教程】搭配Android studio,如何实现app远程真机debug...
  17. Mac制作win to go后的驱动文件
  18. 一念非凡之薛定谔:量子力学是本征值问题
  19. 启动AutoCAD Electrical提示“缺少缺少驱动程序AceRedist”的解决办法
  20. 如何限制同一客户端登录的用户数量以及禁止同一用户同时在不同客户端登录?

热门文章

  1. 转载:flash 跨域 crossdomain.xml
  2. 会员教程翻译:性能和时间
  3. 51nod 1115 最大M子段和 V3
  4. 百炼1001: Exponentiation 解题
  5. iOS App版本号compare
  6. B计划 第四周(开学第一周)
  7. mycat 编辑schema.xml
  8. 写给笨蛋徒弟的学习手册(1)——完整C#项目中各个文件含义
  9. AngularJS学习笔记(一)
  10. UNIX网络编程读书笔记:辅助数据