文章目录

  • 楔子
  • 定义变量:
  • 定义方法
    • 获得划分的feature
    • 生成结点
    • 停止条件及其处理
    • fit()
    • 生成树剪枝

楔子

前面已经实现了各种信息量的计算,那么我们划分的基本有了,那么我们需要使用这个基本来划分,来生成决策树,树的基本单元是node,这里的node是一堆数据的集合+他们内在的方法。由于需要处理三种算法,我们最好能使用基类,该类应该至少包含:

1、选择划分的feature;
2、根据基准划分数据生成新的结点;
3、判断那些节点可以当成叶子,并归类;

定义变量:

node的变量是不是有点多,容易糊涂,我们需要抓住重点,首先是一个node,那么这个node有那些变量了,按照上述方法应该有如下三个变量:self.feature_dim(划分的feature),self._children(生成的子节点),self.leafs(一颗子节点实际代表一颗子树,这颗子树下的所有叶子结点),当然还有数据集(self._x ,self._y),其他的变量根据需要慢慢添加。

# =============================================================================
# 实现完信息量的计算,下面就考虑决策树的生成
# 决策树是一颗x叉树,需要数据结构node
# 我们需要同时处理ID3,C4.5,CART所以需要建立一个基类
# 类的方法:1、根据离散特征划分数据;2、根据连续特征划分数据;3、根据当前数据判断
# 属于哪一个类别
# =============================================================================
import numpy as np
from Cluster import Cluster
class CvDNode:def __init__(self,tree=None,base=2,chaos=None,depth=0,parent=None,is_root=True,pre_feat="Root"):# 数据集的变量,都是numpy数组self._x = self._y = None# 记录当前的log底和当前的不确定度self.base,self.chaos =base,chaos# 计算该节点的信息增益的方法self.criterion = None# 该节点所属的类别self.category = None# 针对连续特征和CART,记录当前结点的左右结点self.left_child = self.right_child =None# 该node的所有子节点和所有的叶子结点self._children,self.leafs = {},{}# 记录样本权重self.sample_weight =None# whether continuous记录各个纬度的特征是否连续self.wc = None# 记录该node为根的treeself.tree =tree# 如果传入tree的话,初始化if tree is not None:# 数据预处理是由tree完成# 各个features是否连续也是由tree记录self.wc = tree.whether_continuous# 这里的node变量是Tree中所记录的所有node的列表tree.nodes.append(self)# 记录该node划分的相关信息:# 记录划分选取的featureself.feature_dim =None# 记录划分二分划分的标准,针对连续特征和CARTself.tar = None# 记录划分标准的feature的featureValuesself.feats =[]# 记录该结点的父节点和是否为根结点self.parent =parentself.is_root = is_root# 记录结点的深度self._depth = depth# 记录该节点的父节点的划分标准featureself.pre_feat = pre_feat# 记录该节点是否使用的CART算法self.is_cart =False# 记录该node划分标准feature是否是连续的self.is_continuous = False# 记录该node是否已经被剪掉,后面剪枝会用到self.pruned = False# 用于剪枝时,删除一个node对那些node有影响self.affected = False # 重载__getitem__运算符,支持[]运算符# 重载__getattribute__运算符,支持.运算符def __getitem__(self,item):if isinstance(item,str):return getattr(self,"_" + item)# 重载__lt__的方法,使node之间可以相互比较,less than,方便调试def __lt__(self,other):return self.pre_feat < other.pre_feat# 重载__str__和 __repr__为了方便调试def __str__(self):# 该节点的类别属性if self.category is None:return "CvDNode ({}) ({} -> {})".format(self.depth,self.pre_feat,self.feature_dim)return "CvDNode ({}) ({} -> class:{})".format(self.depth,self.pre_feat,self.tree.label[self.category])# __repr__ 用于打印,面对程序员,__str__用于打印,面对用户__repr__ = __str__
# =============================================================================
#   定义几个property,Python内置的@property装饰器就是负责把一个方法变成属性调用的
#   @property广泛应用在类的定义中,可以让调用者写出简短的代码,
#   同时保证对参数进行必要的检查,这样,程序运行时就减少了出错的可能性。
# =============================================================================# 定义children属性,主要区分开连续,CART和其余情况# 有了该属性之后所有子节点都不需要区分情况了@propertydef children(self):return {"left" : self.left_child, "right" : self.right_child} if (self.is_cart or self.is_continuous) else self._children# 递归定义高度属性# 叶子结点的高度为1,其余结点高度为子节点最高高度+1@propertydef height(self):if self.category is not None:return 1return 1 + max([_child.height if _child is not None else 0 for _child in self.children.values()])# 定义info_dic属性(信息字典),记录了该node的主要属性# 在更新各个node的叶子结点时,叶子结点self.leafs的属性就是该节点@propertydef info_dic(self):return {"chaos" : self.chaos,"y" :self._y}

