NBDT: Neural-Backed Decision Trees

文章目录

  • NBDT: Neural-Backed Decision Trees
    • 简介
    • 摘要
    • 初步
    • 相关工作
    • 方法
      • 使用嵌入式决策规则进行推理
      • 构建诱导层次结构
      • 用树木监督损失进行训练
    • 实验结果分析
      • 节点语义的可解释性
      • 避免准确性与可解释性的权衡
      • 树遍历的可视化实现
    • 补充笔记
    • 详细步骤
      • 建立诱导层次(Induced Hierarchy)树结构
      • 使用树树监督损失调整网络CNN参数

简介

论文标题

  • NBDT: Neural-Backed Decision Trees
  • NBDT:神经支持决策树
  • 2020.1

贡献

  • 我们提出了一种将任何分类神经网络作为决策树运行的方法,方法是定义一组嵌入的决策规则,这些规则可以从完全连通层构造出来。我们还设计了易于神经网络学习的诱导层次结构。
  • 我们提出了树监督损失,它使神经网络的准确率提高了0.5个百分点,并产生了高精度的NBDT。我们在小型、中型和大型图像分类数据集上证明了我们的NBDT达到了与神经网络相当的精度。
  • 我们为我们的模型决策提供了语义解释的定性和定量证据。

该工具可以直接在以下地址在线使用:

  • Demo:http://nbdt.alvinwan.com/demo/
  • Colab:http://nbdt.alvinwan.com/notebook/
  • git : https://github.com/alvinwan/neural-backed-decision-trees
  • 论文:https://arxiv.org/abs/2004.00221

论文中先讲的推理后讲的训练

摘要

深度学习正被用于需要准确和合理的预测的环境中,从金融到医学成像。虽然最近有为模型预测提供事后解释的工作,但探索更直接可解释的模型以匹配最先进的准确性的工作相对较少。从历史上看,决策树一直是平衡可解释性和准确性的黄金标准。然而,最近将决策树与深度学习相结合的尝试导致了以下模型:(1)即使在较小的数据集(例如MNIST)上,实现的精度也远远低于现代神经网络(例如ResNet),以及(2)需要显著不同的体系结构,迫使实践者在准确性和可解释性之间做出选择。我们通过创建神经支持的决策树(NBDs)来摆脱这一困境,它(1)实现了神经网络的准确性,(2)不需要对神经网络的体系结构进行任何改变。使用最新的WideResNet,NBDT在CIFAR10、CIFAR100、TinyImageNet上的基本神经网络的精确度在1%以内;在ImageNet上,NBDT在EfficientNet上的精确度在2%以内。这在ImageNet上产生了最先进的可解释模型,NBDT将基准∼提高了14%到75.30%TOP-1准确率。此外,我们还通过半自动过程展示了我们模型决策的定性和定量可解释性。代码和预先培训的NBDT可以在github.com/alvinwan/neuralbacked-decision-trees.上找到。

初步

在这项工作中,我们提出了神经支持决策树(NBDT)来使最先进的计算机视觉模型具有可解释性。这些NBDT不需要特殊的体系结构:任何用于图像分类的神经网络都可以通过微调和自定义损失转换成NBDT。此外,NBDT通过将图像分类分解成中间决策序列来执行推理。然后,这个决策序列可以映射到更多可解释的概念,并在底层类中揭示可感知的信息层次结构。关键的是,与计算机视觉中关于决策树的先前工作相比,NBDT在CIFAR10[18]、CIFAR100[18]、TinyImageNet[19]和ImageNet[8]上的最新结果具有竞争力,并且比基于可比决策树的方法精确度大大提高(高达18%),同时也更易于解释。

我们介绍了一种针对NBDT的两阶段训练过程。首先,我们计算一个层次,称为诱导层次(图1,步骤1)。该层次结构是从已经在目标数据集上训练的神经网络的权重导出的。其次,我们使用专门为该树设计的自定义损失来微调网络,称为树监督损失(图1,步骤2)。这种损失迫使模型在给定固定的树层次结构的情况下最大化决策树的准确性。然后,我们分两步进行推理:(1)使用网络主干(图1,步骤3)为每幅训练图像构造特征.然后,对于每个节点,我们在给定决策树层次结构的情况下,在网络的权重空间中计算最能代表其子树中的叶子的向量-我们将该向量称为代表性向量。(2)从根节点开始,将每个样本发送给与样本具有最相似代表向量的子节点。我们继续采摘和遍历这棵树,直到我们到达一片树叶。与此叶相关的类是我们的预测(图1,步骤4)。这与引入可解释性障碍的相关工作形成对比,例如不纯树叶[16]或模型集合[1,16]

