作者:黄天元,复旦大学博士在读,目前研究涉及文本挖掘、社交网络分析和机器学习等。希望与大家分享学习经验,推广并加深R语言在业界的应用。

邮箱:huang.tian-yuan@qq.com

caret包是R语言通用机器学习包之一,能够在统一框架下使用各种不同的模型,从预处理、建模到后期的预测、评估都有非常友好的函数封装。新近学习的DALEX包是给黑箱提供模型解释性的利器。事实上,它不仅仅针对黑箱模型,它能够面向所有模型给出表现的评估、变量的重要性等有价值的信息。本文依照官方文档,尝试习得通用的DALEX解释caret包生成模型的套路。

1 包的载入与数据导入

安装三个包。

library(pacman)p_load(DALEX,caret,tidyverse)

观察我们要使用的目标数据:

apartments %>% as_tibble

# A tibble: 1,000 x 6   m2.price construction.year surface floor no.rooms district         <dbl>             <dbl>   <dbl> <int>    <dbl> <fct>       1     5897              1953      25     3        1 Srodmiescie 2     1818              1992     143     9        5 Bielany     3     3643              1937      56     1        2 Praga       4     3517              1995      93     7        3 Ochota      5     3013              1992     144     6        5 Mokotow     6     5795              1926      61     6        2 Srodmiescie 7     2983              1970     127     8        5 Mokotow     8     2346              1985     105     8        4 Ursus       9     4745              1928     145     6        6 Srodmiescie10     4284              1949     112     9        4 Srodmiescie# ... with 990 more rows

2 使用caret包迅速建模

这里,以m2.price作为响应变量,其余所有变量作为解释变量,进行建模。尝试模型包括:随机森林、GBM和神经网络。其中,随机森林设置树的数量为100,GBM使用默认设置,神经网络在预处理的时候要进行中心化和标准化,最大迭代次数设置为500次,使用线性输出单元,并设置网格对超参数进行优化的选项(这里用了两个隐藏层,权重衰减参数设为0,只设置了一个值,没有用网格去优化)。代码如下:

#下面这串代码的运行可能要等待一段时间

set.seed(123)regr_rf <- train(m2.price~., data = apartments, method="rf", ntree = 100)

regr_gbm <- train(m2.price~. , data = apartments, method="gbm")

regr_nn <- train(m2.price~., data = apartments,                   method = "nnet",                   linout = TRUE,                   preProcess = c('center', 'scale'),                   maxit = 500,                   tuneGrid = expand.grid(size = 2, decay = 0),                   trControl = trainControl(method = "none", seeds = 1))

3 对模型进行解释

这里直接利用DALEX包的explain函数对三个模型进行解释性分析。需要注意的是,做这个分析需要包含4个信息:1.模型信息;2.标签信息(如果没有,会自动从模型抽取);3.验证数据集;4.验证数据集中哪个是响应变量。代码如下:

data(apartmentsTest)

explainer_regr_rf <- DALEX::explain(regr_rf, label="rf",                                     data = apartmentsTest, y = apartmentsTest$m2.price)

explainer_regr_gbm <- DALEX::explain(regr_gbm, label = "gbm",                                      data = apartmentsTest, y = apartmentsTest$m2.price)

explainer_regr_nn <- DALEX::explain(regr_nn, label = "nn",                                     data = apartmentsTest, y = apartmentsTest$m2.price)

建模可能很久,但是解释性验证是非常快的,直接是黑箱的映射关系。

4 模型表现

对模型的表现,需要进行分析:

mp_regr_rf <- model_performance(explainer_regr_rf)mp_regr_gbm <- model_performance(explainer_regr_gbm)mp_regr_nn <- model_performance(explainer_regr_nn)

我们看看得到的结果是什么样子的:

mp_regr_rf

这是样本的残差分布情况,让我们对这个分布进行可视化(累计残差分布图):

plot(mp_regr_rf, mp_regr_nn, mp_regr_gbm)

这个图的正确解释方法是,少数的样本(离群点)贡献了大量的残差(与真实值的偏差)。如果线在上面,那么大量的样本残差都很大,此图表明GBM模型大部分样本的残差都比较小,而神经网络很多样本的残差都比基于树模型的高。让我们采用另一种可视化方法:

plot(mp_regr_rf, mp_regr_nn, mp_regr_gbm, geom = "boxplot")

