CART: Classification And Regression Trees (分类回归树)

# coding=utf-8
#Created on Feb 4, 2011
#Tree-Based Regression Methods
#@author: Peter Harrington
#
from numpy import *


def loadDataSet(fileName):
    """Parse a tab-delimited text file of floats into a list of rows.

    Every non-blank line becomes one list of floats. By convention the last
    column is the target value.

    Args:
        fileName: path to the tab-delimited data file.

    Returns:
        list[list[float]]: one inner list per non-blank line.
    """
    dataMat = []
    # `with` guarantees the file handle is closed once parsing finishes.
    with open(fileName) as fr:
        for line in fr:
            line = line.strip()
            if not line:  # skip blank lines (e.g. a trailing newline at EOF)
                continue
            dataMat.append(list(map(float, line.split('\t'))))
    return dataMat


def binSplitDataSet(dataSet, feature, value):
    """Split the rows of a data matrix in two on one feature threshold.

    Args:
        dataSet: 2-D NumPy matrix/array whose rows are samples.
        feature: index of the column to split on.
        value: threshold; rows with dataSet[:, feature] <= value go to mat0.

    Returns:
        (mat0, mat1): rows with feature value <= value, rows with > value.
    """
    # nonzero(condition)[0] gives the row indices where the condition holds.
    mat0 = dataSet[nonzero(dataSet[:, feature] <= value)[0], :]
    mat1 = dataSet[nonzero(dataSet[:, feature] > value)[0], :]
    return mat0, mat1


def regLeaf(dataSet):
    """Leaf model for a regression tree: the mean of the target (last) column."""
    return mean(dataSet[:, -1])


def regErr(dataSet):
    """Total squared error of the target column around its mean.

    var() is the population variance (ddof=0), so var * row-count equals the
    sum of squared deviations from the mean.
    """
    return var(dataSet[:, -1]) * shape(dataSet)[0]


# ---- Model trees: the leaf becomes a linear model instead of a constant.
# ---- The data set is split into inputs X and target Y and the linear
# ---- regression coefficients ws are fitted. ----
def linearSolve(dataSet):
    """Fit an ordinary-least-squares linear model to a data matrix.

    All columns except the last are inputs; the last column is the target.
    A column of ones is prepended, so ws[0] is the intercept.

    Args:
        dataSet: NumPy matrix of shape (m, n).

    Returns:
        (ws, X, Y): coefficient matrix (n x 1), design matrix X (m x n)
        whose first column is all ones, and the target column Y.

    Raises:
        NameError: if X.T * X is singular (type kept for existing callers).
    """
    m, n = shape(dataSet)
    X = mat(ones((m, n)))  # column 0 stays 1 -> intercept term
    X[:, 1:n] = dataSet[:, 0:n - 1]
    Y = dataSet[:, -1]
    xTx = X.T * X
    if linalg.det(xTx) == 0.0:
        # Adjacent literals concatenate; the old single literal contained a
        # stray "\t" escape that mangled "try" into a tab plus "ry".
        raise NameError('This matrix is singular, cannot do inverse,\n'
                        'try increasing the second value of ops')
    ws = xTx.I * (X.T * Y)
    return ws, X, Y


def modelLeaf(dataSet):
    """Leaf model for a model tree: the fitted linear coefficients ws."""
    ws, X, Y = linearSolve(dataSet)
    return ws


def modelErr(dataSet):
    """Sum of squared residuals of the linear fit on dataSet."""
    ws, X, Y = linearSolve(dataSet)
    yHat = X * ws
    return sum(power(Y - yHat, 2))