图1:神经支持诊断树。在步骤1中,我们使用预先训练的网络的完全连接层权重来构建层次结构(第3.2节)。在步骤2中,我们使用自定义损耗(秒)微调网络。3.3)。在步骤3中,我们使用神经网络主干对样本进行特征化。在步骤4中,我们使用完全连接层的权重来运行决策规则(SEC。3.1)。如上所示,步骤3中的橙色箭头与步骤4中的树的橙色节点相关联。同样,绿色箭头映射到绿色节点。该树获取传入样本与橙色w1和绿色w2矢量中的每一个之间的内积;预测具有较高内积的叶子。

如何构建树,如何训练树,如何选择树分类

相关工作

从决策树到神经网络。最近的工作还用决策树[13]提供的权重播种神经网络,重新引起了人们对基于梯度的方法[29]的兴趣。这些方法在UCI数据集[9]上显示了非常稀疏特征和稀疏样本的经验证据。

神经网络到决策树 最近的工作[10]使用蒸馏,训练决策树来模拟神经网络的输入输出函数。所有这些工作都是在简单的数据集(如UCI[9]或MNIST[20])上进行评估,而我们的方法是在更复杂的数据集(如CIF
AR10[18]、CIF AR100[18]和TinyImageNet[19])上进行评估。

将神经网络与决策树相结合。最近的工作是将神经网络与决策树相结合,将推理扩展到有很多高维样本数据集。深度神经决策森林[16]的性能与ImageNet上的神经网络相匹配。然而,这发生在剩余网络开始之前,通过使用不纯的树叶和需要森林来牺牲模型的可解释性。Murthy等人。[23]提出为决策树中的每个节点建立一个新的神经网络,并给出可解释的输出。艾哈迈德等人。1通过在所有节点之间共享主干来修改这一点,但仅支持深度-2树;NofE认为ImageNet的性能与ResNet之前的架构相媲美。我们的方法进一步建立在此基础上,不仅共享主干,而且共享完全连接层;此外,我们在保持可解释性的同时,还显示了与最先进的神经网络(包括残差网络)的竞争性能。

一些工作没有明确地将神经网络和决策树相结合,而是从决策树中借鉴了神经网络的思想,反之亦然。特别地,几种重新设计的神经网络结构利用决策树分支结构[35,21,34]。虽然精确度提高了,但这种方法牺牲了决策树的可解释性。其他人使用决策树来分析神经网络权重[39,24]。这会带来相反的后果,要么牺牲准确性,要么不支持预测机制。正如我们假设和展示的那样,高精度的决策树对于解释和解释高精度的模型是必要的。此外,我们具有竞争力的性能表明,不需要牺牲准确性和可解释性。

视觉解释。一个正交但占主导地位的可解释性方向包括生成显著图,该图突出神经网络决策所使用的空间证据[30,37,28,38,27,26,25,31]。诸如引导反向传播[30]、去卷积[37,28]、GradCAM[27]和积分梯度[31]之类的白盒技术使用网络的梯度来确定图像中最显著的区域,而诸如LIME[26]和RISE[25]之类的黑盒技术通过扰动输入并测量预测中的变化来确定像素重要性。显著图只解释单个图像,当网络出于错误的原因(例如,一只鸟被错误地归类为飞机)查看正确的东西时,它是没有帮助的。另一方面,我们的方法在整个数据集上表示模型的先验,并显式地将每个分类分解为一系列中间决策。

方法

在本节中,我们描述了将任何分类神经网络转换为决策树的建议步骤,如图1所示:

(1)建立诱导层次结构(SEC。3.2)、

(2)使用树监督损失(SEC)对模型进行微调。3.3)。

对于推理,(3)使用神经网络主干对样本进行特征化,

以及(4)运行嵌入在完全连接层(SEC)中的决策规则。3.1)。

使用嵌入式决策规则进行推理

