论文地址:https://arxiv.org/pdf/2004.00221.pdf

源码地址:https://link.zhihu.com/?target=https%3A//github.com/alvinwan/neural-backed-decision-trees

在线示例:https://research.alvinwan.com/neural-backed-decision-trees/

原博文地址:https://zhuanlan.zhihu.com/p/136015811?utm_source=qq&utm_medium=social&utm_oi=977115204467802112

首先简单讲一下这个可解释的深度学习方法,传统的深度学习方法结合决策树(因为决策树做分类等问题,更加贴切人类思考的逻辑,也就是判断后得到结果),从而实现了精度与可解释性能力的兼顾。

引用机器之心平台说的那句话:“鱼和熊掌我都要!BAIR公布神经支持决策树新研究,兼顾准确率与可解释性。”

接下来的文章转载自机器之心平台,仅供学习记录,如有侵删。

1、背景

随着深度学习在金融、医疗等领域的不断落地,模型的可解释性成了一个非常大的痛点,因为这些领域需要的是预测准确而且可以解释其行为的模型。然而,深度神经网络缺乏可解释性也是出了名的,这就带来了一种矛盾。可解释性人工智能(XAI)试图平衡模型准确率与可解释性之间的矛盾,但 XAI 在说明决策原因时并没有直接解释模型本身。

1.1 决策树及其优缺点

决策树是一种用于分类的经典机器学习方法,它易于理解且可解释性强,能够在中等规模数据上以低难度获得较好的模型。之前很火的微软小冰读心术极可能就是使用了决策树。小冰会先让我们想象一个知名人物(需要有点名气才行),然后向我们询问 15 个以内的问题,我们只需回答是、否或不知道,小冰就可以很快猜到我们想的那个人是谁。

周志华老师曾在「西瓜书」中展示过决策树的示意图:

决策树示意图。

尽管决策树有诸多优点,但历史经验告诉我们,如果遇上 ImageNet 这一级别的数据,其性能还是远远比不上神经网络。

「准确率」和「可解释性」,「鱼」与「熊掌」要如何兼得?把二者结合会怎样?最近,来自加州大学伯克利分校和波士顿大学的研究者就实践了这种想法。

1.2 NBDT 神经支持决策树

他们提出了一种神经支持决策树「Neural-backed decision trees」,在 ImageNet 上取得了 75.30% 的 top-1 分类准确率,在保留决策树可解释性的同时取得了当前神经网络才能达到的准确率,比其他基于决策树的图像分类方法高出了大约 14%。

  • BAIR 博客地址:https://bair.berkeley.edu/blog/2020/04/23/decisions/
  • 论文地址:https://arxiv.org/abs/2004.00221
  • 开源项目地址:https://github.com/alvinwan/neural-backed-decision-trees

这种新提出的方法可解释性有多强?我们来看两张图。

OpenAI Microscope 中深层神经网络可视化后是这样的:

而论文所提方法在 CIFAR100 上分类的可视化结果是这样的:

哪种方法在图像分类上的可解释性强已经很明显了吧。(因为这种可解释,实现了人类的判断过程,所以可解释能力就更强,能够“说人话”)

1.3  决策树的优势与缺陷

在深度学习风靡之前,决策树是准确性和可解释性的标杆。下面,我们首先阐述决策树的可解释性。(可解释性能力很强!!)

如上图所示,这个决策树不只是给出输入数据 x 的预测结果(是「超级汉堡」还是「华夫薯条」),还会输出一系列导致最终预测的中间决策。我们可以对这些中间决策进行验证或质疑。

然而,在图像分类数据集上,决策树的准确率要落后神经网络 40%。神经网络和决策树的组合体也表现不佳,甚至在 CIFAR10 数据集上都无法和神经网络相提并论。(这里指的是  神经网络与决策树的简单组合,   其准确率低的感人)

