深度学习R语言 mlr3 建模,训练,预测,评估(随机森林,Logistic Regression)

本文主要通过使用mlr3包来训练German credit数据集,实现不同的深度学习模型。

1. 加载R使用环境

# 安装官方包,一般情况下大部分常用的包都可以官方安装
# install.packages("tidyverse")
# install.packages("bruceR")
#
# # 安装Github来源的包
# # 先安装devtools包后才可以安装github来源的包
#
# install.packages("devtools")
# devtools::install_github("tidyverse")
# remotes::install_github("tidyverse")# 加载包
library(tidyverse)
## ── Attaching packages ─────────────────────────────────────── tidyverse 1.3.0 ──
## ✓ ggplot2 3.3.3     ✓ purrr   0.3.4
## ✓ tibble  3.1.1     ✓ dplyr   1.0.5
## ✓ tidyr   1.1.3     ✓ stringr 1.4.0
## ✓ readr   1.4.0     ✓ forcats 0.5.1
## ── Conflicts ────────────────────────────────────────── tidyverse_conflicts() ──
## x dplyr::filter() masks stats::filter()
## x dplyr::lag()    masks stats::lag()
library(data.table)
##
## Attaching package: 'data.table'
## The following objects are masked from 'package:dplyr':
##
##     between, first, last
## The following object is masked from 'package:purrr':
##
##     transpose
library(mlr3)
library(mlr3learners)
library(mlr3viz)
library(ggplot2)

2. 数据描述

German credit data

德国信用数据,可以从rchallenge中获得,目标是使用20个解释变量来判断因变量信用风险(好/坏)

2.1 导入数据

# install.package("rchallenge)
data("german", package = "rchallenge") #观察数据
glimpse(german) # 数据类别
## Rows: 1,000
## Columns: 21
## $ status                  <fct> no checking account, no checking account, ... …
## $ duration                <int> 18, 9, 12, 12, 12, 10, 8, 6, 18, 24, 11, 30, 6…
## $ credit_history          <fct> all credits at this bank paid back duly, all c…
## $ purpose                 <fct> car (used), others, retraining, others, others…
## $ amount                  <int> 1049, 2799, 841, 2122, 2171, 2241, 3398, 1361,…
## $ savings                 <fct> unknown/no savings account, unknown/no savings…
## $ employment_duration     <fct> < 1 yr, 1 <= ... < 4 yrs, 4 <= ... < 7 yrs, 1 …
## $ installment_rate        <ord> < 20, 25 <= ... < 35, 25 <= ... < 35, 20 <= ..…
## $ personal_status_sex     <fct> female : non-single or male : single, male : m…
## $ other_debtors           <fct> none, none, none, none, none, none, none, none…
## $ present_residence       <ord> >= 7 yrs, 1 <= ... < 4 yrs, >= 7 yrs, 1 <= ...…
## $ property                <fct> car or other, unknown / no property, unknown /…
## $ age                     <int> 21, 36, 23, 39, 38, 48, 39, 40, 65, 23, 36, 24…
## $ other_installment_plans <fct> none, none, none, none, bank, none, none, none…
## $ housing                 <fct> for free, for free, for free, for free, rent, …
## $ number_credits          <ord> 1, 2-3, 1, 2-3, 2-3, 2-3, 2-3, 1, 2-3, 1, 2-3,…
## $ job                     <fct> skilled employee/official, skilled employee/of…
## $ people_liable           <fct> 0 to 2, 3 or more, 0 to 2, 3 or more, 0 to 2, …
## $ telephone               <fct> no, no, no, no, no, no, no, no, no, no, no, no…
## $ foreign_worker          <fct> no, no, no, yes, yes, yes, yes, yes, no, no, n…
## $ credit_risk             <fct> good, good, good, good, good, good, good, good…
dim(german) # 数据维数
## [1] 1000   21

通过观察发现数据集一共有2000个观测,21个属性(列)。想要预测的因变量是 creadit_risk (good or bad) ,自变量一共有20个,其中 duration, age, amount三个是数值变量,剩余的都是factor因子变量。

可以安装 skimr 包更细致的观察理解变量。

# install.packages("skimr")skimr::skim(german)

Table: Data summary

Name german
Number of rows 1000
Number of columns 21
_______________________
Column type frequency:
factor 18
numeric 3
________________________
Group variables None

Variable type: factor

skim_variable n_missing complete_rate ordered n_unique top_counts
status 0 1 FALSE 4 …: 394, no : 274, …: 269, 0<=: 63
credit_history 0 1 FALSE 5 no : 530, all: 293, exi: 88, cri: 49
purpose 0 1 FALSE 10 fur: 280, oth: 234, car: 181, car: 103
savings 0 1 FALSE 5 unk: 603, …: 183, …: 103, 100: 63
employment_duration 0 1 FALSE 5 1 <: 339, >= : 253, 4 <: 174, < 1: 172
installment_rate 0 1 TRUE 4 < 2: 476, 25 : 231, 20 : 157, >= : 136
personal_status_sex 0 1 FALSE 4 mal: 548, fem: 310, fem: 92, mal: 50
other_debtors 0 1 FALSE 3 non: 907, gua: 52, co-: 41
present_residence 0 1 TRUE 4 >= : 413, 1 <: 308, 4 <: 149, < 1: 130
property 0 1 FALSE 4 bui: 332, unk: 282, car: 232, rea: 154
other_installment_plans 0 1 FALSE 3 non: 814, ban: 139, sto: 47
housing 0 1 FALSE 3 ren: 714, for: 179, own: 107
number_credits 0 1 TRUE 4 1: 633, 2-3: 333, 4-5: 28, >= : 6
job 0 1 FALSE 4 ski: 630, uns: 200, man: 148, une: 22
people_liable 0 1 FALSE 2 0 t: 845, 3 o: 155
telephone 0 1 FALSE 2 no: 596, yes: 404
foreign_worker 0 1 FALSE 2 no: 963, yes: 37
credit_risk 0 1 FALSE 2 goo: 700, bad: 300

