一、CART决策树算法简介

CART(Classification And Regression Trees 分类回归树)算法是一种树构建算法,既可以用于分类任务,又可以用于回归。相比于 ID3 和 C4.5 只能用于离散型数据且只能用于分类任务,CART 算法的适用面要广得多,既可用于离散型数据,又可以处理连续型数据,并且分类和回归任务都能处理。

本文仅讨论基本的CART分类决策树构建,不讨论回归树和剪枝等问题。

首先,我们要明确以下几点:
1. CART算法是二分类常用的方法,由CART算法生成的决策树是二叉树,而 ID3 以及 C4.5 算法生成的决策树是多叉树,从运行效率角度考虑,二叉树模型会比多叉树运算效率高。
2. CART算法通过基尼(Gini)指数来选择最优特征。

二、基尼系数

基尼系数代表模型的不纯度,基尼系数越小,则不纯度越低,注意这和 C4.5的信息增益比的定义恰好相反。

分类问题中,假设有K个类,样本点属于第k类的概率为pk,则概率分布的基尼系数定义为:

若CART用于二类分类问题(不是只能用于二分类),那么概率分布的基尼系数可简化为


假设使用特征 A 将数据集 D 划分为两部分 D1 和 D2,此时按照特征 A 划分的数据集的基尼系数为:

三、CART决策树生成算法

输入:训练数据集D,停止计算的条件
输出:CART决策树
根据训练数据集,从根结点开始,递归地对每个结点进行以下操作,构建二叉决策树:
(1)计算现有特征对该数据集的基尼指数,如上面所示;
(2)选择基尼指数最小的值对应的特征为最优特征,对应的切分点为最优切分点(若最小值对应的特征或切分点有多个,随便取一个即可);
(3)按照最优特征和最优切分点,从现结点生成两个子结点,将训练数据集中的数据按特征和属性分配到两个子结点中;
(4)对两个子结点递归地调用(1)(2)(3),直至满足停止条件。
(5)生成CART树。
算法停止的条件:结点中的样本个数小于预定阈值,或样本集的基尼指数小于预定阈值(样本基本属于同一类,如完全属于同一类则为0),或者特征集为空。

注:最优切分点是将当前样本下分为两类(因为我们要构造二叉树)的必要条件。对于离散的情况,最优切分点是当前最优特征的某个取值;对于连续的情况,最优切分点可以是某个具体的数值。具体应用时需要遍历所有可能的最优切分点取值去找到我们需要的最优切分点。

四、CART算法的Python实现