首先,我们的NBDT方法使用神经网络主干对每个样本进行特征化;主干由最终完全连接层之前的所有神经网络层组成

其次,在每个节点,我们取特征化样本x∈Rd与每个子节点的代表向量ri之间的内积。请注意,所有代表向量都是从神经网络的完全连通层权重计算出来的。因此,这些决策规则被“嵌入”在神经网络中。

第三,我们使用这些内积来做出硬决策或软决策,如下所述。

为了激励我们为什么使用内积,我们将首先构建一个等价于完全连接层的退化决策树

完全连接层

全连接层的权重矩阵为W∈Rk×dW \in \mathbb{R}^{k \times d}W∈Rk×d。用特征化样运行推理是矩阵向量的乘积:

其中,⟨x,wi⟩=y^i\left\langle x, w_{i}\right\rangle=\hat{y}_{i}⟨x,wi​⟩=y^​i​,最大的就是预测的

决策树

考虑一棵最小树,它有一个根节点和k个子节点。每个子节点是叶,并且每个子节点具有代表向量,即行向量ri=wir_{i}=w_{i}ri​=wi​。用特征化样本x运行推断意味着取x和每个子节点的代表向量ri之间的内积,其被写为⟨x,ri⟩=⟨x,wi⟩=y^i\left\langle x, r_{i}\right\rangle=\left\langle x, w_{i}\right\rangle=\hat{y}_{i}⟨x,ri​⟩=⟨x,wi​⟩=y^​i​。与全连接层一样,最大乘积y^i\hat{y}_{i}y^​i​的指标也是我们的类预测。图2(B.)说明了这一点。

尽管这两个计算的表示方式不同,但都是通过取最大内积argmax⁡⟨x,wi⟩\operatorname{argmax}\left\langle x, w_{i}\right\rangleargmax⟨x,wi​⟩的索引来预测类。我们将决策树推理称为运行嵌入式决策规则

接下来,我们将朴素决策树扩展到退化情况之外。我们的判决规则要求每个子节点具有代表性向量ri。因此,如果我们将一个非叶子子代添加到根,那么这个非叶子子代将需要一个代表性向量。我们天真地认为非叶的代表向量是所有子树的叶的代表向量的平均值。对于包含中间节点的更复杂的树结构,现在有两种方式来运行推理:

  • 硬决策树。计算所有子节点上每个节点的argmax。对于每个节点,获取与最大内积对应的子节点,并遍历该子节点。这个过程选择一片叶子(图2,A.硬)。
  • 软决策树。在每个节点上计算所有子节点上的Softmax,以获得每个节点上每个子节点的概率。对于每个叶,获取从其父级遍历该叶的概率。然后取遍历树叶的父代与其祖辈的概率。继续直到到达根部。这个乘积是那片叶子及其到根部的路径的概率。树遍历将为每个叶生成一个概率。在这个叶子分布上计算argmax,以选择一个叶子(图2,C.Soft)。

这允许我们将任何分类神经网络作为嵌入的决策规则序列来运行。然而,以这种方式简单地运行标准问题的预先训练的神经网络将导致较差的精度。在下一节中,我们将讨论如何通过微调神经网络使其在确定层次结构后执行良好,从而最大限度地提高精度。

笔记:

主干部分: 完全连接层之前的所有神经网络层

完全连接层用于拆成决策树

每个节点都有对应的一行向量,其中

叶子节点的特征向量对应原有的权重矩阵中的一行向量

非叶子节对应其子树叶子的所有特征向量的平均值,

树的结构是通过层次聚类或者是wordnet预定义层次结构而来的

如何选择分支是通过取最大内积argmax⁡⟨x,wi⟩\operatorname{argmax}\left\langle x, w_{i}\right\rangleargmax⟨x,wi​⟩的方式来的,称为嵌入式决策规则

主要是讲如何通过内积选择分支

构建诱导层次结构

使用上述内积决策规则,网络可以直观地更容易地学习决策树层次结构。这些更容易的层次结构可以更准确地反映网络是如何达到高精度的。为此,我们对从完全连接的层权重W提取的类代表W运行分层聚集聚类,如上一节所述。3.1,每个叶是一个Wi(图3,步骤B),并且每个中间节点的代表向量是其子树叶子的所有代表的平均值(图3,步骤C)。我们把这个层次称为诱导层次(图3)。

