编程实现基于信息熵进行划分选择的决策树算法,并为表4.3中数据生成一棵决策树
import numpy as np
import pandas as pd
from sklearn.utils.multiclass import type_of_target

from decision_tree import treePlottter
class Node(object):
    """One node of the decision tree built by ``generate``."""

    def __init__(self):
        self.feature_name = None     # name of the feature this node splits on
        self.feature_index = None    # column index of that feature in the original data
        self.subtree = {}            # children, keyed by feature value / split label
        self.impurity = None         # information gain achieved by this node's split
        self.is_continuous = False   # True when the split feature is continuous
        self.split_value = None      # threshold used when the feature is continuous
        self.is_leaf = False         # True when this node is a leaf
        self.leaf_class = None       # predicted class when this node is a leaf
        self.leaf_num = 0            # number of leaves in the subtree rooted here
        self.high = -1               # height of the subtree rooted here


def entroy(y):
    """Return the information entropy of the label series ``y`` (book Eq. 4.1).

    :param y: pd.Series of class labels
    :return: entropy in bits (float)
    """
    # Proportion of each class among the samples.
    # (y.value_counts() replaces the pd.value_counts() alias removed in pandas 2.0.)
    p = y.value_counts() / y.shape[0]
    ent = np.sum(-p * np.log2(p))
    return ent
def info_gain(feature, y, entD, is_continuous=False):
    """Information gain of one feature (book Eq. 4.2).

    :param feature: values of the current feature for every sample (pd.Series)
    :param y: corresponding labels
    :param entD: entropy of the full label set D
    :param is_continuous: True when the feature takes continuous values
    :return: list — [gain] for a discrete feature, [gain, split_point]
             for a continuous one
    """
    n_samples = y.shape[0]
    values = pd.unique(feature)

    if not is_continuous:
        # Weighted entropy of the partition induced by each discrete value D^{v}.
        cond_ent = 0
        for v in values:
            subset = y[feature == v]
            cond_ent += subset.shape[0] / n_samples * entroy(subset)
        return [entD - cond_ent]

    # Continuous feature: candidate thresholds are midpoints of adjacent
    # sorted values; keep the threshold with the smallest weighted entropy.
    values.sort()
    candidates = [(lo + hi) / 2 for lo, hi in zip(values[:-1], values[1:])]
    best_ent = float('inf')
    best_point = None
    for threshold in candidates:
        left = y[feature <= threshold]
        right = y[feature > threshold]
        weighted = (left.shape[0] / n_samples * entroy(left)
                    + right.shape[0] / n_samples * entroy(right))
        if weighted < best_ent:       # strict '<' keeps the first minimum
            best_ent = weighted
            best_point = threshold
    return [entD - best_ent, best_point]
def choose_best_feature_infogain(X, y):
    """Pick the attribute with the largest information gain.

    The length of the returned gain list encodes the feature type:
    one element -> discrete feature, two elements -> [gain, split_point]
    for a continuous feature.

    :param X: candidate features, pd.DataFrame
    :param y: labels, pd.Series
    :return: (best feature name, its gain list)
    """
    entD = entroy(y)
    winner_name = None
    winner_gain = [float('-inf')]
    for name in X.columns:
        continuous = type_of_target(X[name]) == 'continuous'
        gain = info_gain(X[name], y, entD, continuous)
        # Strict '>' keeps the earlier column when two features tie.
        if gain[0] > winner_gain[0]:
            winner_name = name
            winner_gain = gain
    return winner_name, winner_gain
def generate(X, y, columns):
    """Recursively build a decision tree (algorithm of Fig. 4.2 in the book).

    :param X: samples restricted to the still-available features, pd.DataFrame
    :param y: labels of those samples, pd.Series
    :param columns: full list of original column names (used for feature_index)
    :return: root Node of the generated (sub)tree
    """
    node = Node()

    # Case 1: all samples belong to the same class -> leaf.
    if y.nunique() == 1:
        node.is_leaf = True
        node.leaf_class = y.values[0]
        node.high = 0
        node.leaf_num += 1
        return node

    # Case 2: no features left -> leaf labelled with the majority class.
    if X.empty:
        # y.value_counts() replaces the pd.value_counts() alias removed in pandas 2.0.
        node.is_leaf = True
        node.leaf_class = y.value_counts().index[0]
        node.high = 0
        node.leaf_num += 1
        return node

    best_feature_name, best_impurity = choose_best_feature_infogain(X, y)
    node.feature_name = best_feature_name
    node.impurity = best_impurity[0]
    node.feature_index = columns.index(best_feature_name)
    feature_values = X.loc[:, best_feature_name]

    if len(best_impurity) == 1:  # discrete feature: one child per value
        node.is_continuous = False
        unique_vals = pd.unique(feature_values)
        # Discrete features are consumed once used for a split.
        sub_X = X.drop(best_feature_name, axis=1)
        max_high = -1
        for value in unique_vals:
            node.subtree[value] = generate(sub_X[feature_values == value],
                                           y[feature_values == value], columns)
            if node.subtree[value].high > max_high:  # track the tallest child
                max_high = node.subtree[value].high
            node.leaf_num += node.subtree[value].leaf_num
        node.high = max_high + 1
    elif len(best_impurity) == 2:  # continuous feature: binary split at threshold
        node.is_continuous = True
        node.split_value = best_impurity[1]
        up_part = '>= {:.3f}'.format(node.split_value)
        down_part = '< {:.3f}'.format(node.split_value)
        node.subtree[up_part] = generate(X[feature_values >= node.split_value],
                                         y[feature_values >= node.split_value], columns)
        node.subtree[down_part] = generate(X[feature_values < node.split_value],
                                           y[feature_values < node.split_value], columns)
        node.leaf_num += (node.subtree[up_part].leaf_num
                          + node.subtree[down_part].leaf_num)
        node.high = max(node.subtree[up_part].high,
                        node.subtree[down_part].high) + 1

    return node