高下立判,红点为均值,箱线图则为分位数。

5 变量重要性分析

需要看每个模型中,不同变量对于模型预测的相对重要性,可以用如下方法。

vi_regr_rf <- variable_importance(explainer_regr_rf, loss_function = loss_root_mean_square)vi_regr_gbm <- variable_importance(explainer_regr_gbm, loss_function = loss_root_mean_square)vi_regr_nn <- variable_importance(explainer_regr_nn, loss_function = loss_root_mean_square)

plot(vi_regr_rf, vi_regr_gbm, vi_regr_nn)

损失函数使用的是RMSE,这里解释为:如果模型少了这个变量,将会给响应变量的预测值带来多大影响?

6 变量解析

6.1 连续型变量解析

Partial Dependence Plots (PDP),是解释单个连续型解释变量与响应变量关系的方法。专门有相关的包和论文描述这个方法的机理,详情请去找pdp包的官方文档。比如我们想要研究房屋建筑年份(construction.year)对响应变量房价的影响,我们这样做:

pdp_regr_rf  <- variable_response(explainer_regr_rf, variable =  "construction.year", type = "pdp")pdp_regr_gbm  <- variable_response(explainer_regr_gbm, variable =  "construction.year", type = "pdp")pdp_regr_nn  <- variable_response(explainer_regr_nn, variable =  "construction.year", type = "pdp")

plot(pdp_regr_rf, pdp_regr_gbm, pdp_regr_nn)

从随机森林和GBM模型可以看出来,建筑年份与放假具有非线性关系。特别老的房子和新建的房子房价都很贵,但是40年代到90年代的房子则价格较低。不过,神经网络模型不能很好地捕捉这个规律。 此外,还有一种方法称为Acumulated Local Effects (ALE),是为了解决变量相关性的问题设计的,本质上是PDP方法的延伸。实现方法如下:

ale_regr_rf  <- variable_response(explainer_regr_rf, variable =  "construction.year", type = "ale")ale_regr_gbm  <- variable_response(explainer_regr_gbm, variable =  "construction.year", type = "ale")ale_regr_nn  <- variable_response(explainer_regr_nn, variable =  "construction.year", type = "ale")

plot(ale_regr_rf, ale_regr_gbm, ale_regr_nn)

6.2 离散型变量解析

对于离散型变量,DALEX包目前的解析方法是调用了factorMerger包的mergeFactors函数。

mpp_regr_rf  <- variable_response(explainer_regr_rf, variable =  "district", type = "factor")mpp_regr_gbm  <- variable_response(explainer_regr_gbm, variable =  "district", type = "factor")mpp_regr_nn  <- variable_response(explainer_regr_nn, variable =  "district", type = "factor")

plot(mpp_regr_rf, mpp_regr_gbm, mpp_regr_nn)

这个方法的本质是根据响应变量的分布对单个因子变量进行聚类。就上面这个图而言,我们可以看到,对于不同地区的房价是不同的,可以明显分为6类。

——————————————

往期精彩:

  • 今天,我改名了!

  • Oracle裁员补偿N+6,员工仍不满意,为何?

  • 被嫌弃的程序员的一生

