import numpy as np
import pandas as pd
from sklearn.utils.multiclass import type_of_target
from decision_tree import treePlottter


class Node(object):
    """A single node of the decision tree."""

    def __init__(self):
        self.feature_name = None    # name of the feature this node splits on
        self.feature_index = None   # column index of that feature
        self.subtree = {}           # children, keyed by feature value / split label
        self.impurity = None        # impurity score (gain / gain ratio / gini) at this node
        self.is_continuous = False  # whether the splitting feature is continuous
        self.split_value = None     # threshold used when the feature is continuous
        self.is_leaf = False        # whether this node is a leaf
        self.leaf_class = None      # predicted class when this node is a leaf
        self.leaf_num = 0           # number of leaves in the subtree rooted here
        self.high = -1              # height of the subtree rooted here


def entroy(y):
    """Information entropy of the label series *y* (Eq. 4.1 in the book).

    :param y: pd.Series of class labels
    :return: entropy as a float
    """
    # Proportion of samples in each class.
    p = pd.value_counts(y) / y.shape[0]
    ent = np.sum(-p * np.log2(p))
    return ent
    # BUGFIX: an unreachable `return node` (referencing an undefined name,
    # a copy/paste leftover) used to follow here; it has been removed.


def gini(y):
    """Gini value of the label series *y* (Eq. 4.5 in the book).

    :param y: pd.Series of class labels
    :return: 1 - sum(p_k^2) as a float
    """
    p = pd.value_counts(y) / y.shape[0]
    gini = 1 - np.sum(p ** 2)
    return gini
def gini_index(feature, y, is_continuous=False):
    """Gini index of *feature* (Eq. 4.6 in the book).

    For a continuous feature the candidate threshold with the smallest
    Gini index is selected as the split point.

    :param feature: pd.Series with the values of one feature
    :param y: pd.Series with the corresponding labels
    :return: [gini_index] for a discrete feature,
             [gini_index, split_point] for a continuous one
    """
    total = y.shape[0]
    distinct = pd.unique(feature)

    if not is_continuous:
        # Discrete feature: weighted sum of each branch's Gini value (Eq. 4.5 -> 4.6).
        weighted = 0
        for val in distinct:
            branch = y[feature == val]
            weighted += branch.shape[0] / total * gini(branch)
        return [weighted]

    # Continuous feature: sort the distinct values and use the midpoints of
    # consecutive values as candidate thresholds (Eq. 4.7).  Using midpoints
    # instead of the raw feature values guarantees neither branch is empty.
    distinct.sort()
    candidates = [(lo + hi) / 2 for lo, hi in zip(distinct[:-1], distinct[1:])]

    best_score = float('inf')
    best_point = None
    for threshold in candidates:  # scan every candidate, keep the smallest Gini index
        left = y[feature <= threshold]
        right = y[feature > threshold]
        score = left.shape[0] / total * gini(left) + right.shape[0] / total * gini(right)
        if score < best_score:
            best_score = score
            best_point = threshold
    return [best_score, best_point]
def info_gain(feature, y, entD, is_continuous=False):
    """Information gain of *feature* (Eq. 4.2 in the book).

    :param feature: pd.Series with all sample values of the current feature
    :param y: pd.Series with the corresponding labels
    :param entD: entropy of the whole label set D
    :param is_continuous: True when the feature is continuous
    :return: [gain] for a discrete feature,
             [gain, split_point] for a continuous one
    """
    m = y.shape[0]
    unique_value = pd.unique(feature)
    if is_continuous:
        unique_value.sort()  # sort so candidate split points can be built
        # Candidate thresholds are the midpoints of consecutive values (Eq. 4.7).
        split_point_set = [(unique_value[i] + unique_value[i + 1]) / 2
                           for i in range(len(unique_value) - 1)]
        min_ent = float('inf')  # pick the split with the smallest conditional entropy
        min_ent_point = None
        for split_point_ in split_point_set:
            Dv1 = y[feature <= split_point_]
            Dv2 = y[feature > split_point_]
            feature_ent_ = Dv1.shape[0] / m * entroy(Dv1) + Dv2.shape[0] / m * entroy(Dv2)
            if feature_ent_ < min_ent:
                min_ent = feature_ent_
                min_ent_point = split_point_
        gain = entD - min_ent
        return [gain, min_ent_point]
    else:
        feature_ent = 0
        for value in unique_value:
            Dv = y[feature == value]  # samples whose feature equals `value`, i.e. D^v
            feature_ent += Dv.shape[0] / m * entroy(Dv)
        gain = entD - feature_ent  # Eq. 4.2
        return [gain]


