【机器学习基础】数学推导+纯Python实现机器学习算法4:决策树之ID3算法
Python机器学习算法实现
Author:louwill
作为机器学习中的一大类模型,树模型一直以来都颇受学界和业界的重视。目前无论是各大比赛各种大杀器的XGBoost、lightgbm还是像随机森林、Adaboost等典型集成学习模型,都是以决策树模型为基础的。传统的经典决策树算法包括ID3算法、C4.5算法以及GBDT的基分类器CART算法。
三大经典决策树算法最主要的区别在于其特征选择准则的不同。ID3算法选择特征的依据是信息增益、C4.5是信息增益比,而CART则是Gini指数。作为一种基础的分类和回归方法,决策树可以有如下两种理解方式。一种是我们可以将决策树看作是一组if-then规则的集合,另一种则是给定特征条件下类的条件概率分布。关于这两种理解方式,读者朋友可深入阅读相关教材进行理解,笔者这里补详细展开。
根据上述两种理解方式,我们既可以将决策树的本质视作从训练数据集中归纳出一组分类规则,也可以将其看作是根据训练数据集估计条件概率模型。整个决策树的学习过程就是一个递归地选择最优特征,并根据该特征对数据集进行划分,使得各个样本都得到一个最好的分类的过程。
ID3算法理论
所以这里的关键在于如何选择最优特征对数据集进行划分。答案就是前面提到的信息增益、信息增益比和Gini指数。因为本篇针对的是ID3算法,所以这里笔者仅对信息增益进行详细的表述。
在讲信息增益之前,这里我们必须先介绍下熵的概念。在信息论里面,熵是一种表示随机变量不确定性的度量方式。若离散随机变量X的概率分布为:
则随机变量X的熵定义为:
同理,对于连续型随机变量Y,其熵可定义为:
当给定随机变量X的条件下随机变量Y的熵可定义为条件熵H(Y|X):
所谓信息增益就是数据在得到特征X的信息时使得类Y的信息不确定性减少的程度。假设数据集D的信息熵为H(D),给定特征A之后的条件熵为H(D|A),则特征A对于数据集的信息增益g(D,A)可表示为:
g(D,A) = H(D) - H(D|A)
信息增益越大,则该特征对数据集确定性贡献越大,表示该特征对数据有较强的分类能力。信息增益的计算示例如下:
1).计算目标特征的信息熵。
2).计算加入某个特征之后的条件熵。
3).计算信息增益。
以上就是ID3算法的核心理论部分,至于如何基于ID3构造决策树,我们在代码实例中来看。
ID3算法实现
先读入示例数据集:
import numpy as np
import pandas as pd
from math import logdf = pd.read_csv('./example_data.csv')
df
定义熵的计算函数:
def entropy(ele): '''function: Calculating entropy value.input: A list contain categorical value.output: Entropy value.entropy = - sum(p * log(p)), p is a prob value.'''# Calculating the probability distribution of list valueprobs = [ele.count(i)/len(ele) for i in set(ele)] # Calculating entropy valueentropy = -sum([prob*log(prob, 2) for prob in probs]) return entropy
计算示例:
然后我们需要定义根据特征和特征值进行数据划分的方法:
def split_dataframe(data, col): '''function: split pandas dataframe to sub-df based on data and column.input: dataframe, column name.output: a dict of splited dataframe.'''# unique value of columnunique_values = data[col].unique() # empty dict of dataframeresult_dict = {elem : pd.DataFrame for elem in unique_values} # split dataframe based on column valuefor key in result_dict.keys():result_dict[key] = data[:][data[col] == key] return result_dict
根据temp和其三个特征值的数据集划分示例:
然后就是根据熵计算公式和数据集划分方法计算信息增益来选择最佳特征的过程:
def choose_best_col(df, label): '''funtion: choose the best column based on infomation gain.input: datafram, labeloutput: max infomation gain, best column,splited dataframe dict based on best column.'''# Calculating label's entropyentropy_D = entropy(df[label].tolist()) # columns list except labelcols = [col for col in df.columns if col not in [label]] # initialize the max infomation gain, best column and best splited dictmax_value, best_col = -999, Nonemax_splited = None# split data based on different columnfor col in cols:splited_set = split_dataframe(df, col)entropy_DA = 0for subset_col, subset in splited_set.items(): # calculating splited dataframe label's entropyentropy_Di = entropy(subset[label].tolist()) # calculating entropy of current featureentropy_DA += len(subset)/len(df) * entropy_Di # calculating infomation gain of current featureinfo_gain = entropy_D - entropy_DA if info_gain > max_value:max_value, best_col = info_gain, colmax_splited = splited_set return max_value, best_col, max_splited
最先选到的信息增益最大的特征是outlook:
决策树基本要素定义好后,我们即可根据以上函数来定义一个ID3算法类,在类里面定义构造ID3决策树的方法:
class ID3Tree: # define a Node classclass Node: def __init__(self, name):self.name = nameself.connections = {} def connect(self, label, node):self.connections[label] = node def __init__(self, data, label):self.columns = data.columnsself.data = dataself.label = labelself.root = self.Node("Root") # print tree methoddef print_tree(self, node, tabs):print(tabs + node.name) for connection, child_node in node.connections.items():print(tabs + "\t" + "(" + connection + ")")self.print_tree(child_node, tabs + "\t\t") def construct_tree(self):self.construct(self.root, "", self.data, self.columns) # construct treedef construct(self, parent_node, parent_connection_label, input_data, columns):max_value, best_col, max_splited = choose_best_col(input_data[columns], self.label) if not best_col:node = self.Node(input_data[self.label].iloc[0])parent_node.connect(parent_connection_label, node) returnnode = self.Node(best_col)parent_node.connect(parent_connection_label, node)new_columns = [col for col in columns if col != best_col] # Recursively constructing decision treesfor splited_value, splited_data in max_splited.items():self.construct(node, splited_value, splited_data, new_columns)
根据上述代码和示例数据集构造一个ID3决策树:
以上便是ID3算法的手写过程。sklearn中tree模块为我们提供了决策树的实现方式,参考代码如下:
from sklearn.datasets import load_iris
from sklearn import tree
import graphviziris = load_iris()
# criterion选择entropy,这里表示选择ID3算法
clf = tree.DecisionTreeClassifier(criterion='entropy', splitter='best')
clf = clf.fit(iris.data, iris.target)dot_data = tree.export_graphviz(clf, out_file=None,feature_names=iris.feature_names,class_names=iris.target_names,filled=True,rounded=True,special_characters=True)
graph = graphviz.Source(dot_data)
graph
以上便是本篇的全部内容,完整版代码和数据请移步本人github:
https://github.com/luwill/machine-learning-code-writing
参考资料:
李航 统计学习方法
https://github.com/heolin123/id3/blob/master
往期精彩:
数学推导+纯Python实现机器学习算法3:k近邻
数学推导+纯Python实现机器学习算法2:逻辑回归
数学推导+纯Python实现机器学习算法1:线性回归
往期精彩回顾适合初学者入门人工智能的路线及资料下载机器学习及深度学习笔记等资料打印机器学习在线手册深度学习笔记专辑《统计学习方法》的代码复现专辑
AI基础下载机器学习的数学基础专辑获取一折本站知识星球优惠券,复制链接直接打开:https://t.zsxq.com/yFQV7am本站qq群1003271085。加入微信群请扫码进群:
【机器学习基础】数学推导+纯Python实现机器学习算法4:决策树之ID3算法相关推荐
- 【机器学习基础】数学推导+纯Python实现机器学习算法30:系列总结与感悟
Python机器学习算法实现 Author:louwill Machine Learning Lab 终于到了最后的总结.从第一篇线性回归的文章开始到现在,已经接近有两年的时间了.当然,也不是纯写这3 ...
- 【机器学习基础】数学推导+纯Python实现机器学习算法24:HMM隐马尔可夫模型
Python机器学习算法实现 Author:louwill Machine Learning Lab HMM(Hidden Markov Model)也就是隐马尔可夫模型,是一种由隐藏的马尔可夫链随机 ...
- 【机器学习基础】数学推导+纯Python实现机器学习算法28:CRF条件随机场
Python机器学习算法实现 Author:louwill Machine Learning Lab 本文我们来看一下条件随机场(Conditional Random Field,CRF)模型.作为概 ...
- 【机器学习基础】数学推导+纯Python实现机器学习算法27:EM算法
Python机器学习算法实现 Author:louwill Machine Learning Lab 从本篇开始,整个机器学习系列还剩下最后三篇涉及导概率模型的文章,分别是EM算法.CRF条件随机场和 ...
- 【机器学习基础】数学推导+纯Python实现机器学习算法26:随机森林
Python机器学习算法实现 Author:louwill Machine Learning Lab 自从第14篇文章结束,所有的单模型基本就讲完了.而后我们进入了集成学习的系列,整整花了5篇文章的篇 ...
- 【机器学习基础】数学推导+纯Python实现机器学习算法25:CatBoost
Python机器学习算法实现 Author:louwill Machine Learning Lab 本文介绍GBDT系列的最后一个强大的工程实现模型--CatBoost.CatBoost与XGBoo ...
- 【机器学习基础】数学推导+纯Python实现机器学习算法24:LightGBM
Python机器学习算法实现 Author:louwill Machine Learning Lab 第17讲我们谈到了竞赛大杀器XGBoost,本篇我们来看一种比XGBoost还要犀利的Boosti ...
- 【机器学习基础】数学推导+纯Python实现机器学习算法23:kmeans聚类
Python机器学习算法实现 Author:louwill Machine Learning Lab 聚类分析(Cluster Analysis)是一类经典的无监督学习算法.在给定样本的情况下,聚类分 ...
- 【机器学习基础】数学推导+纯Python实现机器学习算法22:最大熵模型
Python机器学习算法实现 Author:louwill Machine Learning Lab 最大熵原理(Maximum Entropy Principle)是一种基于信息熵理论的一般原理,在 ...
- 【机器学习基础】数学推导+纯Python实现机器学习算法21:马尔可夫链蒙特卡洛...
Python机器学习算法实现 Author:louwill Machine Learning Lab 蒙特卡洛(Monte Carlo,MC)方法作为一种统计模拟和近似计算方法,是一种通过对概率模型随 ...
最新文章
- AppManager
- 使用next_permutation()的坑,你中招了么?
- 接口本地正常服务器报500_运维该如何解决服务器底层维护难题?
- ActiveMQ的Transport Connectors配置(六)
- 自定义LOG投递OSS数据Partition,优化你的计算
- 博文强识|进阶企业大咖
- Java中解决继承和接口默认方法冲突
- 09-Windows Server 2012 R2 会话远程桌面-标准部署-使用PowerShell进行部署2-2
- 如何阅读Cookbook技术书——如果我要把一本几百上千页的书从头读到尾,应该怎样有效阅读。...
- 如何直观理解拉格朗日乘子法与KKT条件
- 【路径规划】基于matlab无线充电车辆路径和速度预测【含Matlab源码 1473期】
- 福建高中计算机会考知识点,福建省高中信息技术会考《信息技术基础》复习提纲.doc...
- [Azure - VM] 解决办法:无法通过SSH连接VM,解决错误:This service allows sftp connections only.
- AFML读书笔记--Sample weight
- Java逐行读取fasta文件
- 图片编辑软件_pinta在Linux下安装
- MAC合约3.0API
- 如何安装适用于win11的安卓子系统(WSA)的谷歌框架安卓13版本
- 关于javascript的调试
- VB.net学习笔记(二十九)认识STA与MTA
热门文章
- flink批处理中的source以及sink介绍
- WPF窗口继承实现统一风格的自定义窗口
- python模块之keyword
- js函数中的参数的个数
- 2012 Stackoverflow meetup at Shanghai PRC
- ACE前摄器Proactor模式
- 六式建站浅见,和大家一起分享,不足之处还望斧正。
- 一文搞懂重复测量资料分析
- 2018年高教社杯全国大学生数学建模竞赛题目问题B 智能RGV的动态调度策略
- python采集修改原创_python应用系列教程——python中ftp操作:连接、登录、获取目录,重定向、上传下载,删除更改...