from math import log# 构造数据集
def create_dataset():dataset = [['youth', 'no', 'no', 'just so-so', 'no'],['youth', 'no', 'no', 'good', 'no'],['youth', 'yes', 'no', 'good', 'yes'],['youth', 'yes', 'yes', 'just so-so', 'yes'],['youth', 'no', 'no', 'just so-so', 'no'],['midlife', 'no', 'no', 'just so-so', 'no'],['midlife', 'no', 'no', 'good', 'no'],['midlife', 'yes', 'yes', 'good', 'yes'],['midlife', 'no', 'yes', 'great', 'yes'],['midlife', 'no', 'yes', 'great', 'yes'],['geriatric', 'no', 'yes', 'great', 'yes'],['geriatric', 'no', 'yes', 'good', 'yes'],['geriatric', 'yes', 'no', 'good', 'yes'],['geriatric', 'yes', 'no', 'great', 'yes'],['geriatric', 'no', 'no', 'just so-so', 'no']]features = ['age', 'work', 'house', 'credit']return dataset, features# 计算当前集合的Gini系数
def calcGini(dataset):# 求总样本数num_of_examples = len(dataset)labelCnt = {}# 遍历整个样本集合for example in dataset:# 当前样本的标签值是该列表的最后一个元素currentLabel = example[-1]# 统计每个标签各出现了几次if currentLabel not in labelCnt.keys():labelCnt[currentLabel] = 0labelCnt[currentLabel] += 1# 得到了当前集合中每个标签的样本个数后,计算它们的p值for key in labelCnt:labelCnt[key] /= num_of_exampleslabelCnt[key] = labelCnt[key] * labelCnt[key]# 计算Gini系数Gini = 1 - sum(labelCnt.values())return Gini# 提取子集合
# 功能:从dataSet中先找到所有第axis个标签值 = value的样本
# 然后将这些样本删去第axis个标签值,再全部提取出来成为一个新的样本集
def create_sub_dataset(dataset, index, value):sub_dataset = []for example in dataset:current_list = []if example[index] == value:current_list = example[:index]current_list.extend(example[index + 1 :])sub_dataset.append(current_list)return sub_dataset# 将当前样本集分割成特征i取值为value的一部分和取值不为value的一部分(二分)
def split_dataset(dataset, index, value):sub_dataset1 = []sub_dataset2 = []for example in dataset:current_list = []if example[index] == value:current_list = example[:index]current_list.extend(example[index + 1 :])sub_dataset1.append(current_list)else:current_list = example[:index]current_list.extend(example[index + 1 :])sub_dataset2.append(current_list)return sub_dataset1, sub_dataset2def choose_best_feature(dataset):# 特征总数numFeatures = len(dataset[0]) - 1# 当只有一个特征时if numFeatures == 1:return 0# 初始化最佳基尼系数bestGini = 1# 初始化最优特征index_of_best_feature = -1# 遍历所有特征,寻找最优特征和该特征下的最优切分点for i in range(numFeatures):# 去重,每个属性值唯一uniqueVals = set(example[i] for example in dataset)# Gini字典中的每个值代表以该值对应的键作为切分点对当前集合进行划分后的Gini系数Gini = {}# 对于当前特征的每个取值for value in uniqueVals:# 先求由该值进行划分得到的两个子集sub_dataset1, sub_dataset2 = split_dataset(dataset,i,value)# 求两个子集占原集合的比例系数prob1 prob2prob1 = len(sub_dataset1) / float(len(dataset))prob2 = len(sub_dataset2) / float(len(dataset))# 计算子集1的Gini系数Gini_of_sub_dataset1 = calcGini(sub_dataset1)# 计算子集2的Gini系数Gini_of_sub_dataset2 = calcGini(sub_dataset2)# 计算由当前最优切分点划分后的最终Gini系数Gini[value] = prob1 * Gini_of_sub_dataset1 + prob2 * Gini_of_sub_dataset2# 更新最优特征和最优切分点if Gini[value] < bestGini:bestGini = Gini[value]index_of_best_feature = ibest_split_point = valuereturn index_of_best_feature, best_split_point# 返回具有最多样本数的那个标签的值('yes' or 'no')
def find_label(classList):# 初始化统计各标签次数的字典# 键为各标签,对应的值为标签出现的次数labelCnt = {}for key in classList:if key not in labelCnt.keys():labelCnt[key] = 0labelCnt[key] += 1# 将classCount按值降序排列# 例如:sorted_labelCnt = {'yes': 9, 'no': 6}sorted_labelCnt = sorted(labelCnt.items(), key = lambda a:a[1], reverse = True)# 下面这种写法有问题# sortedClassCount = sorted(labelCnt.iteritems(), key=operator.itemgetter(1), reverse=True)# 取sorted_labelCnt中第一个元素中的第一个值,即为所求return sorted_labelCnt[0][0]def create_decision_tree(dataset, features):# 求出训练集所有样本的标签# 对于初始数据集,其label_list = ['no', 'no', 'yes', 'yes', 'no', 'no', 'no', 'yes', 'yes', 'yes', 'yes', 'yes', 'yes', 'yes', 'no']label_list = [example[-1] for example in dataset]# 先写两个递归结束的情况:# 若当前集合的所有样本标签相等(即样本已被分“纯”)# 则直接返回该标签值作为一个叶子节点if label_list.count(label_list[0]) == len(label_list):return label_list[0]# 若训练集的所有特征都被使用完毕,当前无可用特征,但样本仍未被分“纯”# 则返回所含样本最多的标签作为结果if len(dataset[0]) == 1:return find_label(label_list)# 下面是正式建树的过程# 选取进行分支的最佳特征的下标和最佳切分点index_of_best_feature, best_split_point = choose_best_feature(dataset)# 得到最佳特征best_feature = features[index_of_best_feature]# 初始化决策树decision_tree = {best_feature: {}}# 使用过当前最佳特征后将其删去del(features[index_of_best_feature])# 子特征 = 当前特征(因为刚才已经删去了用过的特征)sub_labels = features[:]# 递归调用create_decision_tree去生成新节点# 生成由最优切分点划分出来的二分子集sub_dataset1, sub_dataset2 = split_dataset(dataset,index_of_best_feature,best_split_point)# 构造左子树decision_tree[best_feature][best_split_point] = create_decision_tree(sub_dataset1, sub_labels)# 构造右子树decision_tree[best_feature]['others'] = create_decision_tree(sub_dataset2, sub_labels)return decision_tree# 用上面训练好的决策树对新样本分类
def classify(decision_tree, features, test_example):# 根节点代表的属性first_feature = list(decision_tree.keys())[0]# second_dict是第一个分类属性的值(也是字典)second_dict = decision_tree[first_feature]# 树根代表的属性,所在属性标签中的位置,即第几个属性index_of_first_feature = features.index(first_feature)# 对于second_dict中的每一个keyfor key in second_dict.keys():# 不等于'others'的keyif key != 'others':if test_example[index_of_first_feature] == key:# 若当前second_dict的key的value是一个字典if type(second_dict[key]).__name__ == 'dict':# 则需要递归查询classLabel = classify(second_dict[key], features, test_example)# 若当前second_dict的key的value是一个单独的值else:# 则就是要找的标签值classLabel = second_dict[key]# 如果测试样本在当前特征的取值不等于key,就说明它在当前特征的取值属于'others'else:# 如果second_dict['others']的值是个字符串,则直接输出if isinstance(second_dict['others'],str):classLabel = second_dict['others']# 如果second_dict['others']的值是个字典,则递归查询else:classLabel = classify(second_dict['others'], features, test_example)return classLabelif __name__ == '__main__':dataset, features = create_dataset()decision_tree = create_decision_tree(dataset, features)# 打印生成的决策树print(decision_tree)# 对新样本进行分类测试features = ['age', 'work', 'house', 'credit']test_example = ['midlife', 'yes', 'no', 'great']print(classify(decision_tree, features, test_example))

