这个代码是可以直接运行的,数据集在github链接下的breast_cancer.csv,把代码和数据集放到同一目录下就好了

或者:百度云链接,提取码:eu8q

详细分析和完善下周日再写吧,感觉复杂度好高啊。。。。。。

import numpy as np
import copy as copy
class Node(object):def __init__(self):self.left = Noneself.right = Noneself.parent = Noneself.items = []self.feature = Noneself.feature_value = None@propertydef predict(self):maxCount = 0for i in np.unique(self.items[1]):if self.items[1].count(i) > maxCount:maxCount = self.items[1].count(i)maxPredict = ireturn maxPredictdef __str__(self):if self.left == None and self.right == None:return "size:%d predict:%s"%(len(self.items),str(self.predict))else:return "feature:%s feature_value:%s"%(self.feature,self.feature_value)def get_leafEntropy(self):g = 1n = len(self.items[1])p = {}for item in self.items[1]:p.setdefault(item,0)p[item] += 1for v in p.values():g -= (v / n) ** 2# print(self.items)# print(p)return gdef get_leaf_num(self):if self.left is not None and self.right is not None:return self.right.get_leaf_num() + self.left.get_leaf_num()else:return 1class Dtree(object):def __init__(self):self.root = Node()def __str__(self):queue = [(self.root, -1)]level = 0res = []while queue:node,prelevel = queue.pop(0)res.append("%d -> %d: %s"%(prelevel, prelevel + 1, str(node)))if node.left:queue.append((node.left, prelevel + 1))if node.right:queue.append((node.right, prelevel + 1))level += 1return "\n".join(res)def get_nodeEntropy(self,node):ll = len(node.left.items[0])lr = len(node.right.items[0])return (ll * node.left.get_leafEntropy() + lr * node.right.get_leafEntropy()) / (ll + lr)def split(self,feature,feature_value,idx,X):div = [[],[]]    #对于不同类型的特征选用不同的划分方法:对于离散的,根据是否相等来划分;对于连续的,根据大于还是小于进行划分# for i in idx:#     if X[i][feature] == feature_value:#         div[0].append(i)#     else:#         div[1].append(i)# return divfor i in idx:if X[i][feature] <= feature_value:div[0].append(i)else:div[1].append(i)return divdef get_G(self,idx,X,y):g = 1n = len(idx)p = {}for i in idx:p.setdefault(y[i], 0)p[y[i]] += 1for v in p.values():g -= (v / n) ** 2return gdef get_bestFeatureValue_forAFeature(self,X,y,idx,feature,best_feature,best_feature_value,minG):feature_vs = np.unique([X[i][feature] for i in idx])for feature_v in feature_vs:div = self.split(feature,feature_v,idx,X)ll = len(div[0])lr = len(div[1])curG = (ll * self.get_G(div[0],X,y) + lr * self.get_G(div[1],X,y)) / (ll + lr)# print(feature,feature_v,curG)if curG < minG:minG = curGbest_feature = featurebest_feature_value = feature_vreturn best_feature,best_feature_value,minGdef get_bestFeatureAndValue(self,X,y,idx):best_feature = 0best_feature_value = X[0][best_feature]minG = 1for feature in range(len(X[0])):best_feature,best_feature_value, minG = self.get_bestFeatureValue_forAFeature(X,y,idx,feature,best_feature,best_feature_value,minG)return best_feature,best_feature_valuedef create_Dtree(self,X,y):queue = [(self.root,range(len(X)))]while queue:node,idx = queue.pop(0)if len(np.unique([y[i] for i in idx])) == 1:node.items = [copy.deepcopy(idx),[y[i] for i in idx]]# node.predict = y[idx[0]]continuebest_feature,best_feature_value = self.get_bestFeatureAndValue(X,y,idx)print("bestFeature: %s, bestFeatureValue: %s"%(str(best_feature),str(best_feature_value)))node.feature = best_featurenode.feature_value = best_feature_valuenode.items = [copy.deepcopy(idx),[y[i] for i in idx]]  #为便于剪枝时比较单节点树形式和子树形式的基尼系数,子树的标记也需要保存到根节点div = self.split(best_feature,best_feature_value,idx,X)if div[0] != []:node.left = Node()node.left.parent = nodequeue.append((node.left,div[0]))if div[1] != []:node.right = Node()node.right.parent = nodequeue.append((node.right,div[1]))def predict(self,xi):node = self.rootwhile node.left or node.right:if xi[node.feature] <= node.feature_value:node = node.leftelse:node = node.rightreturn node.predictdef get_min_gt(self):minGt = 0targetNode = Nonequeue = [(self.root)]while queue:node = queue.pop(0)Ct = node.get_leafEntropy()  #寻找最小的g(t),见统计学习方法p86CTt = self.get_nodeEntropy(node)leafnum = node.get_leaf_num()curGt = (Ct - CTt) / (leafnum - 1)if minGt == 0 or curGt < minGt:minGt = curGttargetNode = nodeif node.left.left and node.left.right:queue.append((node.left))if node.right.left and node.right.right:queue.append((node.right))return targetNode,minGtdef merge_subTree(self,node):node.left = Nonenode.right = None# 用统计学习方法第五章贷款申请样本测试一下:
# 注意这些特征都是离散的,所以要在split函数里把划分标准改成是否相等
# def test_create():
#     dt = Dtree()
#     X = [
#         ['young', 'no', 'no', 'normal'],
#         ['young', 'no', 'no', 'good'],
#         ['young', 'yes', 'no', 'good'],
#         ['young', 'yes', 'yes', 'normal'],
#         ['young', 'no', 'no', 'normal'],
#
#         ['midage', 'no', 'no', 'normal'],
#         ['midage', 'no', 'no', 'good'],
#         ['midage', 'yes', 'yes', 'good'],
#         ['midage', 'no', 'yes', 'verygood'],
#         ['midage', 'no', 'yes', 'verygood'],
#
#         ['old', 'no', 'yes', 'verygood'],
#         ['old', 'no', 'yes', 'good'],
#         ['old', 'yes', 'no', 'good'],
#         ['old', 'yes', 'no', 'verygood'],
#         ['old', 'no', 'no', 'normal'],
#     ]
#     y = ['no', 'no', 'yes', 'yes', 'no',
#          'no', 'no', 'yes', 'yes', 'yes',
#          'yes', 'yes', 'yes', 'yes', 'no']
#     dt.create_Dtree(X,y)
#     dt2 = copy.deepcopy(dt)
#     print("dt:\n",dt)
#     print("dt2:\n",dt2)
#     # print(dt.get_C_multiNode(dt.root))  #准确分类时,任意节点的基尼系数为零
#     dt2.root.left.feature = 100
#     print("dt:\n", dt)
#     print("dt2:\n", dt2)
#
#     print(dt.root.items)
#
#     # print(dt.get_min_gt())
# test_create()import numpy as np
data = np.loadtxt('breast_cancer.csv',delimiter=',',dtype=float)
n = data.shape[0]# 按顺序划分数据,这种方式得到的准确率要比随机生成低,大概几个百分点
# X_train = data[:(int) (0.6 * n), : -1]
# y_train = [int(x) for x in data[:(int) (0.6 * n), -1]]
#
# X_cv = data[(int) (0.6 * n) : (int)(0.8 * n), : -1]
# y_cv = data[(int) (0.6 * n) : (int)(0.8 * n), -1]
#
# X_test = data[(int) (0.8 * n) :, : -1]
# y_test = data[(int) (0.8 * n) :, -1]from numpy.random import choice,seedX = data[:,:-1]
y = data[:,-1]train_size =int(0.6 * n)  #训练集大小
cv_size = int(0.2 * n)      #交叉验证集大小
test_size = int(0.2 * n)      #测试集大小train_rows = choice(range(n),size = train_size,replace=False)
X_train = [X[i] for i in train_rows]
y_train = [y[i] for i in train_rows]remains = [i for i in range(n) if i not in train_rows]
cv_rows = choice(remains,size=cv_size,replace=False)
X_cv = [X[i] for i in cv_rows]
y_cv = [y[i] for i in cv_rows]X_test = [X[i] for i in remains if i not in cv_rows]
y_test = [y[i] for i in remains if i not in cv_rows]t1 = Dtree()
t1.create_Dtree(X_train,y_train)predictTrue = 0
queue = [(t1)]
bestTree = t1
maxAcc = 0
alpha = []while queue:   #剪枝并进行交叉验证curTree = queue.pop(0)predictTrue = 0for i in range(len(X_cv)):curPredict = curTree.predict(X_cv[i])if curPredict == y_cv[i]:predictTrue += 1curAcc = predictTrue / len(y_cv)print("交叉验证准确率: ",curAcc)if curAcc > maxAcc:bestTree = curTreemaxAcc = curAccif curTree.root.left and curTree.root.right:nextTree = copy.deepcopy(curTree)bestNode,ai = nextTree.get_min_gt()  #选在交叉验证集上准确率最高的决策树作为最优模型print("每次生成的参数alpha: ",ai)alpha.append(ai)nextTree.merge_subTree(bestNode)queue.append(nextTree)predictTrue = 0
for i in range(len(X_test)):curPredict = bestTree.predict(X_test[i])if curPredict == y_test[i]:predictTrue += 1
acc = predictTrue / len(y_test)
print(acc)