def info_gainRatio(feature, y, entD, is_continuous=False):
    """Information gain ratio (Eq. 4.3); parameters as in info_gain.

    For a continuous feature the split point is chosen by maximising the
    information gain, and the gain is then corrected by subtracting
    log2(N - 1) / |D| (Quinlan's C4.5 / MDL correction), because C4.5 is
    otherwise biased towards continuous features when discrete and
    continuous features coexist.  N is the number of distinct values of
    the feature, so N - 1 is the number of candidate split points.
    """
    if is_continuous:
        gain, split_point = info_gain(feature, y, entD, is_continuous)
        p1 = np.sum(feature <= split_point) / feature.shape[0]  # fraction of samples <= split point
        p2 = 1 - p1  # fraction of samples above the split point
        IV = -(p1 * np.log2(p1) + p2 * np.log2(p2))
        # BUGFIX: the correction term is log2(N - 1) / |D| — the number of
        # candidate split points, per the rationale above — not log2(N) / |D|
        # as the previous code computed.
        gain_ratio = (gain - np.log2(feature.nunique() - 1) / len(y)) / IV
        return [gain_ratio, split_point]
    else:
        p = pd.value_counts(feature) / feature.shape[0]  # fraction of samples per feature value
        IV = np.sum(-p * np.log2(p))  # Eq. 4.4
        gain_ratio = info_gain(feature, y, entD, is_continuous)[0] / IV
        return [gain_ratio]


def choose_best_feature_gini(X, y):
    """Pick the feature with the smallest Gini index.

    :param X: pd.DataFrame with the remaining features
    :param y: pd.Series of labels
    :return: (feature name, [gini] or [gini, split_point])
    """
    features = X.columns
    best_feature_name = None
    best_gini = [float('inf')]
    for feature_name in features:
        is_continuous = type_of_target(X[feature_name]) == 'continuous'
        gini_idex = gini_index(X[feature_name], y, is_continuous)
        if gini_idex[0] < best_gini[0]:
            best_feature_name = feature_name
            best_gini = gini_idex
    return best_feature_name, best_gini


def choose_best_feature_gainratio(X, y):
    """Pick the feature with the largest information gain ratio.

    The length of the returned impurity list tells whether the chosen
    feature is continuous: length 1 -> discrete, length 2 -> continuous
    (the second element is the split point).

    :param X: pd.DataFrame with the remaining features
    :param y: pd.Series of labels
    :return: (feature name, [gain_ratio] or [gain_ratio, split_point])
    """
    features = X.columns
    best_feature_name = None
    best_gain_ratio = [float('-inf')]
    entD = entroy(y)
    for feature_name in features:
        is_continuous = type_of_target(X[feature_name]) == 'continuous'
        info_gain_ratio = info_gainRatio(X[feature_name], y, entD, is_continuous)
        if info_gain_ratio[0] > best_gain_ratio[0]:
            best_feature_name = feature_name
            best_gain_ratio = info_gain_ratio
    return best_feature_name, best_gain_ratio
def choose_best_feature_infogain(X, y):
    """Pick the feature with the largest information gain.

    The length of the returned impurity list tells whether the chosen
    feature is continuous: length 1 -> discrete, length 2 -> continuous
    (the second element is the split point).

    :param X: pd.DataFrame with the remaining features
    :param y: pd.Series of labels
    :return: (feature name, [gain] or [gain, split_point])
    """
    features = X.columns
    best_feature_name = None
    best_info_gain = [float('-inf')]
    entD = entroy(y)
    for feature_name in features:
        is_continuous = type_of_target(X[feature_name]) == 'continuous'
        infogain = info_gain(X[feature_name], y, entD, is_continuous)
        if infogain[0] > best_info_gain[0]:
            best_feature_name = feature_name
            best_info_gain = infogain
    return best_feature_name, best_info_gain