Variable type: numeric

skim_variable n_missing complete_rate mean sd p0 p25 p50 p75 p100 hist
duration 0 1 20.90 12.06 4 12.0 18.0 24.00 72 ▇▇▂▁▁
amount 0 1 3271.25 2822.75 250 1365.5 2319.5 3972.25 18424 ▇▂▁▁▁
age 0 1 35.54 11.35 19 27.0 33.0 42.00 75 ▇▆▃▁▁

3. 建模

通过使用mlr3包来解决信用风险分类问题。构建机器学习工作流程时出现的典型问题是:

  • 我们试图解决的问题是什么?
  • 什么是合适的学习算法?
  • 我们如何评价“好”的表现?

在 mlr3 中更系统地,它们可以通过五个组件来表示:

  1. 任务定义 Task
  2. 学习期定义 Learner
  3. 模型训练 Training
  4. 预测 Prediction
  5. 通过一项或多项措施进行评估 Evaluation

3.1任务定义 Task Definition

首先,我们要确定建模的目标。大多数监督机器学习问题是回归或分类问题。在 mlr3 中,为了区分这些问题,我们定义了任务。如果我们要解决一个分类问题,我们定义一个分类任务——TaskClassif。对于回归问题,我们定义了一个回归任务——TaskRegr。

在我们的例子中,我们的目标显然是对二元因子变量 credit_risk 进行建模或预测。因此,我们定义了一个 TaskClassif:

# germancredit 是任务标签,可以自行定义, german 数据集,target是目标变量
task = TaskClassif$new("germancredit", german , target = "credit_risk")

3.2学习器定义 Learner Definition

在决定建模目标后,我们需要决定如何建模。这意味着我们需要决定哪些学习算法或 Learners 是合适的。使用先验知识(例如,知道这是一项分类任务或假设类是线性可分的)最终会得到一个或多个合适的学习器。

许多学习者可以通过 mlr3learners 包获得。此外,许多学习器是通过 GitHub 上的 mlr3extralearners 包提供的。这两种资源加起来占标准学习算法的很大一部分。

所有可用的学习器(即您从 mlr3、mlr3learners、mlr3extralearners 或自己编写的安装的所有学习器)都在字典 mlr_learners 中获得:

mlr_learners
## <DictionaryLearner> with 29 stored values
## Keys: classif.cv_glmnet, classif.debug, classif.featureless,
##   classif.glmnet, classif.kknn, classif.lda, classif.log_reg,
##   classif.multinom, classif.naive_bayes, classif.nnet, classif.qda,
##   classif.ranger, classif.rpart, classif.svm, classif.xgboost,
##   regr.cv_glmnet, regr.featureless, regr.glmnet, regr.kknn, regr.km,
##   regr.lm, regr.ranger, regr.rpart, regr.svm, regr.xgboost,
##   surv.cv_glmnet, surv.glmnet, surv.ranger, surv.xgboost

对于我们的问题,合适的学习器可以是以下之一:Logistic regression逻辑回归、CART、random forest随机森林等。

可以使用 lrn() 函数和学习器的名称来初始化学习器,例如 lrn(“classif.xxx”)。使用 ?mlr_learners_xxx 打开名为 xxx 的学习者的帮助页面。

例如,逻辑回归可以通过以下方式初始化(逻辑回归使用 R 的 glm() 函数,由 mlr3learners 包提供):

library("mlr3learners")
learner_logreg = lrn("classif.log_reg")
print(learner_logreg)
## <LearnerClassifLogReg:classif.log_reg>
## * Model: -
## * Parameters: list()
## * Packages: stats
## * Predict Type: response
## * Feature types: logical, integer, numeric, character, factor, ordered
## * Properties: twoclass, weights

3.3 训练 Training

训练是在(训练)数据上拟合模型的过程。

  • 逻辑回归Logistic regression

我们从逻辑回归的例子开始。但是,您会立即看到该过程非常容易推广到任何学习者。

可以使用 $train() 对初始化的学习器进行数据训练:

learner_logreg$train(task)

通常,在机器学习中,我们不使用可用的完整数据,而是使用一个子集,即所谓的训练数据。为了有效地执行数据拆分,可以执行以下操作:

train_set = sample(task$row_ids, 0.8 * task$nrow)
test_set = setdiff(task$row_ids, train_set)

80% 的数据用于训练。剩余的 20% 用于随后进行评估。 train_set 是一个整数向量,指的是原始数据集的选定行:

head(train_set)
## [1] 410 864 543 236 958 851

在 mlr3 中,可以通过附加参数 row_ids = train_set 声明使用数据子集的训练:

learner_logreg$train(task, row_ids = train_set)

训练拟合后的模型可以通过以下命令展示:

learner_logreg$model
##
## Call:  stats::glm(formula = task$formula(), family = "binomial", data = task$data(),
##     model = FALSE)
##
## Coefficients:
##                                               (Intercept)
##                                                -0.1819216
##                                                       age
##                                                 0.0056873
##                                                    amount
##                                                -0.0001196
##    credit_historycritical account/other credits elsewhere
##                                                -1.0951994
## credit_historyno credits taken/all credits paid back duly
##                                                 0.3816992
##    credit_historyexisting credits paid back duly till now
##                                                 0.9330591
##     credit_historyall credits at this bank paid back duly
##                                                 1.3556494
##                                                  duration
##                                                -0.0271785
##                                 employment_duration< 1 yr
##                                                -0.0150296
##                       employment_duration1 <= ... < 4 yrs
##                                                 0.2004790
##                       employment_duration4 <= ... < 7 yrs
##                                                 0.9713337
##                               employment_duration>= 7 yrs
##                                                 0.3789241
##                                          foreign_workerno
##                                                -1.2704600
##                                               housingrent
##                                                 0.6250064
##                                                housingown
##                                                 0.6444397
##                                        installment_rate.L
##                                                -0.5924806
##                                        installment_rate.Q
##                                                 0.0909648
##                                        installment_rate.C
##                                                 0.0636166
##                                   jobunskilled - resident
##                                                -0.8209089
##                              jobskilled employee/official
##                                                -0.7988798
##             jobmanager/self-empl./highly qualif. employee
##                                                -0.9088915
##                                          number_credits.L
##                                                -0.4671141
##                                          number_credits.Q
##                                                 0.0976312
##                                          number_credits.C
##                                                 0.0062673
##                                 other_debtorsco-applicant
##                                                -0.9178934
##                                    other_debtorsguarantor
##                                                 1.3397823
##                             other_installment_plansstores
##                                                 0.1427722
##                               other_installment_plansnone
##                                                 0.4974245
##                                       people_liable0 to 2
##                                                 0.2534176
##   personal_status_sexfemale : non-single or male : single
##                                                -0.0183188
##                 personal_status_sexmale : married/widowed
##                                                 0.6102816
##                        personal_status_sexfemale : single
##                                                 0.0759193
##                                       present_residence.L
##                                                -0.1602614
##                                       present_residence.Q
##                                                 0.4513743
##                                       present_residence.C
##                                                -0.3567466
##                                      propertycar or other
##                                                -0.2797497
##         propertybuilding soc. savings agr./life insurance
##                                                -0.1006801
##                                       propertyreal estate
##                                                -0.7330205
##                                          purposecar (new)
##                                                 1.6559118
##                                         purposecar (used)
##                                                 0.8993030
##                                purposefurniture/equipment
##                                                 0.8574892
##                                   purposeradio/television
##                                                -0.0496272
##                                purposedomestic appliances
##                                                -0.0426126
##                                            purposerepairs
##                                                 0.0285772
##                                           purposevacation
##                                                 0.7196447
##                                         purposeretraining
##                                                 0.7088115
##                                           purposebusiness
##                                                 2.3256145
##                                      savings... <  100 DM
##                                                 0.2495854
##                               savings100 <= ... <  500 DM
##                                                 0.5232586
##                               savings500 <= ... < 1000 DM
##                                                 1.3157498
##                                     savings... >= 1000 DM
##                                                 0.9884852
##                                          status... < 0 DM
##                                                 0.1314611
##                                    status0<= ... < 200 DM
##                                                 0.8973969
##          status... >= 200 DM / salary for at least 1 year
##                                                 1.6226985
##                        telephoneyes (under customer name)
##                                                 0.3142853
##
## Degrees of Freedom: 799 Total (i.e. Null);  745 Residual
## Null Deviance:       982.4
## Residual Deviance: 700.6     AIC: 810.6

可以查看Logistic regression 训练后模型的类型以及总结:

class(learner_logreg$model)
## [1] "glm" "lm"
summary(learner_logreg$model)
##
## Call:
## stats::glm(formula = task$formula(), family = "binomial", data = task$data(),
##     model = FALSE)
##
## Deviance Residuals:
##     Min       1Q   Median       3Q      Max
## -2.7481  -0.6573   0.3599   0.6823   2.0764
##
## Coefficients:
##                                                             Estimate Std. Error
## (Intercept)                                               -1.819e-01  1.313e+00
## age                                                        5.687e-03  1.045e-02
## amount                                                    -1.196e-04  5.297e-05
## credit_historycritical account/other credits elsewhere    -1.095e+00  6.830e-01
## credit_historyno credits taken/all credits paid back duly  3.817e-01  4.971e-01
## credit_historyexisting credits paid back duly till now     9.331e-01  5.441e-01
## credit_historyall credits at this bank paid back duly      1.356e+00  4.897e-01
## duration                                                  -2.718e-02  1.083e-02
## employment_duration< 1 yr                                 -1.503e-02  4.935e-01
## employment_duration1 <= ... < 4 yrs                        2.005e-01  4.693e-01
## employment_duration4 <= ... < 7 yrs                        9.713e-01  5.181e-01
## employment_duration>= 7 yrs                                3.789e-01  4.733e-01
## foreign_workerno                                          -1.270e+00  7.304e-01
## housingrent                                                6.250e-01  2.761e-01
## housingown                                                 6.444e-01  5.408e-01
## installment_rate.L                                        -5.925e-01  2.489e-01
## installment_rate.Q                                         9.096e-02  2.255e-01
## installment_rate.C                                         6.362e-02  2.311e-01
## jobunskilled - resident                                   -8.209e-01  7.516e-01
## jobskilled employee/official                              -7.989e-01  7.274e-01
## jobmanager/self-empl./highly qualif. employee             -9.089e-01  7.380e-01
## number_credits.L                                          -4.671e-01  8.489e-01
## number_credits.Q                                           9.763e-02  6.951e-01
## number_credits.C                                           6.267e-03  5.218e-01
## other_debtorsco-applicant                                 -9.179e-01  4.757e-01
## other_debtorsguarantor                                     1.340e+00  4.751e-01
## other_installment_plansstores                              1.428e-01  5.116e-01
## other_installment_plansnone                                4.974e-01  2.944e-01
## people_liable0 to 2                                        2.534e-01  2.831e-01
## personal_status_sexfemale : non-single or male : single   -1.832e-02  4.396e-01
## personal_status_sexmale : married/widowed                  6.103e-01  4.300e-01
## personal_status_sexfemale : single                         7.592e-02  5.179e-01
## present_residence.L                                       -1.603e-01  2.457e-01
## present_residence.Q                                        4.514e-01  2.304e-01
## present_residence.C                                       -3.567e-01  2.293e-01
## propertycar or other                                      -2.797e-01  2.881e-01
## propertybuilding soc. savings agr./life insurance         -1.007e-01  2.790e-01
## propertyreal estate                                       -7.330e-01  4.750e-01
## purposecar (new)                                           1.656e+00  4.260e-01
## purposecar (used)                                          8.993e-01  3.057e-01
## purposefurniture/equipment                                 8.575e-01  2.807e-01
## purposeradio/television                                   -4.963e-02  9.327e-01
## purposedomestic appliances                                -4.261e-02  6.641e-01
## purposerepairs                                             2.858e-02  4.360e-01
## purposevacation                                            7.196e-01  1.287e+00
## purposeretraining                                          7.088e-01  3.815e-01
## purposebusiness                                            2.326e+00  9.776e-01
## savings... <  100 DM                                       2.496e-01  3.377e-01
## savings100 <= ... <  500 DM                                5.233e-01  4.443e-01
## savings500 <= ... < 1000 DM                                1.316e+00  5.692e-01
## savings... >= 1000 DM                                      9.885e-01  2.983e-01
## status... < 0 DM                                           1.315e-01  2.558e-01
## status0<= ... < 200 DM                                     8.974e-01  4.427e-01
## status... >= 200 DM / salary for at least 1 year           1.623e+00  2.681e-01
## telephoneyes (under customer name)                         3.143e-01  2.305e-01
##                                                           z value Pr(>|z|)
## (Intercept)                                                -0.139 0.889817
## age                                                         0.544 0.586361
## amount                                                     -2.259 0.023910 *
## credit_historycritical account/other credits elsewhere     -1.604 0.108806
## credit_historyno credits taken/all credits paid back duly   0.768 0.442612
## credit_historyexisting credits paid back duly till now      1.715 0.086353 .
## credit_historyall credits at this bank paid back duly       2.768 0.005636 **
## duration                                                   -2.511 0.012052 *
## employment_duration< 1 yr                                  -0.030 0.975704
## employment_duration1 <= ... < 4 yrs                         0.427 0.669230
## employment_duration4 <= ... < 7 yrs                         1.875 0.060842 .
## employment_duration>= 7 yrs                                 0.801 0.423408
## foreign_workerno                                           -1.739 0.081956 .
## housingrent                                                 2.264 0.023571 *
## housingown                                                  1.192 0.233383
## installment_rate.L                                         -2.380 0.017307 *
## installment_rate.Q                                          0.403 0.686685
## installment_rate.C                                          0.275 0.783095
## jobunskilled - resident                                    -1.092 0.274757
## jobskilled employee/official                               -1.098 0.272063
## jobmanager/self-empl./highly qualif. employee              -1.231 0.218137
## number_credits.L                                           -0.550 0.582157
## number_credits.Q                                            0.140 0.888294
## number_credits.C                                            0.012 0.990417
## other_debtorsco-applicant                                  -1.930 0.053659 .
## other_debtorsguarantor                                      2.820 0.004806 **
## other_installment_plansstores                               0.279 0.780181
## other_installment_plansnone                                 1.689 0.091136 .
## people_liable0 to 2                                         0.895 0.370704
## personal_status_sexfemale : non-single or male : single    -0.042 0.966764
## personal_status_sexmale : married/widowed                   1.419 0.155847
## personal_status_sexfemale : single                          0.147 0.883447
## present_residence.L                                        -0.652 0.514230
## present_residence.Q                                         1.959 0.050116 .
## present_residence.C                                        -1.556 0.119724
## propertycar or other                                       -0.971 0.331602
## propertybuilding soc. savings agr./life insurance          -0.361 0.718219
## propertyreal estate                                        -1.543 0.122757
## purposecar (new)                                            3.887 0.000101 ***
## purposecar (used)                                           2.942 0.003265 **
## purposefurniture/equipment                                  3.055 0.002251 **
## purposeradio/television                                    -0.053 0.957566
## purposedomestic appliances                                 -0.064 0.948839
## purposerepairs                                              0.066 0.947743
## purposevacation                                             0.559 0.576176
## purposeretraining                                           1.858 0.063142 .
## purposebusiness                                             2.379 0.017359 *
## savings... <  100 DM                                        0.739 0.459905
## savings100 <= ... <  500 DM                                 1.178 0.238878
## savings500 <= ... < 1000 DM                                 2.312 0.020794 *
## savings... >= 1000 DM                                       3.313 0.000922 ***
## status... < 0 DM                                            0.514 0.607376
## status0<= ... < 200 DM                                      2.027 0.042667 *
## status... >= 200 DM / salary for at least 1 year            6.052 1.43e-09 ***
## telephoneyes (under customer name)                          1.363 0.172751
## ---
## Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
##
## (Dispersion parameter for binomial family taken to be 1)
##
##     Null deviance: 982.41  on 799  degrees of freedom
## Residual deviance: 700.57  on 745  degrees of freedom
## AIC: 810.57
##
## Number of Fisher Scoring iterations: 5
  • 随机森林Random forest

就像逻辑回归一样,我们可以训练一个随机森林。我们使用 ranger包快速实现。为此,我们首先需要定义学习器,然后实际训练它。

我们现在另外提供重要性参数(importance = “permutation”)。这样做,我们覆盖默认值,让学习器根据排列特征重要性来确定特征重要性:

learner_rf = lrn("classif.ranger", importance = "permutation")
learner_rf$train(task, row_ids = train_set)

我们可以通过$importance命令来观察自变量的重要程度:

learner_rf$importance()
##                  status                duration                  amount
##            0.0330947539            0.0175370797            0.0134572307
##          credit_history                 savings                     age
##            0.0129659380            0.0095783381            0.0065733821
##                property     employment_duration                 purpose
##            0.0053766886            0.0053485974            0.0047822849
##           other_debtors        installment_rate     personal_status_sex
##            0.0043989633            0.0036503334            0.0029137105
##       present_residence          number_credits                 housing
##            0.0022437675            0.0017202412            0.0013506399
##               telephone           people_liable                     job
##            0.0012456826            0.0007195306            0.0006561488
## other_installment_plans          foreign_worker
##            0.0003107618            0.0001042939