此外,我们还使用另一种基于WordNet的层次结构进行了实验。Wordnet[22]提供了一个现有的名词层次结构,我们利用它在语言上将每个数据集中的类联系起来。我们找到了WordNet层次结构的最小子集,其中包括所有类作为叶子,修剪冗余的叶子和单子中间节点。因此,WordNet关系为该候选决策树提供了“自由”和可解释的标签,例如将一只猫也归类为哺乳动物和生物。为了利用这个“自由”的标签源,我们通过找到每个子树叶子的最早祖先,为诱导层次结构中的每个中间节点自动生成假设。

图3:构建诱导层次结构。步骤A,将预先训练好的神经网络最终的全连通层的权值加载到权重矩阵W∈Rd×kW \in \mathbb{R}^{d \times k}W∈Rd×k;步骤B,以W为代表的每一列作为每个叶节点的代表向量。例如,A中的红色w1被指定给B中的红色叶子。步骤C使用每对叶子的平均值作为父代的代表向量。例如,B中的w1和w2(红色和紫色)平均为C中的w5(蓝色)。步骤D。对于每个祖先,取其根所在的子树。子树中所有树叶的平均表示向量。这个平均值是祖先的代表性矢量。在这个图中,祖先是根,所以它的代表向量是所有叶子w1、w2、w3、w4的平均值。

这幅图说明了如何构建一棵树,注意构建树之前一般需要预训练的权重

用树木监督损失进行训练

上面提出的所有决策树都有一个主要问题:即使鼓励原始神经网络为每一类分离代表向量,但它没有被训练为为每个内部节点分离代表向量。图4说明了这一点。为了解决这个问题,我们添加了损失项,鼓励神经网络在训练期间分离内部节点的代表。现在我们依次解释硬决策规则和软决策规则的附加损失条款(图5)。

图4:病理性诊断树。在地块中,一簇点用绿色圆圈标记,另一簇用黄色标记。每个圆的中心由它的两个灰点的平均值给出。在每个绘图的右侧绘制相应的决策树。答:一旦给出一个点,决策树的根将计算具有最接近代表向量(绿色或黄色的点)的子节点。请注意,类4(红色)的所有样本将比正确的父级(绿色)更接近错误的父级(黄色)。这是因为A试图用4聚类2,用3聚1。因此,神经网络很难获得高精度,因为它需要大幅移动所有的点来区分黄色和绿色的点。B:对于相同的点,这棵树将1与2聚为一簇,而将3与4聚为一簇,从而产生更多可分离的簇。请注意,B中的决策边界(虚线)相对于绿点和黄点的边距要大得多。因此,对于神经网络来说,这棵树更容易对点进行正确分类。

图A:红色点是绿色类,但他更接近黄色点的圆心而不是绿色

图B:这种情况下划分好,不容易出错

由于直接使用平均值作为特征向量有误差,所以这里添加了损失函数

