机器学习--CART分类回归树
目录
文章目录
- 目录
- 前言
- 1.CART回归树简介
- 2.剪枝策略
- 3.模型树
- 4.线性回归 回归树 模型树比较
前言
虽然许多问题都可以用线性方法取得良好的结果,但现实中也有许多问题是非线性的,用线性模型并不能很好的拟合数据,这种情况下可以使用树回归来拟合数据。因此本文特别介绍一下CART, 树剪枝,模型树等等算法。
1.CART回归树简介
传统决策树是一种贪心算法,在给定时间内做出最佳选择,不关心是否达到全局最优。切分过于迅速,特征一旦使用后面将不再使用。不能处理连续型特征,进行离散化可能会破坏连续变量的内在特征。
CART 分类回归树,既能分类又能回归。CRAT来进行节点决策时,使用二元切分来处理连续型变量,给定特征属性以及特征值,若大于该值则执行左子树,相反则放入右子树。当某个节点不能再切分时,节点值是单个值(CART),也可以是一个线性方程(模型树)。
加载数据集 按行加载到矩阵中:
def loadDataSet(fileName): #general function to parse tab -delimited floatsdataMat = [] #assume last column is target valuefr = open(fileName)for line in fr.readlines():curLine = line.strip().split('\t')fltLine = list(map(float,curLine)) #map all elements to float()dataMat.append(fltLine)return dataMat
按照某一特征以及响应特征值来对数据集进行划分:
feature是特征属性的索引 即列数 value划分阈值 如果value大于阈值则放入mat0 否则放入mat1
def binSplitDataSet(dataSet, feature, value):mat0 = dataSet[nonzero(dataSet[:,feature] > value)[0],:]mat1 = dataSet[nonzero(dataSet[:,feature] <= value)[0],:]return mat0,mat1
- 创建回归树:
- 找到最佳待切分属性;
- 如果该节点不能切分,则该节点存为子节点
- 执行二元切分
- 右子树调用createTree()
- 左子树调用createTree()
leafType 创建叶节点的函数 errType代表误差计算函数 每一个节点使用字典来存储,分别包含spInd spVal left right等key值。
def createTree(dataSet, leafType=regLeaf, errType=regErr, ops=(1,4)):#assume dataSet is NumPy Mat so we can array filteringfeat, val = chooseBestSplit(dataSet, leafType, errType, ops)#choose the best splitif feat == None: return val #if the splitting hit a stop condition return valretTree = {}retTree['spInd'] = featretTree['spVal'] = vallSet, rSet = binSplitDataSet(dataSet, feat, val)retTree['left'] = createTree(lSet, leafType, errType, ops)retTree['right'] = createTree(rSet, leafType, errType, ops)return retTree
树节点划分的度量,计算连续函数的混乱度(决策树使用信息熵以及基尼纯等),这里可以采用数据的总方差来计算数据的混乱度,均方差乘以数据集的样本数。
遍历所有特征以及所有特征值使总方差最小的值即为划分特征以及划分阈值。
def chooseBestSplit(dataSet, leafType=regLeaf, errType=regErr, ops=(1,4)):tolS = ops[0]; tolN = ops[1]#if all the target variables are the same value: quit and return valueif len(set(dataSet[:,-1].T.tolist()[0])) == 1: #exit cond 1return None, leafType(dataSet)m,n = shape(dataSet)#the choice of the best feature is driven by Reduction in RSS error from meanS = errType(dataSet)bestS = inf; bestIndex = 0; bestValue = 0for featIndex in range(n-1):for splitVal in set((dataSet[:,featIndex].T.A.tolist())[0]):mat0, mat1 = binSplitDataSet(dataSet, featIndex, splitVal)if (shape(mat0)[0] < tolN) or (shape(mat1)[0] < tolN): continuenewS = errType(mat0) + errType(mat1)if newS < bestS: bestIndex = featIndexbestValue = splitValbestS = newS#if the decrease (S-bestS) is less than a threshold don't do the splitif (S - bestS) < tolS: return None, leafType(dataSet) #exit cond 2mat0, mat1 = binSplitDataSet(dataSet, bestIndex, bestValue)if (shape(mat0)[0] < tolN) or (shape(mat1)[0] < tolN): #exit cond 3return None, leafType(dataSet)return bestIndex,bestValue#returns the best feature to split on#and the value used for that split
tolS为容许的误差最小下降值,当划分一次误差小于该值时,提升效果不大,直接返回。
tolN为切分的最少样本数,当切分之后,左右子数量小于tolN,说明切分字节过小,直接返回。
leafType为叶子结点的创建函数,采用均值方式
def regLeaf(dataSet):#returns the value used for each leafreturn mean(dataSet[:,-1])
errType为误差估计函数,这里使用总方差,即均方差乘以样本总数
def regErr(dataSet):return var(dataSet[:,-1]) * shape(dataSet)[0]
如果某个节点数据特征值都相同,则无法继续划分,直接返回叶子结点。
if len(set(dataSet[:,-1].T.tolist()[0])) == 1:
遍历每一个特征以及相应的特征值来进行划分,计算每一种划分的总方差,返回最优的特征属性以及特征阈值:
for featIndex in range(n-1):for splitVal in set(dataSet[:,featIndex]):
绘出样本集的分布图:
def plotarr(arr):import matplotlib.pyplot as pltfig = plt.figure()ax = fig.add_subplot(111)ax.scatter(arr[:,0].flatten().A[0], arr[:,1].flatten().A[0])plt.show()
运行测试如下:
加载另一数据集:
得到CART回归结果:
2.剪枝策略
当回归树叶子结点过多时,容易发生过拟合,导致泛化性能降低。可以采取剪枝来防止过拟合,有预剪枝以及后剪枝。
def chooseBestSplit(dataSet, leafType=regLeaf, errType=regErr, ops=(1,4)):tolS = ops[0]; tolN = ops[1]#if all the target variables are the same value: quit and return valueif len(set(dataSet[:,-1].T.tolist()[0])) == 1: #exit cond 1return None, leafType(dataSet)m,n = shape(dataSet)#the choice of the best feature is driven by Reduction in RSS error from meanS = errType(dataSet)bestS = inf; bestIndex = 0; bestValue = 0for featIndex in range(n-1):for splitVal in set((dataSet[:,featIndex].T.A.tolist())[0]):mat0, mat1 = binSplitDataSet(dataSet, featIndex, splitVal)if (shape(mat0)[0] < tolN) or (shape(mat1)[0] < tolN): continuenewS = errType(mat0) + errType(mat1)if newS < bestS: bestIndex = featIndexbestValue = splitValbestS = newS#if the decrease (S-bestS) is less than a threshold don't do the splitif (S - bestS) < tolS: return None, leafType(dataSet) #exit cond 2mat0, mat1 = binSplitDataSet(dataSet, bestIndex, bestValue)if (shape(mat0)[0] < tolN) or (shape(mat1)[0] < tolN): #exit cond 3return None, leafType(dataSet)return bestIndex,bestValue#returns the best feature to split on#and the value used for that split
其中tolS与tolN就能在一定程度上防止过拟合,主要采用预剪枝。通过tolS如果剪枝对于数据集的误差降低不大则可以不划分节点,tolN如果剪枝之后叶子结点数据过少,则也可以预剪枝处理。这对参数tolS,tolN的取值提出了很高的要求,往往难以设置求解。
后剪枝:将数据分为训练集与测试集,首先构建一颗完整树,然后依次寻找叶子结点,用测试集来判断将叶子结点合并是否能降低测试误差,若能则采取后剪枝。
基于已有的树切分测试数据:
如果存在任一子集是一棵树,在该子集继续剪枝过程。
计算将两个叶子结点合并后的误差
计算不合并的误差
若合并会降低误差,则合并两个叶子结点
判断某一节点是否是一棵树,及判断是否为字典类型:
def isTree(obj):return (type(obj).__name__=='dict')
执行树坍塌过程,返回树的平均值
def getMean(tree):if isTree(tree['right']): tree['right'] = getMean(tree['right'])if isTree(tree['left']): tree['left'] = getMean(tree['left'])return (tree['left']+tree['right'])/2.0
进行后剪枝处理:
def prune(tree, testData):if shape(testData)[0] == 0: return getMean(tree) #if we have no test data collapse the tree#如果该树是子集,则划分测试数据,继续后剪枝if (isTree(tree['right']) or isTree(tree['left'])):#if the branches are not trees try to prune themlSet, rSet = binSplitDataSet(testData, tree['spInd'], tree['spVal'])if isTree(tree['left']): tree['left'] = prune(tree['left'], lSet)if isTree(tree['right']): tree['right'] = prune(tree['right'], rSet)#if they are now both leafs, see if we can merge them#如果节点是叶子结点if not isTree(tree['left']) and not isTree(tree['right']):#划分测试数据lSet, rSet = binSplitDataSet(testData, tree['spInd'], tree['spVal'])#没有合并前的误差errorNoMerge = sum(power(lSet[:,-1] - tree['left'],2)) +\sum(power(rSet[:,-1] - tree['right'],2))#合并后的误差 合并后 节点值变为两个子节点的平均值treeMean = (tree['left']+tree['right'])/2.0#取出最后一列y的值真实值与预测值求总方差errorMerge = sum(power(testData[:,-1] - treeMean,2))if errorMerge < errorNoMerge: print ("merging")return treeMeanelse: return treeelse: return tree
加载数据集,创建一颗最完整的分类回归树, 设置tolS=0, tolN=1
>>> dd = mat(regTrees.loadDataSet('ex2.txt'))
>>> mt = regTrees.createTree(dd, ops(0,1))
此时回归树
剪枝之后:
有一部分节点被剪掉。
3.模型树
前面CART叶子结点为某个值,现在可以把叶子结点变为一个分段函数,即某一个叶子结点下面允许分段函数形式的数据存在。
将一个数据集求出线性拟合函数:
def linearSolve(dataSet): #helper function used in two placesm,n = shape(dataSet)X = mat(ones((m,n))); Y = mat(ones((m,1)))#create a copy of data with 1 in 0th postionX[:,1:n] = dataSet[:,0:n-1]; Y = dataSet[:,-1]#and strip out YxTx = X.T*Xif linalg.det(xTx) == 0.0:raise NameError('This matrix is singular, cannot do inverse,\n\try increasing the second value of ops')ws = xTx.I * (X.T * Y)return ws,X,Y
首先进行数据矩阵变换,利用线性模型直接求解回归系数ws
如果一个节点是叶子结点时,需要存储ws系数权向量
def modelLeaf(dataSet):#create linear model and return coeficientsws,X,Y = linearSolve(dataSet)return ws
当采用线性模型时,使用平方误差和来计算总误差:
def modelErr(dataSet):ws,X,Y = linearSolve(dataSet)yHat = X * wsreturn sum(power(Y - yHat,2))
加载数据集进行测试:
def testmodel():tt = mat(loadDataSet('exp2.txt'))return createTree(tt, modelLeaf, modelErr, (1, 10))
数据集分布:
4.线性回归 回归树 模型树比较
通过对于同一份数据进行训练模型,在通过测试集比较不同模型之间的性能差异。
模型树与回归树预测值的输出:
def treeForeCast(tree, inData, modelEval=regTreeEval):if not isTree(tree): return modelEval(tree, inData)if inData[tree['spInd']] > tree['spVal']:if isTree(tree['left']): return treeForeCast(tree['left'], inData, modelEval)else: return modelEval(tree['left'], inData)else:if isTree(tree['right']): return treeForeCast(tree['right'], inData, modelEval)else: return modelEval(tree['right'], inData)
tree训练树所得,inData为待预测的样本行向量,modelEval表示节点类型,当modelEval=regTreeEval说明叶子节点为分类类型,节点值为具体的分类值,即预测值值直接返回节点值即可,当modelEval=modelTreeEval时,说明叶子结点为回归类型,节点值为线性权向量,返回值应该与测试数据相乘得到最终预测值。
def regTreeEval(model, inDat):return float(model)def modelTreeEval(model, inDat):n = shape(inDat)[1]X = mat(ones((1,n+1)))X[:,1:n+1]=inDatreturn float(X*model)
返回测试集的预测值,列向量:
def createForeCast(tree, testData, modelEval=regTreeEval):m=len(testData)yHat = mat(zeros((m,1)))for i in range(m):yHat[i,0] = treeForeCast(tree, mat(testData[i]), modelEval)return yHat
利用相关系数来衡量数据拟合情况:
def regtree():traindata = mat(loadDataSet('bikeSpeedVsIq_train.txt'))testdata = mat(loadDataSet('bikeSpeedVsIq_test.txt'))mt = createTree(traindata, ops=(1, 20))yHat = createForeCast(mt, testdata[:,0])return corrcoef(yHat, testdata[:,1], rowvar=0)[0,1]def modeltree():traindata = mat(loadDataSet('bikeSpeedVsIq_train.txt'))testdata = mat(loadDataSet('bikeSpeedVsIq_test.txt'))mt = createTree(traindata, modelLeaf, modelErr, ops=(1, 20))yHat = createForeCast(mt, testdata[:,0], modelTreeEval)return corrcoef(yHat, testdata[:,1], rowvar=0)[0,1]def reg():traindata = mat(loadDataSet('bikeSpeedVsIq_train.txt'))testdata = mat(loadDataSet('bikeSpeedVsIq_test.txt'))ws, x, y = linearSolve(traindata)yHat=[0]*shape(testdata)[0]for i in range(shape(testdata)[0]):yHat[i] = testdata[i,0]*ws[1,0]+ws[0,0]return corrcoef(yHat, testdata[:,1], rowvar=0)[0,1]
可知模型树拟合效果最好
5.Tkinter库图形化
使用tkinter库来实现是图画化展示数据拟合
from numpy import *#python3导入方式不变from tkinter import *import regTreesimport matplotlibmatplotlib.use('TkAgg')from matplotlib.backends.backend_tkagg import FigureCanvasTkAggfrom matplotlib.figure import Figuredef reDraw(tolS,tolN):reDraw.f.clf() # clear the figurereDraw.a = reDraw.f.add_subplot(111)if chkBtnVar.get():if tolN < 2: tolN = 2#绘出模型树myTree=regTrees.createTree(reDraw.rawDat, regTrees.modelLeaf,\regTrees.modelErr, (tolS,tolN))yHat = regTrees.createForeCast(myTree, reDraw.testDat, \regTrees.modelTreeEval)else:#绘出回归树myTree=regTrees.createTree(reDraw.rawDat, ops=(tolS,tolN))yHat = regTrees.createForeCast(myTree, reDraw.testDat)#绘出数据分布时,矩阵和一位数组之间的转换reDraw.a.scatter(reDraw.rawDat[:,0].flatten().A[0], reDraw.rawDat[:,1].flatten().A[0], s=5) #use scatter for data setreDraw.a.plot(reDraw.testDat, yHat, linewidth=2.0) #use plot for yHat#修改draw()为show()reDraw.canvas.draw()def getInputs():try: tolN = int(tolNentry.get())except: tolN = 10 print ("enter Integer for tolN")tolNentry.delete(0, END)tolNentry.insert(0,'10')try: tolS = float(tolSentry.get())except: tolS = 1.0 print ("enter Float for tolS")tolSentry.delete(0, END)tolSentry.insert(0,'1.0')return tolN,tolSdef drawNewTree():tolN,tolS = getInputs()#get values from Entry boxesreDraw(tolS,tolN)root=Tk()reDraw.f = Figure(figsize=(5,4), dpi=100) #create canvasreDraw.canvas = FigureCanvasTkAgg(reDraw.f, master=root)# show()方法应该修改为draw()reDraw.canvas.draw()reDraw.canvas.get_tk_widget().grid(row=0, columnspan=3)Label(root, text="tolN").grid(row=1, column=0)tolNentry = Entry(root)tolNentry.grid(row=1, column=1)tolNentry.insert(0,'10')Label(root, text="tolS").grid(row=2, column=0)tolSentry = Entry(root)tolSentry.grid(row=2, column=1)tolSentry.insert(0,'1.0')Button(root, text="ReDraw", command=drawNewTree).grid(row=1, column=2, rowspan=3)chkBtnVar = IntVar()chkBtn = Checkbutton(root, text="Model Tree", variable = chkBtnVar)chkBtn.grid(row=3, column=0, columnspan=2)reDraw.rawDat = mat(regTrees.loadDataSet('sine.txt'))reDraw.testDat = arange(min(reDraw.rawDat[:,0]),max(reDraw.rawDat[:,0]),0.01)reDraw(1.0, 10)root.mainloop()
由于python3的变化,代码需要改变如下:
1.from tkinter import * 库导入库名变为小写
2.reDraw.canvas.draw() FigureCanvasTkAgg对象draw方法而不是show()
3.reDraw.a.scatter(reDraw.rawDat[:,0].flatten().A[0], reDraw.rawDat[:,1].flatten().A[0], s=5)进行数据分布绘制时需要转换矩阵为一维数组。
改变tolS tolN的值,绘制如下:
机器学习--CART分类回归树相关推荐
- id3决策树 鸢尾花 python_机器学习之分类回归树(python实现CART)
机器学习之分类回归树(python实现CART) 之前有文章介绍过决策树(ID3).简单回顾一下:ID3每次选取最佳特征来分割数据,这个最佳特征的判断原则是通过信息增益来实现的.按照某种特征切分数据后 ...
- 机器学习之分类回归树(CART)
前言 写这一章本来是想来介绍GBDT-LR这一个推荐模型的.但是这里面就涉及到了很多机器学习的基础树形算法,思前想后还是决定分成几篇文章来写,这里先介绍一下CART数,因为在GBDT中用来分类回归的树 ...
- 监督学习 | CART 分类回归树原理
文章目录 CART 算法 1. CART 生成 1.1 回归树生成 最小二乘回归树生成算法 1.2 分类树生成 基尼指数 CART 生成算法 参考文献 相关文章: 机器学习 | 目录 监督学习 | I ...
- 【机器学习】决策树——CART分类回归树(理论+图解+公式)
- python 决策树回归参数_python决策树之CART分类回归树详解
{"moduleinfo":{"card_count":[{"count_phone":1,"count":1}],&q ...
- 机器学习实战(八)分类回归树CART(Classification And Regression Tree)
目录 0. 前言 1. 回归树 2. 模型树 3. 剪枝(pruning) 3.1. 预剪枝 3.2. 后剪枝 4. 实战案例 4.1. 回归树 4.2. 模型树 学习完机器学习实战的分类回归树,简单 ...
- 机器学习算法之CART(分类回归树)概要
分类回归树 classification and regression tree(C&RT) racoon 优点 (1)可自动忽略对目标变量没有贡献的属性变量,也为判断属性变量的重要性,减 ...
- 机器学习系列之手把手教你实现一个分类回归树
https://www.ibm.com/developerworks/cn/analytics/library/machine-learning-hands-on5-cart-tree/index.h ...
- CART决策树(分类回归树)分析及应用建模
一.CART决策树模型概述(Classification And Regression Trees) 决策树是使用类似于一棵树的结构来表示类的划分,树的构建可以看成是变量(属性)选择的过程,内部节 ...
最新文章
- Eclipse SVN插件冲突导致不能使用解决办法
- php memcached 扩展下载,编译安装 PHP 的 Memcached 扩展
- 牛客练习赛74 E CCA的期望(算概率的技巧+floyd处理)
- BGP基本配置(GNS3)
- 微信支付小年上线“点鞭炮,响优惠”活动 大额提现免费券限时发放
- 面向模式的软件体系结构
- C#模拟js的Json对象创建,操作
- android 扇形统计动画,Android自定义View——扇形统计图的实现代码
- rdkit获得原子的标准排序序号
- mtk屏幕背光默认时间修改
- 计算机更改虚拟内存有用吗,电脑虚拟内存有什么用(小白必知虚拟内存作用及设置技巧)...
- 干货|爬虫被封的几个常见原因
- 【整理收集】使用当中IE浏览器遇到的兼容性问题
- 解决iPhone无法连接iTunes
- STemWin专题--画直线
- 信息论与编码课件,希望传播给更多的人
- ESIM(Enhanced Sequential Inference Model)- 模型详解
- 九龙证券|港股盘中暴涨110%!多股涨停,有色、汽车板块爆发!
- 面试题目之:mvvm框架是什么?它与其他框架(jquery)的区别是什么?哪些场景适合?
- 美通企业日报 | 全球金融科技50强榜中企居首;施乐退出与富士胶片合资企业...