# ---- Find the best binary split: minimize the summed error of the two
# ---- halves (rather than maximizing entropy reduction), subject to
# ---- minimum-improvement and minimum-sample-count thresholds. ----
def chooseBestSplit(dataSet, leafType=regLeaf, errType=regErr, ops=(1, 4)):
    """Find the best binary split (feature index, threshold) for a node.

    Every distinct value of every input column is tried as a threshold. The
    split that minimizes errType(left) + errType(right) wins.

    Args:
        dataSet: NumPy matrix; last column is the target.
        leafType: function building a leaf value from a data matrix.
        errType: function measuring the error of a data matrix.
        ops: (tolS, tolN). tolS is the minimum error reduction required to
            split; tolN is the minimum number of rows allowed on each side.

    Returns:
        (featIndex, splitValue) for the chosen split, or
        (None, leafType(dataSet)) when the node should become a leaf.
    """
    tolS = ops[0]
    tolN = ops[1]
    # Exit 1: all target values identical -> nothing to gain by splitting.
    if len(set(dataSet[:, -1].T.tolist()[0])) == 1:
        return None, leafType(dataSet)
    n = shape(dataSet)[1]
    S = errType(dataSet)  # error of the unsplit node
    bestS = inf
    bestIndex = 0
    bestValue = 0
    for featIndex in range(n - 1):
        # A matrix column is unhashable; convert it to a plain list for set().
        for splitVal in set(dataSet[:, featIndex].T.A.tolist()[0]):
            mat0, mat1 = binSplitDataSet(dataSet, featIndex, splitVal)
            if shape(mat0)[0] < tolN or shape(mat1)[0] < tolN:
                continue
            newS = errType(mat0) + errType(mat1)
            if newS < bestS:
                bestIndex = featIndex
                bestValue = splitVal
                bestS = newS
    # Exit 2: the best split does not reduce the error by at least tolS.
    # (If no split satisfied tolN, bestS is inf and this also triggers.)
    if (S - bestS) < tolS:
        return None, leafType(dataSet)
    mat0, mat1 = binSplitDataSet(dataSet, bestIndex, bestValue)
    # Exit 3: either side of the chosen split has too few rows.
    if shape(mat0)[0] < tolN or shape(mat1)[0] < tolN:
        return None, leafType(dataSet)
    return bestIndex, bestValue


def createTree(dataSet, leafType=regLeaf, errType=regErr, ops=(1, 4)):
    """Recursively grow a binary regression/model tree.

    Args:
        dataSet: NumPy matrix (matrix indexing is required); last column is
            the target.
        leafType, errType, ops: forwarded to chooseBestSplit.

    Returns:
        A leaf value when no split is made. Otherwise a dict with keys
        'spInd' (feature index), 'spVal' (threshold), 'left' (subtree for
        rows <= spVal) and 'right' (subtree for rows > spVal).
    """
    feat, val = chooseBestSplit(dataSet, leafType, errType, ops)
    if feat is None:  # a stop condition was hit: val is the leaf value
        return val
    retTree = {'spInd': feat, 'spVal': val}
    lSet, rSet = binSplitDataSet(dataSet, feat, val)
    retTree['left'] = createTree(lSet, leafType, errType, ops)
    retTree['right'] = createTree(rSet, leafType, errType, ops)
    return retTree


# ---------------------------------------------------------------------------
###上面用tolS和tolN控制tree切分的程度,为预剪枝,下面是后剪枝:用trainData训练出tree后,testData来测试,
####若合并叶节点能降低误差,则进行合并(剪枝)
def isTree(obj):
    """Return True if obj is an internal split node (a dict); leaves are not dicts."""
    return isinstance(obj, dict)


def getMean(tree):
    """Collapse a (sub)tree into one value by recursively averaging children.

    Each node's value is the unweighted mean of its two children.
    NOTE: mutates `tree` in place; every subtree child is replaced by its
    collapsed value.
    """
    if isTree(tree['right']):
        tree['right'] = getMean(tree['right'])
    if isTree(tree['left']):
        tree['left'] = getMean(tree['left'])
    return (tree['left'] + tree['right']) / 2.0