对于硬决策规则,我们使用硬树监督损失。原始神经网络的损失Loriginal \mathcal{L}_{\text {original }}Loriginal ​最小化了跨类的交叉熵。对于k类数据集,这是k路交叉熵损失。每个内部节点的目标是相似的,最小化跨子节点的交叉熵损失。对于具有c个子节点的节点i,这是预测概率D(i)pred \mathcal{D}(i)_{\text {pred }}D(i)pred ​和标签D(i)label \mathcal{D}(i)_{\text {label }}D(i)label ​之间的c路交叉熵损失。我们将这组新的损失术语称为硬树监督损失(等式2)。默认情况下,每个节点的单个交叉熵损失被缩放,使得原始交叉熵损失和树监督损失被相等地加权。我们在SEC中测试了各种权重方案。4.2.。如果我们假设树中有N个节点(不包括树叶),那么我们将有N+1个不同的交叉熵损失项-原始的交叉熵损失项和N个硬树监督损失项。这是Loriginal +Lhard \mathcal{L}_{\text {original }}+\mathcal{L}_{\text {hard }}Loriginal ​+Lhard ​,其中:
Lhard=1N∑i=1NCROSSENTROPY⁡(D(i)pred ,D(i)label )⏟over the cchildren for each node \mathcal{L}_{\mathrm{hard}}=\frac{1}{N} \sum_{i=1}^{N} \underbrace{\operatorname{CROSSENTROPY}\left(\mathcal{D}(i)_{\text {pred }}, \mathcal{D}(i)_{\text {label }}\right)}_{\text {over the } c \text { children for each node }} Lhard​=N1​i=1∑N​over the c children for each node CROSSENTROPY(D(i)pred ​,D(i)label ​)​​
对于软决策规则,我们使用软树监督损失。在3.1节中,我们描述了软决策树如何在树叶上提供单一分布Dpred。我们在这个分布上增加了交叉熵损失。总共有2个不同的交叉熵损失项-原始交叉熵损失项和软树监督损失项。这是Loriginal +Lsoft \mathcal{L}_{\text {original }}+\mathcal{L}_{\text {soft }}Loriginal ​+Lsoft ​,其中:
Lsoft =CROSSENTROPY (Dpred ,Dlabel )\mathcal{L}_{\text {soft }}=\text { CROSSENTROPY }\left(\mathcal{D}_{\text {pred }}, \mathcal{D}_{\text {label }}\right) Lsoft ​= CROSSENTROPY (Dpred ​,Dlabel ​)
硬决策树因为每个节点都是一个概率分布,所以每个节点都算一次交叉熵

软决策树因为整体服从一个概率分布,所以只在叶子上算一次交叉熵

图5:树监督损失有两种变体:硬树监督损失(A)定义了每个节点的交叉熵项。蓝色节点的蓝色框和橙色节点的橙色框说明了这一点。交叉熵被取而代之的是子节点概率。绿色节点是标签叶。虚线节点不包括在从标签到根的路径中,因此没有定义的损耗。软树监督损失(B)定义了所有叶概率上的交叉熵损失。绿叶的概率是通向根部的概率的乘积(在本例中,⟨x,w2⟩⟨x,w6⟩=0.6×0.7\left\langle x, w_{2}\right\rangle\left\langle x, w_{6}\right\rangle=0.6 \times 0.7⟨x,w2​⟩⟨x,w6​⟩=0.6×0.7)。其他树叶的概率被类似地定义。每个叶概率用一个彩色方框表示。然后在该叶概率分布上计算交叉熵,该分布由坐在彼此直接相邻的彩色框表示。

实验结果分析

精度表格略

节点语义的可解释性

由于诱导层次是使用模型权重构建的,因此不会强制对特定属性进行拆分。虽然像WordNet这样的层次结构为节点的含义提供了假设,但图6显示WordNet是不够的,因为树可能会在上下文属性(如水下和陆地)上分裂。为了诊断节点含义,我们执行以下4步测试:

图6:使用(A)WordNet层次结构和(B)来自训练有素的ResNet10模型的诱导树,对TinyImageNet中的10个类进行树可视化。

  1. 假设节点的含义(例如,动物与车辆)。这一假设可以从给定的分类法(如WordNet)自动计算出来,也可以从手动检查每个孩子的叶子中推导出来(图7)。
  2. 收集一个数据集,其中包含测试步骤1中节点的假设含义的新的、看不见的类(例如,Elephant是一种看不见的动物)。此数据集中的样本称为分布外样本,因为它们是从单独标记的数据集中提取的。
  3. 将此数据集中的样本传递给相关节点。对于每个样本,检查所选子节点是否与假设一致。
  4. 假设的准确性是传递给正确孩子的样本的百分比。如果精确度较低,请使用不同的假设重复。

这个过程自动验证WordNet假设,但WordNet之外的假设需要人工干预。图7a描述了由在CIFAR10上训练的WideResNet28x10模型诱导的CIFAR10树。我们的假设是,根节点在Animal和Vehicle上分裂。我们从CIF
AR100收集在培训时间看不到的动物和车辆类的分发外图像。然后我们计算假设的准确性。图7b显示了我们的假设准确地预测了每个看不见的类的样本遍历的是哪个子类。

注意,诱导出来的树形结构需要检查或自动推导每个非叶子节点的含义

避免准确性与可解释性的权衡