为了获得重要性值的图,我们将重要性转换为 data.table格式,然后用 ggplot2 处理它:

importance = as.data.table(learner_rf$importance(), keep.rownames = TRUE)
# 修改列名称
colnames(importance) = c("Feature", "Importance")# 用ggplot包画出重要性的图ggplot(data=importance,aes(x = reorder(Feature, Importance), y = Importance)) + geom_col() + coord_flip() + xlab("")

可以看出前七个变量对于预测因变量起到了重要作用。

3.3 预测 Prediction

接下来我们要使用训练得到的模型进行预测。训练模型后,该模型可用于预测。通常,预测是机器学习模型的主要目的。

在我们的案例中,该模型可用于对新的信用申请人进行分类。它们基于特征的相关信用风险(好与坏)。通常,机器学习模型会预测数值。在回归情况下,这是很自然的。对于分类,大多数模型预测分数或概率。基于这些值,可以得出类别预测。

  • 预测类别 Predict Classes

首先,我们直接预测类别:

pred_logreg = learner_logreg$predict(task, row_ids = test_set)
pred_rf = learner_rf$predict(task, row_ids = test_set)pred_logreg
## <PredictionClassif> for 200 observations:
##     row_ids truth response
##           2  good      bad
##           3  good     good
##           6  good     good
## ---
##         986   bad     good
##         998   bad     good
##        1000   bad     good
pred_rf
## <PredictionClassif> for 200 observations:
##     row_ids truth response
##           2  good     good
##           3  good     good
##           6  good     good
## ---
##         986   bad     good
##         998   bad     good
##        1000   bad     good

$predict() 方法返回一个 Prediction 对象。如果想在之后使用它,可以将其转换为 data.table格式。

我们还可以显示在混淆矩阵中的预测结果:

pred_logreg$confusion
##         truth
## response bad good
##     bad   28   26
##     good  29  117
pred_rf$confusion
##         truth
## response bad good
##     bad   22   15
##     good  35  128
  • 预测概率 Predict Probabilities

大多数学习期Learner不仅可以预测类别变量(“响应”),还可以预测他们对给定响应的“置信度”/“不确定性”程度。通常,我们通过将 Learner 的 $predict_type设置为“prob”来实现这一点。有时这需要在学习者接受培训之前完成。或者,我们可以使用此选项直接创建学习器:lrn(“classif.log_reg”, predict_type=“prob”)

learner_logreg$predict_type = "prob"
learner_logreg$predict(task, row_ids = test_set)
## <PredictionClassif> for 200 observations:
##     row_ids truth response  prob.bad prob.good
##           2  good      bad 0.5502737 0.4497263
##           3  good     good 0.2432334 0.7567666
##           6  good     good 0.1617924 0.8382076
## ---
##         986   bad     good 0.1088596 0.8911404
##         998   bad     good 0.1524203 0.8475797
##        1000   bad     good 0.3172837 0.6827163

3.4 评估Performance Evaluation

为了衡量学习者在新的数据上的表现,我们通常通过将数据分成训练集和测试集来模拟unseen数据的场景。训练集用于训练学习器,测试集仅用于预测和评估训练后的学习器的表现。许多重采样方法(交叉验证cross-validation、引导bootstrap)以不同的方式重复分割过程。

在 mlr3 中,我们需要使用 rsmp() 函数指定重采样策略resampling strategy:

resampling = rsmp("holdout", ratio = 2/3)
print(resampling)
## <ResamplingHoldout> with 1 iterations
## * Instantiated: FALSE
## * Parameters: ratio=0.6667

在这里,我们使用“holdout”,这是一个简单的训练-测试分割(只有一次迭代)。我们使用resample()函数进行重采样计算:

res = resample(task, learner = learner_logreg, resampling = resampling)
## INFO  [16:08:51.897] [mlr3]  Applying learner 'classif.log_reg' on task 'germancredit' (iter 1/1)
res
## <ResampleResult> of 1 iterations
## * Task: germancredit
## * Learner: classif.log_reg
## * Warnings: 0 in 0 iterations
## * Errors: 0 in 0 iterations

度量的默认分数包含在 $aggregate() 中:

res$aggregate()
## classif.ce
##  0.2612613

这种情况下的默认度量是分类错误。越低越好。

我们可以运行不同的重采样策略,例如重复坚持(“二次抽样”),或交叉验证。大多数方法对不同的数据子集执行重复的训练/预测循环并聚合结果(通常作为平均值)。手动执行此操作需要我们编写循环。 mlr3 为我们完成了这项工作:

resampling = rsmp("subsampling", repeats=10)
rr = resample(task, learner = learner_logreg, resampling = resampling)
rr$aggregate()
## classif.ce
##  0.2564565

此外,我们也可以使用交叉验证

resampling = resampling = rsmp("cv", folds=10)
rr = resample(task, learner = learner_logreg, resampling = resampling)
rr$aggregate()
## classif.ce
##      0.246

mlr3 具有更多评估的分数。在这里,我们用 mlr_measures_classif.fpr 计算 false positive rate,用 mlr_measures_classif.fnr 计算 false negative rate。可以将多个度量作为度量列表提供(可以通过 msrs() 直接构造):

# false positive rate
rr$aggregate(msr("classif.fpr"))
## classif.fpr
##   0.1345898
# false positive rate and false negative
measures = msrs(c("classif.fpr", "classif.fnr"))
rr$aggregate(measures)
## classif.fpr classif.fnr
##   0.1345898   0.5068602

还有更多的重采样方法和相当多的度量(在 mlr3measures 中实现)。

