import numpy as np
import pandas as pd
from sklearn.utils.multiclass import type_of_target
from decision_tree import treePlottterclass Node(object):def __init__(self):self.feature_name = None  # 特性的名称self.feature_index = None  # 特性的下标self.subtree = {}  #树节点的集合self.impurity = None #信息此节点的信息增益self.is_continuous = False #是否为连续值self.split_value = None #连续值时的划分依据self.is_leaf = False #是否为叶子节点self.leaf_class = None #叶子节点对应的类self.leaf_num = 0 # 叶子数目self.high = -1 # 树的高度def entroy(y):p = pd.value_counts(y) / y.shape[0]  # 计算各类样本所占比率ent = np.sum(-p * np.log2(p))return entreturn  node
def info_gain(feature, y, entD, is_continuous=False):'''计算信息增益------:param feature: 当前特征下所有样本值:param y:       对应标签值:return:        当前特征的信息增益, list类型,若当前特征为离散值则只有一个元素为信息增益,若为连续值,则第一个元素为信息增益,第二个元素为切分点'''m = y.shape[0]unique_value = pd.unique(feature)if is_continuous:unique_value.sort()  # 排序, 用于建立分割点split_point_set = [(unique_value[i] + unique_value[i + 1]) / 2 for i in range(len(unique_value) - 1)]min_ent = float('inf')  # 挑选信息熵最小的分割点min_ent_point = Nonefor split_point_ in split_point_set:Dv1 = y[feature <= split_point_]Dv2 = y[feature > split_point_]feature_ent_ = Dv1.shape[0] / m * entroy(Dv1) + Dv2.shape[0] / m * entroy(Dv2)if feature_ent_ < min_ent:min_ent = feature_ent_min_ent_point = split_point_gain = entD - min_entreturn [gain, min_ent_point]else:feature_ent = 0for value in unique_value:Dv = y[feature == value]  # 当前特征中取值为 value 的样本,即书中的 D^{v}feature_ent += Dv.shape[0] / m * entroy(Dv)gain = entD - feature_ent  # 原书中4.2式return [gain]
def choose_best_feature_infogain(X, y):'''以返回值中best_info_gain 的长度来判断当前特征是否为连续值,若长度为 1 则为离散值,若长度为 2 , 则为连续值:param X: 当前所有特征的数据 pd.DaraFrame格式:param y: 标签值:return:  以信息增益来选择的最佳划分属性,第一个返回值为属性名称,'''features = X.columnsbest_feature_name = Nonebest_info_gain = [float('-inf')]entD = entroy(y)for feature_name in features:is_continuous = type_of_target(X[feature_name]) == 'continuous'infogain = info_gain(X[feature_name], y, entD, is_continuous)if infogain[0] > best_info_gain[0]:best_feature_name = feature_namebest_info_gain = infogainreturn best_feature_name, best_info_gain
def generate(X,y,columns):node = Node()# Pandas.Series.nunique()统计不同值的个数if y.nunique() == 1:  # 属于同一类别node.is_leaf = Truenode.leaf_class = y.values[0]node.high = 0node.leaf_num += 1return nodeif X.empty:  # 特征用完了,数据为空,返回样本数最多的类node.is_leaf = Truenode.leaf_class = pd.value_counts(y).index[0]  # 返回样本数最多的类node.high = 0node.leaf_num += 1return nodebest_feature_name, best_impurity = choose_best_feature_infogain(X, y)node.feature_name = best_feature_namenode.impurity = best_impurity[0]node.feature_index = columns.index(best_feature_name)feature_values = X.loc[:, best_feature_name]if len(best_impurity) == 1:  # 离散值node.is_continuous = Falseunique_vals = pd.unique(feature_values)sub_X = X.drop(best_feature_name, axis=1)max_high = -1for value in unique_vals:node.subtree[value] = generate(sub_X[feature_values == value], y[feature_values == value],columns)if node.subtree[value].high > max_high:  # 记录子树下最高的高度max_high = node.subtree[value].highnode.leaf_num += node.subtree[value].leaf_numnode.high = max_high + 1elif len(best_impurity) == 2:  # 连续值node.is_continuous = Truenode.split_value = best_impurity[1]up_part = '>= {:.3f}'.format(node.split_value)down_part = '< {:.3f}'.format(node.split_value)node.subtree[up_part] = generate(X[feature_values >= node.split_value],y[feature_values >= node.split_value],columns)node.subtree[down_part] = generate(X[feature_values < node.split_value],y[feature_values < node.split_value],columns)node.leaf_num += (node.subtree[up_part].leaf_num + node.subtree[down_part].leaf_num)node.high = max(node.subtree[up_part].high, node.subtree[down_part].high) + 1return nodeif __name__ == "__main__":data = pd.read_csv("西瓜3.0.txt", index_col=0)  # index_col参数设置第一列作为index#不带第一列,求得西瓜的属性x = data.iloc[:, :8] #<class 'pandas.core.frame.DataFrame'>y = data.iloc[:, 8] #<class 'pandas.core.series.Series'>columns_name = list(x.columns)  # 包括原数据的列名node = generate(x,y,columns_name)
treePlottter.create_plot(node)