定义方法

大概回忆一下算法的大概处理流程:选择划分的features,递归的划分生成子节点,满足停止条件,形成叶子。上述就是整个fit的过程。在实现主要函数fit()前,我们需要定义好一些fit()会使用到的函数。

获得划分的feature

这个放在fit()函数里面,见fit()

生成结点

获取了划分的feature之后,就需要根据这标准来生成子树,或者说子节点。
处理大概过程:我们需要生成一个新结点,需要知道他的数据样本(通过mask找到),他对应的chaos,还有一些其他的辅助变量,形成新的结点或者子树,然后递归处理这堆数据不断的生成结点。这里面需要区分离散、连续和CART3种情况处理。

# =============================================================================
#   生成结点的方法
# =============================================================================# chaos_lst:[featureValue] = chaos,指定feature下,不同featureValue的不确定度def _gen_children(self, chaos_lst):# 获取当前结点的划分feature,连续性时:该feature划分基准为tarfeat, tar = self.feature_dim, self.tar# 获取该feature是否是连续的self.is_continuous = continuous = self.wc[feat]# 取对应feature的N个数据,这是简化的写法,得到该feature的那列,实际是一行features = self._x[..., feat]# 当前结点可以使用的featuresnew_feats = self.feats.copy()# 连续性二叉处理if continuous:# 根据划分依据tar得到一类的maskmask = features < tar# 这个就是分成两类的maskmasks = [mask, ~mask]else:if self.is_cart:# CART根据划分依据tar得到一类的maskmask = features == tar# 分成两类的maskmasks = [mask, ~mask]# 把这个划分tar从指定feature下featureValue里移除self.tree.feature_sets[feat].discard(tar)else:# 离散型,没有mask,直接使用featureValue数量生成子节点masks = None# 二分情况处理if self.is_cart or continuous:feats = [tar, "+"] if not continuous else ["{:6.4}-".format(tar), "{:6.4}+".format(tar)]for feat, side, chaos in zip(feats, ["left_child", "right_child"], chaos_lst):new_node = self.__class__(self.tree, self.base, chaos=chaos,depth=self._depth + 1, parent=self, is_root=False, prev_feat=feat)new_node.criterion = self.criterionsetattr(self, side, new_node)for node, feat_mask in zip([self.left_child, self.right_child], masks):if self.sample_weight is None:local_weights = Noneelse:local_weights = self.sample_weight[feat_mask]local_weights /= np.sum(local_weights)tmp_data, tmp_labels = self._x[feat_mask, ...], self._y[feat_mask]if len(tmp_labels) == 0:continuenode.feats = new_featsnode.fit(tmp_data, tmp_labels, local_weights)else:# 离散情况处理# 可选择的features里移除已选择的feature,子节点就使用这些features寻找划分new_feats.remove(self.feature_dim)# self.tree.feature_sets[self.feature_dim]:对应的是这个特征的所有特征值# chaos_lst:对应的是每个特征值的不确定度for feat, chaos in zip(self.tree.feature_sets[self.feature_dim], chaos_lst):# 这个特征值的maskfeat_mask = features == feat# 根据这个mask,找到这个特征值对应那些数据tmp_x = self._x[feat_mask, ...]# 如果这个特征值没有数据,就取下一个特征值,相当于没有必要生成新结点if len(tmp_x) == 0:continue# 否则的话生成新结点,新结点四个参数比较重要:ent,feature_dim,children,leafsnew_node = self.__class__(tree=self.tree, base=self.base, chaos=chaos,depth=self._depth + 1, parent=self, is_root=False, prev_feat=feat)# 新结点的可选维度就是上述的new_featsnew_node.feats = new_feats# 更新当前结点的子节点集self.children[feat] = new_nodeif self.sample_weight is None:local_weights = Noneelse:# 带权重的处理local_weights = self.sample_weight[feat_mask]# 需要归一化local_weights /= np.sum(local_weights)# 递归的处理新结点,需要把分块的数据传入进来new_node.fit(tmp_x, self._y[feat_mask], local_weights)