def generate(X, y, columns, criterion):
    """Recursively build a decision tree.

    :param X: pd.DataFrame with the remaining features
    :param y: pd.Series of labels
    :param columns: list of the original column names (used for feature_index)
    :param criterion: 'gini', 'infogain' or 'gainratio'
    :return: root Node of the (sub)tree
    :raises ValueError: if criterion is not one of the three supported names
    """
    node = Node()
    # Series.nunique() counts the distinct values.
    if y.nunique() == 1:  # all samples belong to one class -> leaf
        node.is_leaf = True
        node.leaf_class = y.values[0]
        node.high = 0
        node.leaf_num += 1
        return node
    if X.empty:  # features exhausted -> leaf labelled with the majority class
        node.is_leaf = True
        node.leaf_class = pd.value_counts(y).index[0]  # majority class
        node.high = 0
        node.leaf_num += 1
        return node

    if criterion == 'gini':
        best_feature_name, best_impurity = choose_best_feature_gini(X, y)
    elif criterion == 'infogain':
        best_feature_name, best_impurity = choose_best_feature_infogain(X, y)
    elif criterion == 'gainratio':
        best_feature_name, best_impurity = choose_best_feature_gainratio(X, y)
    else:
        # BUGFIX: an unknown criterion previously fell through the chain and
        # crashed later with a confusing NameError; fail fast instead.
        raise ValueError("criterion must be 'gini', 'infogain' or 'gainratio', "
                         "got {!r}".format(criterion))

    node.feature_name = best_feature_name
    node.impurity = best_impurity[0]
    node.feature_index = columns.index(best_feature_name)
    feature_values = X.loc[:, best_feature_name]

    if len(best_impurity) == 1:  # discrete feature
        node.is_continuous = False
        unique_vals = pd.unique(feature_values)
        # A discrete feature is consumed once it has been used for a split.
        sub_X = X.drop(best_feature_name, axis=1)
        max_high = -1
        for value in unique_vals:
            node.subtree[value] = generate(sub_X[feature_values == value],
                                           y[feature_values == value], columns, criterion)
            if node.subtree[value].high > max_high:  # track the tallest child subtree
                max_high = node.subtree[value].high
            node.leaf_num += node.subtree[value].leaf_num
        node.high = max_high + 1
    elif len(best_impurity) == 2:  # continuous feature: binary split at the threshold
        node.is_continuous = True
        node.split_value = best_impurity[1]
        up_part = '>= {:.3f}'.format(node.split_value)
        down_part = '< {:.3f}'.format(node.split_value)
        node.subtree[up_part] = generate(X[feature_values >= node.split_value],
                                         y[feature_values >= node.split_value], columns, criterion)
        node.subtree[down_part] = generate(X[feature_values < node.split_value],
                                           y[feature_values < node.split_value], columns, criterion)
        node.leaf_num += (node.subtree[up_part].leaf_num + node.subtree[down_part].leaf_num)
        node.high = max(node.subtree[up_part].high, node.subtree[down_part].high) + 1
    return node


if __name__ == "__main__":
    # Watermelon data set 3.0; the first column of the file is the index.
    data = pd.read_csv("西瓜3.0.txt", index_col=0)
    x = data.iloc[:, :8]  # feature columns (pd.DataFrame)
    y = data.iloc[:, 8]   # label column (pd.Series)
    columns_name = list(x.columns)  # original column names
    criterion = "gini"  # one of 'gini' / 'infogain' / 'gainratio'
    node = generate(x, y, columns_name, criterion)
    treePlottter.create_plot(node)

下面是画图的函数(即上面导入的 treePlottter 模块的代码):

from matplotlib import pyplot as plt
# Use the SimHei font so the Chinese node labels render correctly
# (matplotlib's default fonts lack CJK glyphs).
plt.rcParams['font.sans-serif'] = ['SimHei']
# Box styles for the two node kinds and the connecting arrows.
decision_node = {'boxstyle': 'round,pad=0.3', 'fc': '#FAEBD7'}  # internal (decision) nodes
leaf_node = {'boxstyle': 'round,pad=0.3', 'fc': '#F4A460'}      # leaf nodes
arrow_args = {'arrowstyle': '<-'}

# Layout state shared by plot_tree(); initialised in create_plot().
y_off = None
x_off = None
total_num_leaf = None
total_high = None  # height of the whole tree; set in create_plot()


def plot_node(node_text, center_pt, parent_pt, node_type, ax_):
    # Draw one node box at center_pt with an arrow pointing from parent_pt.
    ax_.annotate(node_text, xy=[parent_pt[0], parent_pt[1] - 0.02], xycoords='axes fraction',
                 xytext=center_pt, textcoords='axes fraction',
                 va="center", ha="center", size=15,
                 bbox=node_type, arrowprops=arrow_args)


