目录

1 决策树训练和可视化

2 做出预测

3 估算类别概率

4 CART训练算法

5 正则化超参数

6 回归

7 不稳定性


1 决策树训练和可视化

下面简单看一下例子:

常规模块的导入以及图像可视化的设置:

# Common imports
import numpy as np
import os# to make this notebook's output stable across runs
np.random.seed(42)# To plot pretty figures
%matplotlib inline
import matplotlib as mpl
import matplotlib.pyplot as plt
mpl.rc('axes', labelsize=14)
mpl.rc('xtick', labelsize=12)
mpl.rc('ytick', labelsize=12)
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier,export_graphviziris = load_iris()
X = iris.data[:, 2:] # petal length and width
y = iris.targettree_clf = DecisionTreeClassifier(max_depth=2, random_state=42)
tree_clf.fit(X, y)#可视化决策树
#网站显示结构:http://webgraphviz.com/
#http://dreampuf.github.io/GraphvizOnline/将dot文件内容复制该网站即可,等待一会出图
export_graphviz(tree_clf,out_file="iris1_tree.dot")

默认路径下打开iris1_tree.dot文件:

digraph Tree {
node [shape=box, fontname="helvetica"] ;
edge [fontname="helvetica"] ;
0 [label="X[0] <= 2.45\ngini = 0.667\nsamples = 150\nvalue = [50, 50, 50]"] ;
1 [label="gini = 0.0\nsamples = 50\nvalue = [50, 0, 0]"] ;
0 -> 1 [labeldistance=2.5, labelangle=45, headlabel="True"] ;
2 [label="X[1] <= 1.75\ngini = 0.5\nsamples = 100\nvalue = [0, 50, 50]"] ;
0 -> 2 [labeldistance=2.5, labelangle=-45, headlabel="False"] ;
3 [label="gini = 0.168\nsamples = 54\nvalue = [0, 49, 5]"] ;
2 -> 3 ;
4 [label="gini = 0.043\nsamples = 46\nvalue = [0, 1, 45]"] ;
2 -> 4 ;
}

具体可视化步骤已在本篇博文中讲述:

机器学习(18)——分类算法(补充)_WHJ226的博客-CSDN博客

简单步骤如下:首先打开该网站Graphviz Online ,最后将dot文件内容复制粘贴左侧代码区即可。

效果如下:(另外pycharm中的插件也可以实现决策树可视化,不过目前上述方法还没出现问题就未曾探索)

决策树可视化

2 做出预测

假设我们找到了一朵鸢尾花,想要归类,那么从根节点(深度0,顶部)开始:这朵花的花瓣长度是否小于2.45CM?如果是,则向下移动到根的左侧子节点(深度1,左) 。在上面的例子中,这是一个叶节点(即没有任何子节点),所以不在继续,直接查看预测类别。

我们找到的另一朵鸢尾花,花瓣长度大于2.45厘米。这次我们需要移动到根节点的右侧子节点(深度1,右),由于该节点不是叶节点,所以它提出另一个问题:花瓣宽度是否小于1.75CM? 然后再做出预测。

节点的samples属性统计它应用的训练实例数量。例如,有100个训练实例的花瓣长度大于2.45cm(深度1,右),其中54个花瓣宽度小于1.75cm(深度2,左)。节点的value属性说明了该节点上每个类别的训练实例数量:例如,右下节点应用在0个Setoca鸢尾花、1个Versicolor鸢尾花和45个Virginica鸢尾花实例上。节点的gini属性衡量其不纯度:如果应用的所有训练实例都属于同一个类别,那么节点就是“纯”的(gini=0).例如,深度1左侧节点仅应用于Setoca鸢尾花训练实例,所以它就是纯的,并且gini=0。下面的基尼不纯度公式将说明第i个节点的基尼系数  的计算方式。

例如,深度2左侧节点,基尼系数等于  。

下图是决策树的决策边界。加粗直线表示根节点(深度0)的决策边界:花瓣长度=2.45厘米。因为左侧区域是纯的(只有Setoca鸢尾花),所以不可再分。右侧区域不是纯的,所以深度1右侧的节点在花瓣宽度=1.75厘米处(虚线表示)再次分离。此处最大深度max_depth 设置为2,所以决策树在此停止。但若max_depth 设置为3,那么两个深度为2的节点将各自再产生一条决策边界(点线表示)。

决策树的决策边界

代码实现如下:

from matplotlib.colors import ListedColormapdef plot_decision_boundary(clf, X, y, axes=[0, 7.5, 0, 3], iris=True, legend=False, plot_training=True):x1s = np.linspace(axes[0], axes[1], 100)x2s = np.linspace(axes[2], axes[3], 100)x1, x2 = np.meshgrid(x1s, x2s)X_new = np.c_[x1.ravel(), x2.ravel()]y_pred = clf.predict(X_new).reshape(x1.shape)custom_cmap = ListedColormap(['#fafab0','#9898ff','#a0faa0'])plt.contourf(x1, x2, y_pred, alpha=0.3, cmap=custom_cmap)if not iris:custom_cmap2 = ListedColormap(['#7d7d58','#4c4c7f','#507d50'])plt.contour(x1, x2, y_pred, cmap=custom_cmap2, alpha=0.8)if plot_training:plt.plot(X[:, 0][y==0], X[:, 1][y==0], "yo", label="Iris-Setosa")plt.plot(X[:, 0][y==1], X[:, 1][y==1], "bs", label="Iris-Versicolor")plt.plot(X[:, 0][y==2], X[:, 1][y==2], "g^", label="Iris-Virginica")plt.axis(axes)if iris:plt.xlabel("Petal length", fontsize=14)plt.ylabel("Petal width", fontsize=14)else:plt.xlabel(r"$x_1$", fontsize=18)plt.ylabel(r"$x_2$", fontsize=18, rotation=0)if legend:plt.legend(loc="lower right", fontsize=14)plt.figure(figsize=(8, 4))
plot_decision_boundary(tree_clf, X, y)
plt.plot([2.45, 2.45], [0, 3], "k-", linewidth=2)
plt.plot([2.45, 7.5], [1.75, 1.75], "k--", linewidth=2)
plt.plot([4.95, 4.95], [0, 1.75], "k:", linewidth=2)
plt.plot([4.85, 4.85], [1.75, 3], "k:", linewidth=2)
plt.text(1.40, 1.0, "Depth=0", fontsize=15)
plt.text(3.2, 1.80, "Depth=1", fontsize=13)
plt.text(4.05, 0.5, "(Depth=2)", fontsize=11)plt.show()

3 估算类别概率

决策树同样可以估算某个实例属于特定类别k的概率:首先,我们跟随决策树找到该实例的叶节点,然后返回该节点中类别k的训练实例占比。例如,我们发现一朵鸢尾花,花瓣长5厘米,宽1.5厘米。相应的叶节点为深度2左侧节点,因此决策树输出如下概率:Setoca鸢尾花,0%(0/54);Versicolor鸢尾花,90.7%(49/54);Virginica鸢尾花,9.3%(5/54)。

代码展示:

tree_clf.predict_proba([[5, 1.5]])

运行结果如下:

array([[0.        , 0.90740741, 0.09259259]])

预测类别:

tree_clf.predict([[5, 1.5]])

运行结果如下:

array([1])

4 CART训练算法

分类与回归树(Classification And Regression Tree,简称CART):首先,使用单个特征k和阈值 (例如,花瓣长度≤2.45厘米)将训练集分成两个子集。k和阈值 :是产生出最纯子集(受其大小加权)的 k和阈值 就是经算法搜索确定的(,)。算法尝试最小化的成本函数公式如下:

一旦成功将训练集一分为二,它将使用相同的逻辑,继续分裂子集,然后是子集的子集,依次循环递进。直到抵达最大深度(超参数 max_depth 控制) ,或是再也找不到能够降低不纯度的分裂,它才会停止。

注意,CART是一种贪婪算法:从顶层开始搜索最优分裂,然后每层重复这个过程。几层分裂之后,它不会检视这个分裂的不纯度是否为可能的最低值。通常会产生一个相当不错的解,但不能保证是最优解。

5 正则化超参数

决策树极少对训练数据作出假设。如果不加以限制,树的结构将跟随训练集变化,有可能出现过度拟合。这种模型通常被称为非参数模型,不是说它不包含任何参数,而是指在训练之前没有确定参数的数量,导致模型结构自由而紧密地贴近数据。未避免过度拟合,需要在训练过程中降低决策树的自由度,这个过程就是正则化。Scikit-Learn中,这由超参数 max_depth 控制。减小max_depth可使模型正则化,从而降低过度拟合的风险。

DecisionTreeClassifier 也有一些参数,可以限制决策树的形状:min_samples_split (分裂前节点必须有的最小样本数)、 min_samples_leaf(叶节点必须有的最小样本数量)、 min_weight_fraction_leaf(同min_samples_leaf,但表现为加权实例总数的占比)、 max_leaf_nodes(最大叶节点数量)、 max_features(分裂每个节点评估的最大特征数量)。 增大超参数min_* 或减小 max_* 将使模型正则化。