另一个画图算法,这是另一个py文件

from matplotlib import pyplot as plt
plt.rcParams['font.sans-serif'] = ['SimHei']
decision_node = dict(boxstyle='round,pad=0.3', fc='#FAEBD7')
leaf_node = dict(boxstyle='round,pad=0.3', fc='#F4A460')
arrow_args = dict(arrowstyle="<-")y_off = None
x_off = None
total_num_leaf = None
total_high = Nonedef plot_node(node_text, center_pt, parent_pt, node_type, ax_):ax_.annotate(node_text, xy=[parent_pt[0], parent_pt[1] - 0.02], xycoords='axes fraction',xytext=center_pt, textcoords='axes fraction',va="center", ha="center", size=15,bbox=node_type, arrowprops=arrow_args)def plot_mid_text(mid_text, center_pt, parent_pt, ax_):x_mid = (parent_pt[0] - center_pt[0]) / 2 + center_pt[0]y_mid = (parent_pt[1] - center_pt[1]) / 2 + center_pt[1]ax_.text(x_mid, y_mid, mid_text, fontdict=dict(size=10))def plot_tree(my_tree, parent_pt, node_text, ax_):global y_offglobal x_offglobal total_num_leafglobal total_highnum_of_leaf = my_tree.leaf_numcenter_pt = (x_off + (1 + num_of_leaf) / (2 * total_num_leaf), y_off)plot_mid_text(node_text, center_pt, parent_pt, ax_)if total_high == 0:  # total_high为零时,表示就直接为一个叶节点。因为西瓜数据集的原因,在预剪枝的时候,有时候会遇到这种情况。plot_node(my_tree.leaf_class, center_pt, parent_pt, leaf_node, ax_)returnplot_node(my_tree.feature_name, center_pt, parent_pt, decision_node, ax_)y_off -= 1 / total_highfor key in my_tree.subtree.keys():if my_tree.subtree[key].is_leaf:x_off += 1 / total_num_leafplot_node(str(my_tree.subtree[key].leaf_class), (x_off, y_off), center_pt, leaf_node, ax_)plot_mid_text(str(key), (x_off, y_off), center_pt, ax_)else:plot_tree(my_tree.subtree[key], center_pt, str(key), ax_)y_off += 1 / total_highdef create_plot(tree_):global y_offglobal x_offglobal total_num_leafglobal total_hightotal_num_leaf = tree_.leaf_numtotal_high = tree_.highy_off = 1x_off = -0.5 / total_num_leaffig_, ax_ = plt.subplots()ax_.set_xticks([])  # 隐藏坐标轴刻度ax_.set_yticks([])ax_.spines['right'].set_color('none')  # 设置隐藏坐标轴ax_.spines['top'].set_color('none')ax_.spines['bottom'].set_color('none')ax_.spines['left'].set_color('none')plot_tree(tree_, (0.5, 1), '', ax_)plt.show()