mlr_resamplings
## <DictionaryResampling> with 8 stored values
## Keys: bootstrap, custom, cv, holdout, insample, loo, repeated_cv,
##   subsampling
# 评估分数类型
mlr_measures
## <DictionaryMeasure> with 54 stored values
## Keys: classif.acc, classif.auc, classif.bacc, classif.bbrier,
##   classif.ce, classif.costs, classif.dor, classif.fbeta, classif.fdr,
##   classif.fn, classif.fnr, classif.fomr, classif.fp, classif.fpr,
##   classif.logloss, classif.mbrier, classif.mcc, classif.npv,
##   classif.ppv, classif.prauc, classif.precision, classif.recall,
##   classif.sensitivity, classif.specificity, classif.tn, classif.tnr,
##   classif.tp, classif.tpr, debug, oob_error, regr.bias, regr.ktau,
##   regr.mae, regr.mape, regr.maxae, regr.medae, regr.medse, regr.mse,
##   regr.msle, regr.pbias, regr.rae, regr.rmse, regr.rmsle, regr.rrse,
##   regr.rse, regr.rsq, regr.sae, regr.smape, regr.srho, regr.sse,
##   selected_features, time_both, time_predict, time_train

3.5模型效果对 Performance Comparision and Benchmarks

我们可以通过手动评估每个学习期的 resample() 来比较学习器。但是, benchmark() 会自动为多个学习者和任务执行重采样评估。 benchmark_grid() 创建完全交叉的设计:比较多个任务的多个学习者 w.r.t.多次重采样。

learners = lrns(c("classif.log_reg", "classif.ranger"), predict_type = "prob")bm_design = benchmark_grid(tasks = task,learners = learners,resamplings = rsmp("cv", folds = 50)
)bmr = benchmark(bm_design)

在基准测试中,我们可以比较不同的度量。在这里,我们看一下误分类率和 AUC:

measures = msrs(c("classif.ce", "classif.auc"))
performances = bmr$aggregate(measures)
performances[, c("learner_id", "classif.ce", "classif.auc")]

3.6超参数调优Deviating from hyperparameters defaults

之前展示的技术构建了以 mlr3 为特色的机器学习工作流程的支柱。然而,在大多数情况下,人们永远不会像我们那样进行。虽然许多 R 包都精心选择了默认设置,但它们在任何情况下都不会以最佳方式运行。通常,我们可以选择此类超参数的值。学习者的(超)参数可以通过它的 ParamSet $param_set 访问和设置:

learner_rf$param_set
## <ParamSet>
##                               id    class lower upper nlevels        default
##  1:                        alpha ParamDbl  -Inf   Inf     Inf            0.5
##  2:       always.split.variables ParamUty    NA    NA     Inf <NoDefault[3]>
##  3:                class.weights ParamDbl  -Inf   Inf     Inf
##  4:                      holdout ParamLgl    NA    NA       2          FALSE
##  5:                   importance ParamFct    NA    NA       4 <NoDefault[3]>
##  6:                   keep.inbag ParamLgl    NA    NA       2          FALSE
##  7:                    max.depth ParamInt  -Inf   Inf     Inf
##  8:                min.node.size ParamInt     1   Inf     Inf              1
##  9:                     min.prop ParamDbl  -Inf   Inf     Inf            0.1
## 10:                      minprop ParamDbl  -Inf   Inf     Inf            0.1
## 11:                         mtry ParamInt     1   Inf     Inf <NoDefault[3]>
## 12:            num.random.splits ParamInt     1   Inf     Inf              1
## 13:                  num.threads ParamInt     1   Inf     Inf              1
## 14:                    num.trees ParamInt     1   Inf     Inf            500
## 15:                    oob.error ParamLgl    NA    NA       2           TRUE
## 16:        regularization.factor ParamUty    NA    NA     Inf              1
## 17:      regularization.usedepth ParamLgl    NA    NA       2          FALSE
## 18:                      replace ParamLgl    NA    NA       2           TRUE
## 19:    respect.unordered.factors ParamFct    NA    NA       3         ignore
## 20:              sample.fraction ParamDbl     0     1     Inf <NoDefault[3]>
## 21:                  save.memory ParamLgl    NA    NA       2          FALSE
## 22: scale.permutation.importance ParamLgl    NA    NA       2          FALSE
## 23:                    se.method ParamFct    NA    NA       2        infjack
## 24:                         seed ParamInt  -Inf   Inf     Inf
## 25:         split.select.weights ParamDbl     0     1     Inf <NoDefault[3]>
## 26:                    splitrule ParamFct    NA    NA       2           gini
## 27:                      verbose ParamLgl    NA    NA       2           TRUE
## 28:                 write.forest ParamLgl    NA    NA       2           TRUE
##                               id    class lower upper nlevels        default
##        parents       value
##  1:
##  2:
##  3:
##  4:
##  5:            permutation
##  6:
##  7:
##  8:
##  9:
## 10:
## 11:
## 12:  splitrule
## 13:                      1
## 14:
## 15:
## 16:
## 17:
## 18:
## 19:
## 20:
## 21:
## 22: importance
## 23:
## 24:
## 25:
## 26:
## 27:
## 28:
##        parents       value
learner_rf$param_set$values = list(verbose = FALSE)

我们可以通过两种不同的方式为我们的学习者选择参数。如果我们对学习器应该如何(超)参数化有先验知识,那么要走的路将是在参数集中手动输入参数。然而,在大多数情况下,我们希望调整学习器,以便它可以自己搜索“好的”模型配置。目前,我们只想比较几个模型。

要了解可以操作哪些参数,我们可以调查原始包版本的参数或查看学习器的参数集:

as.data.table(learner_rf$param_set)[,.(id, class, lower, upper)]

对于随机森林,控制模型复杂性的两个有意义的参数是 num.trees 和 mtry。 num.trees 默认为 500,mtry 为 floor(sqrt(ncol(data) - 1)),在我们的例子中是 4。

