凌云时刻 · 技术

导读:这篇笔记我们来看看机器学习中一个重要的非参数学习算法,决策树。

作者 | 计缘

来源 | 凌云时刻(微信号:linuxpk)

什么是决策树

上图是一个向银行申请信用卡的示例,图中的树状图展示了申请人需要在银行过几道关卡后才能成功申请到一张信用卡的流程图。在图中树状图的根节点是申请人输入的信息,叶子节点是银行作出的决策,也就相当于是对申请者输入信息作出的分类决策。从第一个根节点到最后一个叶子节点经过的根节点数量称为树状图的深度(depth)。上图示例中的树状图从第一个根节点申请人是否办理过信用卡,到最后一个发放信用卡叶子节点共经过了三个根节点,所以深度为3。那么像这样使用树状图对输入信息一步步分类的方式就称为决策树方式。

我们再来看一个问题,上图中每一个根节点的输入信息都可以用来做判断分类,但是机器学习的样本数据都是数字,那么此时如果做判断呢?我们先来使用Scikit Learn中提供的决策树直观的看一下通过决策树对样本数据的分类过程和分类结果。

# 我们使用Scikit Learn提供的鸢尾花数据集
import numpy as np
import matplotlib.pyplot as plt
from sklearn import datasets# 为了绘图方便,我们只使用鸢尾花数据的后两个特征
iris = datasets.load_iris()
X = iris.data[:, 2:]
y = iris.target# 将样本点绘制出来
plt.scatter(X[y==0, 0], X[y==0, 1])
plt.scatter(X[y==1, 0], X[y==1, 1])
plt.scatter(X[y==2, 0], X[y==2, 1])
plt.show()

# 引入Scikit Learn中的决策树分类器
from sklearn.tree import DecisionTreeClassifier# max_depth参数是决策树深度,criterion参数这里我们先不用管
dt_clf = DecisionTreeClassifier(max_depth=2, criterion="entropy")
dt_clf.fit(X, y)def plot_decision_boundary(model, axis):# meshgrid函数用两个坐标轴上的点在平面上画格,返回坐标矩阵X0, X1 = np.meshgrid(# 随机两组数,起始值和密度由坐标轴的起始值决定np.linspace(axis[0], axis[1], int((axis[1] - axis[0]) * 100)).reshape(-1, 1),np.linspace(axis[2], axis[3], int((axis[3] - axis[2]) * 100)).reshape(-1, 1),)# ravel()方法将高维数组降为一维数组,c_[]将两个数组以列的形式拼接起来,形成矩阵X_grid_matrix = np.c_[X0.ravel(), X1.ravel()]# 通过训练好的逻辑回归模型,预测平面上这些点的分类y_predict = model.predict(X_grid_matrix)y_predict_matrix = y_predict.reshape(X0.shape)# 设置色彩表from matplotlib.colors import ListedColormapmy_colormap = ListedColormap(['#EF9A9A', '#40E0D0', '#FFFF00'])# 绘制等高线,并且填充等高区域的颜色plt.contourf(X0, X1, y_predict_matrix, linewidth=5, cmap=my_colormap)plot_decision_boundary(dt_clf, axis=[0.5, 7.5, 0, 3])
plt.scatter(X[y==0, 0], X[y==0, 1])
plt.scatter(X[y==1, 0], X[y==1, 1])
plt.scatter(X[y==2, 0], X[y==2, 1])
plt.show()

从决策边界图中可以看到决策树分类器将鸢尾花的数据较好的进行的归类区分。那么再回到刚才的问题,那就是在对样本数据进行决策树分类时,根节点是怎么做判断的呢?

我们仔细看上图,假设我们将横轴表示的特征记为  ‍‍,纵轴表示的特征记为  ‍‍,那么红色的区域就是‍‍  的‍‍区域,所以第一个根节点就直接可定义为xx是否小于2.4,既样本数据的某一个特征数值是否小于2.4,如果小于2.4,那么将该样本点判定为蓝色分类的点。

然后再来看纵轴,观察可得,  ‍‍是蓝色和黄色区域的分界点,所以可将  ‍‍作为第二层的根节点,如果样本数据的另一个特征数值小于1.8,则判定为黄色分类的点,大于1.8则判定为绿色分类的点。那么此时的决策树为:

