点上方计算机视觉联盟获取更多干货

仅作学术分享,不代表本公众号立场,侵权联系删除

转载于:机器之心

AI博士笔记系列推荐

周志华《机器学习》手推笔记正式开源!可打印版本附pdf下载链接

在人工智能发展史上,各类算法可谓层出不穷。近十几年来,深层神经网络的发展在机器学习领域取得了显著进展。通过构建分层或「深层」结构,模型能够在有监督或无监督的环境下从原始数据中学习良好的表征,这被认为是其成功的关键因素。

而深度森林,是 AI 领域重要的研究方向之一。

2017 年,周志华和冯霁等人提出了深度森林框架,这是首次尝试使用树集成来构建多层模型的工作。2018 年,周志华等人又在研究《Multi-Layered Gradient Boosting Decision Trees》中探索了多层的决策树。今年 2 月,周志华团队开源深度森林软件包 DF21:训练效率高、超参数少,在普通设备就能运行。

就在近日,TensorFlow 开源了 TensorFlow 决策森林 (TF-DF)。TF-DF 是用于训练、服务和解释决策森林模型(包括随机森林和梯度增强树)生产方面的 SOTA 算法集合。现在,你可以使用这些模型进行分类、回归和排序任务,具有 TensorFlow 和 Keras 的灵活性和可组合性。

谷歌大脑研究员、Keras之父François Chollet表示:「现在可以用Keras API训练TensorFlow决策森林了。」

对于这一开源项目,网友表示:「这非常酷!随机森林是我最喜欢的模型。」

决策森林

决策森林是一系列机器学习算法,其质量和速度可与神经网络相竞争(它比神经网络更易于使用,功能也很强大),实际上与特定类型的数据配合使用时,它们比神经网络更出色,尤其是在处理表格数据时。

随机森林是一种流行的决策森林模型。在这里,你可以看到一群树通过投票结果对一个例子进行分类。

决策森林是由许多决策树构建的,它包括随机森林和梯度提升树等。这使得它们易于使用和理解,而且可以利用已经存在的大量可解释性工具和技术进行操作。

决策树是一系列仅需做出是 / 否判断的问题,使用决策树将动物分成鸡、猫、袋鼠。

TF-DF 为 TensorFlow 用户带来了模型和一套定制工具:

  • 对初学者来说,开发和解释决策森林模型更容易。不需要显式地列出或预处理输入特征(因为决策森林可以自然地处理数字和分类属性)、指定体系架构(例如,通过尝试不同的层组合,就像在神经网络中一样),或者担心模型发散。一旦你的模型经过训练,你就可以直接绘制它或者用易于解释的统计数据来分析它。

  • 高级用户将受益于推理时间非常快的模型(在许多情况下,每个示例的推理时间为亚微秒)。而且,这个库为模型实验和研究提供了大量的可组合性。特别是,将神经网络和决策森林相结合是很容易的。

如上图所示,只需使用一行代码就能构建模型,相比之下,动图中的下面代码是用于构建神经网络的代码。在 TensorFlow 中,决策森林和神经网络都使用 Keras。可以使用相同的 API 来实验不同类型的模型,更重要的是,可以使用相同的工具,例如 TensorFlow Serving 来部署这两种模型。

以下是 TF-DF 提供的一些功能:

  • TF-DF 提供了一系列 SOTA 决策森林训练和服务算法,如随机森林、CART、(Lambda)MART、DART 等。

  • 基于树的模型与各种 TensorFlow 工具、库和平台(如 TFX)更容易集成,TF-DF 库可以作为通向丰富 TensorFlow 生态系统的桥梁。

  • 对于神经网络用户,你可以使用决策森林这种简单的方式开始 TensorFlow,并继续探索神经网络。

代码示例

下面进行示例展示,可以让使用者简单明了。

  • 项目地址:https://github.com/tensorflow/decision-forests

  • TF-DF 网站地址:https://www.tensorflow.org/decision_forests

  • Google I/O 2021 地址:https://www.youtube.com/watch?v=5qgk9QJ4rdQ

模型训练