def prune(tree, testData):
    """Post-prune a regression tree against held-out test data.

    Works bottom-up. Two sibling leaves are merged into their mean when that
    lowers the squared error on the test rows reaching their parent. A
    subtree that receives no test rows is collapsed with getMean.

    Args:
        tree: dict tree from createTree (regression tree with float leaves).
        testData: NumPy matrix of test rows; last column is the target.

    Returns:
        The pruned tree, or a single leaf value if it was fully merged.
        The input tree is modified in place. Prints "merging" on each merge.
    """
    if shape(testData)[0] == 0:  # no test data: collapse the whole subtree
        return getMean(tree)
    if isTree(tree['right']) or isTree(tree['left']):
        lSet, rSet = binSplitDataSet(testData, tree['spInd'], tree['spVal'])
        if isTree(tree['left']):
            tree['left'] = prune(tree['left'], lSet)
        if isTree(tree['right']):
            tree['right'] = prune(tree['right'], rSet)
    # If both children are now leaves, see whether merging them helps.
    if isTree(tree['left']) or isTree(tree['right']):
        return tree
    lSet, rSet = binSplitDataSet(testData, tree['spInd'], tree['spVal'])
    errorNoMerge = sum(power(lSet[:, -1] - tree['left'], 2)) + \
        sum(power(rSet[:, -1] - tree['right'], 2))
    treeMean = (tree['left'] + tree['right']) / 2.0
    errorMerge = sum(power(testData[:, -1] - treeMean, 2))
    if errorMerge < errorNoMerge:
        print("merging")
        return treeMean
    return tree


# ---- Prediction code ----
def regTreeEval(model, inDat):
    """Evaluate a regression-tree leaf: the leaf stores a constant.

    inDat is unused; it is accepted so both leaf evaluators share a signature.
    """
    return float(model)


def modelTreeEval(model, inDat):
    """Evaluate a model-tree leaf: apply linear coefficients to one sample.

    Args:
        model: (n+1) x 1 coefficient matrix from linearSolve (intercept first).
        inDat: 1 x n matrix holding one sample's features.
    """
    n = shape(inDat)[1]
    X = mat(ones((1, n + 1)))  # leading 1 multiplies the intercept
    X[:, 1:n + 1] = inDat
    # Take the single element explicitly; float() on a 1x1 matrix is
    # deprecated in recent NumPy.
    return float((X * model)[0, 0])


def treeForeCast(tree, inData, modelEval=regTreeEval):
    """Predict the target for one sample by walking the tree to a leaf.

    Args:
        tree: tree from createTree, or a bare leaf value.
        inData: one sample, either a 1 x n matrix row (as createForeCast
            passes) or a 1-D sequence of feature values.
        modelEval: leaf evaluator (regTreeEval or modelTreeEval).
    """
    if not isTree(tree):
        return modelEval(tree, inData)
    # Flatten before indexing: on a 1 x n matrix, inData[k] selects row k,
    # not feature k, which broke prediction for multi-feature data.
    featureValue = asarray(inData).ravel()[tree['spInd']]
    branch = tree['left'] if featureValue <= tree['spVal'] else tree['right']
    return treeForeCast(branch, inData, modelEval)


def createForeCast(tree, testData, modelEval=regTreeEval):
    """Predict every row of testData with the tree.

    Returns:
        m x 1 matrix of predictions yHat, with m = len(testData).
    """
    m = len(testData)
    yHat = mat(zeros((m, 1)))
    for i in range(m):
        yHat[i, 0] = treeForeCast(tree, mat(testData[i]), modelEval)
    return yHat
    # Quality check, e.g.: corrcoef(yHat, testData[:,-1], rowvar=0)[0,1]
    # (the closer to 1, the better)
############################################
#myData=loadDataSet(r'C:\Users\li\Downloads\machinelearninginaction\Ch09\ex0.txt')
#myMat=mat(myData)
#retTree=createTree(myMat)
#print(retTree)##############################################################
##########用TKinter创建GUI###############################
from tkinter import *
import matplotlib
matplotlib.use('TkAgg')
from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg
from matplotlib.figure import Figure
def reDraw(tolS, tolN):
    """Rebuild the tree with the given stop parameters and redraw the plot.

    If the "Model Tree" checkbox is set, a model tree (linear leaves) is
    built; otherwise a regression tree. The tree is fitted to reDraw.rawDat
    and evaluated on reDraw.testDat. The data points and the prediction curve
    are then drawn on reDraw.f.
    """
    reDraw.f.clf()  # clear the figure
    reDraw.a = reDraw.f.add_subplot(111)
    if chkBtnVar.get():
        # Linear leaves fit intercept + slope; presumably tolN >= 2 keeps
        # linearSolve from seeing a singular X.T*X on tiny leaves.
        if tolN < 2:
            tolN = 2
        myTree = createTree(reDraw.rawDat, modelLeaf, modelErr, (tolS, tolN))
        yHat = createForeCast(myTree, reDraw.testDat, modelTreeEval)
    else:
        myTree = createTree(reDraw.rawDat, ops=(tolS, tolN))
        yHat = createForeCast(myTree, reDraw.testDat)
    # Data set as a scatter, predictions as a line.
    reDraw.a.scatter(reDraw.rawDat[:, 0].A.T, reDraw.rawDat[:, 1].A.T, s=5)
    reDraw.a.plot(reDraw.testDat, yHat, linewidth=2.0)
    # FigureCanvasTkAgg.show() no longer exists in Matplotlib; draw() replaces it.
    reDraw.canvas.draw()