也就是决策树在对样本数据分类时,会先选定一个维度,或者说一个特征,再选定和这个维度想对应阈值构成一个根节点,既判断条件。

通过上面的示例,大家应该什么是决策树有了直观的了解,并且也能看出来决策树的一些特点:

  • 决策树算法是非参数学习算法。

  • 决策树算法可以解决分类问题。

  • 决策树算法可以天然解决多分类问题。

  • 决策树算法处理样本数据的结果具有非常好的可解释性。

了解了决策树后,那么问题来了,对于决策树来说核心的工作是确定根节点,那么这个根节点该如何确定呢?在哪个维度做划分?某个维度在哪个值上做划分?我们往后看。

信息熵

这一节我们先来看看什么是信息熵。其实信息熵的概念很简单,熵在信息论中代表随机变量不确定的度量

  • 熵越大,数据的不确定性越高。

  • 熵越小,数据的不确定性越低。

 信息熵的公式

上面解释了信息熵的定义,那么这一小节我们使用信息熵的公式来解释它的定义。

上面的公式就是香农提出的信息熵的公式。逐个解释一下:

  • 假如一组数据有k类信息,那么每一个信息所占的比例就是  ‍‍。比如鸢尾花数据包含三种鸢尾花的数据,那么每种鸢尾花所占的比例就是  ,那么‍‍‍‍  、‍‍  ‍‍、  ‍‍就分别为  。

  • 因为  ‍‍只可能是小于1的,所以  ‍‍始终是负数。所以需要在公式最前面加负号,让整个熵的值大于0。

我们来举几个例子看一下,首先用鸢尾花的例子,三种鸢尾花各占  :

那么代入信息熵的公式可得:

再来看一个例子:

‍‍

代入公式可得:

从上面两个例子可以看出,第二个例子的信息熵比一个例子的小,那么意味着第二个示例的数据不确定性要低于第一个示例的数据。其实从数据中也能看出,其中有一类信息占全部信息的  ,所以大多数据是能确定在某一类中的,故而不确定性低。而第一个示例中每类信息都占了全部信息‍‍‍‍的  ,所以数据不能很明确的确定是哪类,故而不确定性高。

再来看一个极端的例子,  ,‍‍将其代入信息熵公式后得到的值是0。因为整个数据中就一种类型的数据,所以不确定性达到了最低,既信息熵的最小值为0。

 信息熵曲线

这一小节我们在Jupyter Notebook中将信息熵的曲线绘制出来再让大家感性的理解一下。假设一组信息中有两个类别,那么当一个类别所占比例为  ‍‍时,另一个类别的所占比例肯定是‍‍  ,‍‍将其代入信息熵的公式展开后可‍‍‍‍‍‍‍‍‍‍‍‍得:

下面我们用代码来绘制一下这个曲线:

import numpy as np
import matplotlib.pyplot as pltdef entropy(p):return -p * np.log(p) - (1-p) * np.log(1-p)# 构建不同的比例,这里避免log(0)的计算,所以要避免p=0和1-p=0
x = np.linspace(0.01, 0.99, 200)# 将不同比例和不同信息熵的值绘制出来
plt.plot(x, entropy(x))
plt.show()

从图中可以看到,曲线是以横轴0.5为中心的抛物线。当比例xx为0.5时,纵轴达到最大值,也就是信息熵达到了最大值,表示此时变量不确定性最大。因为两个类别各以0.5的比例出现,很难确定是哪个类型,所以不确定性大。

当以0.5为中心向两边延展后,可以看到信息熵在逐渐减小,意味着不确定性在减小。因为此时要么有一个比例在减小,要么有一个比例在增大。这两种情况都可以表明其中一种类别的比例在增大,所以更容易确定是哪种类别,故而不确定性减小。

使用信息熵寻找最优划分

之前我们有带着两个问题进入到了信息熵小节。那么这一节就来解答这两个问题:

  • 决策树每个节点在哪个维度做划分?

  • 某个维度在哪个值上做划分?

那么我们要做的事情就是找到一个维度和一个阈值,使得通过该维度和阈值划分后的信息熵最低,此时这个划分才是最好的划分。