入的分层结构在权重空间中对群集向量进行分类,但是在权重空间中接近的类可能不具有类似的语义含义:图8分别描绘了由WideResNet20x10和ResNet10诱导的树。虽然WideResNet诱导的层次结构(图8A)对语义相似的类进行分组,但ResNet(图8B)诱导的层次结构不是这样,将青蛙、猫和飞机等类分组。WideResNet的准确率提高了4%,这解释了语义意义上的差异:我们认为,准确度越高的模型在语义上表现出更多的声音权重空间。因此,与以前的工作不同的是,NBDT的特点是更好的可解释性和更高的准确性,而不是牺牲一个来换取另一个。此外,层次结构中的差异表明,低精度、可解释的模型不能提供对高精度决策的洞察力;需要可解释的、最先进的模型来解释最先进的神经网络。

左边的精度高,解释性也好,用来证明精度和解释性是可以相辅相成的

值得注意的是,在具有 10 个类(如 CIFAR10)的小型数据集中,研究者可以找到所有节点的 WordNet 假设。但是,在具有 1000 个类别的大型数据集(即 ImageNet)中,则只能找到节点子集中的 WordNet 假设。

小型分类上可以用wordnet去解释节点,但大型分类网络则不能全部都找到wordnet上的解释

树遍历的可视化实现

为了不仅解释树层次结构,也解释树遍历,我们可视化了通过每个节点的样本的百分比(图9)。这既突出显示了正确的路径(最频繁遍历的路径),又允许我们解释常见的错误路径(图9A)。具体地说,我们可以解释遍历节点的叶子之间共享的属性。这些属性可以是背景或场景,但也可以是颜色或形状。图9B描绘了描述上下文的样本的路径。在这种情况下,很少有动物在海滨环境中被认出,而船只几乎总是在那个环境中被看到。图9C描述了属于不符合假设节点的属性但保持路径一致性的非分布类的样本的路径。在这种情况下,泰迪倾向于动物类,特别是狗,因为它有相似的形状和视觉特征。

图9:三个不同类的路径遍历频率的可视化。(A)分配类:马使用在训练中发现的类样本。中间节点的假设含义来自WordNet。(B)背景类:海边使用训练时看不到的样本,表明对背景的依赖。©混淆类:Teddy使用在节点含义中识别边缘情况的样本。

况下,泰迪倾向于动物类,特别是狗,因为它有相似的形状和视觉特征。

[外链图片转存中…(img-zbTn0SCl-1588410850893)]

图9:三个不同类的路径遍历频率的可视化。(A)分配类:马使用在训练中发现的类样本。中间节点的假设含义来自WordNet。(B)背景类:海边使用训练时看不到的样本,表明对背景的依赖。©混淆类:Teddy使用在节点含义中识别边缘情况的样本。

https://mp.weixin.qq.com/s/WrfLMXfgFbk_SaMvy2pweg

补充笔记

特点

  1. 任何用于图像分类的神经网络都可以通过微调和自定义损失转换成NBDT
  2. NBDT是在一个预训练好的模型的基础上,把最后一层全连接改成一个树再微调得到的
  3. 因此,NBDT的结构可以大致认为是“前面的CNN + 后面的DT(决策树)”
  4. NBDT没有传统神经网络精度高,但比传统决策树的精度高
  5. NBDT平衡了模型准确率模型解释性

过程

如图所示

训练过程

  1. 建立诱导层次(Induced Hierarchy)树结构
  2. 使用树树监督损失调整网络CNN参数

推理过程

  1. 利用诱导层次(Induced Hierarchy)树结构进行判断

详细步骤

建立诱导层次(Induced Hierarchy)树结构

  1. 如图stepA和stepB所示,将全连接层的权重矩阵(倒数第一层和倒数第二层)按照列向量分给每一个叶子节点,这个向量称为节点的代表向量,(叶子节点个数=全连接层节点的个数=分类个数)
  2. 对叶子节点按照代表向量进行层次聚类得到整个树的结构(如图StepD,W5是由W1和W2聚类来的,在其他情况下),树的结构与层次聚类的结果有关
  3. 父节点的代表向量等于子节点代表向量的均值

此时,我们获得了一颗每个节点都有代表向量的树结构.

此时,子节点具有语义,即每个子节点代表着对应的类别.但父节点的语义还不明确

  1. 按照WordNet的层次结构,在WordNet中寻找两个子节点最近的父节点类别,并为父节点标识语义.(e.g. “猫”和“狗”在WordNet中可能最近的归属是都位于“哺乳动物”下,那么“哺乳动物”就被作为“猫”和“狗”的父节点。)

