之前介绍了随机森林、极端随机森林以及深度森林的原理,本次介绍一下相关的代码

本次实验全部使用糖尿病数据集

数据导入

import pandas as pd
train = pd.read_csv("/Users/admin/Desktop/database/diabetes/diabetes_train.txt",header=None,index_col=False)
test = pd.read_csv("/Users/admin/Desktop/database/diabetes/diabetes_test.txt",header=None,index_col=False)#数据转换
label = train.loc[:,[8]].values.reshape(-1)
data = train.drop(columns=8).values.reshape(-1,8)y_test =  test.loc[:,[8]].values.reshape(-1)
X_test =  test.drop(columns=8).values.reshape(-1,8)

随机森林

rf = RandomForestRegressor(n_estimators=1000)
rf=rf.fit(data, label)
predictions_rf=rf.predict(X_test)
#将概率转换成0、1
pred_rf=judge(predictions_rf)acc = accuracy_score(y_test, pred_rf)
print("Test Accuracy of rf = {:.2f} %".format(acc * 100))

Test Accuracy of rf = 80.97 %

极端随机森林


etr=ExtraTreesRegressor(n_estimators=1000)
etr=etr.fit(data, label) predictions_etr=etr.predict(X_test)
pred_etr=judge(predictions_etr)acc = accuracy_score(y_test, pred_etr)
print("Test Accuracy of pred_etr = {:.2f} %".format(acc * 100))

Test Accuracy of pred_etr = 79.85 %

深度森林

from gcforest.gcforest import GCForestdef get_toy_config():config = {}ca_config = {}ca_config["random_state"] = 0  # 0 or 1ca_config["n_cascadeRFtree"] = 1000ca_config["max_layers"] = 100 # 最大的层数,layer对应论文中的levelca_config["early_stopping_rounds"] = 3 #如果出现某层的三层以内的准确率都没有提升,层中止ca_config["n_classes"] = 2 #判别的类别数量ca_config["estimators"] = []ca_config["estimators"].append({"n_folds": 2, "type": "RandomForestClassifier", "n_estimators": 10, "max_depth": None, "n_jobs": -1})ca_config["estimators"].append({"n_folds": 2, "type": "ExtraTreesClassifier", "n_estimators": 10, "max_depth": None, "n_jobs": -1})ca_config["estimators"].append({"n_folds": 2, "type": "LogisticRegression"})config["cascade"] = ca_config #共使用了3个基学习器return configconfig=get_toy_config()
gc = GCForest(config)
#X_train_enc是每个模型最后一层输出的结果,每一个类别的可能性
X_train_enc = gc.fit_transform(data, label)y_pred = gc.predict(X_test)
acc = accuracy_score(y_test, y_pred)
print("Test Accuracy of GcForest = {:.2f} %".format(acc * 100))

Test Accuracy of GcForest = 80.22 %

MNIST

可以用深度森林跑一下深度学习的常用数据集MNIST


y_pred = gc.predict(x_valid)
acc = accuracy_score(y_valid, y_pred)
print("Test Accuracy of GcForest = {:.2f} %".format(acc * 100))

Test Accuracy of GcForest = 97.32 %

完整代码