决策树实现(CART生成及剪枝)相关推荐

  1. 决策树的生成与剪枝CART

    跟我一起机器学习系列文章将首发于公众号:月来客栈,欢迎文末扫码关注! 在之前的一篇文章中,笔者分别介绍了用ID3和C4.5这两种算法来生成决策树.其中ID3算法每次用信息增益最大的特征来划分数据集,C ...

  2. 机器学习爬大树之决策树(CART与剪枝)

    分类与回归树(classification and regression tree,CART)是应用广泛的决策树学习方法,同样由特征选择,树的生成以及剪枝组成,既可以用于分类也可以用于回归.CART假 ...

  3. 决策树:CART决策树剪枝算法(超详细)

    文章目录 CART算法 1. CART生成算法 2. CART剪枝算法 CART剪枝算法流程 CART剪枝算法解析( 超详细 ) CART算法 CART假设决策树是二叉树,内部结点特征的取值为&quo ...

  4. 集成学习—决策树(CART)

    集成学习: 集成学习(Ensemble learning)是指通过构建并合并多个学习器(模型)来完成学习任务.它一般先产生一组个体学习器,再用某种策略将它们结合起来,常常可以获得比单一学习器显著优越的 ...

  5. [机器学习算法]决策树和CART树

    决策树综述 决策树的工作原理 决策树(decision tree)分类法是一种简单但广泛使用的分类技术.以是否贷款违约的二分类问题为例,当我们希望根据给定的训练集习得一个模型对新出现的贷款人进行分类时 ...

  6. 决策树一CART算法(第四部分)

    决策树一CART算法(第四部分) CART树的剪枝:算法步骤 输入:CART算法生成的决策树. 输出:最优决策树T 设K=0,T=T0K=0,T=T_0K=0,T=T0​ ,从完整的决策树出发 ​ k ...

  7. 决策树(CART)算法总结

    文章目录 1.决策树原理 2.决策树优缺点 3.CART算法 4.CART算法实现 5. 应用实例--泰坦尼克号数据集 5.1 数据集获取 5.2 数据描述 5.3 代码实例 1.决策树原理 决策树算 ...

  8. 树模型之三种常见的决策树:CART,…

    树模型(又称决策树或者树结构模型):基本思想和方差分析中的变异分解极为相似. 目的(基本原则):将总研究样本通过某些牲(自变量取值)分成数个相对同质的子样本.每一子样本因变量的取值高度一致,相应的变异 ...

  9. 决策树一一CART算法(第三部分)

    决策树一一CART算法(第三部分) CART-回归树模型 ​ 如果输出变量是 连续 的,对应的就是 回归 问题,对于决策树而言,输出的信息一定就是叶子结点,所以需要将连续变量按照一定的要求划分. 回归 ...