其实,我们还可以先不加约束的训练模型,然后再对不必要的节点进行删除。如果一个节点的子节点全部为叶节点,则该节点可被认为不必要,除非它所表示的纯度提升有重要的统计意义。标准统计测试,比如  测试,用来估算“提升纯粹是处于偶然”  (被称为虚假设)的概率。如果这个概率(称之为p值)高于一个给定阈值(通常为5%,超参数控制),那么这个节点可被认为不必要,其子节点可被删除。

下图显示的是在卫星数据集上训练的两个决策树。左图使用默认参数(无约束)来训练决策树,右图的决策树应用 min_samples_leaf=4 进行训练。

min_samples_leaf 正则化

很明显,左图模型过度拟合,右图泛化效果好。

代码实现如下:

from sklearn.datasets import make_moons
Xm, ym = make_moons(n_samples=100, noise=0.25, random_state=53)deep_tree_clf1 = DecisionTreeClassifier(random_state=42)
deep_tree_clf2 = DecisionTreeClassifier(min_samples_leaf=4, random_state=42)
deep_tree_clf1.fit(Xm, ym)
deep_tree_clf2.fit(Xm, ym)plt.figure(figsize=(11, 4))
plt.subplot(121)
plot_decision_boundary(deep_tree_clf1, Xm, ym, axes=[-1.5, 2.5, -1, 1.5], iris=False)
plt.title("No restrictions", fontsize=16)
plt.subplot(122)
plot_decision_boundary(deep_tree_clf2, Xm, ym, axes=[-1.5, 2.5, -1, 1.5], iris=False)
plt.title("min_samples_leaf = {}".format(deep_tree_clf2.min_samples_leaf), fontsize=14)plt.show()

6 回归

我们可以使用Scikit-Learn的 DecisionTreeRegressor 来构建一个回归树。下面我们在一个带有噪声的二次数据集上进行训练,其中max_depth = 2 :

# Quadratic training set + noise
np.random.seed(42)
m = 200
X = np.random.rand(m, 1)
y = 4 * (X - 0.5) ** 2
y = y + np.random.randn(m, 1) / 10from sklearn.tree import DecisionTreeRegressortree_reg = DecisionTreeRegressor(max_depth=2, random_state=42)
tree_reg.fit(X, y)export_graphviz(tree_reg,out_file="random_tree.dot")

结果如下:

这棵树与之前的差别在于,每个节点上不再是预测一个类别而是预测一个值。假设,我们想对一个x1=0.6的新实例进行预测,最后到达value=0.111的叶节点。该预测结果其实就是与这个叶节点关联的110个实例的平均目标值。在这110个实例上,预测产生的均方根误差为0.015。

下图显示了该模型的预测。如果设置max_depth=3,将得到右图预测。注意,每个区域的的预测值永远等于该区域内实例的目标平均值。算法分裂每个区域的方法,就是使最多的训练实例尽可能接近这个预测值。

两个决策树回归模型的对比

代码实现如下:

tree_reg1 = DecisionTreeRegressor(random_state=42, max_depth=2)
tree_reg2 = DecisionTreeRegressor(random_state=42, max_depth=3)
tree_reg1.fit(X, y)
tree_reg2.fit(X, y)def plot_regression_predictions(tree_reg, X, y, axes=[0, 1, -0.2, 1], ylabel="$y$"):x1 = np.linspace(axes[0], axes[1], 500).reshape(-1, 1)y_pred = tree_reg.predict(x1)plt.axis(axes)plt.xlabel("$x_1$", fontsize=18)if ylabel:plt.ylabel(ylabel, fontsize=18, rotation=0)plt.plot(X, y, "b.")plt.plot(x1, y_pred, "r.-", linewidth=2, label=r"$\hat{y}$")plt.figure(figsize=(11, 4))
plt.subplot(121)
plot_regression_predictions(tree_reg1, X, y)
for split, style in ((0.1973, "k-"), (0.0917, "k--"), (0.7718, "k--")):plt.plot([split, split], [-0.2, 1], style, linewidth=2)
plt.text(0.21, 0.65, "Depth=0", fontsize=15)
plt.text(0.01, 0.2, "Depth=1", fontsize=13)
plt.text(0.65, 0.8, "Depth=1", fontsize=13)
plt.legend(loc="upper center", fontsize=18)
plt.title("max_depth=2", fontsize=14)plt.subplot(122)
plot_regression_predictions(tree_reg2, X, y, ylabel=None)
for split, style in ((0.1973, "k-"), (0.0917, "k--"), (0.7718, "k--")):plt.plot([split, split], [-0.2, 1], style, linewidth=2)
for split in (0.0458, 0.1298, 0.2873, 0.9040):plt.plot([split, split], [-0.2, 1], "k:", linewidth=1)
plt.text(0.3, 0.5, "Depth=2", fontsize=13)
plt.title("max_depth=3", fontsize=14)plt.show()