在数据集 Palmer's Penguins 上训练随机森林模型。目的是根据一种动物的特征来预测它的种类。该数据集包含数值和类别特性,并存储为 csv 文件。

Palmer's Penguins 数据集示例。

模型训练代码:

# Install TensorFlow Decision Forests
!pip install tensorflow_decision_forests
# Load TensorFlow Decision Forests
import tensorflow_decision_forests as tfdf
# Load the training dataset using pandas
import pandas
train_df = pandas.read_csv("penguins_train.csv")
# Convert the pandas dataframe into a TensorFlow dataset
train_ds = tfdf.keras.pd_dataframe_to_tf_dataset(train_df, label="species")
# Train the model
model = tfdf.keras.RandomForestModel()
model.fit(train_ds)

请注意,代码中没有提供输入特性或超参数。这意味着,TensorFlow 决策森林将自动检测此数据集中的输入特征,并对所有超参数使用默认值。

评估模型

现在开始对模型的质量进行评估:

# Load the testing dataset
test_df = pandas.read_csv("penguins_test.csv")
# Convert it to a TensorFlow dataset
test_ds = tfdf.keras.pd_dataframe_to_tf_dataset(test_df, label="species")
# Evaluate the model
model.compile(metrics=["accuracy"])
print(model.evaluate(test_ds))
# >> 0.979311
# Note: Cross-validation would be more suited on this small dataset.
# See also the "Out-of-bag evaluation" below.
# Export the model to a TensorFlow SavedModel
model.save("project/my_first_model")

带有默认超参数的随机森林模型为大多数问题提供了一个快速和良好的基线。决策森林一般会对中小尺度问题进行快速训练,与其他许多类型的模型相比,需要较少的超参数调优,并且通常会提供强大的结果。

解读模型

现在,你已经了解了所训练模型的准确率,接下来该考虑它的可解释性了。如果你希望理解和解读正被建模的现象、调试模型或者开始信任其决策,可解释性就变得非常重要了。如上所述,有大量的工具可用来解读所训练的模型。首先从 plot 开始:

tfdf.model_plotter.plot_model_in_colab(model, tree_idx=0)

其中一棵决策树的结构。

你可以直观地看到树结构。此外,模型统计是对 plot 的补充,统计示例包括:

  • 每个特性使用了多少次?

  • 模型训练的速度有多快(树的数量和时间)?

  • 节点在树结构中是如何分布的(比如大多数 branch 的长度)?

这些问题的答案以及更多类似查询的答案都包含在模型概要中,并可以在模型检查器中访问。

# Print all the available information about the model
model.summary()
>> Input Features (7):
>>   bill_depth_mm
>>   bill_length_mm
>>   body_mass_g>>
...
>> Variable Importance:
>>   1.    "bill_length_mm" 653.000000 ################
>>   ...
>> Out-of-bag evaluation: accuracy:0.964602 logloss:0.102378
>> Number of trees: 300
>> Total number of nodes: 4170
>>   ...
# Get feature importance as a array
model.make_inspector().variable_importances()["MEAN_DECREASE_IN_ACCURACY"]
>> [("flipper_length_mm", 0.149),
>>      ("bill_length_mm", 0.096),
>>      ("bill_depth_mm", 0.025),
>>      ("body_mass_g", 0.018),
>>      ("island", 0.012)]

在上述示例中,模型通过默认超参数值进行训练。作为首个解决方案而言非常好,但是调整超参数可以进一步提升模型的质量。可以如下这样做:

# List all the other available learning algorithms
tfdf.keras.get_all_models()
>> [tensorflow_decision_forests.keras.RandomForestModel,
>>  tensorflow_decision_forests.keras.GradientBoostedTreesModel,
>>  tensorflow_decision_forests.keras.CartModel]
# Display the hyper-parameters of the Gradient Boosted Trees model
? tfdf.keras.GradientBoostedTreesModel
>> A GBT (Gradient Boosted [Decision] Tree) is a set of shallow decision trees trained sequentially. Each tree is trained to predict and then "correct" for the errors of the previously trained trees (more precisely each tree predicts the gradient of the loss relative to the model output).....Attributes:num_trees: num_trees: Maximum number of decision trees. The effective number of trained trees can be smaller if early stopping is enabled. Default: 300.max_depth: Maximum depth of the tree. `max_depth=1` means that all trees will be roots. Negative values are ignored. Default: 6....# Create another model with specified hyper-parameters
model = tfdf.keras.GradientBoostedTreesModel(num_trees=500,growing_strategy="BEST_FIRST_GLOBAL",max_depth=8,split_axis="SPARSE_OBLIQUE",)
# Evaluate the model
model.compile(metrics=["accuracy"])
print(model.evaluate(test_ds))#
>> 0.986851

