前文(机器学习算法 - 随机森林之决策树初探(1))讲述了决策树的基本概念、决策评价标准并手算了单个变量单个分组Gini impurity。是一个基本概念学习的过程,如果不了解,建议先读一下再继续。

本篇通过 R 代码(希望感兴趣的朋友能够投稿这个代码的Python实现)从头暴力方式自写函数训练决策树。之前计算的结果,可以作为正对照,确定后续函数结果的准确性。

训练决策树 - 确定根节点的分类阈值

Gini impurity可以用来判断每一步最合适的决策分类方式,那么怎么确定最优的分类变量和分类阈值呢?

最粗暴的方式是,我们用每个变量的每个可能得阈值来进行决策分类,选择具有最低Gini impurity值的分类组合。这不是最快速的解决问题的方式,但是最容易理解的方式。

定义计算Gini impurity的函数

data <- data.frame(x=c(0,0.5,1.1,1.8,1.9,2,2.5,3,3.6,3.7),y=c(1,0.5,1.5,2.1,2.8,2,2.2,3,3.3,3.5),color=c(rep('blue',3),rep('red',2),rep('green',5)))data##      x   y color
## 1  0.0 1.0  blue
## 2  0.5 0.5  blue
## 3  1.1 1.5  blue
## 4  1.8 2.1   red
## 5  1.9 2.8   red
## 6  2.0 2.0 green
## 7  2.5 2.2 green
## 8  3.0 3.0 green
## 9  3.6 3.3 green
## 10 3.7 3.5 green

首先定义个函数计算每个分支的Gini_impurity

Gini_impurity <- function(branch){# print(branch)len_branch <- length(branch)if(len_branch==0){return(0)}table_branch <- table(branch)wrong_probability <- function(x, total) (x/total*(1-x/total))return(sum(sapply(table_branch, wrong_probability, total=len_branch)))
}

测试下,没问题。

Gini_impurity(c(rep('a',2),rep('b',3)))## [1] 0.48

再定义一个函数,计算每次决策的总Gini impurity.

Gini_impurity_for_split_branch <- function(threshold, data, variable_column, class_column, Init_gini_impurity=NULL){total = nrow(data)left <- data[data[variable_column]<threshold,][[class_column]]left_len = length(left)left_table = table(left)left_gini <- Gini_impurity(left)right <- data[data[variable_column]>=threshold,][[class_column]]right_len = length(right)right_table = table(right)right_gini <- Gini_impurity(right)total_gini <- left_gini * left_len / total + right_gini * right_len /totalresult = c(variable_column,threshold, paste(names(left_table), left_table, collapse="; ", sep=" x "),paste(names(right_table), right_table, collapse="; ", sep=" x "),total_gini)names(result) <- c("Variable", "Threshold", "Left_branch", "Right_branch", "Gini_impurity")if(!is.null(Init_gini_impurity)){Gini_gain <- Init_gini_impurity - total_giniresult = c(variable_column, threshold, paste(names(left_table), left_table, collapse="; ", sep=" x "),paste(names(right_table), right_table, collapse="; ", sep=" x "),Gini_gain)names(result) <- c("Variable", "Threshold", "Left_branch", "Right_branch", "Gini_gain")}return(result)
}

测试下,跟之前计算的结果一致:

as.data.frame(rbind(Gini_impurity_for_split_branch(2, data, 'x', 'color'), Gini_impurity_for_split_branch(2, data, 'y', 'color')))##   Variable Threshold       Left_branch       Right_branch     Gini_impurity
## 1        x         2 blue x 3; red x 2          green x 5              0.24
## 2        y         2          blue x 3 green x 5; red x 2 0.285714285714286

暴力决策根节点和阈值

基于前面定义的函数,遍历每一个可能的变量和阈值。

首先看下基于变量x的计算方法:

uniq_x <- sort(unique(data$x))
delimiter_x <- zoo::rollmean(uniq_x,2)
impurity_x <- as.data.frame(do.call(rbind, lapply(delimiter_x, Gini_impurity_for_split_branch, data=data, variable_column='x', class_column='color')))
print(impurity_x)##   Variable Threshold                  Left_branch                 Right_branch     Gini_impurity
## 1        x      0.25                     blue x 1 blue x 2; green x 5; red x 2 0.533333333333333
## 2        x       0.8                     blue x 2 blue x 1; green x 5; red x 2             0.425
## 3        x      1.45                     blue x 3           green x 5; red x 2 0.285714285714286
## 4        x      1.85            blue x 3; red x 1           green x 5; red x 1 0.316666666666667
## 5        x      1.95            blue x 3; red x 2                    green x 5              0.24
## 6        x      2.25 blue x 3; green x 1; red x 2                    green x 4 0.366666666666667
## 7        x      2.75 blue x 3; green x 2; red x 2                    green x 3 0.457142857142857
## 8        x       3.3 blue x 3; green x 3; red x 2                    green x 2             0.525
## 9        x      3.65 blue x 3; green x 4; red x 2                    green x 1 0.577777777777778