def plot_mid_text(mid_text, center_pt, parent_pt, ax_):
    # Write the edge label halfway between a child node and its parent.
    x_mid = (parent_pt[0] - center_pt[0]) / 2 + center_pt[0]
    y_mid = (parent_pt[1] - center_pt[1]) / 2 + center_pt[1]
    ax_.text(x_mid, y_mid, mid_text, fontdict=dict(size=10))


def plot_tree(my_tree, parent_pt, node_text, ax_):
    """Recursively draw the subtree rooted at *my_tree*.

    Uses the module globals x_off / y_off as a layout cursor and
    total_num_leaf / total_high as the normalisation constants, all
    initialised by create_plot().

    :param my_tree: root Node of the subtree to draw
    :param parent_pt: (x, y) position of the parent node, in axes fraction
    :param node_text: label for the edge from parent to this node
    :param ax_: matplotlib Axes to draw into
    """
    global y_off
    global x_off
    global total_num_leaf
    global total_high
    num_of_leaf = my_tree.leaf_num
    # Center this node above the horizontal span occupied by its leaves.
    center_pt = (x_off + (1 + num_of_leaf) / (2 * total_num_leaf), y_off)
    plot_mid_text(node_text, center_pt, parent_pt, ax_)
    # total_high == 0 means the whole tree is a single leaf; this can occur
    # on the watermelon data set, e.g. during pre-pruning.
    if total_high == 0:
        plot_node(my_tree.leaf_class, center_pt, parent_pt, leaf_node, ax_)
        return
    plot_node(my_tree.feature_name, center_pt, parent_pt, decision_node, ax_)
    y_off -= 1 / total_high  # descend one level before drawing children
    for key in my_tree.subtree.keys():
        if my_tree.subtree[key].is_leaf:
            x_off += 1 / total_num_leaf  # each leaf advances the cursor one slot
            plot_node(str(my_tree.subtree[key].leaf_class), (x_off, y_off), center_pt, leaf_node, ax_)
            plot_mid_text(str(key), (x_off, y_off), center_pt, ax_)
        else:
            plot_tree(my_tree.subtree[key], center_pt, str(key), ax_)
    y_off += 1 / total_high  # restore the level on the way back up


def create_plot(tree_):
    """Entry point: initialise the global layout state, set up a clean
    (axis-less) figure and draw the whole decision tree *tree_*.

    :param tree_: root Node of the tree (needs .leaf_num and .high filled in)
    """
    global y_off
    global x_off
    global total_num_leaf
    global total_high
    total_num_leaf = tree_.leaf_num
    total_high = tree_.high
    y_off = 1
    x_off = -0.5 / total_num_leaf
    fig_, ax_ = plt.subplots()
    ax_.set_xticks([])  # hide the axis ticks
    ax_.set_yticks([])
    ax_.spines['right'].set_color('none')  # hide the axes frame
    ax_.spines['top'].set_color('none')
    ax_.spines['bottom'].set_color('none')
    ax_.spines['left'].set_color('none')
    plot_tree(tree_, (0.5, 1), '', ax_)
    plt.show()

网上下载或自己编程实现任意一种多变量决策树算法,并观察其在西瓜数据集3.0上产生的结果。