参考链接:

https://blog.tensorflow.org/2021/05/introducing-tensorflow-decision-forests.html

-------------------

END

--------------------

我是王博Kings,985AI博士,华为云专家、CSDN博客专家(人工智能领域优质作者)。单个AI开源项目现在已经获得了2100+标星。现在在做AI相关内容,欢迎一起交流学习、生活各方面的问题,一起加油进步!

我们微信交流群涵盖以下方向(但并不局限于以下内容):人工智能,计算机视觉,自然语言处理,目标检测,语义分割,自动驾驶,GAN,强化学习,SLAM,人脸检测,最新算法,最新论文,OpenCV,TensorFlow,PyTorch,开源框架,学习方法...

这是我的私人微信,位置有限,一起进步!

王博的公众号,欢迎关注,干货多多

王博Kings的系列手推笔记(附高清PDF下载):

博士笔记 | 周志华《机器学习》手推笔记第一章思维导图

博士笔记 | 周志华《机器学习》手推笔记第二章“模型评估与选择”

博士笔记 | 周志华《机器学习》手推笔记第三章“线性模型”

博士笔记 | 周志华《机器学习》手推笔记第四章“决策树”

博士笔记 | 周志华《机器学习》手推笔记第五章“神经网络”

博士笔记 | 周志华《机器学习》手推笔记第六章支持向量机(上)

博士笔记 | 周志华《机器学习》手推笔记第六章支持向量机(下)

博士笔记 | 周志华《机器学习》手推笔记第七章贝叶斯分类(上)

博士笔记 | 周志华《机器学习》手推笔记第七章贝叶斯分类(下)

博士笔记 | 周志华《机器学习》手推笔记第八章集成学习(上)

博士笔记 | 周志华《机器学习》手推笔记第八章集成学习(下)

博士笔记 | 周志华《机器学习》手推笔记第九章聚类

博士笔记 | 周志华《机器学习》手推笔记第十章降维与度量学习

博士笔记 | 周志华《机器学习》手推笔记第十一章稀疏学习

博士笔记 | 周志华《机器学习》手推笔记第十二章计算学习理论

博士笔记 | 周志华《机器学习》手推笔记第十三章半监督学习

博士笔记 | 周志华《机器学习》手推笔记第十四章概率图模型

点分享

点收藏

点点赞

点在看