if __name__ == "__main__":
    # Watermelon data set 3.0 (book Table 4.3); first column is the sample index.
    data = pd.read_csv("西瓜3.0.txt", index_col=0)
    # Generalized from iloc[:, :8] / iloc[:, 8]: last column is the label,
    # everything before it is a feature (works for any feature count).
    x = data.iloc[:, :-1]   # features, pd.DataFrame
    y = data.iloc[:, -1]    # labels, pd.Series
    columns_name = list(x.columns)
    node = generate(x, y, columns_name)
    treePlottter.create_plot(node)
下面是配套的画图算法,它位于另一个独立的 py 文件(即上面导入的 treePlottter 模块)中。
from matplotlib import pyplot as plt
# Use a font that can render Chinese characters in node labels.
plt.rcParams['font.sans-serif'] = ['SimHei']
# Matplotlib bbox style for internal (decision) nodes.
decision_node = dict(boxstyle='round,pad=0.3', fc='#FAEBD7')
# Matplotlib bbox style for leaf nodes.
leaf_node = dict(boxstyle='round,pad=0.3', fc='#F4A460')
# Arrow is reversed ("<-") because each annotation is drawn at the child
# and points back toward its parent node.
arrow_args = dict(arrowstyle="<-")

# Module-level layout state, initialised by create_plot() before drawing.
y_off = None            # current vertical drawing position
x_off = None            # current horizontal drawing position
total_num_leaf = None   # total leaf count of the tree being drawn
total_high = None  # height of the tree being drawn; set by create_plot()


def plot_node(node_text, center_pt, parent_pt, node_type, ax_):
    """Draw one boxed node at center_pt with an arrow back to parent_pt."""
    ax_.annotate(node_text, xy=[parent_pt[0], parent_pt[1] - 0.02], xycoords='axes fraction',
                 xytext=center_pt, textcoords='axes fraction',
                 va="center", ha="center", size=15,
                 bbox=node_type, arrowprops=arrow_args)


def plot_mid_text(mid_text, center_pt, parent_pt, ax_):
    """Write the edge label halfway between a node and its parent."""
    x_mid = (parent_pt[0] - center_pt[0]) / 2 + center_pt[0]
    y_mid = (parent_pt[1] - center_pt[1]) / 2 + center_pt[1]
    ax_.text(x_mid, y_mid, mid_text, fontdict=dict(size=10))


def plot_tree(my_tree, parent_pt, node_text, ax_):
    """Recursively lay out and draw the (sub)tree rooted at my_tree."""
    global y_off
    global x_off
    global total_num_leaf
    global total_high
    num_of_leaf = my_tree.leaf_num
    center_pt = (x_off + (1 + num_of_leaf) / (2 * total_num_leaf), y_off)
    plot_mid_text(node_text, center_pt, parent_pt, ax_)
    # A tree of height 0 is a single leaf (this can happen after pre-pruning
    # on the watermelon data set); draw it and return before the division by
    # total_high below would fail.
    if total_high == 0:
        # FIX: wrap in str() like the other leaf call — leaf_class may be a
        # non-string label (e.g. 0/1), which annotate cannot render.
        plot_node(str(my_tree.leaf_class), center_pt, parent_pt, leaf_node, ax_)
        return
    plot_node(my_tree.feature_name, center_pt, parent_pt, decision_node, ax_)
    y_off -= 1 / total_high
    for key in my_tree.subtree.keys():
        if my_tree.subtree[key].is_leaf:
            x_off += 1 / total_num_leaf
            plot_node(str(my_tree.subtree[key].leaf_class), (x_off, y_off), center_pt, leaf_node, ax_)
            plot_mid_text(str(key), (x_off, y_off), center_pt, ax_)
        else:
            plot_tree(my_tree.subtree[key], center_pt, str(key), ax_)
    y_off += 1 / total_high


def create_plot(tree_):
    """Entry point: lay out the whole decision tree and show the figure.

    :param tree_: root Node produced by the tree-building module (must carry
                  leaf_num and high attributes)
    """
    global y_off
    global x_off
    global total_num_leaf
    global total_high
    total_num_leaf = tree_.leaf_num
    total_high = tree_.high
    y_off = 1
    x_off = -0.5 / total_num_leaf
    fig_, ax_ = plt.subplots()
    ax_.set_xticks([])  # hide axis ticks
    ax_.set_yticks([])
    ax_.spines['right'].set_color('none')  # hide the axes frame
    ax_.spines['top'].set_color('none')
    ax_.spines['bottom'].set_color('none')
    ax_.spines['left'].set_color('none')
    plot_tree(tree_, (0.5, 1), '', ax_)
    plt.show()
