本文整理自《Python机器学习》

决策树

决策树可视为数据从顶向下划分的一种方法,通常为二叉树。
通过决策树算法,从树根开始,基于可获得的最大信息增益(Information Gain, IG)的特征对数据进行划分。
目标函数能够在每次划分时实现对信息增益的最大化,其定义如下:
IG ( D p , f ) = I ( D p ) − ∑ j = 1 m N j N p I ( D j ) \text{IG}(D_p,f)=I(D_p)-\sum_{j=1}^m\frac{N_j}{N_p}I(D_j) IG(Dp​,f)=I(Dp​)−j=1∑m​Np​Nj​​I(Dj​)
其中 f f f为将要进行划分的特征, D p D_p Dp​与 D j D_j Dj​分别为父节点和第 j j j个子节点, I I I为不纯度衡量标准, N p N_p Np​为父节点中样本数量, N j N_j Nj​为第 j j j个子节点中样本的数量。上式即表示,信息增益是父节点的不纯度与所有子节点不纯度总和之差,子节点的不纯度越低,信息增益越大。
对于二叉树(scikit-learn中的实现方式)有:
IG ( D p , a ) = I ( D p ) − N l e f t N p I ( D l e f t ) − N r i g h t N p I ( D r i g h t ) \text{IG}(D_p,a)=I(D_p)-\frac{N_{left}}{N_p}I(D_{left})-\frac{N_{right}}{N_p}I(D_{right}) IG(Dp​,a)=I(Dp​)−Np​Nleft​​I(Dleft​)−Np​Nright​​I(Dright​)
二叉决策树主要有三类不纯度衡量标准。
熵(entropy):
I H ( t ) = − ∑ i = 1 c p ( i ∣ t ) log ⁡ 2 p ( i ∣ t ) I_H(t)=-\sum_{i=1}^cp(i|t)\log_2p(i|t) IH​(t)=−i=1∑c​p(i∣t)log2​p(i∣t)
基尼系数(Gini index):
I G ( t ) = 1 − ∑ i = 1 c p ( i ∣ t ) 2 I_G(t)=1-\sum_{i=1}^cp(i|t)^2 IG​(t)=1−i=1∑c​p(i∣t)2
误分类率(classification error)
I E = 1 − max ⁡ { p ( i ∣ t ) } I_E=1-\max\{p(i|t)\} IE​=1−max{p(i∣t)}
p ( i ∣ t ) p(i|t) p(i∣t)为特定节点 t t t中,属于类别 i i i的样本占特定节点 t t t中样本总数的比例。
实践中,基尼系数和熵会产生非常相似的效果,不会花大量时间用不纯度评判决策树的好坏,而尝试使用不同的剪枝算法,误分类率是对于剪枝方法的一个很好的准则但不建议用于决策树的构建。
样本属于类别1,概率介于[0,1]情况下三种不纯度的图像可由如下代码构建:

import matplotlib.pyplot as plt
import numpy as npdef gini(p):return (p)*(1-(p)) + (1-p)*(1-(1-p))def entropy(p):return -p*np.log2(p)-(1-p)*np.log2((1-p))def error(p):return 1-np.max([p, 1-p])x = np.arange(0, 1, 0.01)
giniVal=gini(x)
ent = [entropy(p) if p !=0 else None for p in x]
sc_ent = [e*0.5 if e else None for e in ent] # 按0.5比例缩放
err = [error(i) for i in x]
fig = plt.figure()
ax = plt.subplot(111)
for i, lab, ls, c in zip([ent, sc_ent, gini(x), err], ['Entropy', 'Entropy (scaled)', 'Gini Impurity', 'Missclassification Error'], ['-', '-', '--','-.'],['black','lightgray', 'red', 'green', 'cyan']):line = ax.plot(x, i, label=lab, linestyle=ls, lw=2, color=c)
ax.legend(loc='upper center', bbox_to_anchor=(0.5,1.15), ncol=3, fancybox=True, shadow=False)
ax.axhline(y=0.5, linewidth=1, color='k', linestyle='--') # horizon line
ax.axhline(y=1.0, linewidth=1, color='k', linestyle='--')
plt.ylim([0, 1.1])
plt.xlabel('p(i=1)')
plt.ylabel('Impurity Index')
plt.show()

