获取更多R语言和生信知识,请欢迎关注公众号:医学和生信笔记

医学和生信笔记 公众号主要分享:1.医学小知识、肛肠科小知识;2.R语言和Python相关的数据分析、可视化、机器学习等;3.生物信息学学习资料和自己的学习笔记!

目录

  • IML
    • 企鹅任务
    • FeatureEffects
    • Shapley
    • Featurelmp
    • 独立测试数据
  • DALEX
    • 读取数据
    • 建模
    • `DALEX`工作的一般流程
    • 数据集水平的探索
    • instance水平的探索

关于模型解释平常接触的不是特别多,简单学习下。

理论上,所有通用的模型解释框架都可应用于mlr3,只需要把训练好的模型从Learner对象中提取出来即可。

目前最受欢迎的两个框架分别是:

  • iml
  • DALEX

IML

关于iml包进行模型解释有专门一本书:IML Book。
这里简单介绍。

企鹅任务

企鹅数据包括8个变量,344个企鹅(344行)。

data("penguins", package = "palmerpenguins")
str(penguins)
## tibble [344 x 8] (S3: tbl_df/tbl/data.frame)
##  $ species          : Factor w/ 3 levels "Adelie","Chinstrap",..: 1 1 1 1 1 1 1 1 1 1 ...
##  $ island           : Factor w/ 3 levels "Biscoe","Dream",..: 3 3 3 3 3 3 3 3 3 3 ...
##  $ bill_length_mm   : num [1:344] 39.1 39.5 40.3 NA 36.7 39.3 38.9 39.2 34.1 42 ...
##  $ bill_depth_mm    : num [1:344] 18.7 17.4 18 NA 19.3 20.6 17.8 19.6 18.1 20.2 ...
##  $ flipper_length_mm: int [1:344] 181 186 195 NA 193 190 181 195 193 190 ...
##  $ body_mass_g      : int [1:344] 3750 3800 3250 NA 3450 3650 3625 4675 3475 4250 ...
##  $ sex              : Factor w/ 2 levels "female","male": 2 1 1 NA 1 2 1 2 NA NA ...
##  $ year             : int [1:344] 2007 2007 2007 2007 2007 2007 2007 2007 2007 2007 ...

创建任务:

library(iml)
library(mlr3)
library(mlr3learners)set.seed(1)
penguins <- na.omit(penguins)
task_peng <- as_task_classif(penguins, target = "species")

选择模型并训练,提取模型:

learner <- lrn("classif.ranger", predict_type = "prob")
learner$train(task_peng)
learner$model
## Ranger result
##
## Call:
##  ranger::ranger(dependent.variable.name = task$target_names, data = task$data(),      probability = self$predict_type == "prob", case.weights = task$weights$weight,      num.threads = 1L)
##
## Type:                             Probability estimation
## Number of trees:                  500
## Sample size:                      333
## Number of independent variables:  7
## Mtry:                             2
## Target node size:                 10
## Variable importance mode:         none
## Splitrule:                        gini
## OOB prediction error (Brier s.):  0.01790106
x <- penguins[which(names(penguins) != "species")]
model <- Predictor$new(learner, data = x, y = penguins$species)

FeatureEffects

num_features <- c("bill_length_mm", "bill_depth_mm", "flipper_length_mm", "body_mass_g", "year")
effect <- FeatureEffects$new(model)
plot(effect, features = num_features)

Shapley

x <- penguins[which(names(penguins) != "species")]
model <- Predictor$new(learner, data = penguins, y = "species")
x.interest <- data.frame(penguins[1, ])
shapley <- Shapley$new(model, x.interest = x.interest)
plot(shapley)

Featurelmp

effect <- FeatureImp$new(model, loss = "ce")
effect$plot(features = num_features)

独立测试数据