下面我们的目标是训练三个不同的学习器:

  1. 默认随机森林。
  2. 低 num.trees 和低 mtry 的随机森林。
  3. 具有高 num.trees 和高 mtry 的随机森林。

我们将在德国信用数据集上对他们的表现进行基准测试。为此,我们构建了三个学习器并相应地设置参数:

rf_med = lrn("classif.ranger", id = "med", predict_type = "prob")rf_low = lrn("classif.ranger", id = "low", predict_type = "prob",num.trees = 5, mtry = 2)rf_high = lrn("classif.ranger", id = "high", predict_type = "prob",num.trees = 1000, mtry = 11)

一旦定义了学习器,我们就可以对它们进行基准测试:

learners = list(rf_low, rf_med, rf_high)
bm_design = benchmark_grid(tasks = task,learners = learners,resamplings = rsmp("cv", folds = 10)
)
bmr = benchmark(bm_design)
print(bmr)
## <BenchmarkResult> of 30 rows with 3 resampling runs
##  nr      task_id learner_id resampling_id iters warnings errors
##   1 germancredit        low            cv    10        0      0
##   2 germancredit        med            cv    10        0      0
##   3 germancredit       high            cv    10        0      0

我们比较不同学习器的误分类率和 AUC:

measures = msrs(c("classif.ce", "classif.auc"))
performances = bmr$aggregate(measures)
performances[, .(learner_id, classif.ce, classif.auc)]
autoplot(bmr)

“低”设置似乎有点不适合,“高”设置的标准差比默认设置“中”的大。所以对比三个参数调优模型,本文中还是默认参数的模型更优。

Session info

## R version 4.0.3 (2020-10-10)
## Platform: x86_64-apple-darwin17.0 (64-bit)
## Running under: macOS Catalina 10.15.7
##
## Matrix products: default
## BLAS:   /Library/Frameworks/R.framework/Versions/4.0/Resources/lib/libRblas.dylib
## LAPACK: /Library/Frameworks/R.framework/Versions/4.0/Resources/lib/libRlapack.dylib
##
## locale:
## [1] en_US.UTF-8/en_US.UTF-8/en_US.UTF-8/C/en_US.UTF-8/en_US.UTF-8
##
## attached base packages:
## [1] stats     graphics  grDevices utils     datasets  methods   base
##
## other attached packages:
##  [1] mlr3viz_0.5.3      mlr3learners_0.4.5 mlr3_0.11.0        data.table_1.14.0
##  [5] forcats_0.5.1      stringr_1.4.0      dplyr_1.0.5        purrr_0.3.4
##  [9] readr_1.4.0        tidyr_1.1.3        tibble_3.1.1       ggplot2_3.3.3
## [13] tidyverse_1.3.0
##
## loaded via a namespace (and not attached):
##  [1] httr_1.4.2           sass_0.3.1           jsonlite_1.7.2
##  [4] modelr_0.1.8         bslib_0.2.4          assertthat_0.2.1
##  [7] lgr_0.4.2            highr_0.9            cellranger_1.1.0
## [10] yaml_2.2.1           mlr3misc_0.8.0       globals_0.14.0
## [13] pillar_1.6.0         backports_1.2.1      lattice_0.20-41
## [16] glue_1.4.2           uuid_0.1-4           digest_0.6.27
## [19] checkmate_2.0.0      rvest_1.0.0          colorspace_2.0-0
## [22] htmltools_0.5.1.1    Matrix_1.2-18        pkgconfig_2.0.3
## [25] mlr3measures_0.3.1   broom_0.7.6.9001     listenv_0.8.0
## [28] haven_2.3.1          scales_1.1.1         ranger_0.12.1
## [31] farver_2.0.3         generics_0.1.0       ellipsis_0.3.1
## [34] withr_2.4.1          repr_1.1.3           skimr_2.1.3
## [37] cli_2.4.0            magrittr_2.0.1       crayon_1.4.1
## [40] readxl_1.3.1         paradox_0.7.1        evaluate_0.14
## [43] future_1.21.0        fs_1.5.0             fansi_0.4.2
## [46] parallelly_1.24.0    xml2_1.3.2           palmerpenguins_0.1.0
## [49] tools_4.0.3          hms_1.0.0            lifecycle_1.0.0
## [52] munsell_0.5.0        reprex_2.0.0         compiler_4.0.3
## [55] jquerylib_0.1.3      rlang_0.4.10         grid_4.0.3
## [58] rstudioapi_0.13      base64enc_0.1-3      labeling_0.4.2
## [61] rmarkdown_2.7        codetools_0.2-16     gtable_0.3.0
## [64] DBI_1.1.1            R6_2.5.0             lubridate_1.7.9.2
## [67] knitr_1.33           future.apply_1.7.0   utf8_1.2.1
## [70] stringi_1.5.3        parallel_4.0.3       Rcpp_1.0.6
## [73] vctrs_0.3.7          dbplyr_2.1.0         tidyselect_1.1.0
## [76] xfun_0.22

Reference

Lovelace, Robin, Jakub Nowosad, and Jannes Muenchow. 2019. Geocomputation with r. CRC Press.

Lang, Michel. 2017. “checkmate: Fast Argument Checks for Defensive R Programming.” The R Journal 9 (1): 437–45. https://doi.org/10.32614/RJ-2017-028.

Funk, et al. (2020, July 27). mlr3gallery: Bike Sharing Demand - Use Case. Retrieved from https://mlr3gallery.mlr-org.com/posts/2020-07-27-bikesharing-demand/

Binder & Pfisterer (2020, March 11). mlr3gallery: mlr3tuning Tutorial - German Credit. Retrieved from https://mlr3gallery.mlr-org.com/posts/2020-03-11-mlr3tuning-tutorial-german-credit/

Pfisterer (2020, April 27). mlr3gallery: A Pipeline for the Titanic Data Set - Advanced. Retrieved from https://mlr3gallery.mlr-org.com/posts/2020-04-27-mlr3pipelines-Imputation-titanic/