随机森林、极端随机森林以及深度森林代码相关推荐

  1. 周志华团队:深度森林挑战多标签学习,9大数据集超越传统方法

    来源:arXiv 本文转载自新智元(公众号ID:AI_era),未经许可请勿二次转载. [导读]南京大学周志华团队最新研究首次将深度森林引入到多标签学习中,提出多标签深度森林方法MLDF,在9个基准数 ...

  2. 决策树(九)--极端随机森林及OpenCV源码分析

    原文: http://blog.csdn.net/zhaocj/article/details/51648966 一.原理 ET或Extra-Trees(Extremely randomized tr ...

  3. 集成学习、Bagging算法、Bagging+Pasting、随机森林、极端随机树集成(Extra-trees)、特征重要度、包外评估

    集成学习.Bagging算法.Bagging+Pasting.随机森林.极端随机树集成(Extra-trees).特征重要度.包外评估 目录

  4. ML之回归预测:利用十类机器学习算法(线性回归、kNN、SVM、决策树、随机森林、极端随机树、SGD、提升树、LightGBM、XGBoost)对波士顿数据集回归预测(模型评估、推理并导到csv)

    ML之回归预测:利用十类机器学习算法(线性回归.kNN.SVM.决策树.随机森林.极端随机树.SGD.提升树.LightGBM.XGBoost)对波士顿数据集[13+1,506]回归预测(模型评估.推 ...

  5. 词袋模型 matlab,【火炉炼AI】机器学习051-视觉词袋模型+极端随机森林建立图像分类器...

    [火炉炼AI]机器学习051-视觉词袋模型+极端随机森林建立图像分类器 (本文所使用的Python库和版本号: Python 3.6, Numpy 1.14, scikit-learn 0.19, m ...

  6. 随机森林(randomForest)和极限树或者叫做极端随机树(extraTree),

    随机森林:是一个包含多个决策树的分类器, 并且其输出的类别是由个别树输出的类别的众数而定.,随机森林对回归的结果在内部是取得平均 但是并不是所有的回归都是取的平均,有些是取的和,以后会发博文来解释这样 ...

  7. 随机森林与极端随机森林

    ET或Extra-Trees(Extremely randomized trees,极端随机树)是由PierreGeurts等人于2006年提出. 该算法与随机森林算法十分相似,都是由许多决策树构成. ...

  8. 决策树与剪枝、bagging与随机森林、极端随机树、Adaboost、GBDT算法原理详解

    目录 1.决策树 1.1 ID3 1.2 C4.5 1.3 CART 1.4 预剪枝和后剪枝 2 bagging与随机森林 2.1 bagging 2.2 随机森林 3 极端随机树 4 GBDT 5 ...

  9. Python实现Stacking回归模型(随机森林回归、极端随机树回归、AdaBoost回归、GBDT回归、决策树回归)项目实战

    说明:这是一个机器学习实战项目(附带数据+代码+文档+视频讲解),如需数据+代码+文档+视频讲解可以直接到文章最后获取. 1.项目背景 Stacking通常考虑的是异质弱学习器(不同的学习算法被组合在 ...

  10. 12_信息熵,信息熵公式,信息增益,决策树、常见决策树使用的算法、决策树的流程、决策树API、决策树案例、随机森林、随机森林的构建过程、随机森林API、随机森林的优缺点、随机森林案例

    1 信息熵 以下来自:https://www.zhihu.com/question/22178202/answer/161732605 1.2 信息熵的公式 先抛出信息熵公式如下: 1.2 信息熵 信 ...

最新文章

  1. centos7安装uwsgi报错_如何将CentOS 7升级到CentOS 8
  2. jzoj4273-圣章-精灵使的魔法语【线段树】
  3. FEIM Studios 团队欢迎您加入
  4. 第二阶段冲刺报告(六)
  5. Java题 细胞分裂
  6. python解压7z文件_如何读取用7z压缩的文本文件?
  7. P.J. Plauger
  8. CentOS7 下配置svn的安装及基础配置介绍
  9. CentOS7_64位操作系统模板搭建
  10. 编程c语言零基础知识,零基础学习C语言都需要掌握哪些基础知识
  11. Word VBA中的光标操作
  12. 表格-table 样式
  13. 帝国cms cj1.php,帝国cms源码中常用函数所在位置
  14. window服务如何通过程序如何打开谷歌浏览器并登陆指定网站_亚马逊如何看listing销量,亚马逊如何看销量排名...
  15. 网页头部声明lang=”zh-cn”、lang=“zh”、lang=“zh-cmn-Hans”区别
  16. 【element-ui】
  17. Git之cherry-pick
  18. 关于32位系统中int、float、short、double等占多少个字节
  19. gif如何压缩?怎么在线gif压缩?
  20. mysql中vlookup函数_EXCEL表格中VLOOKUP函数怎么用

热门文章

  1. php安装libpng,安装php:configure: error: libpng.(a|so) not found解决办法
  2. 【Excel】数据处理与查看
  3. grads插值_grads各类参数设置.pptx
  4. jsp企业员工请假管理系统
  5. 现代电工技术实训考核装置
  6. Python自学记录--steam密码加密逆向
  7. windows/Linux网络工具
  8. 数据库设计的基本规范和原则
  9. Tomcat8.5访问HTML页面出现乱码
  10. STM32单片机简介