这种准确率缺陷使其可解释性的优点变得「一文不值」:我们首先需要一个准确率高的模型,但这个模型也要具备可解释性。

1.4 走近神经支持决策树

现在,这种两难处境终于有了进展。加州大学伯克利分校和波士顿大学的研究者通过建立既可解释又准确的模型来解决这个问题。

研究的关键点是将神经网络和决策树结合起来,保持高层次的可解释性,同时用神经网络进行低层次的决策。如下图所示,研究者称这种模型为「神经支持决策树(NBDT)」,并表示这种模型在保留决策树的可解释性的同时,也能够媲美神经网络的准确性。

在这张图中,每一个节点都包含一个神经网络,上图放大标记出了一个这样的节点与其包含的神经网络。在这个 NBDT 中,预测是通过决策树进行的,保留高层次的可解释性。但决策树上的每个节点都有一个用来做低层次决策的神经网络,比如上图的神经网络做出的低层决策是「有香肠」或者「没有香肠」。

NBDT 具备和决策树一样的可解释性。并且 NBDT 能够输出预测结果的中间决策,这一点优于当前的神经网络。

(除了逻辑上  利用模仿人的逻辑方式进行逐条判断,还给出了中间的步骤,更好地实现了可解释。这一点在Pro2Pnet上也有体现,讲中间的每一个模板都单独展示并保存。)

如下图所示,在一个预测「狗」的网络中,神经网络可能只输出「狗」,但 NBDT 可以输出「狗」和其他中间结果(动物、脊索动物、肉食动物等)。

此外,NBDT 的预测层次轨迹也是可视化的,可以说明哪些可能性被否定了。

与此同时,NBDT 也实现了可以媲美神经网络的准确率。在 CIFAR10、CIFAR100 和 TinyImageNet200 等数据集上,NBDT 的准确率接近神经网络(差距<1%),在 ImageNet 上的准确性差距也在 2% 内,达到了 75.30%,比基于决策树的现有最佳方法高出整整 14%,实现了可解释模型在准确率上的新 SOTA。

(讲道理,这里的SOTA原意应该是视觉效果,但是 精度上的state of the art ,应该当做“史无前例”与“追求完美”)

针对论文实验精度  尊重“基准”再追求“SOTA”问题,分享一个链接:
“只有达到SOTA的方法才能发文章吗”

https://zhuanlan.zhihu.com/p/57528638

1.5 神经支持决策树是如何解释的

(A)个体特殊实例可解释性

最有参考价值的辩证理由是面向该模型从未见过的对象。例如,考虑一个 NBDT(如下图所示),同时在 Zebra 斑马样本上进行推演。虽然此模型从未见过斑马,但下图所显示的中间决策是正确的-斑马既是动物又是蹄类动物。

对于从未见过的物体而言,个体预测的合理性至关重要。这一点,很多模型都很难做到(因为他们都直接给出错误分类答案,而不能够将正确的中间推导过程展示出来)

(B)避免准确性与可解释性的权衡

此外,研究者发现使用 NBDT,可解释性随着准确性的提高而提高。这与文章开头中介绍的准确性与可解释性的对立背道而驰,即:NBDT 不仅具有准确性和可解释性,还可以使准确性和可解释性成为同一目标。