split <- partition(task_peng, ratio = 0.8)
train_set <- split$train
test_set <- split$testlearner$train(task_peng, row_ids = train_set)
prediction <- learner$predict(task_peng, row_ids = test_set)
# 训练集
model <- Predictor$new(learner, data = penguins[train_set, ], y = "species")
effect <- FeatureImp$new(model, loss = "ce")
plot_train <- plot(effect, features = num_features)# 测试集
model <- Predictor$new(learner, data = penguins[test_set, ], y = "species")
effect <- FeatureImp$new(model, loss = "ce")
plot_test <- plot(effect, features = num_features)# 放到一起
library("patchwork")
plot_train + plot_test

分别查看feasurelmp

model <- Predictor$new(learner, data = penguins[train_set, ], y = "species")
effect <- FeatureEffects$new(model)
plot(effect, features = num_features)

model <- Predictor$new(learner, data = penguins[test_set, ], y = "species")
effect <- FeatureEffects$new(model)
plot(effect, features = num_features)

DALEX

这个包介绍的方法也有一本书:Explanatory Model Analysis。

DALEX包可透视预测模型,帮助我们探索、解释、可视化模型行为。将使用fifa20数据集进行演示。

这个包干的事情可通过下图理解:

读取数据

library(DALEX)
## Welcome to DALEX (version: 2.3.0).
## Find examples and detailed introduction at: http://ema.drwhy.ai/
##
## 载入程辑包:'DALEX'
## The following object is masked from 'package:generics':
##
##     explain
## The following object is masked from 'package:dplyr':
##
##     explain
data(fifa, package = "DALEX")
fifa[1:2, c("value_eur", "age", "height_cm", "nationality", "attacking_crossing")]
##                   value_eur age height_cm nationality attacking_crossing
## L. Messi           95500000  32       170   Argentina                 88
## Cristiano Ronaldo  58500000  34       187    Portugal                 84

对于每个球员,都有42个feature,

dim(fifa)
## [1] 5000   42

进行简单的处理,有助于我们理解:

fifa[, c("nationality", "overall", "potential", "wage_eur")] = NULL
for (i in 1:ncol(fifa)) fifa[, i] = as.numeric(fifa[, i])

建模

library(mlr3)
library(mlr3learners)fifa_task <- as_task_regr(fifa, target = "value_eur")fifa_ranger <- lrn("regr.ranger", num.trees = 250)
fifa_ranger$train(fifa_task)
fifa_ranger
## <LearnerRegrRanger:regr.ranger>
## * Model: ranger
## * Parameters: num.threads=1, num.trees=250
## * Packages: mlr3, mlr3learners, ranger
## * Predict Type: response
## * Feature types: logical, integer, numeric, character, factor, ordered
## * Properties: hotstart_backward, importance, oob_error, weights

DALEX工作的一般流程

model %>%explain_mlr3(data = ..., y = ..., label = ...) %>%model_parts() %>%plot()
library("DALEX")
library("DALEXtra")
## Anaconda not found on your computer. Conda related functionality such as create_env.R and condaenv and yml parameters from explain_scikitlearn will not be available
ranger_exp <- explain_mlr3(fifa_ranger,data = fifa,y = fifa$value_eur,label = "Ranger RF",colorize = FALSE)
## Preparation of a new explainer is initiated
##   -> model label       :  Ranger RF
##   -> data              :  5000  rows  38  cols
##   -> target variable   :  5000  values
##   -> predict function  :  yhat.LearnerRegr  will be used (  default  )
##   -> predicted values  :  No value for predict function target column. (  default  )
##   -> model_info        :  package mlr3 , ver. 0.13.2.9000 , task regression (  default  )
##   -> predicted values  :  numerical, min =  509536.7 , mean =  7472248 , max =  92074300
##   -> residual function :  difference between y and yhat (  default  )
##   -> residuals         :  numerical, min =  -8364287 , mean =  1039.203 , max =  17510200
##   A new explainer has been created!

数据集水平的探索