用大白话解释一下就是,我们在所有数据中寻找到信息熵最低的维度和阈值,然后将数据划分为多个部分,再寻找划分后每部分信息熵最低的维度和阈值,继续划分下去,最终形成完整的决策树。这一节就来看看如何使用信息熵寻找最优划分。

 划分函数

我们先定义一个划分函数,也就是相当于构建决策树根节点的作用:

import numpy as np# 四个参数分别为样本特征数据、样本目标数据、维度、阈值
def split(X, y, d, v):# 划分后左侧和右侧的索引数组index_l = (X[:, d] <= v)index_r = (X[:, d] > v)return X[index_l], X[index_r], y[index_l], y[index_r]

举个例子来看看上面的划分函数:

# 构建一个五行四列的样本数据
X = np.linspace(1, 10, 20)
X = X.reshape(5, 4)
y = 2 * X + 3X
# 结果
array([[  1.        ,   1.47368421,   1.94736842,   2.42105263],[  2.89473684,   3.36842105,   3.84210526,   4.31578947],[  4.78947368,   5.26315789,   5.73684211,   6.21052632],[  6.68421053,   7.15789474,   7.63157895,   8.10526316],[  8.57894737,   9.05263158,   9.52631579,  10.        ]])# 期望以第1个维度,既第1列特征为划分维度,以5为划分阈值
X_l, X_r, y_l, y_r = split(X, y, 1, 5)X_l
# 结果
array([[ 1.        ,  1.47368421,  1.94736842,  2.42105263],[ 2.89473684,  3.36842105,  3.84210526,  4.31578947]])# 如果以第0个维度,既第0列特征为划分维度,以5为划分阈值,那么X_l为三行四列矩阵,因为第0列的第三行值也小于5
X_l, X_r, y_l, y_r = split(X, y, 0, 5)X_l
# 结果
array([[ 1.        ,  1.47368421,  1.94736842,  2.42105263],[ 2.89473684,  3.36842105,  3.84210526,  4.31578947],[ 4.78947368,  5.26315789,  5.73684211,  6.21052632]])

 计算信息熵函数

然后定义计算信息熵的函数:

from collections import Counter
from math import log
# 计算信息熵时不关心样本特征,只关心样本目标数据的类别和每个类别的数量
def entropy(y):# 使用Counter生成字典,key为y的值,value为等于该值的元素数量counter_y = Counter(y)entropy_result = 0for num in counter_y.values():p = num / len(y)# 将所有类别的占比加起来,得到信息熵entropy_result += -p * log(p)return entropy_result

举个例子验证一下:

test_y = np.array([0, 0, 0, 0, 1, 1, 1, 2, 2, 2, 2, 2, 2, 3, 3, 3])
test_y1 = np.array([0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1])# test_y信息熵应该比较大,因为有4种类别
entropy(test_y)
# 结果
1.3421257227487469# test_y1信息熵应该比较小,因为有2种类别
entropy(test_y1)
# 结果
0.5982695885852573

 寻找维度和阈值函数

最后我们来定义寻找最优信息熵、维度、阈值的函数,该函数的最基本思路就是遍历样本数据的每个维度,对该维度中每两个相邻的值求均值作为阈值,然后求出信息熵,最终找到最小的信息熵,此时计算出该信息熵的维度和阈值既是最优维度和最优阈值:

# 寻找最优信息熵
def try_split(X, y):# 最优信息熵初始取无穷大best_entropy = float('inf')# 最优维度和最优阈值,初始值为-1best_d, best_v = -1, -1# 对样本特征数据的每个维度,既每个特征进行搜索for d in range(X.shape[1]):# 在d这个维度上,将每两个样本点中间的值作为阈值# 对样本数据在d这个维度上进行排序,返回排序索引sorted_index = np.argsort(X[:, d])# 遍历每行样本数据,注意从第一行开始,因为需要用上一行的值和该行的值求均值for row in range(1, len(X)):# 如果两个值相等,那么均值无法区分这两个值,所以忽略这种情况if X[sorted_index[row-1], d] != X[sorted_index[row], d]:v = (X[sorted_index[row-1], d] + X[sorted_index[row], d]) / 2# 使用split()函数做划分X_l, X_r, y_l, y_r = split(X, y, d, v)# 求划分后两部分的信息熵e = entropy(y_l) + entropy(y_r)# 保存最优信息熵、维度、阈值if e < best_entropy:best_entropy, best_d, best_v = e, d, vreturn best_entropy, best_d, best_v