(说实话,前面看的那篇“不要做事后解释”的论文中,提到了准确性和可解释性存在一个负相关,并且引用了别人的一篇论文实验结果辅证。而这里的两个特性居然能够同时提升,目标一致,说明这个网络真的还挺牛

分层结构在权重空间中对群集向量进行分类,但是在权重空间中接近的类可能不具有类似的语义含义:下图分别描绘了由WideResNet20x10ResNet10诱导的树。虽然WideResNet诱导的层次结构(8A)对语义相似的类进行分组,但ResNet(8B)诱导的层次结构不是这样,将青蛙、猫和飞机等类分组。WideResNet的准确率提高了4%并且它的分类更加合理,更符合人的判断方式,这解释了语义意义上的差异:我们认为,准确度越高的模型在语义上表现出更多的声音权重空间。因此,与以前的工作不同的是,NBDT的特点是更好的可解释性和更高的准确性,而不是牺牲一个来换取另一个。此外,层次结构中的差异表明,低精度、可解释的模型不能提供对高精度决策的洞察力;需要可解释的、最先进的模型来解释最先进的神经网络。

例如,ResNet10 的准确度比 CIFAR10 上的 WideResNet28x10 低 4%。相应地,较低精度的 ResNet ^ 6 层次结构(左)将青蛙,猫和飞机分组在一起且意义较小,因为很难找到三个类共有的视觉特征。而相比之下,准确性更高的 WideResNet 层次结构(右)更有意义,将动物与车完全分离开了。

(看到这里我有点迷糊了,这个我印象中的Resnet不太一样呀!  但是这篇论文中可是结合了决策树的神经网络,所以这里就以决策树的形式展示了网络结构。)

另外针对 WideResNet ,分享一篇文章:

这个宝藏博主总结了近些年的经典CNN网络,其他的博文也值得一看。

https://www.cnblogs.com/liaohuiqiang/p/9691458.html

因此可以说,准确性越高,NBDT 就越容易解释。

补充一些文章中对相关工作的介绍:

从决策树到神经网络。最近的工作还用决策树[13]提供的权重播种神经网络,重新引起了人们对基于梯度的方法[29]的兴趣。这些方法在UCI数据集[9]上显示了非常稀疏特征和稀疏样本的经验证据。

神经网络到决策树 最近的工作[10]使用蒸馏,训练决策树来模拟神经网络的输入输出函数。所有这些工作都是在简单的数据集(UCI[9]MNIST[20])上进行评估,而我们的方法是在更复杂的数据集(CIFAR10[18]CIF AR100[18]TinyImageNet[19])上进行评估。

将神经网络与决策树相结合。最近的工作是将神经网络与决策树相结合,将推理扩展到有很多高维样本数据集。深度神经决策森林[16]的性能与ImageNet上的神经网络相匹配。然而,这发生在剩余网络开始之前,通过使用不纯的树叶和需要森林来牺牲模型的可解释性。Murthy等人。[23]提出为决策树中的每个节点建立一个新的神经网络,并给出可解释的输出。艾哈迈德等人。1通过在所有节点之间共享主干来修改这一点,但仅支持深度-2树;NofE认为ImageNet的性能与ResNet之前的架构相媲美。我们的方法进一步建立在此基础上,不仅共享主干,而且共享完全连接层;此外,我们在保持可解释性的同时,还显示了与最先进的神经网络(包括残差网络)的竞争性能。

一些工作没有明确地将神经网络和决策树相结合,而是从决策树中借鉴了神经网络的思想,反之亦然。特别地,几种重新设计的神经网络结构利用决策树分支结构[352134]。虽然精确度提高了,但这种方法牺牲了决策树的可解释性。其他人使用决策树来分析神经网络权重[3924]。这会带来相反的后果,要么牺牲准确性,要么不支持预测机制。正如我们假设和展示的那样,高精度的决策树对于解释和解释高精度的模型是必要的。此外,我们具有竞争力的性能表明,不需要牺牲准确性和可解释性。

视觉解释。一个正交但占主导地位的可解释性方向包括生成显著图,该图突出神经网络决策所使用的空间证据[3037283827262531]。诸如引导反向传播[30]、去卷积[3728]GradCAM[27]和积分梯度[31]之类的白盒技术使用网络的梯度来确定图像中最显著的区域,而诸如LIME[26]RISE[25]之类的黑盒技术通过扰动输入并测量预测中的变化来确定像素重要性。显著图只解释单个图像,当网络出于错误的原因(例如,一只鸟被错误地归类为飞机)查看正确的东西时,它是没有帮助的。另一方面,我们的方法在整个数据集上表示模型的先验,并显式地将每个分类分解为一系列中间决策。

2、了解决策规则

使用低维表格数据时,决策树中的决策规则很容易解释,例如,如果盘子中有面包,然后分配给合适的孩子(如下所示)。然而,决策规则对于像高维图像的输入而言则不是那么直接。模型的决策规则不仅基于对象类型,而且还基于上下文,形状和颜色等等。

此案例演示了如何使用低维表格数据轻松解释决策的规则。为了定量解释决策规则,研究者使用了 WordNet3 的现有名词层次;通过这种层次结构可以找到类别之间最具体的共享含义。例如,给定类别 Cat 和 Dog,WordNet 将反馈哺乳动物。

(插一嘴,WordNet家族应该是做自然语言处理用的网络,当然百度大大说这个是一个大牛做的英语词典,只是存在一个特殊的结构,根据单词的意义组成一个“单词的网络”。

像我这种做深度学习可解释性这种佛性方向的,一开始还有点陌生。)

在下图中,研究者定量验证了这些 WordNet 假设。

左侧从属树(红色箭头)的 WordNet 假设是 Vehicle。右边的 WordNet 假设(蓝色箭头)是 Animal。

值得注意的是,在具有 10 个类(如 CIFAR10)的小型数据集中,研究者可以找到所有节点的 WordNet 假设。但是,在具有 1000 个类别的大型数据集(即 ImageNet)中,则只能找到节点子集中的 WordNet 假设。

2.1 How it Works?

Neural-Backed 决策树的训练与推断过程可分解为如下四个步骤:

  1. 为决策树构建称为诱导层级「Induced Hierarchy」的层级;
  2. 该层级产生了一个称为树监督损失「Tree Supervision Loss」的独特损失函数;
  3. 通过将样本传递给神经网络主干开始推断。在最后一层全连接层之前,主干网络均为神经网络;
  4. 以序列决策法则方式运行最后一层全连接层结束推断,研究者将其称为嵌入决策法则「Embedded Decision Rules」。

Neural-Backed 决策树训练与推断示意图。

针对这四个步骤,我的理解就是:

①首先构造一个决策树,用于最后全连接做决策。

②针对第一步构造的数据结构,生成一个特定的树结构监督损失函数(因为这里的网络结构不再是传统的神经网络,所以要针对树结构构建一个更合适的LOSS function)

③搭建一个backbone神经网络获取特征

④嵌入决策法则——按序列进行决策,完成全连接阶段的推断工作

2.2 运行嵌入决策法则(第四步)

这里首先讨论推断问题。如前所述,NBDT 使用神经网络主干提取每个样本的特征。

为便于理解接下来的操作,研究者首先构建一个与全连接层等价的退化决策树,如下图所示:

以上产生了一个矩阵-向量乘法,之后变为一个向量的内积,这里将其表示为$\hat{y}$。

以上输出最大值的索引即为对类别的预测。

简单决策树(naive decision tree):研究者构建了一个每一类仅包含一个根节点与一个叶节点的基本决策树,如上图中「B—Naive」所示。每个叶节点均直接与根节点相连,并且具有一个表征向量(来自 W 的行向量)。

使用从样本提取的特征 x 进行推断意味着,计算 x 与每个子节点表征向量的内积。类似于全连接层,最大内积的索引即为所预测的类别。

全连接层与简单决策树之间的直接等价关系,启发研究者提出一种特别的推断方法——使用内积的决策树。

2.3 构建诱导层级(第一步)

该层级决定了 NBDT 需要决策的类别集合。由于构建该层级时使用了预训练神经网络的权重,研究者将其称为诱导层级。

具体地,研究者将全连接层中权重矩阵 W 的每个行向量,看做 d 维空间中的一点,如上图「Step B」所示。

接下来,在这些点上进行层级聚类。连续聚类之后便产生了这一层级。

2.4 使用树监督损失进行训练(第二、三步)

考虑上图中的「A-Hard」情形。假设绿色节点对应于 Horse 类。这只是一个类,同时它也是动物(橙色)。对结果而言,也可以知道到达根节点(蓝色)的样本应位于右侧的动物处。到达节点动物「Animal」的样本也应再次向右转到「Horse」。所训练的每个节点用于预测正确的子节点。

研究者将强制实施这种损失的树称为树监督损失(Tree Supervision Loss)。换句话说,这实际上是每个节点的交叉熵损失。

在地块中,一簇点用绿色圆圈标记,另一簇用黄色标记。每个圆的中心由它的两个灰点的平均值给出。在每个绘图的右侧绘制相应的决策树。

一旦给出一个点,决策树的根将计算具有最接近代表向量(绿色或黄色的点)的子节点。请注意,类4(红色)的所有样本将比正确的父级(绿色)更接近错误的父级(黄色)。这是因为A试图用4聚类2,用31。因此,神经网络很难获得高精度,因为它需要大幅移动所有的点来区分黄色和绿色的点。B:对于相同的点,这棵树将12聚为一簇,而将34聚为一簇,从而产生更多可分离的簇。请注意,B中的决策边界(虚线)相对于绿点和黄点的边距要大得多。因此,对于神经网络来说,这棵树更容易对点进行正确分类。

图A:红色点是绿色类,但他更接近黄色点的圆心而不是绿色

图B:这种情况下划分好,不容易出错

由于直接使用平均值作为特征向量有误差,所以这里添加了损失函数

3、 使用指南

我们可以直接使用 Python 包管理工具来安装 nbdt:

pip install nbdt

安装好 nbdt 后即可在任意一张图片上进行推断,nbdt 支持网页链接或本地图片。

nbdt https://images.pexels.com/photos/126407/pexels-photo-126407.jpeg?auto=compress&cs=tinysrgb&dpr=2&w=32# OR run on a local image
nbdt /imaginary/path/to/local/image.png

不想安装也没关系,研究者为我们提供了网页版演示以及 Colab 示例,地址如下:

  • Demo:http://nbdt.alvinwan.com/demo/
  • Colab:http://nbdt.alvinwan.com/notebook/

下面的代码展示了如何使用研究者提供的预训练模型进行推断:

from nbdt.model import SoftNBDT
from nbdt.models import ResNet18, wrn28_10_cifar10, wrn28_10_cifar100, wrn28_10  # use wrn28_10 for TinyImagenet200model = wrn28_10_cifar10()
model = SoftNBDT(pretrained=True,dataset='CIFAR10',arch='wrn28_10_cifar10',model=model)

另外,研究者还提供了如何用少于 6 行代码将 nbdt 与我们自己的神经网络相结合,详细内容请见其 GitHub 开源项目。

阅读笔记5:神经支持决策树(可解释性)相关推荐

  1. 阅读笔记2: 深度学习可解释性学习:不要做事后解释

    选择可解释性高的机器学习模型, 而不是解释决策风险高的黑匣子模型 (原论文名:Stop Explaining Black Box Machine Learning Models for High St ...

  2. [论文阅读笔记53]2021深度神经方法的关系三元组抽取综述

    1. 题目 Deep Neural Approaches to Relation Triplets Extraction: A Comprehensive Survey Tapas Nayak†, N ...

  3. 论文阅读笔记——基于CNN-GAP可解释性模型的软件源码漏洞检测方法

    本论文相关内容 论文下载地址--Engineering Village 论文阅读笔记--基于CNN-GAP可解释性模型的软件源码漏洞检测方法 文章目录 本论文相关内容 前言 基于CNN-GAP可解释性 ...

  4. 【知识图谱】本周文献阅读笔记(3)——周二 2023.1.10:英文)知识图谱补全研究综述 + 网络安全知识图谱研究综述 + 知识图谱嵌入模型中的损失函数 + 图神经网络应用于知识图谱推理的研究综述

    声明:仅学习使用~ 对于各文献,目前仅是泛读形式,摘出我认为重要的点,并非按照原目录进行简单罗列! 另:鉴于阅读paper数目稍多,对paper内提到的多数模型暂未细致思考分析.目的是总结整理关于KG ...

  5. 【知识图谱】 | 《知识图谱——方法、实践与应用》阅读笔记

    <知识图谱--方法.实践与应用>的阅读笔记 知识图谱--方法.实践与应用 第1章 知识图谱概述 1.1 什么是知识图谱 1.2 知识图谱的发展历史 1.3 知识图谱的价值 1.4 国内外典 ...

  6. 论文阅读笔记:为什么深度神经网络的训练无论多少次迭代永远有效?可能类内分布已经坍缩为一个点,模型已经崩溃为线性分类器

    论文阅读笔记:Prevalence of neural collapse during the terminalphase of deep learning training,深度学习训练末期普遍的神 ...

  7. 论文Learning to Solve Large-Scale Security-Constrained Unit Commitment Problems阅读笔记

    论文Learning to Solve Large-Scale Security-Constrained Unit Commitment Problems阅读笔记 论文arxiv链接:Learning ...

  8. A Survey of Deep Learning-based Object Detection论文翻译 + 阅读笔记

    A Survey of Deep Learning-based Object Detection论文翻译 + 阅读笔记 //2022.1.7 日下午16:00开始阅读 双阶段检测器示意图 单阶段检测器 ...

  9. 《Beta Embeddings for Multi-Hop Logical Reasoning in Knowledge Graphs》论文阅读笔记

    <Beta Embeddings for Multi-Hop Logical Reasoning in Knowledge Graphs>论文阅读笔记 主要挑战贡献: KG上的推理挑战主要 ...

最新文章

  1. Android移动开发之【通往安卓的神奇之旅】Android的五大布局和AndroidManifest
  2. 每天一点小知识004--关于获取物体名字
  3. zabbix安装报错
  4. 中石油训练赛 - Cafebazaar’s Chess Tournament(FFT)
  5. 小奇遐想 树状数组实现+容斥思想
  6. Mac键盘被锁的解决方法
  7. 提升效率Mysql函数(function)|存储过程(procedure)
  8. 医学图像之DICOM格式解析
  9. 对计算机网络的认识400字,对计算机网络的初步认识
  10. html5的指南针app,HTML5 App实战(五):指南针
  11. 练习时长两年半的Matlab
  12. M0、M1、M2、M3
  13. 传统推荐模型——协同过滤
  14. 为什么越来越多的绘图员开始用云渲染来渲图?
  15. NetFPGA-SUME下reference_nic测试
  16. 读《男子为让孩子成为北京人执意找京籍女结婚(图)》有感——致北漂的XDJM
  17. 解决Iframe嵌入帆软BI系统后,Chrome升级后跨域出现登录界面,Cookie写入不成功。
  18. 【验证码功能合集】vue简单实现验证码功能,纯前端实现验证码,拿来即用【输入,滑动,拼图】
  19. php yield Generator 处理大数组
  20. 数字转换为人民币的大写(复制直接用)

热门文章

  1. 【函数】oracle translate() 详解+实例
  2. python实训主要成果_Python实训周总结
  3. 在 vue-cli 项目中添加标签页图标 favicon
  4. qam映射c程序_基于星型24QAM映射的光概率成型编码方法与流程
  5. 手机怎么搜索备忘录内容?
  6. linux+usb刻录,如何在Ubuntu上安装Etcher-开源USB刻录机工具
  7. AES加密工具类AESUtil
  8. 思考网游企业的上市浪潮
  9. 我的世界国际版中如何安装光影
  10. 正则表达式--前缀r