本文通过示例介绍R实现CART(classification and regression tree)过程。

当一组预测变量与响应变量的关系为线性时,我们使用多重线性回归可以生成准确的预测模型。但当它们的关系为更复杂的非线性关系时,则需采用非线性模型。

分类回归CART(classification and regression tree)方法使用一组预测变量构建决策树,用来预测响应变量。响应变量是连续的,我们能构建回归树;如果响应变量是分类类型,则构建分类树。下面通过示例构建回归和分类树过程。

构建回归树

我们使用ISLR包中的Hitters数据集,它包括263个专业棒球运动员的各类信息。我们将使用该数据集构建回归树,预测变量是home runsyears played ,响应变量运动员的Salary

  1. 加载包
library(ISLR)       # 包含 Hitters 数据集
library(rpart)      # 决策树算法实现
library(rpart.plot) # 图视化决策树
  1. 构建初步回归树

首先构建大的初始回归树,为了让树足够大,我们使用较小的cp值(complexity parameter:复杂性参数)。这意味着指定较小的cp值,只要模型总体R方增加就继续产生新的分支。然后使用printcp()函数打印模型结果:


# 构建初始回归树
tree <- rpart(Salary ~ Years + HmRun, data=Hitters, control=rpart.control(cp=.0001))# 查看结果
printcp(tree)
#
# Regression tree:
# rpart(formula = Salary ~ Years + HmRun, data = Hitters, control = rpart.control(cp = 1e-04))
#
# Variables actually used in tree construction:
# [1] HmRun Years
#
# Root node error: 53319113/263 = 202734
#
# n=263 (因为不存在,59个观察量被删除了)
#
#            CP nsplit rel error  xerror    xstd
# 1  0.24674996      0   1.00000 1.00878 0.13855
# 2  0.10806932      1   0.75325 0.76404 0.12750
# 3  0.01865610      2   0.64518 0.69032 0.12187
# 4  0.01761100      3   0.62652 0.72818 0.12517
# 5  0.01747617      4   0.60891 0.72653 0.12519
# 6  0.01038188      5   0.59144 0.70819 0.12032
# 7  0.01038065      6   0.58106 0.69777 0.11848
# 8  0.00731045      8   0.56029 0.69620 0.12013
# 9  0.00714883      9   0.55298 0.69893 0.11987
# 10 0.00708618     10   0.54583 0.69754 0.11989
# 11 0.00516285     12   0.53166 0.70187 0.11974
# 12 0.00445345     13   0.52650 0.69581 0.12115
# 13 0.00406069     14   0.52205 0.69963 0.12120
# 14 0.00264728     15   0.51799 0.70416 0.12220
# 15 0.00196586     16   0.51534 0.69776 0.12096
# 16 0.00016686     17   0.51337 0.69266 0.11828
# 17 0.00010000     18   0.51321 0.69318 0.11857
  1. 树剪枝

下面对回归树进行剪枝,使用cp值寻找最优值(最低测试误差)。从上节输出我们看到cp的最佳值是致xerror最低的记录值,它表示交叉验证数据的观察结果的误差。