fifa_vi <- model_parts(ranger_exp)
head(fifa_vi)
##              variable mean_dropout_loss     label
## 1        _full_model_           1339676 Ranger RF
## 2           value_eur           1339676 Ranger RF
## 3           weight_kg           1400918 Ranger RF
## 4    movement_balance           1402226 Ranger RF
## 5 goalkeeping_kicking           1405259 Ranger RF
## 6           height_cm           1409160 Ranger RF
plot(fifa_vi, max_vars = 12, show_boxplots = F)

selected_variables <- c("age", "movement_reactions","skill_ball_control", "skill_dribbling")fifa_pd <- model_profile(ranger_exp,variables = selected_variables)$agr_profiles
fifa_pd
## Top profiles    :
##              _vname_   _label_ _x_  _yhat_ _ids_
## 1 skill_ball_control Ranger RF   5 7535469     0
## 2    skill_dribbling Ranger RF   7 7911763     0
## 3    skill_dribbling Ranger RF  11 7904604     0
## 4    skill_dribbling Ranger RF  12 7903967     0
## 5    skill_dribbling Ranger RF  13 7902823     0
## 6    skill_dribbling Ranger RF  14 7901248     0
library("ggplot2")
plot(fifa_pd) +scale_y_continuous("Estimated value in Euro", labels = scales::dollar_format(suffix = "€", prefix = "")) +ggtitle("Partial Dependence profiles for selected variables")

instance水平的探索

ronaldo <- fifa["Cristiano Ronaldo", ]
ronaldo_bd_ranger <- predict_parts(ranger_exp,new_observation = ronaldo)
head(ronaldo_bd_ranger)
##                                         contribution
## Ranger RF: intercept                         7472248
## Ranger RF: movement_reactions = 96          11845999
## Ranger RF: skill_ball_control = 92           7170577
## Ranger RF: mentality_positioning = 95        4565939
## Ranger RF: attacking_finishing = 94          4874197
## Ranger RF: attacking_short_passing = 83      4279799
plot(ronaldo_bd_ranger)

ronaldo_shap_ranger <- predict_parts(ranger_exp,new_observation = ronaldo,type = "shap")plot(ronaldo_shap_ranger) +scale_y_continuous("Estimated value in Euro", labels = scales::dollar_format(suffix = "€", prefix = ""))

selected_variables <- c("age", "movement_reactions","skill_ball_control", "skill_dribbling")ronaldo_cp_ranger <- predict_profile(ranger_exp, ronaldo, variables = selected_variables)plot(ronaldo_cp_ranger, variables = selected_variables) +scale_y_continuous("Estimated value of Christiano Ronaldo", labels = scales::dollar_format(suffix = "€", prefix = ""))

获取更多R语言和生信知识,请欢迎关注公众号:医学和生信笔记

医学和生信笔记 公众号主要分享:1.医学小知识、肛肠科小知识;2.R语言和Python相关的数据分析、可视化、机器学习等;3.生物信息学学习资料和自己的学习笔记!