def getInputs():
    """Read tolN (int) and tolS (float) from the entry boxes.

    An unparsable entry falls back to tolN=10 / tolS=1.0, prints a hint, and
    resets that entry box to the default text.

    Returns:
        (tolN, tolS)
    """
    try:
        tolN = int(tolNentry.get())
    except ValueError:
        tolN = 10
        print("enter Integer for tolN")
        tolNentry.delete(0, END)
        tolNentry.insert(0, '10')
    try:
        tolS = float(tolSentry.get())
    except ValueError:
        tolS = 1.0
        print("enter Float for tolS")
        tolSentry.delete(0, END)
        tolSentry.insert(0, '1.0')
    return tolN, tolS


def drawNewTree():
    """ReDraw button callback: read parameters from the entry boxes and redraw."""
    tolN, tolS = getInputs()
    reDraw(tolS, tolN)


root = Tk()

# Matplotlib figure embedded in the Tk window.
reDraw.f = Figure(figsize=(5, 4), dpi=100)
reDraw.canvas = FigureCanvasTkAgg(reDraw.f, master=root)
reDraw.canvas.draw()
reDraw.canvas.get_tk_widget().grid(row=0, columnspan=3)

Label(root, text="tolN").grid(row=1, column=0)
tolNentry = Entry(root)
tolNentry.grid(row=1, column=1)
tolNentry.insert(0, '10')

Label(root, text="tolS").grid(row=2, column=0)
tolSentry = Entry(root)
tolSentry.grid(row=2, column=1)
tolSentry.insert(0, '1.0')

Button(root, text="ReDraw", command=drawNewTree).grid(row=1, column=2, rowspan=3)

chkBtnVar = IntVar()
chkBtn = Checkbutton(root, text="Model Tree", variable=chkBtnVar)
chkBtn.grid(row=3, column=0, columnspan=2)

DATA_FILE = r'C:\Users\li\Downloads\machinelearninginaction\Ch09\sine.txt'
reDraw.rawDat = mat(loadDataSet(DATA_FILE))
# .min()/.max() give plain scalars; the builtin min/max over a matrix column
# would return 1x1 matrices, which arange then has to coerce.
reDraw.testDat = arange(reDraw.rawDat[:, 0].min(), reDraw.rawDat[:, 0].max(), 0.01)
reDraw(1.0, 10)
root.mainloop()