此时,决策树构建完毕.

使用树树监督损失调整网络CNN参数

首先,看原始CNN的全连接层是如何进行判断是哪个分类的

设,全连接层的权重矩阵为W∈Rk×dW \in \mathbb{R}^{k \times d}W∈Rk×d。

其中,⟨x,wi⟩=y^i\left\langle x, w_{i}\right\rangle=\hat{y}_{i}⟨x,wi​⟩=y^​i​,最大乘积y^i\hat{y}_{i}y^​i​就是预测的.

对于最简单的树

如图b所示,考虑一棵最简单的树,它有一个根节点和k个子节点,并且每个子节点具有代表向量wiw_{i}wi​.

因此,叶子节点与x的内积⟨x,wi⟩=y^i\left\langle x, w_{i}\right\rangle=\hat{y}_{i}⟨x,wi​⟩=y^​i​,就相当于是全连接层的判断分类,最大乘积y^i\hat{y}_{i}y^​i​就是目标类别.

这成为运行嵌入式决策规则

对于复杂情况

子节点拥有了判断能力,但是父节点的向量值为子节点的均值,这可能不具备判断能力

为了让父节点和子节点一样具有判断能力(这样决策树才能从上往下判断),需要对原始网络结构参数进行调整,

使其适应新的决策树判断方式.微调的损失函数有两种方式,根据调整使用的损失函数不同,决策树可分为两种

**硬决策树 **每次x会与左右两边的子节点分别算内积,哪边大就把x归为哪一边,一直计算到叶节点为止,最后x落到的叶节点,即为x所属的最终类别。

这就是每个父节点和两个子节点是一个单独的概率分布.(局部最优,相同父节点的叶子节点和为1)

软决策树。则x会自顶向下遍历全部中间节点并计算内积,然后叶节点的最终概率是到达叶节点的路径上各中间节点的概率之乘积,最后通过比较各叶节点上的最终概率值的大小,即可确定x所属类别。

决策树的概率分布统一.(全局最优,所有叶子节点和为1)

损失函数

如果我们假设树中有N个节点(不包括树叶),那么我们将有N+1个不同的交叉熵损失项:原始的交叉熵损失项和N个硬树监督损失项。这是Loriginal +Lhard \mathcal{L}_{\text {original }}+\mathcal{L}_{\text {hard }}Loriginal ​+Lhard ​,
Lhard=1N∑i=1NCROSSENTROPY⁡(D(i)pred ,D(i)label )⏟over the cchildren for each node \mathcal{L}_{\mathrm{hard}}=\frac{1}{N} \sum_{i=1}^{N} \underbrace{\operatorname{CROSSENTROPY}\left(\mathcal{D}(i)_{\text {pred }}, \mathcal{D}(i)_{\text {label }}\right)}_{\text {over the } c \text { children for each node }} Lhard​=N1​i=1∑N​over the c children for each node CROSSENTROPY(D(i)pred ​,D(i)label ​)​​
Hard模式损失函数计算的是“路径交叉熵”

对于软决策规则,我们使用软树监督损失。在3.1节中,我们描述了软决策树如何在树叶上提供单一分布Dpred。我们在这个分布上增加了交叉熵损失。总共有2个不同的交叉熵损失项:原始交叉熵损失项和软树监督损失项。这是Loriginal +Lsoft \mathcal{L}_{\text {original }}+\mathcal{L}_{\text {soft }}Loriginal ​+Lsoft ​,其中:
Lsoft =CROSSENTROPY (Dpred ,Dlabel )\mathcal{L}_{\text {soft }}=\text { CROSSENTROPY }\left(\mathcal{D}_{\text {pred }}, \mathcal{D}_{\text {label }}\right) Lsoft ​= CROSSENTROPY (Dpred ​,Dlabel ​)
Soft模式则计算的是“叶节点交叉熵”

https://blog.csdn.net/weixin_38912070/article/details/106561422