所得结果如下:

使用scikit-learn中的决策树对鸢尾花进行分类

from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from matplotlib.colors import ListedColormap
import matplotlib.pyplot as plt
import numpy as np
from sklearn.tree import DecisionTreeClassifier
from sklearn.tree import export_graphviziris = datasets.load_iris()
X = iris.data[:, [2, 3]]
y = iris.target
X_train, X_test, y_train, y_test = train_test_split(X,y,test_size=0.3,random_state=0)sc = StandardScaler()
sc.fit(X_train)
X_train_std = sc.transform(X_train)
X_test_std: object = sc.transform(X_test)def plot_decision_regions(X, y, classifier, test_idx=None, resolution=0.02):markers = ('s', 'x', 'o', '^', 'v')colors = ('red', 'blue', 'lightgreen', 'gray', 'cyan')cmap = ListedColormap(colors[:len(np.unique(y))])x1_min, x1_max = X[:, 0].min() - 1, X[:, 0].max() + 1x2_min, x2_max = X[:, 1].min() - 1, X[:, 1].max() + 1xx1, xx2 = np.meshgrid(np.arange(x1_min, x1_max, resolution),np.arange(x2_min, x2_max, resolution))Z = classifier.predict(np.array([xx1.ravel(), xx2.ravel()]).T)Z = Z.reshape(xx1.shape)plt.contourf(xx1, xx2, Z, alpha=0.4, cmap=cmap)plt.xlim(xx1.min(), xx1.max())plt.ylim = (xx2.min(), xx2.max())X_test, y_test = X[test_idx, :], y[test_idx]for idx, cl in enumerate(np.unique(y)):plt.scatter(x=X[y == cl, 0], y=X[y == cl, 1], alpha=0.8, c=cmap(idx), marker=markers[idx], label=cl)if test_idx:X_test, y_test = X[test_idx, :], y[test_idx]plt.scatter(X_test[:, 0], X_test[:, 1], c='black', alpha=0.8, linewidths=1, marker='o', s=10, label='test set')tree = DecisionTreeClassifier(criterion='entropy', max_depth=3, random_state=0)
tree.fit(X_train, y_train)
X_combined=np.vstack((X_train, X_test))
y_combined=np.hstack((y_train, y_test))
plot_decision_regions(X_combined, y_combined,classifier=tree, test_idx=range(105, 150))
plt.xlabel('petal length [cm]')
plt.ylabel('petal width [cm]')
plt.legend(loc='upper left')
plt.show()export_graphviz(tree, out_file='tree.dot',feature_names=['petal length', 'petal width']) # 导出为dot文件

分类结果如下:

对于输出的tree.dot文件,我们可以通过GraphViz在命令行中输入指令

dot -Tpng tree.dot -o tree.png

转换为决策树的直观图片:

GraphViz可以在www.graphviz.org免费下载。