# 识别最佳CP值
best <- tree$cptable[which.min(tree$cptable[,"xerror"]),"CP"]# 基于最佳CP值对模型树进行剪枝
pruned_tree <- prune(tree, cp=best)# 画出剪枝后的模型树
prp(pruned_tree,faclen=0,   # 使用完整标签名称extra=1,    # 显示每个终端节点数量roundint=F, # 输出数值不近似为整数digits=5)   # 输出显示小数位数5位

我们看到最终剪枝为3个终端节点。每个终端节点显示运动员的薪资及原始数据中属于该节点的观察记录数量。

举例,原始数据中职业经验小于4.5年的有90个运动员,薪资为$225.83k。

  1. 使用剪枝树进行预测

下面使用最终的剪枝树预测新的运动员薪资,基于职业经验和平均本垒(home runs). 举例,某运动员有7年职业经验,平均home runs为4,则预测薪资为:577.61k .

执行predict函数进行验证:

# 给定新的运动员信息
new <- data.frame(Years=7, HmRun=4)# 使用剪枝树预测运动员薪资
predict(pruned_tree, newdata=new)# 577.6061

构建分类树

这个示例使用 rpart.plot 包中的 ptitanic 数据集,它包含Titanic(泰坦尼克号)上乘客的各类信息。我们利用该信息构建分类树,使用预测变量:pclass(乘客等级), sex, 和 age,预测变量为是否存活。

  1. 加载包
library(rpart)      # 决策树算法实现
library(rpart.plot) # 图视化决策树
  1. 构建初始分类树
#build the initial tree
tree <- rpart(survived~pclass+sex+age, data=ptitanic, control=rpart.control(cp=.0001))#view results
printcp(tree)# Classification tree:
# rpart(formula = survived ~ pclass + sex + age, data = ptitanic,
#     control = rpart.control(cp = 1e-04))
#
# Variables actually used in tree construction:
# [1] age    pclass sex
#
# Root node error: 500/1309 = 0.38197
#
# n= 1309
#
#       CP nsplit rel error xerror     xstd
# 1 0.4240      0     1.000  1.000 0.035158
# 2 0.0140      1     0.576  0.576 0.029976
# 3 0.0095      3     0.548  0.580 0.030050
# 4 0.0070      7     0.510  0.550 0.029477
# 5 0.0050      9     0.496  0.524 0.028952
# 6 0.0025     11     0.486  0.534 0.029157
# 7 0.0020     19     0.464  0.538 0.029238
# 8 0.0001     22     0.458  0.528 0.029035
  1. 剪枝树

下面对回归树进行剪枝,使用cp值寻找最优值(最低测试误差)。从上节输出我们看到cp的最佳值是致xerror最低的记录值,它表示交叉验证数据的观察结果的误差。

# 识别最佳CP值
best <- tree$cptable[which.min(tree$cptable[,"xerror"]),"CP"]# 基于最佳cp值进行剪枝
pruned_tree <- prune(tree, cp=best)# 画出剪枝后的模型树
prp(pruned_tree,faclen=0,   # 使用完整标签名称extra=1,    # 显示每个终端节点数量roundint=F, # 输出数值不近似为整数digits=5)   # 输出显示小数位数5位

我们看到最终有10个终端节点,每个节点显示死亡和幸存乘客的数量。举例,最左边节点显示664个乘客死亡,136个乘客幸存。

  1. 预测

我们现在能使用最终剪枝树模型,通过pclass,age,sex变量预测给定乘客生存的概率。

给定乘客:pclass:1st, age:8 sex:male, 则生存概率:11/29=37.9% .

总结

本文通过回归树和分类树两个示例展示决策树实现过程,希望对你有帮助。

R语言决策树实战教程相关推荐

  1. R语言并行计算实战教程

    foreach包增强了R循环遍历功能,并且提供了并行执行能力.本文通过示例带你轻松掌握这个高级技能. foreach语法介绍 %do% 和 %dopar% 是对遍历对象执行一段业务功能代码的操作. # ...

  2. 2020互联网数据分析师教程视频 统计学分析与数据实战 r语言数据分析实战 python数据分析实战 excel自动化报表分析实战 excel数据分析处理实战

    2020互联网数据分析师教程视频 统计学分析与数据实战 r语言数据分析实战 python数据分析实战 excel自动化报表分析实战 excel数据分析处理实战

  3. R语言数据挖掘实战系列(4)

    R语言数据挖掘实战系列(4)--数据预处理 数据预处理一方面是要提高数据的质量,另一方面是要让数据更好地适应特定的挖掘技术或工具.数据预处理的主要内容包括数据清洗.数据集成.数据变换和数据规约. 一. ...

  4. R语言学习实战——解决边际分布图

    目录 0 R语言概述 1 本次实战简介 2 涉及的工具包 2.1 ggplot2简介 2.2 ggExtra简介 2.3 ggpointdensity简介 3 开始画图 3.1 安装并载入 3.2 导 ...

  5. R语言决策树、bagging、随机森林模型在训练集以及测试集的预测结果(accuray、F1、偏差Deviance)对比分析、计算训练集和测试集的预测结果的差值来分析模型的过拟合(overfit)情况

    R语言决策树.bagging.随机森林模型在训练集以及测试集的预测结果(accuray.F1.偏差Deviance)对比分析.计算训练集和测试集的预测结果的差值来分析模型的过拟合(overfit)情况 ...

  6. R语言R-markdown实战示例、R-markdown、R-markdown生成结果汇报的HTML文件

    R语言R-markdown实战示例.R-markdown.R-markdown生成结果汇报的HTML文件 目录 R语言R-markdown实战示例.R-markdown.R-markdown生成结果汇 ...

  7. 三十四、R语言数据分析实战

    @Author : By Runsen @Date : 2020/5/14 作者介绍:Runsen目前大三下学期,专业化学工程与工艺,大学沉迷日语,Python, Java和一系列数据分析软件.导致翘 ...

  8. r语言中which的使用_大数据分析R语言RStudio使用教程

    RStudio是用于R编程的开源工具.如果您对使用R编程感兴趣,则值得了解RStudio的功能.它是一种灵活的工具,可帮助您创建可读的分析,并将您的代码,图像,注释和图解保持在一起. 在此大数据分析R ...

  9. 大数据分析R语言RStudio使用教程

    RStudio是用于R编程的开源工具.如果您对使用R编程感兴趣,则值得了解RStudio的功能.它是一种灵活的工具,可帮助您创建可读的分析,并将您的代码,图像,注释和图解保持在一起. 在此大数据分析R ...

  10. 统计学习导论之R语言应用(四):分类算法R语言代码实战

    统计学习导论之R语言应用(ISLR) 参考资料: The Elements of Statistical Learning An Introduction to Statistical Learnin ...

最新文章

  1. matlab根据 2 6,#2.6 应用MATLAB进行模型处理
  2. C# 获取Excel版本
  3. 绝望的力量:美术创作者的晋级之路
  4. 前端vue后端java,Vue调用后端java接口的实例代码_亦心_前端开发者
  5. Java初学者必知 关于Java字符串问题
  6. 计算机电子琴音乐,电脑电子琴软件
  7. Spring 中获取 request 的几种方法,及其线程安全性分析
  8. Win7下的C盘重新划分为两个盘
  9. Android的onCreateOptionsMenu()创建菜单Menu详解
  10. 数据结构------图(一)
  11. OpenSSL笔记-PKCS#1和PKCS#8的区别及分别调用的API
  12. linux tar命令 打包压缩
  13. SPSS学习笔记(四)非参数检验
  14. 机器人 迷宫算法_机器人,迷宫和附属建筑
  15. k3s-(3)k3s-agent工作节点安装
  16. 抖音视频选择封面android,抖音视频封面怎么选取?
  17. 排列组合、伯努利试验
  18. VSCode Remote SSH 过程试图写入的管道不存在
  19. 【IoT】加密与安全:动态密码 OTP 算法详解
  20. 数据库DDL、DML分别是什么

热门文章

  1. 在线答题助手c语言源码,很早之前发的逆水寒答题助手,开源!!自己可以修改成任何答题器源码!~~...
  2. Spring源码分析-Bean生命周期概述
  3. Pybluez Win10系统安装教程(蓝牙通信模块pybluez,Python完美安装)
  4. php 字符串大写转小写转大写,字符串大小写批量互相转换 - 在线工具
  5. 人人商城小程序 java版_人人商城小程序用户授权问题
  6. excel中如何对比两个表格的重复数据
  7. 4个老司机常用的黑科技资源网站
  8. uv422转换为yuv420_YUV420 Planar 转换为 YUV422 Packed
  9. JMeter接口测试及接口登陆压力测试
  10. 区位码、国标码、机内码的区别和内在机制