CART算法的工作原理跟前面介绍的大致相同,唯一不同在于,它分类训练集的方式不是最小化纯度,而是最小化MSE。下面的公式为该算法尝试最小化的成本函数。

CART回归成本函数

前面我们曾说过决策树过度拟合。如果没有任何正则化(即使用默认超参数),我们将得到下图所示的预测结果,显然左图出现了过度拟合。 我们可以通过设置 min_samples_leaf=10 ,得到一个看起来合理的模型,右图所示。

回归决策树正则化

代码实现如下:

tree_reg1 = DecisionTreeRegressor(random_state=42)
tree_reg2 = DecisionTreeRegressor(random_state=42, min_samples_leaf=10)
tree_reg1.fit(X, y)
tree_reg2.fit(X, y)x1 = np.linspace(0, 1, 500).reshape(-1, 1)
y_pred1 = tree_reg1.predict(x1)
y_pred2 = tree_reg2.predict(x1)plt.figure(figsize=(11, 4))plt.subplot(121)
plt.plot(X, y, "b.")
plt.plot(x1, y_pred1, "r.-", linewidth=2, label=r"$\hat{y}$")
plt.axis([0, 1, -0.2, 1.1])
plt.xlabel("$x_1$", fontsize=18)
plt.ylabel("$y$", fontsize=18, rotation=0)
plt.legend(loc="upper center", fontsize=18)
plt.title("No restrictions", fontsize=14)plt.subplot(122)
plt.plot(X, y, "b.")
plt.plot(x1, y_pred2, "r.-", linewidth=2, label=r"$\hat{y}$")
plt.axis([0, 1, -0.2, 1.1])
plt.xlabel("$x_1$", fontsize=18)
plt.title("min_samples_leaf={}".format(tree_reg2.min_samples_leaf), fontsize=14)plt.show()

7 不稳定性

你可能注意到,决策树青睐正交的决策边界(所有分裂都与轴线垂直),因此他们对训练集的旋转非常敏感。下图是一个简单的线性可分离数据集:左图中决策树可以轻松分裂,右图中,数据集旋转了45°后,决策边界产生了不必要的卷曲。这也导致右侧模型可能泛化不佳。限制这种问题的方法之一就是用PCA。

代码实现如下:

np.random.seed(6)
Xs = np.random.rand(100, 2) - 0.5
ys = (Xs[:, 0] > 0).astype(np.float32) * 2angle = np.pi / 4
rotation_matrix = np.array([[np.cos(angle), -np.sin(angle)], [np.sin(angle), np.cos(angle)]])
Xsr = Xs.dot(rotation_matrix)tree_clf_s = DecisionTreeClassifier(random_state=42)
tree_clf_s.fit(Xs, ys)
tree_clf_sr = DecisionTreeClassifier(random_state=42)
tree_clf_sr.fit(Xsr, ys)plt.figure(figsize=(11, 4))
plt.subplot(121)
plot_decision_boundary(tree_clf_s, Xs, ys, axes=[-0.7, 0.7, -0.7, 0.7], iris=False)
plt.subplot(122)
plot_decision_boundary(tree_clf_sr, Xsr, ys, axes=[-0.7, 0.7, -0.7, 0.7], iris=False)plt.show()

更概括的说,决策树的主要问题是它们对训练数据中的小变化非常敏感。例如,我们从鸢尾花数据集中移除花瓣最宽的Versicolor鸢尾花(花瓣长4.8厘米,宽1.8厘米),然后我们重新训练一个决策树,将得到下图模型。

事实上,由于Scikit-Learn所使用的算法是随机的,如果我们想要得到相同的模型,需要对超参数random_state 进行设置。

学习笔记——《机器学习实战:基于Scikit-Learn和TensorFlow》