停止条件及其处理

什么时候停止生成子树?就是什么时候我们形成叶子,两种情况,见下述代码。形成叶子之后的处理:已经判断这堆数据可以当作叶子了,我们需要干什么:一是这对数据属于哪一类(少数服从多数),二是更新当前结点的列祖列组的self.leafs,告诉列祖列组我是你们正宗的leaf。停止条件及其处理就是回溯法里面的限界函数,用于剪枝,实际上这个只是预剪枝。

# =============================================================================
#   定义生成算法的准备工作:定义停止生成的准则,定义停止后该node的行为
# =============================================================================# 停止的第一种情况:当特征纬度为0或者当前node的数据的不确定性小于阈值停止# 假如指定了树的最大深度,那么当node的深度太深时也停止# 满足停止条件返回True,否则返回Falsedef stop1(self,eps):if (self._x.shape[1] == 0 or (self.chaos is not None and self.chaos < eps)or (self.tree.max_depth is not None and self._depth >=self.tree.max_depth)):# 调用停止的方法self._handle_terminate()return Truereturn False# 当最大信息增益小于阈值时停止def stop2(self, max_gain, eps):if max_gain <= eps:# 调用停止的方法self._handle_terminate()return Truereturn False# 定义该node所属类别的方法,假如特征已经选完了,将此事样本中最多的类,作为该节点的类def get_category(self):return np.argmax(np.bincount(self._y))# 定义剪枝停止的处理方法,核心思想是:将node结点转换为叶子结点def _handle_terminate(self):# 首先先生成该node的所属类别self.category = self.get_category()# 然后一路回溯,更新该节点的所有祖先的叶子结点信息_parent =self.parentwhile _parent is not None:# id(self)获取当前对象的内存地址_parent.leafs[id(self)] = self.info_dic_parent = _parent.parent

fit()

下面就到了重点,这个是整个node处理的核心函数,实现前面提到的三个方法的处理流程,每次新结点的处理都是调用这个函数来实现:传入数据集,计算信息量,得到划分的feature,是否满足停止条件,否生成子节点递归的处理。

    # 挑选出最佳划分的方法,要注意二分和多分的情况def fit(self, x, y, sample_weight, eps= 1e-8):self._x = np.atleast_2d(x)self._y = np.array(y)# 如果满足第一种停止条件,则退出函数体if self.stop1(eps):return# 使用该node的数据实例化Cluster类以便计算各种信息量_cluster = Cluster(self._x, self._y, sample_weight, self.base)# 对于根节点,需要额外计算其数据的不确定性,第一次需要计算,# 其他时候已经传入了chaosif self.is_root:if self.criterion == "gini":self.chaos = _cluster.gini()else:self.chaos = _cluster.ent()# 用于存最大增益_max_gain = 0# 用于存取最大增益的那个feature,不同的featureValue限制下的不确定度# 最后[featureValue] = 对应的不确定度_chaos_lst = []# 遍历还能选择的featuresfor feat in self.feats:# 如果是该维度是连续的,或者使用CART,则需要计算二分标准的featureValue# 的取值集合if self.wc[feat]:_samples = np.sort(self._x.T(feat))# 连续型的featureValue的取值集合_set = (_samples[:-1] + _samples[1:]) *0.5else:_set = self.tree.feature_sets[feat]# 连续型和CART,需要使用二分类的计算信息量if self.wc[feat] or self.is_cart:# 取一个featureValuefor tar in _set:_tmp_gain, _tmp_chaos_lst = _cluster.bin_info_gain(feat, tar,  criterion=self.criterion,get_chaos_lst= True, continuous=self.wc[feat])if _tmp_gain > _max_gain:(_max_gain,_chaos_lst),_max_feature,_max_tar =(_tmp_gain, _tmp_chaos_lst), feat, tar# 离散数据使用一般的处理                   else:_tmp_gain, _tmp_chaos_lst = _cluster.info_gain(feat, self.criterion, True, self.tree.feature_sets[feat])if _tmp_gain > _max_gain:(_max_gain,_chaos_lst),_max_feature =(_tmp_gain, _tmp_chaos_lst), feat# 当所有features里面最大的不确定都足够小,退出函数体                    if self.stop2(_max_gain, eps):return# 更新相关的属性# 当前结点划分的featureself.feature_dim = _max_feature# 二叉的处理if self.is_cart or self.wc[_max_feature]:self.tar = _max_tar# 根据划分进行生成结点self._gen_children(_chaos_lst)# 这个是专门针对二叉的处理,是在生成树到底之后回溯时# 生成树同一层两个结点都生成的话,看能不能剪枝# 实际这也是后剪枝的一种,但是二叉树比较好处理,剪枝的策略也比较简单# 只是简单的去重,所以x叉树没有实施,后剪枝有更加高效的策略# 如果左右结点都是叶子结点且他们都属于同一个分类,可以使用剪枝操作if (self.left_child.category is not None and self.left_child.category == self.right_child.category):self.prune()# 调用tree的方法,剪掉该节点的左右子树# 从Tree的记录所有Node的列表nodes中除去self.tree.reduce_nodes()else:# 离散情况的处理self._gen_children(_chaos_lst)# 上述的剪枝策略,只是同一父亲的子结点去重,x叉树不好实现。# 直接采用更加高效的策略后剪枝

