深度森林(Deep Forest)是周志华教授和冯霁博士在2017年2月28日发表的论文《Deep Forest: Towards An Alternative to Deep Neural Networks》中提出来的一种新的可以与深度神经网络相媲美的基于树的模型,其结构如图所示。

gcForest.png

文中提出的多粒度级联森林(Multi-Grained Cascade Forest)是一种决策树集成方法,性能较之深度神经网络有很强的竞争力。相比深度神经网络,gcForest有如下若干有点:

1. 容易训练,计算开销小
2.天然适用于并行的部署,效率高
3. 超参数少,模型对超参数调节不敏感,并且一套超参数可使用到不同数据集
4.可以适应于不同大小的数据集,模型复杂度可自适应伸缩
5. 每个级联的生成使用了交叉验证,避免过拟合
6. 在理论分析方面也比深度神经网络更加容易。

Paper:https://arxiv.org/abs/1702.08835v2
Github:https://github.com/kingfengji/gcForest
Website:http://lamda.nju.edu.cn/code_gcForest.ashx

南京大学机器学习与数据挖掘研究所提供了基于Python 2.7官方实现版本,在本文中,我们使用基于Python3实现的gcForest实现分类任务。

Github:https://github.com/pylablanche/gcForest

gcForest类与sklearn包装的分类器使用方法类似,使用 a .fit() 进行训练,使用a .predict() 进行预测。其中需要我们进行设置的属性为shape_1X和window。shape_1X由数据集决定(所有样本必须具有相同的形状),而window取决于我们自己的选择。

shape_1X 告诉代码我们的样本数据的形状是怎样的,它接受一个列表或数组,其中第一个元素是行数,第二个元素是列数。例如,对于20行和30列的图片,需要给出:shape_1X = [20,30],如果给出长度为40的序列,需要给shape_1X = [1,40]。

window 是数据切片的窗口大小。例如,如果正在使用一个形状[1,40]的序列,并且想要切片的尺寸为20,那么只需设置window = [20]。如果正在使用大小为[20,20]的图片,要进行4x4的切片操作,只需设置“window = [4]”。

分类器构建时需要的参数如下所示:

shape_1X: int or tuple list or np.array (default=None)训练量样本的大小,格式为[n_lines, n_cols]. n_mgsRFtree: int (default=30)多粒度扫描时构建随即森林使用的决策树数量.window: int (default=None)多粒度扫描时的数据扫描窗口大小.stride: int (default=1)数据切片时的步长大小.cascade_test_size: float or int (default=0.2)级联训练时的测试集大小.n_cascadeRF: int (default=2)每个级联层的随机森林的大小.n_cascadeRFtree: int (default=101)每个级联层的随即森林中包含的决策树的数量.min_samples_mgs: float or int (default=0.1)多粒度扫描期间,要执行拆分行为时节点中最小样本数.min_samples_cascade: float or int (default=0.1)训练级联层时,要执行拆分行为时节点中最小样本数.cascade_layer: int (default=np.inf)级联层层数的最大值tolerance: float (default=0.0)判断级联层是否增长的准确度公差。如果准确性的提高不如tolerance,那么层数将停止增长。n_jobs: int (default=1)随机森林并行运行的工作数量。如果为-1,则设置为cpu核心数.

我们使用sklearn带有的Iris数据集进行分类测试,plot_confusion_matrix函数用来绘制混淆矩阵,gcf函数用来进行训练、预测与评估。代码如下所示:

# -*- coding: utf-8 -*-
import itertools
import numpy as np
import matplotlib.pyplot as pltimport sklearn.metrics as metrics
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_splitfrom GCForest import gcForestdef plot_confusion_matrix(cm, classes, normalize=False,title='Confusion matrix', cmap=plt.cm.Blues):plt.imshow(cm, interpolation='nearest', cmap=cmap)plt.title(title)plt.colorbar()tick_marks = np.arange(len(classes))plt.xticks(tick_marks, classes, rotation=45)plt.yticks(tick_marks, classes)if normalize:cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]print("Normalized confusion matrix")else:print('Confusion matrix, without normalization')thresh = cm.max() / 2.for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):plt.text(j, i, cm[i, j],horizontalalignment="center",color="white" if cm[i, j] > thresh else "black")plt.tight_layout()plt.ylabel('True label')plt.xlabel('Predicted label')plt.show()def gcf(X_train, X_test, y_train, y_test, cnames):clf = gcForest(shape_1X=(1, 3), window=[2])clf.fit(X_train, y_train)y_pred = clf.predict(X_test)print()print('accuracy:', metrics.accuracy_score(y_test, y_pred))print('kappa:', metrics.cohen_kappa_score(y_test, y_pred))print(metrics.classification_report(y_test, y_pred, target_names=cnames))cnf_matrix = metrics.confusion_matrix(y_test, y_pred)plot_confusion_matrix(cnf_matrix, classes=cnames, normalize=True,title='Normalized confusion matrix')if __name__ == '__main__':data = load_iris()x = data.datay = data.targetcnames = list(data.target_names)X_train, X_test, y_train, y_test = train_test_split(x, y, test_size=0.2)gcf(X_train, X_test, y_train, y_test, cnames)

深度森林的运行与评估结果如下图所示。模型自动选择了深度为2的结构,我们使用accuracy、kappa、f1-score作为分类结果评估指标,并绘制出其结果的混淆矩阵,可以看出深度森林的分类结果非常可观。