再包装2个函数,一个计算单个变量为决策节点的各种可能决策的Gini impurity, 另一个计算所有变量依次作为决策节点的各种可能决策的Gini impurity

Gini_impurity_for_all_possible_branches_of_one_variable <- function(data, variable, class, Init_gini_impurity=NULL){uniq_value <- sort(unique(data[[variable]]))delimiter_value <- zoo::rollmean(uniq_value,2)impurity <- as.data.frame(do.call(rbind, lapply(delimiter_value, Gini_impurity_for_split_branch, data=data, variable_column=variable, class_column=class,Init_gini_impurity=Init_gini_impurity)))if(is.null(Init_gini_impurity)){decreasing = F} else {decreasing = T}impurity <- impurity[order(impurity[[colnames(impurity)[5]]], decreasing = decreasing),]return(impurity)
}Gini_impurity_for_all_possible_branches_of_all_variables <- function(data, variables, class, Init_gini_impurity=NULL){one_split_gini <- do.call(rbind, lapply(variables,Gini_impurity_for_all_possible_branches_of_one_variable, data=data, class=class,Init_gini_impurity=Init_gini_impurity))if(is.null(Init_gini_impurity)){decreasing = F} else {decreasing = T}one_split_gini[order(one_split_gini[[colnames(one_split_gini)[5]]], decreasing = decreasing),]
}

测试下:

Gini_impurity_for_all_possible_branches_of_one_variable(data, 'x', 'color')##   Variable Threshold                  Left_branch                 Right_branch     Gini_impurity
## 5        x      1.95            blue x 3; red x 2                    green x 5              0.24
## 3        x      1.45                     blue x 3           green x 5; red x 2 0.285714285714286
## 4        x      1.85            blue x 3; red x 1           green x 5; red x 1 0.316666666666667
## 6        x      2.25 blue x 3; green x 1; red x 2                    green x 4 0.366666666666667
## 2        x       0.8                     blue x 2 blue x 1; green x 5; red x 2             0.425
## 7        x      2.75 blue x 3; green x 2; red x 2                    green x 3 0.457142857142857
## 8        x       3.3 blue x 3; green x 3; red x 2                    green x 2             0.525
## 1        x      0.25                     blue x 1 blue x 2; green x 5; red x 2 0.533333333333333
## 9        x      3.65 blue x 3; green x 4; red x 2                    green x 1 0.577777777777778

两个变量的各个阈值分别进行决策,并计算Gini impurity,输出按Gini impurity由小到大排序后的结果。根据变量x和阈值1.95(与上面选择的阈值2获得的决策结果一致)的决策可以获得本步决策的最好结果。

variables <- c('x', 'y')
Gini_impurity_for_all_possible_branches_of_all_variables(data, variables, class="color")##    Variable Threshold                  Left_branch                 Right_branch     Gini_impurity
## 5         x      1.95            blue x 3; red x 2                    green x 5              0.24
## 3         x      1.45                     blue x 3           green x 5; red x 2 0.285714285714286
## 31        y      1.75                     blue x 3           green x 5; red x 2 0.285714285714286
## 4         x      1.85            blue x 3; red x 1           green x 5; red x 1 0.316666666666667
## 6         x      2.25 blue x 3; green x 1; red x 2                    green x 4 0.366666666666667
## 41        y      2.05          blue x 3; green x 1           green x 4; red x 2 0.416666666666667
## 2         x       0.8                     blue x 2 blue x 1; green x 5; red x 2             0.425
## 21        y      1.25                     blue x 2 blue x 1; green x 5; red x 2             0.425
## 51        y      2.15 blue x 3; green x 1; red x 1           green x 4; red x 1              0.44
## 7         x      2.75 blue x 3; green x 2; red x 2                    green x 3 0.457142857142857
## 71        y       2.9 blue x 3; green x 2; red x 2                    green x 3 0.457142857142857
## 61        y       2.5 blue x 3; green x 2; red x 1           green x 3; red x 1 0.516666666666667
## 8         x       3.3 blue x 3; green x 3; red x 2                    green x 2             0.525
## 81        y      3.15 blue x 3; green x 3; red x 2                    green x 2             0.525
## 1         x      0.25                     blue x 1 blue x 2; green x 5; red x 2 0.533333333333333
## 11        y      0.75                     blue x 1 blue x 2; green x 5; red x 2 0.533333333333333
## 9         x      3.65 blue x 3; green x 4; red x 2                    green x 1 0.577777777777778
## 91        y       3.4 blue x 3; green x 4; red x 2                    green x 1 0.577777777777778

  • https://victorzhou.com/blog/intro-to-random-forests/

  • https://victorzhou.com/blog/gini-impurity/

  • https://stats.stackexchange.com/questions/192310/is-random-forest-suitable-for-very-small-data-sets

  • https://towardsdatascience.com/understanding-random-forest-58381e0602d2

  • https://www.stat.berkeley.edu/~breiman/RandomForests/reg_philosophy.html

  • https://medium.com/@williamkoehrsen/random-forest-simple-explanation-377895a60d2d

往期精品(点击图片直达文字对应教程)

后台回复“生信宝典福利第一波”或点击阅读原文获取教程合集