生成树剪枝

得到一颗生成树之后,这棵树没啥约束或者仅仅只靠信息增益约束,可能枝繁叶茂,为了使这棵树在其他数据集上能取得较好的泛化性能,我们需要剪枝,删除一些叶子结点。
在树生成完毕或者局部生成完毕的剪枝称之为后剪枝,上述fit()实现针对二叉树的处理,左右结点生成完毕判断两个结点是否属于同一类别剪枝,就是后剪枝的一种,实际后剪枝有一套全局规划的策略,下次再讲。
post-prune的实现:这种情况的剪枝生成树已经生成完毕,剪枝相当于把下面所有的叶子看成一个叶子结点,那么需要把下面所有的叶子从列祖列宗的leafs谱里除名,然后把当前结点成了叶子加入到列祖列宗的leafs谱里,完了还需要标记自己和下面的结点为pruned已被剪掉,为什么自己也需要剪掉,因为叶子的信息已经在列祖列宗的leafs谱里。

    # 剪枝操作,将当前结点转化为叶子结点,这里可能觉得莫名奇妙,前面停止处理的时候不是# 已经剪枝了吗,这个地方怎么还有专门的剪枝函数,而且还要把,把删除剪枝剪掉的叶子。# 实际上这是后剪枝,称之为post-pruning, stop里面的剪枝称之为pre-pruning.def prune(self):# 调用方法计算该node属于的类别self.category = self.get_category()# 当前结点转化为叶子结点,记录其下属结点的叶子结点_pop_lst = [key for key in self.leafs]# 然后一路回溯,更新各个parent的属性叶子_parent = self.parentwhile _parent is not None: # 删除由于被剪枝而剪掉的叶子结点# 由于删除结点,对父节点产生了影响,用于后剪枝更新新损失函数_parent.affected = Truefor _k in _pop_lst:_parent.leafs.pop(_k)# 把当前结点更新进去,因为当前结点变成了叶子_parent.leafs[id(self)] = self.info_dic_parent = _parent.parent# 调用mark_pruned函数将所有的子子孙孙标记为已剪掉,pruned属性为Trueself.mark_pruned()# 重置各个属性,这个有必要吗,理论上是要删除,毕竟现在是叶子了,叶子这些都# 应该是空的self.feature_dim = Noneself.left_child = self.right_child = Noneself._children = {}self.leafs = {}# 下面实现mark_pruned函数,self._children放的是子树# 为啥把自己也置为True,自己也被剪掉了?还是因为叶子结点信息已在父亲的leafs里不重要?def mark_pruned(self):self.pruned = True# 这里使用的children属性,获取的是子树,递归的调用for _child in self.children.value():if _child is  not None:_child.mark_pruned()

至此node类的变量和方法基本实现完毕,为什么说基本呢,因为真正的后剪枝还没将,他还需要在node类里添加一些方法。