若是二分类问题,则函数calcGini和choose_best_feature可简化如下:

# 计算样本属于第1个类的概率p
def calcProbabilityEnt(dataset):numEntries = len(dataset)count = 0label = dataset[0][len(dataset[0]) - 1]for example in dataset:if example[-1] == label:count += 1probabilityEnt = float(count) / numEntriesreturn probabilityEntdef choose_best_feature(dataset):# 特征总数numFeatures = len(dataset[0]) - 1# 当只有一个特征时if numFeatures == 1:return 0# 初始化最佳基尼系数bestGini = 1# 初始化最优特征index_of_best_feature = -1for i in range(numFeatures):# 去重,每个属性值唯一uniqueVals = set(example[i] for example in dataset)# 定义特征的值的基尼系数Gini = {}for value in uniqueVals:sub_dataset1, sub_dataset2 = split_dataset(dataset,i,value)prob1 = len(sub_dataset1) / float(len(dataset))prob2 = len(sub_dataset2) / float(len(dataset))probabilityEnt1 = calcProbabilityEnt(sub_dataset1)probabilityEnt2 = calcProbabilityEnt(sub_dataset2)Gini[value] = prob1 * 2 * probabilityEnt1 * (1 - probabilityEnt1) + prob2 * 2 * probabilityEnt2 * (1 - probabilityEnt2)if Gini[value] < bestGini:bestGini = Gini[value]index_of_best_feature = ibest_split_point = valuereturn index_of_best_feature, best_split_point

五、运行结果