我们同样使用鸢尾花的数据进行验证,在验证之前,我们先将之前通过Scikit Learn的决策树求出的鸢尾花决策边界的图贴出来,做以对比:

from sklearn import datasets
iris = datasets.load_iris()
X = iris.data[:, 2:]
y = iris.targetbest_entropy, best_d, best_v = try_split(X, y)
print("best_entropy = ", best_entropy)
print("best_d = ", best_d)
print("best_v = ", best_v)# 结果
best_entropy =  0.6931471805599453
best_d =  0
best_v =  2.45

此时样本数据的第一个根节点的判断条件就求出来了,从上面的决策边界图中可以看到,第一次划分确实是从第0个维度,既横轴开始,以2.4左右为阈值进行的。我们来看看通过这个根节点划分后的数据是什么样的:

X_l1, X_r1, y_l1, y_r1 = split(X, y, best_d, best_v)
entropy(y_l1)# 结果
0.0entropy(y_r1)# 结果
0.6931471805599453

从上面结果可以看到,经过第一个根节点划分后,对左侧数据求信息熵的结果为0,说明左侧的数据现在只有一个类型了。从上图也可以清晰的看到,以横轴2.4往左的区域全部是蓝色的点。那么对右侧而言,它的信息熵是0.69,说明对右侧还可以继续划分:

best_entropy2, best_d2, best_v2 = try_split(X_r1, y_r1)
print("best_entropy2 = ", best_entropy2)
print("best_d2 = ", best_d2)
print("best_v2 = ", best_v2)# 结果
best_entropy2 =  0.4132278899361904
best_d2 =  1
best_v2 =  1.75

第二次划分以第一个维度上的1.75为阈值进行,这和上图中的纵轴的划分界限基本是一致的。再来看看划分后的结果:

X_l2, X_r2, y_l2, y_r2 = split(X_r1, y_r1, best_d2, best_v2)
entropy(y_l2)# 结果
0.30849545083110386entropy(y_r2)# 结果
0.10473243910508653

可以看到第二次划分后,两部分的信息熵都不为零,其实还可以对每部分再进行划分。不过在之前使用Scikit Learn的决策树划分时,我们将深度设为了2,所以就只划分到了目前的阶段,如果将深度设的更大的话,那么就会继续划分下去。在我们模拟寻找最优划分的过程中,就不再继续划分下去了,大家理解了划分过程就可以了。

END

往期精彩文章回顾

机器学习笔记(二十八):高斯核函数

机器学习笔记(二十七):核函数(Kernel Function)

机器学习笔记(二十六):支撑向量机(SVM)(2)

机器学习笔记(二十五):支撑向量机(SVM)

机器学习笔记(二十四):召回率、混淆矩阵

机器学习笔记(二十三):算法精准率、召回率

机器学习笔记(二十二):逻辑回归中使用模型正则化

机器学习笔记(二十一):决策边界

机器学习笔记(二十):逻辑回归(2)

机器学习笔记(十九):逻辑回归

长按扫描二维码关注凌云时刻

每日收获前沿技术与科技洞见