最新文章

  1. 你没听说过的Go语言惊人优点
  2. 利用BP神经网络教计算机进行非线函数拟合(代码部分单层)
  3. linkedlist(c语言_简单实现)
  4. android—获取网络数据
  5. 使用QT创建PythonGUI程序
  6. [译]Hour 7 Teach.Yourself.WPF.in.24.Hours
  7. axture动画原型制作_Axure制作原型-基础操作
  8. vue——走马灯-类轮播图
  9. java 查找链表中间元素_java查找链表中间元素_如何通过Java单次查找链表的中间元素...
  10. P1666 前缀单词
  11. ubuntu20.04安装mysql_Ubuntu 20.04安装MySQL 8.0.20记
  12. 请不要再用那种态度把我搞火了!
  13. 配置Windows Server 2008 允许多用户远程桌面连接
  14. 3号团队-团队任务4:每日例会(2018-11-28)
  15. KMP模式匹配算法——C++
  16. 计算机图形学(二)——实验二:直线的生成算法
  17. centos7 安装最新rabbitmq,并设置开机自启
  18. 如何在 Python3 中对列表 通过比较排序(不懂就问)?
  19. 【C语言】输入一个正整数 n,输入 n 个数,生成一个 n*n 的矩阵, 矩阵中第 1 行是输入的 n 个数,以后每一行都是上一行循环左移一个元素。
  20. 【H2】绘制三角警示牌,使用turtle绘制三角警示牌,陈斌老师北京大学暑期学校Python语言基础与应用

热门文章

  1. Qt实现telnet服务【记录】
  2. [Javascript 高级程序设计]学习心得记录3 根据对象数组的属性进行排序
  3. 距离李现生日还有3天 与荣耀30一起解锁生日专属礼包
  4. Android手机号码获取问题
  5. LONGSYS 64G M6固态硬盘SM2244G主控开卡
  6. java环境安装及java编译
  7. A direct formulation for sparse PCA using semidefinite programming
  8. 2019牛客多校第十场 F.Popping Balloons
  9. 架构 Varnish+nginx+php(FastCGI)+MYSQL5+MenCache+MenC
  10. 强化学习系列7:无模型的蒙特卡洛法