编程实现基于信息熵进行划分选择的决策树算法,并为表4.3中数据生成一棵决策树相关推荐
- 周志华《机器学习》习题4.4——python实现基于信息熵进行划分选择的决策树算法
1.题目 试编程实现基于信息熵进行划分选择的决策树算法,并为表4.3中数据生成一棵决策树. 表4.3如下: 另外再附个txt版的,下次可以复制粘贴: 青绿,蜷缩,浊响,清晰,凹陷,硬滑,0.697,0 ...
- 【机器学习】西瓜书_周志华,python实现基于信息熵进行划分选择的决策树算法
python:实现基于信息熵进行划分选择的决策树算法 本文主要介绍本人用python基于信息熵进行划分选择的决策树代码实现,参考教材为西瓜书第四章--决策树.ps.本文只涉及决策树连续和离散两种情况, ...
- c语言编程实现二叉树的镜像,C/C++知识点之C++实现利用(前序和中序生成二叉树)以及(二叉树的镜像)...
本文主要向大家介绍了C/C++知识点之C++实现利用(前序和中序生成二叉树)以及(二叉树的镜像),通过具体的内容向大家展示,希望对大家学习C/C++知识点有所帮助. #include #include ...
- 《机器学习》西瓜书课后习题4.3——python实现基于信息熵划分的决策树算法(简单、全面)
<机器学习>西瓜书课后习题4.3--python实现基于信息熵划分的决策树算法 <机器学习>西瓜书P93 4.3 试编程实现基于信息熵进行划分选择的决策树算法,并为表4.3中数 ...
- 西瓜书习题4.3(基于信息熵的决策树)
试编程实现基于信息熵进行划分选择的决策树算法,并为表4.3中数据生成一棵决策树. 代码 import numpy as np import matplotlib.pyplot as plt from ...
- 【机器学习入门】(4) 决策树算法理论:算法原理、信息熵、信息增益、预剪枝、后剪枝、算法选择
各位同学好,今天我向大家介绍一下python机器学习中的决策树算法的基本原理.内容主要有: (1) 概念理解:(2) 信息熵:(3) 信息增益:(4) 算法选择:(5) 预剪枝和后剪枝. python ...
- 计算机成绩统计优秀率,基于决策树算法的成绩优秀率分析与研究.pdf
基于决策树算法的成绩优秀率分析与研究.pdf · · EraNo.122015 70 Computer DOI:10.166448.cnki.cn33-1094/tp,2015.12,019 基于决策 ...
- python编程实现决策树算法
最近布置了个课堂作业,用python实现决策树算法 .整了几天勉勉强强画出了棵歪脖子树,记录一下. 大体思路: 1.创建决策树My_Decision_Tree类,类函数__init__()初始化参数. ...
- 改进多目标粒子群储能选址定容matlab 采用matlab编程得到33节点系统改进多目标储能选址定容方案,采用基于信息熵的序数偏好法(TOPSIS)求解储能的最优接入方案
改进多目标粒子群储能选址定容matlab 采用matlab编程得到33节点系统改进多目标储能选址定容方案,采用基于信息熵的序数偏好法(TOPSIS)求解储能的最优接入方案,程序运行稳定,注释清楚. 现 ...
最新文章
- 将数据库服务器的文件D 改名为,MySQL如何更改数据库数据存储目录详解
- html中radio值的获取、赋值、注册事件示例详解
- SqlHelper数据库操作辅助类
- 农艺师需要职称计算机,2015年农艺师职称计算机考试宝典.doc
- vue加载时闪现模板语法-处理方法
- win2003 IIS6配置PHP 5.3.3(fastCGI方式+eAccelerator)+ASP.NET 4.0(MVC3)
- mysql noinstall 5.5_mysqlnoinstall 手动安装
- 从 Storm 到 Flink,汽车之家基于 Flink 的实时 SQL 平台设计思路与实践
- 企业权限管理系统之AdminLTE的基本介绍(一)
- installshield java_项目创建失败(vs2012中的InstallShield)
- mac卸载java1.7_Mac 下安装、卸载Java 7
- linux驱动开发 ST7789 LCD驱动移植(I.MX6ULL平台)
- 在线云html排版,云标签,关键字图排版 html5 canvas版
- 卡尔曼滤波(Kalman filter)算法
- 闭关修炼(十八)maven
- android 编译器indel,Overview of the HbbTV compliant browser upgrade on Android based DTV platform
- java 风的角度转风向
- 技术贴_关于某信辅助分析记录和若干检测方法
- 影集制作php源码_2018最新仿720全景在线制作云平台网站PHP源码(新增微信支付+打赏+场景红包+本地存储)...
- 使用轻量型模型对deepsort特征提取模块重训练