gcForest模型是2018年南京大学机器学习大师周志华老师团队提出来的以决策树和随机森林为基础模型的级联深度森林模型,这个论文我看过了感觉跟我当时硕士期间的一个研究有一点类似,当时我基于XGBOOST的再编码能力有效提升了GBDT模型的分类能力,这个gcForest模型也是需要“再编码”,然后将上一层模型的数据累加到下一层的输入中去。它的表征学习能力可以通过对高维输入数据的多粒度扫描而进行加强。串联的层数也可以通过自适应的决定从而使得模型复杂度不需要成为一个自定义的超参数,而是一个根据数据情况而自动设定的参数。值得注意的是,gcForest会比DNN有更少的超参数,更好的一点在于gcForest对参数是有非常好的鲁棒性,哪怕用默认参数也可以获得很棒的结果。下面是论文中提出的gcForest模型的示意图:

论文中提出了一种Mutil-Grained Scanning的方法,使用窗口切片的方式来进行多粒度的划分,示意图如下:
                                         

gcForest的总体结构示意图如下所示:

我们今天并不是要来详细去讨论分析gcForest模型的内部构造和算法原理,而是基于gcForest模型来做一点实践性的工作来看一下这个模型的表现能力怎么样。

官方的源码在这里,一位外国小哥实现的gcForest模块在这里。感兴趣的话都可以去拿去试试,下面是我具体的实现:

#!usr/bin/env python
#encoding:utf-8
from __future__ import division'''
__Author__:沂水寒城
功能: gcForest 实践
'''import numpy as np
from GCForest import gcForest
from sklearn.externals import joblib
from sklearn.metrics import accuracy_score
from sklearn.datasets import load_iris, load_digits
from sklearn.model_selection import train_test_splitdef irisFunc():'''对鸢尾花数据集进行测试'''iris=load_iris()X,y=iris.data,iris.targetprint('==========================Data Shape======================')print(X.shape,y.shape)X_train,X_test,y_train,y_test=train_test_split(X,y,test_size=0.3)model=gcForest(shape_1X=4, window=2,tolerance=0.0)model.fit(X_train,y_train)#持久化存储joblib.dump(model,'irisModel.sav')model=joblib.load('irisModel.sav')y_predict=model.predict(X_test)print('===========================y_predict======================')print(y_predict)accuarcy=accuracy_score(y_true=y_test,y_pred=y_predict)print('gcForest accuarcy : {}'.format(accuarcy))

上述代码中总体来看十分地简洁,gcForest模型的调用方式同sklearn中其他方法的调用接口十分地相似,这个就不需要太多的学习成本了,得到模型后我们先借助于joblib模块实现了模型的持久化存储,之后加载本地保存的模型来对测试数据集进行预测,结果如下图所示:
                     

刚开始执行的时候,给我了一种深度学习模型启动的感觉,O(∩_∩)O哈哈~,我们可以看到准确率达到了97%以上,可见模型的性能还是不错的。

树模型是可以预测类别的概率的,我们这里也来做一下:

def irisFunc():'''对鸢尾花数据集进行测试'''iris=load_iris()X,y=iris.data,iris.targetprint('==========================Data Shape======================')print(X.shape,y.shape)X_train,X_test,y_train,y_test=train_test_split(X,y,test_size=0.3)model=gcForest(shape_1X=4, window=2,tolerance=0.0)model.fit(X_train,y_train)y_predict=model.predict_proba(X_test)y_predict=y_predict.tolist()print('==========================y_predict======================')for one_res in y_predict:print(one_res)

结果如下:

从上面的结果截图中可以看到:我将类别概率预测结果转化为列表的形式,依次输出每个样本的预测结果,在每个结果中,都有三个数值,分别对应0类、1类和2类这三个类别模型判定的概率,predict方法就是将最大的概率对应的类别进行输出。

最后简单贴一下原论文中作者给出来的各种模型和数据集上gcForest的对比结果统计:

gcForest模型和多粒度扫描的结果对比:

python实践gcForest模型对鸢尾花数据集iris进行分类相关推荐

  1. 基于逻辑回归模型对鸢尾花数据集进行分类

    基于逻辑回归模型对鸢尾花数据集进行分类 理论知识 不做过多赘述,相关知识有:指数分布族.GLM建模(分布函数+连接函数,对于本例来说是二项分布+sigmoid函数).最大似然函数.交叉熵函数(评估逻辑 ...

  2. python KNN分类算法 使用鸢尾花数据集实战

    KNN分类算法,又叫K近邻算法,它概念极其简单,但效果又很优秀. 如觉得有帮助请点赞关注收藏啦~~~ KNN算法的核心是,如果一个样本在特征空间中的K个最相似,即特征空间中最邻近的样本中的大多数属于某 ...

  3. Python-线性判别分析(Fisher判别分析)使用鸢尾花数据集 Iris

    本博客运行环境为Jupyter Notebook.Python3.使用的数据集是鸢尾花数据集. 目录 线性判别分析 代码实现 缺少一组数据的问题已解决!代码已更新! 线性判别分析 线性判别分析(Lin ...

  4. python画蝴蝶结_使用鸢尾花数据集,通过Sklearn,绘制精确率-召回率曲线—Python...

    Python深度学习的一个小例子,用sklearn自己带的鸢尾花数据集训练. 在导入库的过程中,如果导入from sklearn.model_selection import train_test_s ...

  5. 机器学习与深度学习——通过knn算法分类鸢尾花数据集iris求出错误率并进行可视化

    什么是knn算法? KNN算法是一种基于实例的机器学习算法,其全称为K-最近邻算法(K-Nearest Neighbors Algorithm).它是一种简单但非常有效的分类和回归算法. 该算法的基本 ...

  6. 机器学习实战篇:使用贝叶斯模型对鸢尾花数据集分类

    1.简介 本文主要讲解朴素贝叶斯及其推理,并实现鸢尾花数据的分类问题 2.算法解释 朴素贝叶斯最初来源于统计科学领域.根据朴素贝叶斯公式: 由于类似然涉及到多个特征的组合求解较为困难.所以为了简化运算 ...

  7. 2.试读取鸢尾花数据集iris.npz,绘制sepal_length和sepal_width两个特征之间的散点图,X轴添加“SepalLength”标签,Y轴添加“SepalWidth”标签,散点设置

    2022-2023学年第1期期末考试 <Python数据分析与应用>试卷A卷 (大数据技术专业2131.2132班适用 120分钟 机试开卷) 班级 学号 姓名 1 题 号 一 总 分 得 ...

  8. Python原生代码实现KNN算法(鸢尾花数据集)

    一.作业题目 Python原生代码实现KNN分类算法,使用鸢尾花数据集. KNN算法介绍: K最近邻(k-Nearest Neighbor,KNN)分类算法,是机器学习算法之一. 该方法的思路是:如果 ...

  9. 『自己的工作3』梯度下降实现SVM多分类+最详细的数学推导+Python实战(鸢尾花数据集)

    梯度下降实现SVM多分类+最详细的数学推导+Python实战(鸢尾花数据集)! 文章目录 一. SVM梯度公式详细推导 1.1. SVM多分类模型 1.2. SVM多分类梯度公式推导 1.3. SVM ...

  10. python机器学习常用模型

    python机器学习 算法分类 监督学习 定义︰输入数据是由输入特征值和目标值所组成.函数的输出可以是一个连续的值(称为回归),或是输出是有限个离散值(称作分类). 分类: k-近邻 贝叶斯 决策树 ...

最新文章

  1. 基础设施即代码:Terraform和AWS无服务器
  2. 点滴篇(一) 第一篇 博客
  3. 有关nginx upstream的几种配置方式
  4. 暴雪还不赶快?劳拉与光之守护者PC平台登陆
  5. jQuery 动画效果
  6. fastapi PUT更新数据 / PATCH部分更新
  7. java script 技巧_java script 技巧
  8. 这就是库克的重大计划?英特尔新CEO帕特誓言:CPU必须要比苹果好!
  9. Makefile for Sphinx documentation
  10. RHive的安装和用法
  11. php操作elasticsearch
  12. redis实现可重入锁
  13. 循环神经网络-Recurrent Neural Networks
  14. 立创eda学习笔记二十五:绘制原理图的电气网络(绘制导线,使用节点)
  15. java gui 数独_数独-GUI开发
  16. 02Windows日志分析
  17. java里this.a=a,JAVA基础-关键字之this
  18. 批量图片缩小工具,JPG|PNG|BMP图片缩小工具
  19. win10系统查看占用端口
  20. 信号与系统作业之我的朋友把我的大作业分享了好朋友

热门文章

  1. 持续提高安卓应用安全性与性能
  2. vue $emit 父组件与子组件之间的通信(父组件向子组件传参)
  3. Node.js 8有哪些重要功能和修复? 1
  4. 字符函数-(学习笔记)
  5. 学了N年英语,你学会翻译了吗?——最基本的数据库连接
  6. 2010.2--ip redirects 和 ip directed-broadcast含义
  7. Spring 框架基础(03):核心思想 IOC 说明,案例演示
  8. [LuoguP1360][USACP07MAR]黄金阵容均衡
  9. python web框架【补充】自定义web框架
  10. 【.Net Framework 体积大?】不安装.net framework 也能运行!?原理补充-3