相关推荐

  1. 周志华《机器学习》3.5答案-编程实现线性判别分析,并给出西瓜数据集3.0α上的结果

    #机器学习线性判别分析3.5题 import numpy as np import matplotlib.pyplot as plt data = [[0.697, 0.460, 1],[0.774, ...

  2. 机器学习《西瓜书》9.4解答——k-means算法:编程实现k均值算法,设置三组不同的k值、三组不同初始中心点,在西瓜数据集4.0上进行实验比较,并讨论什么样的初始中心有助于得到好结果。

    1.运行结果:(注:图中方块标注的点为随机选取的初始样本点) k=2时: 本次选取的2个初始向量为[[0.243, 0.267], [0.719, 0.103]] 共进行61轮 共耗时0.10s k= ...

  3. 如何将自己设计的图标或通过网上下载的图标上传到阿里图标图库中使用方法教程

    如何将自己设计的图标或通过网上下载的图标上传到阿里图标图库中使用方法教程 作者:张国军_Suger 开发工具与关键技术:Win10.项目.图标 对于编程人员来说,有一个不可获取的图库就是阿里图标库,这 ...

  4. 用CMake编译运行在网上下载的源文件src

    参考:http://blog.csdn.net/yiqiudream/article/details/51885698 (一).怎么用CMake打开下载的源文件? 工具:下载CMake --> ...

  5. java 文件下载详解_Java 从网上下载文件的几种方式实例代码详解

    废话不多说了,直接给大家贴代码了,具体代码如下所示: package com.github.pandafang.tool; import java.io.BufferedOutputStream; i ...

  6. android 服务端 导入工程,如何导入与配置从网上下载的android源代码及服务器端源代码...

    将Android项目导入import进Eclipse. 注意SDK版本是否匹配 . 服务器部署到Tomcat下. 你得在数据库中将这个点菜系统的数据库和表建好,或者导入.在服务器的代码中修改好你的数据 ...

  7. 如何把网上下载的前端页面在Spring Boot中跑起来(CSS,JavaScript,程序运行等路径设置)

    这个功能非常有用,估计99.99%的java web开发者都干过,本人是初学者,特写这个博客记录下! 方便本人以后查阅,方便以后进行投机取巧 这里使用thymeleaf模板引擎! 在网上下载了一个Bo ...

  8. 自己封装的Windows7 64位旗舰版,微软官网上下载的Windows7原版镜像制作,绝对纯净版...

    MSDN官网上下载的Windows7 64位 旗舰版原版镜像制作,绝对纯净版,无任何精简,不捆绑任何第三方软件.浏览器插件,不含任何木马.病毒等. 集成: 1.Office2010 2.DirectX ...

  9. 在spring官网上下载历史版本的spring插件,springsource-tool-suite

    在spring官网上下载历史版本的spring插件,springsource-tool-suite 如何为自己的eclipse下载历史版本的sts呢?拼下载的url. 首先,鼠标右键可下载的sts链接 ...

  10. 解决 网上下载的例子 My Mac 64-bit 不能运行的问题

    在左侧选中项目名称,在右侧TARGETS中选择  Summary , 将Deployment Target 字段改成你本机能够支持的版本.例如此次我从网上下载的例子是基于6.0开发的,那么我将其改为5 ...

最新文章

  1. 为什么ConcurrentHashMap的读操作不需要加锁?
  2. Effective C++学习笔记(Part Five:Item 26-31)
  3. ExtJs xtype一览
  4. 删除msconfig启动项不打勾的东西
  5. 敏捷开发绩效管理之四:为团队设立外部绩效目标(目标管理,外向型绩效)...
  6. 第二章 Python基本元素:数字、字符串和变量
  7. 如何在Kubernetes集群动态使用 NAS 持久卷 1
  8. python如何实现循环_如何构造python循环
  9. 程序员转型架构师,推荐你读这几本书
  10. c语言中形参和实参的区别
  11. b站黑马springCloud-常见面试题,多多三连
  12. 二叉树叶子结点个数——C++
  13. PMP每日一练 | 考试不迷路-5.13
  14. php清除垃圾代码,批处理清理系统垃圾代码,简单快速实用(适用于xp win7)
  15. php 百度地图导航代码,百度地图API自动定位和3种导航
  16. 四、大话HTTP协议-用Wireshark研究一个完整的TCP连接
  17. 如何查找SCI期刊的缩写
  18. oracle ndb,NDB语法 - ivaneeo's blog - BlogJava
  19. 软件体系结构:应用软件的设计与开发
  20. 杰里之AC696 系列外插 MIC 做混响或扩音设计注意【篇】

热门文章

  1. 视频教程-网络工程师实战系列视频课程【VLAN专题】-网络技术
  2. windows xp 美化大师之系统主题
  3. 0.3-87 GHz频段手持频谱分析仪 —— SAF Spectrum Compact
  4. 响应式web开发 许愿墙
  5. 如何写好简历与迎接面试
  6. ActionForm详解
  7. java action例子_实例——创建ActionForm Bean
  8. ESET_VC52_UPID (nod32激活获取器)V4.2.0.9 绿色版
  9. 宽带连接自动断开是怎么回事?
  10. 利用 0DAY 漏洞 CVE-2018-8174 获取windows系统 shell