机器学习:结点的实现,决策树代码实现(二)相关推荐

  1. 【机器学习】决策树代码练习

    本课程是中国大学慕课<机器学习>的"决策树"章节的课后代码. 课程地址: https://www.icourse163.org/course/WZU-146409617 ...

  2. 温州大学《机器学习》课程代码(二)(回归)

    温州大学<机器学习>课程代码(二)(回归) 代码修改并注释:黄海广,haiguang2000@wzu.edu.cn 课件   视频 下载地址:https://github.com/feng ...

  3. MATLAB机器学习系列-9:决策树和随机森林的原理及其例子代码实现

    决策树 原理 决策树通过把样本实例从根节点排列到某个叶子节点来对其进 行分类.树上的每个非叶子节点代表对一个属性取值的测试, 其分支就代表测试的每个结果:而树上的每个叶子节点均代表 一个分类的类别,树 ...

  4. 【机器学习基础】数学推导+纯Python实现机器学习算法5:决策树之CART算法

    目录 CART概述 回归树 分类树 剪枝 Python实现示例:分类树 在数学推导+纯Python实现机器学习算法4:决策树之ID3算法中笔者已经对决策树的基本原理进行了大概的论述.本节将在上一讲的基 ...

  5. 机器学习(九)决策树,随机森林

    机器学习(九)决策树,随机森林 文章目录 机器学习(九)决策树,随机森林 一.决策树 1.1 如何理解决策树 1.2 信息论的一些基础 1.3 信息论与决策树的关系 1.3.1 信息增益 1.4 常见 ...

  6. [当人工智能遇上安全] 5.基于机器学习算法的主机恶意代码识别研究

    您或许知道,作者后续分享网络安全的文章会越来越少.但如果您想学习人工智能和安全结合的应用,您就有福利了,作者将重新打造一个<当人工智能遇上安全>系列博客,详细介绍人工智能与安全相关的论文. ...

  7. 机器学习深度学习算法及代码实现

    原文地址:https://blog.csdn.net/qq_31456593/article/details/69340697 最近在学机器学习,学习过程中收获颇多,在此留下学习记录,希望与同道中人相 ...

  8. 机器学习 の04 梯度提升决策树GBDT

    机器学习 の04 梯度提升决策树GBDT GBDT的背景知识 集成学习(ensemble learning) Bagging(Bootstrap Aggregating)算法 Boosting提升算法 ...

  9. 机器学习系列(10)_决策树与随机森林回归

    注:本篇文章接上一篇文章>>机器学习系列(9)_决策树详解01 文章目录 一.决策树优缺点 二.泰坦尼克号幸存者案例 三.随机森林介绍 1.随机森林的分类 2.重要参数 [1]n_esti ...

最新文章

  1. 大话数据结构:线性表(2)
  2. 百度App Objective-C/Swift 组件化混编之路(二)- 工程化
  3. 【Tools】Visual Studio 2019下载和安装
  4. Oracle教程之分析Oracle索引扫描四大类
  5. 二级mysql教程下载_全国计算机等级考试教程:二级MySQL数据库程序设计
  6. 324.摆动排序II
  7. Matlab中添加LibPLS安装包
  8. python股票预测模型_一种基于Python和BP神经网络的股票预测方法
  9. 2020年小米校招JAVA岗笔试第二题
  10. IT经理的两条职业路做管理还是管理咨询
  11. 仙剑奇侠传四服务器维护,《仙剑奇侠传四》无法登录怎么办_无法登录解决办法_3DM手游...
  12. MyBatis 源码分析系列文章导读 1
  13. 计算机是如何工作的 计算机原理
  14. 自己的理解——WMD
  15. 以计算机网络为话题的英语作文,以Internet为话题的英语作文
  16. NETBOX BT大改造
  17. 页面显示LCD液晶字体或者其他特殊字体
  18. DaZeng:3分钟搞定内网渗透之外网访问指定域名
  19. 和 Google Play 一起打开日本市场 | ChinaJoy 干货分享
  20. iOS——微信朋友圈小视频的播放和聊天窗口小视频的播放

热门文章

  1. 逻辑错误有哪些c语言,c语言程序,现在出现逻辑错误,哪位高手指点下啊。。。...
  2. python如何入侵服务器的_通过redis入侵服务器的步骤
  3. php推荐码生成,最新最全PHP生成制作验证码代码详解(推荐),验证码详解_PHP教程...
  4. java运行机制以及 运行流程
  5. postgres 退出_如何退出postgresql
  6. matlab敏感词输出代码,敏感词设置
  7. java和mysql中md5+base64的执行结果
  8. jmeter连接MySQL出错_MySQL数据库之jmeter连接mysql数据库报错Cannot create PoolableConnectionFactory...
  9. linux ping策略打开_linux ping策略打开_如何在Linux服务器禁止和开启ping包 互联网技术圈 互联网技术圈......
  10. fiddler如何设置过滤https_Google Analytics如何设置含有过滤器的帐户数据视图