使用决策树对鸢尾花进行分类相关推荐

  1. 基于决策树对鸢尾花进行分类

    决策树的划分依据:信息增益 特征A对训练数据集D的信息增益g(D,A),定义为集合D的信息熵H(D)与特征A给定条件下D的信息条件熵H(D|A)之差,公式为: g(D,A) = H(D) - H(D| ...

  2. 决策树算法:对鸢尾花进行分类

    文章目录 一.什么是决策树 二.对鸢尾花进行分类 三.决策树可视化 三.决策树算法总结 一.什么是决策树 不用语言描述,直接看图 关于是否去上课,首先看的是是否能起床,然后课程,然后人数,然后是否点名 ...

  3. 机器学习:KNN算法对鸢尾花进行分类

    机器学习:KNN算法对鸢尾花进行分类 1.KNN算法的理解: 1.算法概述 KNN(K-NearestNeighbor)算法经常用来解决分类与回归问题, KNN算法的原理可以总结为"近朱者赤 ...

  4. 机器学习(9)决策树(决策树分类鸢尾花)

    目录 一.基础理论 二.决策树分类鸢尾花 API 1.读取数据 2.划分数据集 3.创建决策树预估器,训练 4.模型评估 方法一:比对法 方法二:计算错误率 代码 一.基础理论 决策树思想: 程序设计 ...

  5. python决策树分类鸢尾花_基于决策树—鸢尾花分类

    决策树算法广泛应用于:语音识别.医疗诊断.客户关系管理.模式识别.专家系统等,在实际工作中,必须根据数据类型的特点及数据集的大小,选择合适的算法. 本文选择经典案例--<鸢尾花分类> 一. ...

  6. BP神经网络对鸢尾花进行分类

    题目:BP神经网络分类器 一.实验项目: BP神经网络对鸢尾花进行分类 二.实验目的: 掌握BP神经网络学习算法,利用BP神经网络进行数据分类 三.实验内容: 1.编程实现BP神经网络算法 2.建立三 ...

  7. 【机器学习】决策树案例二:利用决策树进行鸢尾花数据集分类预测

    利用决策树进行鸢尾花数据集分类预测 2 利用决策树进行鸢尾花数据集分类预测 2.1 导入模块与加载数据 2.2 划分数据 2.3 模型创建与应用 2.4 模型可视化 手动反爬虫,禁止转载: 原博地址 ...

  8. 使用scikit-learn对鸢尾花进行分类

    我们可以使用scikit-learn训练感知器和逻辑斯谛模型以对鸢尾花进行分类,在这里我们使用三种鸢尾花,代码引自<python机器学习>. 使用线性分类的感知器的实现如下: from s ...

  9. 机器学习实践:基于支持向量机算法对鸢尾花进行分类

    摘要:List item使用scikit-learn机器学习包的支持向量机算法,使用全部特征对鸢尾花进行分类. 本文分享自华为云社区<支持向量机算法之鸢尾花特征分类[机器学习]>,作者:上 ...

最新文章

  1. Java数据库foreign,mysql中的外键foreign key 作者:Java_xb
  2. DNN数据库核心表结构及设计思路探研
  3. cdoj916-方老师的分身 III 【拓扑排序】
  4. 验算双中心重叠积分程序
  5. Mac-安装Homebrew报错error: could not lock config file .git/config:
  6. linux摄像头内核驱动开发,怎么在Linux下开发摄像头驱动
  7. 天气数据获取接口和网址汇总
  8. 菜鸟关于mvc导出Excel的想法
  9. HTML5 Canvas平移,放缩,旋转演示
  10. 自动驾驶模拟器Carla之python编程-(1)简介
  11. 服务器tomcat优化知识复习总结
  12. 迷你MVVM框架 avalonjs 学习教程6、插入移除处理
  13. 一年级下册健康教育教案
  14. MapGIS Mobile开发
  15. 「ds」网络操作系统和分布式操作系统之间的区别
  16. 人脸识别活体检测技术讨论:基于背景人脸相对运动的活体判断方法
  17. Leetcode力扣 MySQL数据库 1194 竞标赛优胜者
  18. SpringBoot导出txt文件
  19. 通过网线连接两台主机
  20. JAVA:实现PigeonholeSort鸽巢排序算法(附完整源码)

热门文章

  1. 尖叫吧!2015创新中国春季峰会 880元VIP门票免费送
  2. 最小点权覆盖集最大点权独立集
  3. 量子通信基础知识简介(一)
  4. 锐捷交换机,路由器,无线,ESS,EG所有操作配置命令合集
  5. Datatable 插件出现DataTable is not a function 错误
  6. 联通宽带开启 IPV6 的方法
  7. Flutter插件一野狗云实时通信
  8. 我想团:聚划算的反向电子商务实践
  9. 未来智安创始人兼CEO唐伽佳荣膺36氪X·36Under36 “S级创业者”
  10. HDU2567:寻梦