CART决策树算法的Python实现(注释详细)相关推荐

  1. CART决策树算法Python实现 (人工智能导论作业)

    文章目录 决策树的介绍 CART决策树算法简介 基尼指数 CART决策树生成算法及Python代码实现 决策树的介绍 决策树是以树的结构将决策或者分类过程展现出来,其目的是根据若干输入变量的值构造出一 ...

  2. 决策树算法和CART决策树算法详细介绍及其原理详解

    相关文章 K近邻算法和KD树详细介绍及其原理详解 朴素贝叶斯算法和拉普拉斯平滑详细介绍及其原理详解 决策树算法和CART决策树算法详细介绍及其原理详解 线性回归算法和逻辑斯谛回归算法详细介绍及其原理详 ...

  3. 【机器学习】Java 代码实现 CART 决策树算法

    文章目录 一.决策树算法 二.CART 决策树 三.Java 代码实现 3.1 TrainDataSet 3.2 DataType 3.3 PredictResult 3.4 CartDecision ...

  4. ID3决策树算法及其Python实现

    目录 一.决策树算法 基础理论 决策树的学习过程 ID3算法 二.实现针对西瓜数据集的ID3算法 实现代码 三.C4.5和CART的算法代码实现 C4.5算法 CART算法 总结 参考文章 一.决策树 ...

  5. CART决策树算法总结

    CART决策树算法,顾名思义可以创建分类树(classification)和回归树(regression). 1.分类树. 当CART决策树算法用于创建分类树时,和ID3和C4.5有很多相似之处,但是 ...

  6. cart算法_ID3、C4.5、CART决策树算法

    本文主要介绍的主要内容如下: 概念 ID3 决策树算法 C4.5 决策树算法 CART 决策树算法 1. 概念 1.1 信息熵 信息熵(Entropy),随机变量的不确定性,也称为"系统混乱 ...

  7. 决策树算法及Python 代码示例

    决策树是一种基于树形结构的算法,用于在一系列决策和结果之间建立模型.它通过对特征和目标变量之间的关系进行划分,来预测目标变量的值. 决策树算法示例: 假设我们有一组数据,其中包含天气,温度,湿度和是否 ...

  8. 机器学习强基计划2-2:一文详解ID3、C4.5、CART决策树算法+ Python实现

    目录 0 写在前面 1 什么是决策树? 2 决策树算法框架 3 常见决策树算法 3.1 ID3算法 3.2 C4.5算法 3.3 CART算法 4 Python实现三种决策树算法 4.1 数据集 4. ...

  9. 机器学习笔记(4)——ID3决策树算法及其Python实现

    决策树是一种基于树结构来进行决策的分类算法,我们希望从给定的训练数据集学得一个模型(即决策树),用该模型对新样本分类.决策树可以非常直观展现分类的过程和结果,一旦模型构建成功,对新样本的分类效率也相当 ...

  10. python决策树算法_决策树算法及python实现

    决策树算法是机器学习中的经典算法 1.决策树(decision tree) 决策树是一种树形结构,其中每个内部节点表示一个属性上的测试,每个分支代表一个测试输出,每个叶节点代表一种类别. 假设小明去看 ...

最新文章

  1. Python数据框结合lambda函数的使用
  2. 华人一作占半数,陶大程等人上榜,CVPR公布32篇最佳论文候选
  3. windows 7 unmountable boot volume 解决方法
  4. XAML数据绑定(Data Binding)
  5. python十二:字符串格式化
  6. 【Python】调用百度云API图像搜索服务
  7. CSS 设计模式一 元素
  8. html中显示数据库中的一条数据,如何使用html表显示数据库中的数据
  9. jdbc版本低MySQL版本高_Mysql JDBC驱动版本与Mysql版本的对应问题解决
  10. Spring DI依赖注入方式
  11. 【ESP8266】发送HTTP请求
  12. jquery 常用组件的小代码
  13. Object类中的wait()和notify()
  14. Win 95 使用技巧
  15. anaconda怎么切换目录_Anaconda更改工作路径
  16. 人脸识别算法一:特征脸方法(Eigenface)
  17. Hibernate使用原生SQL查询
  18. 【Beta阶段】第一次Scrum Meeting
  19. CVPR2022|稀疏融合稠密:通过深度补全实现高质量的3D目标检测
  20. Windows安装You-get详细教程和问题解决分享

热门文章

  1. 数学建模入门例题python_用Python分析支付宝轻定投收益--Python数学建模实例
  2. 23种设计模式之builder模式
  3. [CI、CD入门]maven打包可执行程序之微服务-服务提供者篇
  4. Python 蓝牙通信模块pybluez Win7
  5. 智能ABC输入法使用技巧
  6. python官方文档(自翻译)
  7. 又是灵格斯导致软件自动关闭
  8. 基于生物特征密钥生成研究 ------应用于区块链领域密钥的生成办法
  9. 常见的SQL优化面试题
  10. 易语言高级表格如何右击选择当前项,再弹出右击菜单?