转载:在Python 3中使用深度森林(Deep Forest)进行分类相关推荐

  1. python画树林_在Python 3中使用深度森林(Deep Forest)进行分类

    深度森林(Deep Forest)是周志华教授和冯霁博士在2017年2月28日发表的论文<Deep Forest: Towards An Alternative to Deep Neural N ...

  2. 学习笔记 | 深度森林 Deep Forest

    文章目录 一.前言 二.主要内容 1. Motivations 2. Insights 3. 解决方案的关键 (1)级联森林结构 (2)多粒度扫描 (3)整体过程和超参数 4. 论文的贡献 三.总结 ...

  3. python编程基础知识点总结_【转载】Python编程中常用的12种基础知识总结

    Python编程中常用的12种基础知识总结:正则表达式替换,遍历目录方法,列表按列排序.去重,字典排序,字典.列表.字符串互转,时间对象操作,命令行参数解析(getopt),print 格式化输出,进 ...

  4. 使用keras进行深度学习_如何在Keras中通过深度学习对蝴蝶进行分类

    使用keras进行深度学习 A while ago I read an interesting blog post on the website of the Dutch organization V ...

  5. 周志华团队:深度森林挑战多标签学习,9大数据集超越传统方法

    来源:arXiv 本文转载自新智元(公众号ID:AI_era),未经许可请勿二次转载. [导读]南京大学周志华团队最新研究首次将深度森林引入到多标签学习中,提出多标签深度森林方法MLDF,在9个基准数 ...

  6. python算法工程师招聘_经验 | 我心目中招聘深度学习算法工程师的标准

    原标题:经验 | 我心目中招聘深度学习算法工程师的标准 本文转载自有三AI 目前利用深度学习这个工具可以做很多事情,各大领域(图像,语音,NLP等),各大行业(娱乐,金融,医疗等)这几年都被玩的风生水 ...

  7. 从深度学习到深度森林方法(Python)

    作者 |泳鱼 来源 |算法进阶 一.深度森林的介绍 目前深度神经网络(DNN)做得好的几乎都是涉及图像视频(CV).自然语言处理(NLP)等的任务,都是典型的数值建模任务(在表格数据tabular d ...

  8. Python 循环中的陷阱(转载)

    Python 中的 for 循环和其他语言中的 for 循环工作方式是不一样的,今天就带你深入了解 Python 的 for 循环,看看它是如何工作的,以及它为什么按照这种方式工作. 循环中的陷阱 我 ...

  9. python 引用文件中的类 报错_Python学习笔记7 头文件的添加规则(转载)

    转载自:https://www.cnblogs.com/taurusfy/p/7605787.html ************************************************ ...

  10. 《Python自然语言处理-雅兰·萨纳卡(Jalaj Thanaki)》学习笔记:09 NLU和NLG问题中的深度学习

    09 NLU和NLG问题中的深度学习 9.1 人工智能概览 9.1.1 人工智能的基础 9.1.2 人工智能的阶段 9.1.3 人工智能的种类 9.1.4 人工智能的目标和应用 9.2 NLU和NLG ...

最新文章

  1. BASIC-23_蓝桥杯_芯片测试
  2. mysql count null_MySQL函数大全及用法示例
  3. 仿 小米运动_小米有品上架“黑科技”床垫,让你睡在“空气”上,改变睡眠体验...
  4. PowerShell 2.0远程管理之隐式远程管理
  5. MVC3 Razor 语法检查 -(转)
  6. C++中类中常规变量、const、static、static const(const static)成员变量的声明和初始化...
  7. 44. 将样式表放在顶部(5)
  8. deeplearning.ai——TensorFlow指南
  9. 位图和矢量图转换工具推荐
  10. 基于二阶矩阵的最优化问题(二)(附matlab代码)
  11. safari html5 自动全屏,IOS10全屏safari Javascript
  12. VueX 以及axios
  13. 小白学爬虫---爬取中国房价工资比
  14. 联想x3650服务器安装硬盘,IBM x3650 M2服务器系统安装攻略(组图)
  15. mysql三大日志_了解的mysql三大日志-----binlog
  16. idea 打包报错:Failed to execute goal org.apache.maven.plugins:maven-compiler-plugin:3.6.1:compile (defau
  17. 暑假?不进厂?那就卷s同学吧
  18. wifiwan口速率什么意思_无线路由器WAN口状态的意思是什么?
  19. jackson将JSON字符串转换成复杂的数据类型
  20. 【Devc++】迷宫小游戏2.0

热门文章

  1. 2022考研复习第二十三周
  2. matlab 向量转置,matlab中向量和矩阵怎么转置 值得收藏
  3. Matlab关于转置与共轭转置
  4. CMMI 2.0 和 1.3
  5. 海康、大华摄像头chrome高版本实时播放(java集成)
  6. JspStudy套件在部署java项目时,如何去掉项目名进行访问网址问题
  7. 旅游日记——2000元北京6天5夜游
  8. python 答题辅助_GitHub - anwzx/TopSup: 答题辅助决策:冲顶大会等答题类游戏
  9. [OT]ubuntu下安装HP-P1108打印机驱动
  10. springboot + h2 + vue + AceEditor + element-ui 数据库管理系统(DMS)- JavaWeb毕业设计|课程设计