R语言机器学习:caret包使用及其黑箱模型解释(连续变量预测)相关推荐

  1. r语言 tunerf_R语言机器学习:caret包使用及其黑箱模型解释(连续变量预测)

    原标题:R语言机器学习:caret包使用及其黑箱模型解释(连续变量预测) 作者:黄天元,复旦大学博士在读,目前研究涉及文本挖掘.社交网络分析和机器学习等.希望与大家分享学习经验,推广并加深R语言在业界 ...

  2. R语言机器学习Caret包(Caret包是分类和回归训练的简称)、数据划分、数据预处理、模型构建、模型调优、模型评估、多模型对比、模型预测推理

    R语言机器学习Caret包(Caret包是分类和回归训练的简称).数据划分.数据预处理.模型构建.模型调优.模型评估.多模型对比.模型预测推理 目录

  3. R语言使用caret包构建遗传算法树模型(Tree Models from Genetic Algorithms )构建回归模型、通过method参数指定算法名称

    R语言使用caret包构建遗传算法树模型(Tree Models from Genetic Algorithms  )构建回归模型.通过method参数指定算法名称.通过trainControl函数控 ...

  4. R语言使用caret包构建岭回归模型(Ridge Regression )构建回归模型、通过method参数指定算法名称、通过trainControl函数控制训练过程

    R语言使用caret包构建岭回归模型(Ridge Regression )构建回归模型.通过method参数指定算法名称.通过trainControl函数控制训练过程 目录

  5. R语言使用caret包构建随机森林模型(random forest)构建回归模型、通过method参数指定算法名称、通过ntree参数指定随机森林中树的个数

    R语言使用caret包构建随机森林模型(random forest)构建回归模型.通过method参数指定算法名称.通过ntree参数指定随机森林中树的个数 目录

  6. R语言使用caret包中的createFolds函数对机器学习数据集进行交叉验证抽样、返回的样本列表长度为k个

    R语言使用caret包中的createFolds函数对机器学习数据集进行交叉验证抽样.返回的样本列表长度为k个 目录

  7. R语言使用caret包中的createMultiFolds函数对机器学习数据集进行交叉验证抽样、返回的样本列表长度为k×times个、times为组内抽样次数

    R语言使用caret包中的createMultiFolds函数对机器学习数据集进行交叉验证抽样.返回的样本列表长度为k×times个.times为组内抽样次数 目录

  8. R语言使用caret包的preProcess函数进行数据预处理:对所有的数据列进行SpatialSign变换(将数据投影到单位圆之内)、设置method参数为spatialSign

    R语言使用caret包的preProcess函数进行数据预处理:对所有的数据列进行SpatialSign变换(将数据投影到单位圆之内).设置method参数为spatialSign 目录

  9. R语言使用caret包构建gbdt模型(随机梯度提升树、Stochastic Gradient Boosting )构建回归模型、通过method参数指定算法名称

    R语言使用caret包构建gbdt模型(随机梯度提升树.Stochastic Gradient Boosting )构建回归模型.通过method参数指定算法名称 目录

  10. R语言使用caret包的findCorrelation函数批量删除相关性冗余特征、实现特征筛选(feature selection)、剔除高相关的变量

    R语言使用caret包的findCorrelation函数批量删除相关性冗余特征.实现特征筛选(feature selection).剔除高相关的变量 目录

最新文章

  1. 使用Epoll 能监听普通文件吗?
  2. 百度地图——判断一个点是否在一个区域内?
  3. 反思找开瓶器的过程:预判选择方法的执行结果 充分主动积极的沟通
  4. 美国伊利诺伊大学香槟分校计算机专业,伊利诺伊大学香槟分校计算机科学排名第7(2020年TFE美国排名)...
  5. 计算机采用二进制形式的表示,计算机部信息的表示及存储往往采用二进制形式,采用这种形式的最主要原因是...
  6. 操作系统考研辅导教程(计算机专业研究生入学考试全真题解) pdf,计算机组成原理考研辅导教程:计算机专业研究生入学考试全真题解...
  7. mycat 分片规则
  8. 导出标签_如何从系统导出审计要求的日记账
  9. oracle 11g dataguard创建的简单方法
  10. 小白_Unity引擎_Console控制台
  11. MyBatis使用总结+整合Spring
  12. JAVA菜鸟教程(一)
  13. 视频图片音乐音效工具等素材网汇总
  14. SSM开发相关安装教程(idea、tomcat、maven、DB)
  15. mybatis-plus和mysql
  16. 中国宠物医疗市场产业消费需求及盈利前景预测报告(2022-2027年)
  17. Linux之CentOS7安装(VMware虚拟机安装及系统安装图文教程)
  18. 探访广州黑人区,我好像来到非洲
  19. 十几加几不进位加法题C语言,十几加几(不进位))的加法教案
  20. PowerApps入门——PowerApps的3种打开方式

热门文章

  1. iOS 动态隐藏状态栏
  2. 阿里为什么推荐使用LongAdder?而不是AtomicLong?
  3. 神马?写了3年代码,连分布式缓存都没用过~
  4. 知乎高赞:为什么许多原本的 Java 项目都试图用 go 进行重写开源?
  5. 再见!公司的烂系统……
  6. 爱了爱了!0.052 秒打开 100GB 数据,这个开源库火爆了!
  7. 彻底搞懂“红黑树”......
  8. 如何快速搭建一个微服务架构?
  9. 中国的 GitHub 要来了?
  10. 推荐两个非常不错的公众号