CART树回归、剪枝、Tkinter GUI相关推荐

  1. 树回归--python Tkinter库创建GUI(2)

    简单的Tkinter库创建GUI的例子可参考: http://blog.csdn.net/lilong117194/article/details/78456376 下面是代码: # -*- codi ...

  2. CART树(分类回归树)

    传送门 决策树算法原理(ID3,C4.5) CART回归树 决策树的剪枝 在决策树算法原理(ID3,C4.5)中,提到C4.5的不足,比如模型是用较为复杂的熵来度量,使用了相对较为复杂的多叉树,只能处 ...

  3. CART树分类、回归、剪枝实现

    决策树ID3,C4.5是多叉树,CART树是一个完全二叉树,CART树不仅能完成分类也能实现回归功能,所谓回归指的是目标是一个连续的数值类型,比如体重.身高.收入.价格等,在介绍ID3,C4.5其核心 ...

  4. CART回归树模型树 生成 剪枝 in Python

    现实中,数据集中经常包含一些复杂的相互关系,使得输入数据和目标变量之间呈现非线性关系.对这些复杂的关系建模,一种可行的方式是使用树来对预测值进行分段,包括分段常数或者分段直线,即通过树结构对数据进行切 ...

  5. 【机器学习实战 第九章】树回归 CART算法的原理与实现 - python3

    本文来自<机器学习实战>(Peter Harrington)第九章"树回归"部分,代码使用python3.5,并在jupyter notebook环境中测试通过,推荐c ...

  6. 机器学习--CART分类回归树

    目录 文章目录 目录 前言 1.CART回归树简介 2.剪枝策略 3.模型树 4.线性回归 回归树 模型树比较 前言 虽然许多问题都可以用线性方法取得良好的结果,但现实中也有许多问题是非线性的,用线性 ...

  7. 监督学习 | CART 分类回归树原理

    文章目录 CART 算法 1. CART 生成 1.1 回归树生成 最小二乘回归树生成算法 1.2 分类树生成 基尼指数 CART 生成算法 参考文献 相关文章: 机器学习 | 目录 监督学习 | I ...

  8. 机器学习-有监督学习-分类算法:决策树算法【CART树:分类树(基于信息熵;分类依据:信息增益、信息增益率、基尼系数)、回归树(基于均方误差)】【损失函数:叶节点信息熵和】【对特征具有很好的分析能力】

    一.决策树概述 注:生产实践中,不使用决策树,太简单,而是使用决策树的升级版:集成学习算法. 集成学习算法有: Random Forest(随机森林) Extremely Randomized For ...

  9. 决策树ID3、决策树C4.5、决策树CART、CART树的生成、树的剪枝、从ID3到CART、从决策树生成规则、决策树优缺点

    决策树ID3.决策树C4.5.决策树CART.CART树的生成.树的剪枝.从ID3到CART.从决策树生成规则.决策树优缺点 目录

最新文章

  1. OpenCASCADE:使用 XDE 文档
  2. Eclipse新建Android项目后,出现“The import android.support.v7.app cannot be resolved”
  3. 【Hibernate】could not instantiate class.. from tuple] with root cause
  4. C#异步批量下载文件
  5. 一般市区有测速吗_高速公路增加这么多龙门架,有抓拍超速违章功能吗?可要仔细辨别...
  6. mac 连接hbase的图形化界面_Mac 视觉史(二):90 年代失败 Mac 操作系统大赏
  7. 深入剖析BIO到NIO演变史
  8. 宏山激光sigmatube套料软件多台电脑安装教程
  9. c语言现代程序设计 现代方法_红河分局加强水文现代化新技术、新仪器、新方法的使用和创新...
  10. 陈丹琦新作:关系抽取新SOTA,用pipeline方式挫败joint模型
  11. 可以弹奏的钢琴页面(HTML实现)
  12. 【Linux记录】Linux 可以telnet通localhost,不能telnet ip,telnet localhost正常,telnet ip失败。
  13. R语言lowess函数数据平滑实战(Locally Weighted Regression, Loess)
  14. 《人体解剖学(基础医学)》
  15. 使用beautifulSoup
  16. 《陶行知教育文集》读后感
  17. 图解k8s中pod的创建流程
  18. 【错误记录】Visual Studio 2019 中运行 Unity C# 脚本时报错 ( 根据解决方案, 可能需要安装额外的组件才能获得 | .NET 桌面开发 | 使用 Unity 的游戏开发 )
  19. 怎么设置CAD建筑标高?CAD建筑标高设置技巧
  20. 内卷?躺平?先看看这6个高质量知识星球

热门文章

  1. DevOps实战 —— 如何高效地远程部署?自动化运维利器 Fabric 教程
  2. Redis面试 - 如何保证缓存与数据库的双写一致性?
  3. Linux日志服务器的搭建
  4. 白帽子讲web安全——访问控制
  5. Python爬虫连载16-OCR工具Tesseract、Scrapt初步
  6. Dockerfile文件中CMD指令与ENTRYPOINT指令的区别
  7. 用until编写一段shell程序,计算1~10的平方和
  8. 【Python】五子棋项目记录
  9. 我从创建具有仅仅一年编码经验的视频游戏中学到了什么
  10. Vue——基础(对象、属性样式操作、条件、循环、事件、绑定)