周志华团队 | TensorFlow开源决策森林库TF-DF相关推荐

  1. 最喜欢随机森林?周志华团队 DF21 后,TensorFlow 开源决策森林库 TF-DF

    点击上方"视学算法",选择加"星标"或"置顶" 重磅干货,第一时间送达 转自 | 机器之心 TensorFlow 决策森林 (TF-DF) ...

  2. 最喜欢随机森林?周志华团队DF21后,TensorFlow开源决策森林库TF-DF

    来源:机器之心本文约2500字,建议阅读9分钟TensorFlow 开源了 TensorFlow 决策森林 (TF-DF). TensorFlow 决策森林 (TF-DF) 现已开源,该库集成了众多 ...

  3. 周志华团队:深度森林挑战多标签学习,9大数据集超越传统方法

    来源:arXiv 本文转载自新智元(公众号ID:AI_era),未经许可请勿二次转载. [导读]南京大学周志华团队最新研究首次将深度森林引入到多标签学习中,提出多标签深度森林方法MLDF,在9个基准数 ...

  4. 南大周志华团队开源深度森林软件包DF21:训练效率高、超参数少,普通设备就能跑 | AI日报...

    中国学者研发新型电子纹身,实现8倍延展,有望用于医疗.VR和可穿戴机器人等领域 可穿戴设备,已经成为我们生活中极为常见的一种设备,它们体积轻巧.佩戴方便.检测数据齐全,但也存在一个很明显的缺点--无法 ...

  5. ID3的REP(Reduced Error Pruning)剪枝代码详细解释+周志华《机器学习》决策树图4.5、图4.6、图4.7绘制

    处理数据对象:离散型数据 信息计算方式:熵 数据集:西瓜数据集2.0共17条数据 训练集(用来建立决策树):西瓜数据集2.0中的第1,2,3,6,7,10,14,15,16,17,4 请注意,书上说是 ...

  6. 【数据产品案例】周志华团队和蚂蚁金服合作:用分布式深度森林算法检测套现欺诈

    案例来源:@AI科技大本营 案例地址: https://mp.weixin.qq.com/s?__biz=MzI0ODcxODk5OA==&mid=2247495146&idx=1&a ...

  7. 9大数据集6大度量指标完胜,周志华等提出深度森林处理多标签学习

    2019-11-25 11:01:57 选自arXiv 机器之心编译参与:路雪.一鸣 近日,南大周志华等人首次提出使用深度森林方法解决多标签学习任务.该方法在 9 个基准数据集.6 个多标签度量指标上 ...

  8. 【深度森林第三弹】周志华等提出梯度提升决策树再胜DNN

    [深度森林第三弹]周志华等提出梯度提升决策树再胜DNN 技术小能手 2018-06-04 14:39:46 浏览848 分布式 性能 神经网络 还记得周志华教授等人的"深度森林"论 ...

  9. 《周志华机器学习详细公式推导版》发布,Datawhale开源项目pumpkin-book

    点击上方↑↑↑蓝字关注我们~ 「2019 Python开发者日」全日程揭晓,请扫码咨询 ↑↑↑ 来源 | Datawhale(ID:Datawhale) 如果让你推荐两本国内机器学习的入门经典作,你会 ...

最新文章

  1. 9开启线程日志_GC 日志分析
  2. 1.2控制台的大体设置:
  3. 阿里巴巴云舒:弹性计算的安全问题
  4. 7、java中的面向对象思想和体现
  5. 做值钱的事比赚钱更有意义
  6. SSH集成项目,使用注解方式,竟然还有这样的问题!!
  7. 有道词典java下载电脑版下载手机版下载安装_【有道词典官方下载】有道词典PC版下载_多特软件站...
  8. r语言 wiod_数据可视化基本套路总结
  9. 服务器30hz显示器240hz,显示器刷新率上不去,这锅到底让谁背
  10. php手册chm打开空白
  11. 免费公开课:讲解DevExpress 2016.2新版本功能
  12. Xilinx ZYNQ Ultrascale+ 性能测试之 Video Multi Scaler
  13. 计算机太极之光,3000多名研究生赛太极,五大太极拳流派名家展风采
  14. qq邮箱怎么发送html文件在哪里,QQ邮箱怎么发送文件夹
  15. XP IIS之——问题总结
  16. poj 1061青蛙的约会
  17. iOS进阶 - 包大小:如何从资源和代码层面实现全方位瘦身
  18. tensorflow-tf基础
  19. VSCode的LeetCode插件中国区账号密码登录错误
  20. $.ajax跨域请求数据的解决方案

热门文章

  1. groovy怎样从sql语句中截取表名_SQl-查询篇
  2. html5获取城市,HTML5 geolocation API获得用户当前城市名
  3. java8实战怎么样_Java8中你可能不知道的一些地方之Stream实战
  4. 在虚拟机中安装linux6,如何在vmvare中安装redhat linux6虚拟机
  5. 易语言注入 c dll,易语言DLL注入模块简单型
  6. python列表存储乱码_python 列表中文乱码
  7. stm32——modbus例程网址收藏
  8. 创建一个dynamics 365 CRM online plugin (三) - PostOperation
  9. MySQL无法创建外键、查询外键的属性
  10. 【统计学习】随机梯度下降法求解感知机模型