机器学习笔记(二十九):决策树、信息熵相关推荐

  1. 【OpenGL】笔记二十九、抗锯齿(MSAA)

    1. 流程 经过之前的教程,我们目前渲染出来的画面已经有了足够的表现力,但是还是有一些缺陷,比如当我们的渲染画面分辨率跟不上屏幕分辨率时,在我们渲染的图形边缘一些比较严重的锯齿效果就会显现: 自然,这 ...

  2. (转载)机器学习知识点(二十九)LDA入门级学习笔记

    入门级学习笔记 1.1文本建模相关 统计文本建模的目的其实很简单:就是估算一组参数,这组参数使得整个语料库出现的概率最大.这是很简单的极大似然的思想了,就是认为观测到的样本的概率是最大的. 建模的目标 ...

  3. 机器学习笔记二十四 中文分词资料整理

    一.常见的中文分词方案 1. 基于字符串匹配(词典) 基于规则的常见的就是最大正/反向匹配,以及双向匹配. 规则里糅合一定的统计规则,会采用动态规划计算最大的概率路径的分词. 以上说起来很简单,其中还 ...

  4. 机器学习笔记(十九)——最大熵原理和模型定义

    一.最大熵原理 最大熵原理是概率模型学习的一个准则.最大熵原理认为,在学习概率模型时,在所有可能的概率分布中,熵最大的模型是最好的模型.通常用约束条件来确定概率模型的集合,所以,最大熵模型也可以表述为 ...

  5. C++语法学习笔记二十九: 详解decltype含义,decltype主要用途

    实例代码 // 详解decltype含义,decltype主要用途#include <iostream> #include <functional> #include < ...

  6. 嵌入式Linux驱动笔记(二十九)------内存管理之伙伴算法(Buddy)分析

    你好!这里是风筝的博客, 欢迎和我一起交流. 我们知道,在一个通用操作系统里,频繁申请内存释放内存都会出现一个非常著名的内存管理问题:内存碎片. 学过操作系统的都知道,有很多行之有效的方法(比如:记录 ...

  7. 吴恩达机器学习(二十九)大规模机器学习

    目录 1.随机梯度下降 2.Mini-Batch梯度下降 3.随机梯度下降收敛 4.减少映射与数据并行 1.随机梯度下降   对很多机器学习算法,例如线性回归.逻辑回归和神经网络,推导算法的方法是提出 ...

  8. opencv学习笔记二十九:SIFT特征点检测与匹配

    SIFT(Scale-invariant feature transform)是一种检测局部特征的算法,该算法通过求一幅图中的特征点(interest points,or corner points) ...

  9. 机器学习(二十九)——Temporal-Difference Learning

    https://antkillerfarm.github.io/ Temporal-Difference Learning(续) TD vs. MC-3 再来看如下示例: 已现有两个状态(A和B),M ...

  10. Mr.J-- jQuery学习笔记(二十九)--属性操作方法(获取属性判断)

    获取 attr() <span class="span1" name="it666"></span> <span class=&q ...

最新文章

  1. TextView中文字实现跑马灯
  2. 一文带你理解Java中Lock的实现原理
  3. 招了一大群学生的游戏代码
  4. MATLAB如何进行系统辨识(传递函数)
  5. /usr/include/stdio.h:27:10: fatal error: bits/libc-header-start.h: No such file or directory 报错解决
  6. 5.7 Components — Sending Actions From Components to Your Application
  7. 灵感加油站|当设计师没有灵感时怎么办?
  8. 【推荐算法】知识驱动的智能化推荐算法(附交流视频和PPT下载链接)
  9. 【pandas】dataframe根据某列是否是null筛选数据
  10. Kafka技术资料总结(不断更新中)
  11. OpenCV-Python教程8-图像混合
  12. Linux进程、线程、任务调度(1)贵在坚持
  13. 又一个Python数据分析学习利器!
  14. WIFI 信道 channel
  15. 【转】框架(蔡学镛)
  16. 中国银联移动支付技术规范
  17. 【GoCN酷Go推荐】protobuf生成Go代码插件gogo/protobuf
  18. 聊聊什么是自动化测试,什么是自动化测试框架
  19. 那些漂亮有创意的思维导图真的更吸引人吗?
  20. JavaScript高级(二)|函数进阶+正则表达式

热门文章

  1. JBoss AS 7中Domain Mode 和 Standalone Mode
  2. spring-第六篇之创建bean的3种方式
  3. 洛谷 P1168 中位数(优先队列)
  4. BZOJ1196 [HNOI2006]公路修建问题 【二分 + Kruskal】
  5. .10-浅析webpack源码之graceful-fs模块
  6. unity3d在菜单栏,一键设置Player setting及自动打包并设置apk的存储位置
  7. 为什么有人把《海贼王》当作人生信条
  8. linux 中的快捷键
  9. HDU 4050 wolf5x 概率dp 难度:1
  10. [WCF编程]12.事务:事务概述