R语言机器学习mlr3:模型解释相关推荐

  1. R语言机器学习mlr3:模型评价和比较

    获取更多R语言和生信知识,请关注公众号:医学和生信笔记. 公众号后台回复R语言,即可获得海量学习资料! 目录 二分类变量和ROC曲线 重抽样 benchmark 前面一篇介绍了如何使用 mlr3创建任 ...

  2. R语言机器学习mlr3:数据预处理和pipelines

    获取更多R语言和生信知识,请欢迎关注公众号:医学和生信笔记 医学和生信笔记 公众号主要分享:1.医学小知识.肛肠科小知识:2.R语言和Python相关的数据分析.可视化.机器学习等:3.生物信息学学习 ...

  3. R语言机器学习mlr3:超参数调优

    获取更多R语言和生信知识,请关注公众号:医学和生信笔记. 公众号后台回复R语言,即可获得海量学习资料! 目录 模型调优 调整超参数 方法一:通过`tuninginstancesinglecrite`和 ...

  4. R语言机器学习mlr3:基础使用

    获取更多R语言和生信知识,请关注公众号:医学和生信笔记. 公众号后台回复R语言,即可获得海量学习资料! 目录 创建任务 创建learner 训练.预测和性能评价 本篇主要介绍mlr3包的基本使用. 一 ...

  5. R语言机器学习mlr3:简介

    获取更多R语言和生信知识,请关注公众号:医学和生信笔记. 公众号后台回复R语言,即可获得海量学习资料! 目录 `mlr3`简介 目标群体 为什么重写? 设计理念 `mlr3`生态 mlr3简介 mlr ...

  6. R语言机器学习mlr3:特征选择和hyperband调参

    获取更多R语言和生信知识,请关注公众号:医学和生信笔记. 公众号后台回复R语言,即可获得海量学习资料! 目录 Hyperband调参 特征选择 filters 计算分数 计算变量重要性 组合方法(wr ...

  7. R语言机器学习mlr3:嵌套重抽样

    获取更多R语言和生信知识,请关注公众号:医学和生信笔记. 公众号后台回复R语言,即可获得海量学习资料! 目录 嵌套重抽样 进行嵌套重抽样 评价模型 把超参数应用于模型 嵌套重抽样 既有外部重抽样,也有 ...

  8. R语言机器学习mlr3:技术细节

    获取更多R语言和生信知识,请欢迎关注公众号:医学和生信笔记 医学和生信笔记 公众号主要分享:1.医学小知识.肛肠科小知识:2.R语言和Python相关的数据分析.可视化.机器学习等:3.生物信息学学习 ...

  9. R语言机器学习Caret包(Caret包是分类和回归训练的简称)、数据划分、数据预处理、模型构建、模型调优、模型评估、多模型对比、模型预测推理

    R语言机器学习Caret包(Caret包是分类和回归训练的简称).数据划分.数据预处理.模型构建.模型调优.模型评估.多模型对比.模型预测推理 目录

最新文章

  1. RHEL6.3配置Apache服务器(2) 构建虚拟主机
  2. 雷讯和pix_青海叶陇沟金矿地质地球化学特征及找矿方向
  3. android TextUtils的使用
  4. Google 图片下载工具
  5. java 安全库_国家信息安全漏洞库
  6. 深入mysql语言_MySQL对数据操作的一些深入语法
  7. 查看grafana版本_使用 Prometheus 与 Grafana 为 Kubernetes 集群建立监控与警报机制
  8. php金币格式转换,php 资金格式转换函数_PHP教程
  9. 数据结构与算法之-----栈的应用(二)
  10. 如何用阿里云服务器建立一个wordpress网站
  11. ImageButton
  12. 10大顶级运营商转型案例剖析
  13. 推荐 10 款 C++ 在线编译器
  14. php项目部署在腾讯云服务器,腾讯云服务器部署
  15. oracle ora 31644,dmp文件损坏导致ORA-39014 ORA-39029 ORA-31693错误
  16. c 语言中古括号,如何将中古调式运用在你的作品上
  17. java中高效遍历list_Java中四种遍历List的方法总结(推荐)
  18. suse种运行wkhtmltopdf
  19. didi.github.io 域名无法打开解决办法
  20. Oracle:数据库备份之exp与imp的使用(切记,不能在plsql或sqlplus中使用)

热门文章

  1. mac 有道云词典闪退问题修复
  2. Gradle教程和指南 - 构建审视
  3. 异端的制作:数字人物Gawain
  4. python微信群聊机器人_Python + itchat 实现微信机器人聊天(支持自动回复指定群聊)...
  5. 电视机尺寸一览表2022
  6. 82.android 简单的当前运行内存清理
  7. Golang如何实现排序
  8. php,ajax -->Uncaught SyntaxError: Unexpected end of JSON input at JSON.parse (<anonymous>)
  9. 攻防演练建设过程中技术考虑
  10. 讨论一下微信小程序中如何长按识别图片中二维码跳转