(请备注姓名-学校/企业-职务等)

机器学习算法-随机森林之决策树R 代码从头暴力实现(2)相关推荐

  1. 机器学习算法-随机森林之决策树R 代码从头暴力实现(3)

    前文 (机器学习算法 - 随机森林之决策树初探(1)) 讲述了决策树的基本概念.决策评价标准并手算了单个变量.单个分组的Gini impurity.是一个基本概念学习的过程,如果不了解,建议先读一下再 ...

  2. 机器学习算法 - 随机森林之决策树初探(1)

    随机森林是基于集体智慧的一个机器学习算法,也是目前最好的机器学习算法之一. 随机森林实际是一堆决策树的组合(正如其名,树多了就是森林了).在用于分类一个新变量时,相关的检测数据提交给构建好的每个分类树 ...

  3. 机器学习算法-随机森林之理论概述

    前面我们用 3 条推文从理论和代码角度讲述了决策树的概念和粗暴生成. 机器学习算法-随机森林之决策树R 代码从头暴力实现(3) 机器学习算法-随机森林之决策树R 代码从头暴力实现(2) 机器学习算法 ...

  4. 机器学习算法-随机森林初探(1)

    机器学习算法-随机森林之理论概述 表达数据集来源于 https://file.biolab.si/biolab/supp/bi-cancer/projections/. 为了展示随机森林的能力,我们用 ...

  5. 【机器学习】机器学习算法 随机森林学习 之决策树

    随机森林是基于集体智慧的一个机器学习算法,也是目前最好的机器学习算法之一. 随机森林实际是一堆决策树的组合(正如其名,树多了就是森林了).在用于分类一个新变量时,相关的检测数据提交给构建好的每个分类树 ...

  6. 机器学习算法 随机森林学习 之决策树

    随机森林是基于集体智慧的一个机器学习算法,也是目前最好的机器学习算法之一. 随机森林实际是一堆决策树的组合(正如其名,树多了就是森林了).在用于分类一个新变量时,相关的检测数据提交给构建好的每个分类树 ...

  7. [机器学习算法]随机森林原理

    随机森林 单棵决策树的劣势 有时候单棵决策树可能难以实现较高的准确率,这主要是由以下几个方面决定的: 求解一棵最优(泛化误差最小)的决策树是一个NP难(无法穷极所有可能的树结构)问题,往往得到的是局部 ...

  8. 各维度 特征 重要程度 随机森林_机器学习算法——随机森林

    随机森林简介 随机森林是一种通用的机器学习方法,能够处理回归和分类问题.它还负责数据降维.缺失值处理.离群值处理以及数据分析的其他步骤.它是一种集成学习方法,将一组一般的模型组合成一个强大的模型 工作 ...

  9. 脑电分析系列[MNE-Python-5]| Python机器学习算法随机森林判断睡眠类型

    案例介绍 本案例通过对多导睡眠图(Polysomnography,PSG)数据进行睡眠阶段的分类来判断睡眠类型. 训练:对Alice的睡眠数据进行训练: 测试:利用训练结果对Bob的睡眠数据进行测试, ...

最新文章

  1. android 验证输入,最佳实践:输入验证(Android)
  2. 【IntelliJ IDEA系列】IDEA编译方式介绍及编译器的设置和选择
  3. 软件工程概论作业01
  4. Jenkins(Pipeline)
  5. linux内核那些事之 VMA Gap
  6. 河北大学工商学院计算机分数线,河北大学工商学院录取分数线()
  7. 基于linux的智能小车_ROS全开源阿克曼转向智能网联无人驾驶车
  8. 约瑟夫环c语言程序完整版,约瑟夫环的C语言实现
  9. 机器学习常用十大算法
  10. HTML5从入门到精通电子书pdf下载
  11. Java第32课——求数组元素最大值
  12. 油管最火KMP算法讲解,阿三哥的源代码!
  13. 狂野飙车8:极速凌云 for Mac v1.0.2 Asphalt 8 好玩的赛车游戏
  14. cartographer自动更新地图,2条路径数据合并为1条数据
  15. java大写转化小写的同时小写转化成大写的方法
  16. 集成电路版图设计(一)
  17. configure配置文件
  18. DS1302实时时钟芯片
  19. 使用C语言你必须知道的常见的字符串错误
  20. 实验二+087+饶慧敏

热门文章

  1. 2016科学数据大会临时通知
  2. 【数据库系统】管理持久对象的模式
  3. 【Java】Maven报错 Dependency ‘aspectj:aspectjrt:1.5.4‘ not found 的解决方法
  4. 【Java】JDBC连接MySQL/SQLServer/Oracle三种数据库
  5. 【Python】Matplotlib使用字符串代替变量绘制散点图
  6. 每天打卡心情好(洛谷P1664题题解,Java语言描述)
  7. 【逻辑与计算理论】组合子逻辑与 Y 组合子
  8. 深入理解JVM垃圾收集机制,下次面试你准备好了吗
  9. sqlserver把小数点后面多余的0去掉
  10. backtype.storm.generated.InvalidTopologyException:null问题的解决