机器学习实战(6)——决策树相关推荐

  1. 机器学习实战之决策树熵的概述

    机器学习实战之决策树熵的概述 一.决策树简介 二.决策树的一般流程 三.决策树构建的准备工作 1.特征选择 (1)香农熵 (2)编写代码计算经验熵 (3)信息增益 (4)编写代码计算信息增益 2.决策 ...

  2. 机器学习实战——绘制决策树(代码)

    最近在学习Peter Harrington的<机器学习实战>,代码与书中的略有不同,但可以顺利运行. import matplotlib.pyplot as plt# 定义文本框和箭头格式 ...

  3. 机器学习实战之决策树(一)构造决策树

    决策树(一)构造决策树 1.简介 1.1 优缺点 1.2 流程 1.3 决策树的构造 1.4 海洋生物数据 2.信息增益 2.1 几个概念 2.2 计算给定数据集的熵 3 划分数据集 3.1 按照给定 ...

  4. 刻意练习:机器学习实战 -- Task01. 决策树

    背景 这是我们为拥有 Python 基础的同学推出的精进技能的"机器学习实战" 刻意练习活动,这也是我们本学期推出的第三次活动了. 我们准备利用8周时间,夯实机器学习常用算法,完成 ...

  5. 《机器学习实战》—— 决策树

    目录 一.决策树的构造 1. 信息增益 2. 划分数据集 3. 递归构建决策树 二.在 Python 中使用 Matplotlib 注解绘制树形图 1. Matplotlib 2. 构造注解树 三.测 ...

  6. 《机器学习实战》——决策树

    一 决策树 决策树是什么?决策树(decision tree)是一种基本的分类与回归方法.举个通俗易懂的例子,如下图所示的流程图就是一个决策树,长方形代表判断模块(decision block),椭圆 ...

  7. 机器学习实战之决策树

    你是否玩过二十个问题的游戏,游戏的规则很简单:参与游戏的一方在脑海里想某个事物,其他参与者向他提问题,只允许提20个问题,问题的答案也只能用对或错回答.问问题的人通过 推断分解,逐步缩小待猜测事物的范 ...

  8. 《机器学习实战》决策树的应用

    课本中给出的是一个预测隐形眼镜的例子. 数据集样式如下: young    myope    no    reduced    no lenses young    myope    no    no ...

  9. 机器学习实战之决策树(四)示例:预测隐形眼镜类型(含数据集)

    决策树(四)示例:预测隐形眼镜类型 流程 代码 决策树小结 转载请注明作者和出处:https://blog.csdn.net/weixin_45814668 微信公众号:qiongjian0427 知 ...

  10. 机器学习实战(二)决策树DT(Decision Tree、ID3算法)

    目录 0. 前言 1. 信息增益(ID3) 2. 决策树(Decision Tree) 3. 实战案例 3.1. 隐形眼镜案例 3.2. 存储决策树 3.3. 决策树画图表示 学习完机器学习实战的决策 ...

最新文章

  1. Centos5.6 VNC安装配置【无错版】
  2. 比特币现金(BCH)和比特币(BTC)之争到底在争些什么?
  3. 【cs229-Lecture7】支持向量机(SVM)
  4. 极光商智®服务器2007今日正式发布
  5. 为你总结了N个真实线上故障,从容应对面试官!
  6. IOS背景半透明渐变问题
  7. Java 中 finally 与 return 的执行顺序详解
  8. sql 时态表的意义_在SQL Server 2016中拉伸时态历史记录表
  9. GetConsoleWindow was not declared in this scope
  10. Spring+SpringMVC+MyBatis整合基础篇
  11. 安卓Toast显示提示消息(自定义view,根据子线程消息显示提示)
  12. 2015年 安防圈的明星代言人有哪些?
  13. mysql 分页 pageindex_mysql 超1亿数据,优化分页查询
  14. mkdir用大括号同时建立多个同级和下级目录
  15. docker常用到的一些命令
  16. 我爱淘冲刺阶段站立会议2每天任务2
  17. 华硕T100 安装linux,华硕T100重装win10系统教程
  18. 激光雷达的应用及发展前景
  19. 大数据处理技术的总结与分析
  20. 如何安装node.js

热门文章

  1. 计算机网络之物理层(理论附带题目)
  2. vs2015+qt国际化翻译问题:Linguist中源代码不可见
  3. Qt Linguist(语言家)与QtCreator集成
  4. 方寸知识篇 — 芯片的失效机理
  5. 无线充电技术究竟有何神秘之处?一篇文章带你读懂什么是无线充电技术
  6. linux伤硬盘,硬盘安装linux
  7. python 列表删除元素
  8. JAVA租房网站计算机毕业设计Mybatis+系统+数据库+调试部署
  9. 看图工具 -- 蓝湖 Axure 墨刀
  10. Ubantu错误汇总