Li, Lisha, Kevin G. Jamieson, Giulia DeSalvo, Afshin Rostamizadeh, and Ameet Talwalkar. 2016. “Efficient Hyperparameter Optimization and Infinitely Many Armed Bandits.” CoRR abs/1603.06560. http://arxiv.org/abs/1603.06560.

Schratz, Patrick, Jannes Muenchow, Eugenia Iturritxa, Jakob Richter, and Alexander Brenning. 2019. “Hyperparameter Tuning and Performance Assessment of Statistical and Machine-Learning Algorithms Using Spatial Data.” Ecological Modelling 406 (August): 109–20. https://doi.org/10.1016/j.ecolmodel.2019.06.002.

深度学习R语言 mlr3 建模,训练,预测,评估(随机森林,Logistic Regression)相关推荐

  1. 【2022新书】深度学习R语言实战,第二版

    来源:专知 本文为书籍介绍,建议阅读5分钟使用R和强大的Keras库从头开始进行深度学习! R深度学习,第二版 使用R和强大的Keras库从头开始进行深度学习! 在R深度学习第二版中,您将学习: 从基 ...

  2. R语言第十一讲 决策树与随机森林

    概念 决策树主要有树的回归和分类方法,这些方法主要根据分层和分割 的方式将预测变量空间划分为一系列简单区域.对某个给定待预测的观 测值,用它所属区域中训练集的平均值或众数对其进行预测.         ...

  3. 备战数学建模43-决策树随机森林Logistic模型(攻坚站7)

    决策树(Decision Tree)是在已知各种情况发生概率的基础上,通过构成决策树来求取净现值的期望值大于等于零的概率,评价项目风险,判断其可行性的决策分析方法,是直观运用概率分析的一种图解法.由于 ...

  4. 基于R语言对股市价格预测的ARIMA建模

    基于R语言对股市价格预测的ARIMA建模 获取数据 tushare ID=399224 利用ARIMA对股市价格进行拟合后预测,本次实验的数据源于tushare 首先导入本次实验所需要的所有包 req ...

  5. R语言GARCH建模常用软件包比较、拟合标准普尔SP 500指数波动率时间序列和预测可视化...

    原文链接:http://tecdat.cn/?p=24441 我们研究波动聚集,以及使用单变量 GARCH(1,1) 模型对其进行建模. 波动聚集 波动聚集--存在相对平稳时期和高波动时期的现象--是 ...

  6. 基于机器学习与深度学习的金融风控贷款违约预测

    基于机器学习与深度学习的金融风控贷款违约预测 目录 一.赛题分析 1. 任务分析 2. 数据属性 3. 评价指标 4. 问题归类 5. 整体思路 二.数据可视化分析 1. 总体数据分析 2. 数值型数 ...

  7. 【零基础深度学习教程第二课:深度学习进阶之神经网络的训练】

    深度学习进阶之神经网络的训练 神经网络训练优化 一.数据集 1.1 数据集分类 1.2 数据集的划分 1.3 同源数据集的重要性 1.4 无测试集的情况 二.偏差与方差 2.1 概念定义 2.1.1 ...

  8. DeepRMethylSite:一种基于深度学习的蛋白质精氨酸甲基化位点预测方法

    DeepRMethylSite:一种基于深度学习的蛋白质精氨酸甲基化位点预测方法 https://www.researchgate.net/publication/341890599_DeepRMet ...

  9. 用最酷的方法学习R语言

    1. 看大神怎么说 前几天去新疆培训,制作了R语言的基础教程,在翻阅资料时,看到了知乎张敬信关于R学习的观点,很是赞同. 张敬信老师写了一本书<R语言编程–基于tidyverse>,网址: ...

最新文章

  1. Windows编程设备描述表的概念和在客户区绘制、在窗口标题栏绘制、在桌面绘制图解
  2. ACL 2018 论文解读 | 基于深度强化学习的远程监督关系抽取
  3. auto static 的区别
  4. DB2数据库安全的12条军规
  5. 网恋奔现发现对方长得很好看是什么样的体验?
  6. 新常态 新核心,浪潮商用机器为关键行业数字化转型打造新Power
  7. 微信小程序热潮或渐趋冷静
  8. GNSS RTK 北斗GPS接收机多径环境测试接收机自主完好性监测实验
  9. 云队友丨抖音之后,互联网失去创造力
  10. PlaySound error
  11. 求圆周长、圆面积、圆球表面积、圆球体积、圆柱体积。用scanf输入数据,输出计算结果
  12. Inoventica干线网络,600 Gbit / s。
  13. 移动云迁移工具:物理服务器迁移到移动云
  14. 关于勾股数的规律及证明
  15. 结合实际案例谈谈项目管理经验
  16. 湘潭大学 Hurry Up 三分,求凹函数的最小值问题
  17. h5大前端常用网站以及npm模块整理
  18. 几款有意思的html游戏推荐(在线云玩+源码)
  19. 【参赛作品65】MOGDB/openGauss的txid_snapshot 数据类型和相关函数
  20. 计算机内存类型包括什么,计算机内存类型是什么

热门文章

  1. 口腔医疗APP开发的优势和作用分析
  2. 数据“跑路”1.47亿次 浪潮政务云助“蜀道难”转向“全渝通办”
  3. Javabase万年历
  4. 如何学习一门新的语言?
  5. BZOJ[1696][Usaco2007 Feb]Building A New Barn新牛舍 贪心
  6. C++ OpenCV绘制非对称圆点标定图案
  7. nest学习(2) js装饰器+nest控制器
  8. Android自定义圆角圆形图片
  9. Nginx反向代理 对响应网页中的字符串进行替换设置
  10. 定积分华里士公式推广_分部积分法与点火公式|第四十六回|高数(微积分)...