[论文解读]NBDT: Neural-Backed Decision Trees相关推荐

  1. NIPS2018最佳论文解读:Neural Ordinary Differential Equations...

    雷锋网 AI 科技评论按,不久前,NeurIPS 2018 在加拿大蒙特利尔召开,在这次著名会议上获得最佳论文奖之一的论文是<Neural Ordinary Differential Equat ...

  2. 论文解读:Combining Distant and Direct Supervision for Neural Relation Extraction

    论文解读:Combining Distant and Direct Supervision for Neural Relation Extraction 夏栀的博客--王嘉宁的个人网站 正式上线,欢迎 ...

  3. (论文解读)High-frequency Component Helps Explain the Generalization of Convolutional Neural Networks

    目录 论文解读之: High-frequency Component Helps Explain the Generalization of Convolutional Neural Networks ...

  4. Neural Tangent Kernel 理解(一)原论文解读

    欢迎关注WX公众号,每周发布论文解析:PaperShare, 点我关注 NTK的理解系列 暂定会从(一)论文解读,(二)kernel method基础知识,(三)神经网络表达能力,(四)GNN表达能力 ...

  5. 论文解读:Semantic Neural Machine Translation using AMR

    论文解读:Semantic Neural Machine Translation using AMR   机器翻译主要得益于语义表征能力,为了缓解数据的稀疏性,作者提出一种神经机器翻译方法,利用AMR ...

  6. [Scene Graph] Neural Motifs: Scene Graph Parsing with Global Context 论文解读

    [Scene Graph] Neural Motifs: Scene Graph Parsing with Global Context (CVPR 2018) 论文解读 简介 这篇文章工作的创新之处 ...

  7. Exploring the Connection Between Binary andSpiking Neural Networks论文解读

    Exploring the Connection Between Binary andSpiking Neural Networks论文解读 前言 总说 提出B-SNN(论文中为Ⅲ) 实验和结果(论文 ...

  8. 论文解读:Improved Neural Relation Detection for Knowledge Base Question Answering

    论文解读:Improved Neural Relation Detection for Knowledge Base Question Answering   本文解决KBQA中的子问题--Relat ...

  9. 论文解读:Question Answering over Knowledge Base with Neural Attention Combining Global Knowledge Info...

    论文解读:Question Answering over Knowledge Base with Neural Attention Combining Global Knowledge Informa ...

最新文章

  1. 原代脂肪细胞提取的准备内容
  2. 开源网店系统_amazon都做不行,就不可构建外贸网店系统吗?
  3. BP+SGD+激活函数+代价函数+基本问题处理思路
  4. Ubuntu 加速安装Opencv 3.4.3
  5. js实战代码系列—周杰伦给你报时间+网页页签制作模板+jQuery初体验
  6. windows + visual studio 2010 配置SVN
  7. [洛谷P4183][USACO18JAN]Cow at Large P
  8. 一网打尽Mac上的软件套装 - Omi特别篇(附演示视频)
  9. 超人视觉助我成功转型机器视觉行业
  10. Python文本挖掘练习(一)// 新闻摘要
  11. mysql求回购率_用户行为分析——回购率、复购率(SQL、Python求解)
  12. BeyondCorp 打造得物零信任安全架构
  13. 知识管理:营建学习型团队
  14. WebDAV之葫芦儿·派盘+FX文件管理器
  15. Java中violate关键字详解(2)?真正了解violate
  16. UE4让物体始终朝向摄像机(二)—RInterp To用法
  17. IDEA 插件的设置和引用,以及插件库
  18. 西部数据硬盘 跳线 (收藏)
  19. 【Stewart并联机器人运动学逆解可视化仿真】
  20. 华为Mate30EPro和华为mate40哪个好

热门文章

  1. pygame坦克大战
  2. Anton and Fairy Tale
  3. 卓训教育:给孩子讲故事,打造学习愿景
  4. android集成sdk 马甲包,Android配置马甲包
  5. html如何画出四个圆圈,css3如何绘制一个圆圆的loading转圈动画
  6. 深度学习——深度神经网络结构
  7. grokking algorithms K-nearest neighbors第十章 K-邻近算法 中文翻译
  8. 2022年京东618活动规则:618满减规则为299减50
  9. scss exceeded maximum budget. Budget 4.00 kB was not met by 130 bytes with a total of 4.13 kB.
  10. VSCode RemoteSSH 过程试图写入的管道不存在问题 解决