编程实现基于信息熵进行划分选择的决策树算法,并为表4.3中数据生成一棵决策树相关推荐

  1. 周志华《机器学习》习题4.4——python实现基于信息熵进行划分选择的决策树算法

    1.题目 试编程实现基于信息熵进行话饭选择的决策树算法,并为表4.3中数据生成一棵决策树. 表4.3如下: 另外再附个txt版的,下次可以复制粘贴: 青绿,蜷缩,浊响,清晰,凹陷,硬滑,0.697,0 ...

  2. 【机器学习】西瓜书_周志华,python实现基于信息熵进行划分选择的决策树算法

    python:实现基于信息熵进行划分选择的决策树算法 本文主要介绍本人用python基于信息熵进行划分选择的决策树代码实现,参考教材为西瓜书第四章--决策树.ps.本文只涉及决策树连续和离散两种情况, ...

  3. c语言编程实现二叉树的镜像,C/C++知识点之C++实现利用(前序和中序生成二叉树)以及(二叉树的镜像)...

    本文主要向大家介绍了C/C++知识点之C++实现利用(前序和中序生成二叉树)以及(二叉树的镜像),通过具体的内容向大家展示,希望对大家学习C/C++知识点有所帮助. #include #include ...

  4. 《机器学习》西瓜书课后习题4.3——python实现基于信息熵划分的决策树算法(简单、全面)

    <机器学习>西瓜书课后习题4.3--python实现基于信息熵划分的决策树算法 <机器学习>西瓜书P93 4.3 试编程实现基于信息熵进行划分选择的决策树算法,并为表4.3中数 ...

  5. 西瓜书习题4.3(基于信息熵的决策树)

    试编程实现基于信息熵进行划分选择的决策树算法,并为表4.3中数据生成一颗决策树. 代码 import numpy as np import matplotlib.pyplot as plt from ...

  6. 【机器学习入门】(4) 决策树算法理论:算法原理、信息熵、信息增益、预剪枝、后剪枝、算法选择

    各位同学好,今天我向大家介绍一下python机器学习中的决策树算法的基本原理.内容主要有: (1) 概念理解:(2) 信息熵:(3) 信息增益:(4) 算法选择:(5) 预剪枝和后剪枝. python ...

  7. 计算机成绩统计优秀率,基于决策树算法的成绩优秀率分析与研究.pdf

    基于决策树算法的成绩优秀率分析与研究.pdf · · EraNo.122015 70 Computer DOI:10.166448.cnki.cn33-1094/tp,2015.12,019 基于决策 ...

  8. python编程实现决策树算法

    最近布置了个课堂作业,用python实现决策树算法 .整了几天勉勉强强画出了棵歪脖子树,记录一下. 大体思路: 1.创建决策树My_Decision_Tree类,类函数__init__()初始化参数. ...

  9. 改进多目标粒子群储能选址定容matlab 采用matlab编程得到33节点系统改进多目标储能选址定容方案,采用基于信息熵的序数偏好法(TOPSIS)求解储能的最优接入方案

    改进多目标粒子群储能选址定容matlab 采用matlab编程得到33节点系统改进多目标储能选址定容方案,采用基于信息熵的序数偏好法(TOPSIS)求解储能的最优接入方案,程序运行稳定,注释清楚. 现 ...

最新文章

  1. 将数据库服务器的文件D 改名为,MySQL如何更改数据库数据存储目录详解
  2. html中radio值的获取、赋值、注册事件示例详解
  3. SqlHelper数据库操作辅助类
  4. 农艺师需要职称计算机,2015年农艺师职称计算机考试宝典.doc
  5. vue加载时闪现模板语法-处理方法
  6. win2003 IIS6配置PHP 5.3.3(fastCGI方式+eAccelerator)+ASP.NET 4.0(MVC3)
  7. mysql noinstall 5.5_mysqlnoinstall 手动安装
  8. 从 Storm 到 Flink,汽车之家基于 Flink 的实时 SQL 平台设计思路与实践
  9. 企业权限管理系统之AdminLTE的基本介绍(一)
  10. installshield java_项目创建失败(vs2012中的InstallShield)
  11. mac卸载java1.7_Mac 下安装、卸载Java 7
  12. linux驱动开发 ST7789 LCD驱动移植(I.MX6ULL平台)
  13. 在线云html排版,云标签,关键字图排版 html5 canvas版
  14. 卡尔曼滤波(Kalman filter)算法
  15. 闭关修炼(十八)maven
  16. android 编译器indel,Overview of the HbbTV compliant browser upgrade on Android based DTV platform
  17. java 风的角度转风向
  18. 技术贴_关于某信辅助分析记录和若干检测方法
  19. 影集制作php源码_2018最新仿720全景在线制作云平台网站PHP源码(新增微信支付+打赏+场景红包+本地存储)...
  20. 使用轻量型模型对deepsort特征提取模块重训练

热门文章

  1. 全能开发工具 ComponentOne(3)——常用控件下篇
  2. ifconfig查看本机IP
  3. Ubuntu搭建Socks5代理服务器
  4. 为什么要将线程设置成分离状态
  5. c语言中结构体的指针初始化,c语言结构体指针初始化
  6. Prototype 入门
  7. 为什么公共关系应该在您的社交媒体营销中发挥作用
  8. ASP.NET网站与Discuz!NT论坛整合
  9. Ant Design Vue子表格展开只展开一行,其他行折叠
  10. 中文字幕!吴恩达 ChatGPT 最新课程