论文目录

  • 0 概述
    • 0.1 论文题目
    • 0.2 摘要
  • 1 简介
  • 2 相关的工作
  • 3 提出的方法
    • 3.1 前言
      • 3.1.1 提出问题
      • 3.1.2 模型无关元学习 Model-agnostic meta-learning
    • 3.2 具有任务自适应损失函数的元学习 MeTAL
      • 3.2.1 概述
      • 3.2.2 任务自适应损失函数 Task-adaptive loss function
      • 3.2.3 架构
  • 4 实验
    • 4.1 少样本分类
      • 4.1.1 数据集
      • 4.1.2 实验结果
    • 4.2 跨域少样本分类 Cross-domain few-shot classification
      • 4.2.1 数据集
      • 4.2.2 实验结果
    • 4.3 少样本回归 Few-shot regression
    • 4.4 消融研究 Ablation studies
      • 4.4.1 学习损失函数
      • 4.4.2 任务自适应损失函数 Task-adaptive loss function
      • 4.4.3 半监督内环优化 Semi-supervised inner-loop optimization
      • 4.4.4 任务说明
    • 4.5 可视化
  • 5 结论

0 概述

0.1 论文题目

用于少样本学习的基于任务自适应损失函数的元学习(Meta-Learning with Task-Adaptive Loss Function for Few Shot Learning)

0.2 摘要

  在少样本学习的场景中,挑战在于当每个任务只有很少的标记示例可用时,在新的未见示例上泛化并表现良好。模型无关元学习(MAML)因其灵活性和对各种问题的适用性而成为具有代表性的少样本学习方法之一。然而,MAML及其变体通常采用简单的损失函数,没有任何辅助损失函数或正则化项来帮助实现更好的泛化。问题在于每个应用程序和任务可能需要不同的辅助损失函数,特别是当任务多样化和不同的时候。我们没有尝试为每个应用程序和任务手动设计一个辅助损失函数,而是引入了一个新的元学习框架,该框架的损失函数可以适应每个任务。我们提出的框架名为基于任务自适应损失函数的元学习(MeTAL) ,证明了其在不同领域的有效性和灵活性,如少样本分类和少样本回归。

1 简介

  训练深度神经网络需要大量的标记数据和相应的努力,这阻碍了它在新领域的迅速应用。因此,人们对于少样本学习越来越感兴趣,其目标是 在只有少数标记示例(如支持示例)的情况下,赋予人工智能系统学习新概念的能力。少样本学习的核心挑战是减轻深度神经网络在少数据情况下过度拟合的敏感性,并在新示例(例如查询样本)上实现泛化。

  最近,元学习,也称为“学会学习”(learning-to-learn),已成为少样本学习的主要方法之一。元学习用于少样本学习领域,以学习一个能够适应新任务并在少量数据情况下泛化的学习框架。

  图 1 基于优化的元学习框架中的内环优化概述。(a) 传统的方法,如MAML,在适应任务的过程中利用一个固定的给定的经典损失函数(如分类的交叉熵)。(b) 我们提出的方案,MeTAL,则是元学习一个损失函数,其参数 ϕ\phiϕ 在适应第 iii 个任务的第 jjj 个步骤中被调整为适应当前任务状态 τττ 。

  在元学习算法中,基于优化的元学习因其灵活性而受到不同领域的关注,能够在不同领域中应用。基于优化的元学习算法通常被表述为双层优化。在这样的公式中,外环优化训练学习算法以实现泛化,而内环优化使用学习算法使基础学习器适应具有少量样本的新任务。

  模型无关元学习(MAML)是一种开创性的基于优化的元学习方法,它学习一组初始网络权重值来实现泛化。学习到的初始化作为一个良好的起点,以适应样本少、更新少的新任务。尽管学习到的初始化被训练成一个很好的起点,但MAML经常面临着实现泛化的困难,特别是当训练和测试阶段之间的任务多样化或显著不同的时候。一些工作试图通过尝试找到更好的初始化或更好的快速适应过程(内环更新规则)。然而,这些方法在内环优化中采用一个简单的损失函数(如分类中的交叉熵),尽管其他的辅助损失函数,如ℓ2ℓ_2ℓ2​正则化项,可以帮助实现更好的泛化。

  另一方面,我们专注于为MAML框架中的内环优化设计更好的损失函数。如图1所示,我们提出了一个名为基于任务自适应损失函数的元学习(MeTAL) 的新的框架,来学习一个自适应的损失函数,从而更好的泛化每个任务。具体来说,MeTAL通过两个元学习器学习 任务自适应损失函数(task-adaptive loss function)一个元学习器用于学习损失函数,另一个元学习器用于生成转换学习损失函数的参数。我们的任务适应性损失函数设计得很灵活,因为有标签(如支持)和无标签(如查询)的样本都可以一起使用,以便在内环优化过程中使基础学习器适应每个任务。

  实验结果表明,MeTAL大大提高了MAML的泛化能力。由于MeTAL的简单性和灵活性,我们进一步证明了它不仅在不同领域,而且在其他基于MAML的算法中的有效性。当应用于其他基于MAML的算法时,MeTAL不断带来泛化性能的大幅提升,在基于MAML的算法中引入了一个新的最先进的性能。这说明了任务适应性损失函数的重要性,与初始化方案和内环更新规则相比,它受到的关注较少。总的实验结果强调,为内环优化学习一个更好的损失函数是学习一个更好的内环更新或更好的初始化的重要补充部分。

2 相关的工作

  少样本学习目的在于解决每个任务只有少数例子可用的情况。最终目标是通过这些给定的几个示例学习新任务,同时在未知示例上实现泛化。为此,元学习算法试图通过学习以前任务中的先验知识来解决少样本学习问题,然后用这些知识来适应新的任务,而不会过度拟合。

  根据先验知识的学习和任务适应过程的制定方式,元学习系统一般可以分为基于度量的学习、基于黑箱或网络的学习和基于优化的学习方法。基于度量学习的方法将先验知识编码到一个嵌入空间中,在这个空间中,相似(不同)的类别相互之间更接近(更远)。 黑箱或基于网络的方法采用网络或外部存储器直接生成权重,权重更新或预测。同时,基于优化的方法采用双层优化来学习学习过程,例如初始化和权重更新,以适应样本很少的新任务。

  在这项工作中,我们专注于模型无关元学习(MAML)算法,这是优化方法中最流行的实例之一,因为它的简单性和在不同问题领域的适用性。MAML将先验知识表述为可学习的初始化,在用给定的几个例子进行基于梯度的微调后,可以为新任务实现良好的泛化性能。虽然MAML以其简单性和灵活性而闻名,但它也以其相对较低的泛化性能而闻名。最近有研究表明,通过加强初始化的学习方案或改进基于梯度的微调过程,可以提高整体性能。

  然而,在内环优化期间,上述工作仍然只使用与任务相对应的常见损失函数(例如分类中的交叉熵)。另一方面,常见的深度学习框架通常使用辅助损失项,如ℓ2ℓ_2ℓ2​正则化项,以防止过拟合。 由于少样本学习的目标是在仅有的几个例子的适应下实现对未见过的例子的泛化,使用辅助损失项似乎是一个自然的选择。最近引入的一些方法在内环优化中应用了辅助损失函数,以降低计算成本或提高泛化能力。其他工作试图学习强化学习、监督学习的损失函数,并将无监督学习纳入少样本学习。然而,这些方法的损失函数要么有特定的任务要求,如RL中的环境交互,要么在训练后保持固定。当新的任务可能喜欢不同的损失函数时,固定的损失函数可能是不利的,特别是在训练任务和新任务有明显不同的情况下(例如跨域的少样本分类)。

  为此,我们提出了一个基于任务自适应损失函数(MeTAL)的新元学习框架。特别是,通过元网络学习特定任务的损失函数,其参数与给定任务相适应。MeTAL不仅实现了出色的性能,而且还保持了简单性,可以与其他元学习算法联合使用。

3 提出的方法

3.1 前言

3.1.1 提出问题

  我们首先介绍了在少样本学习背景下的元学习的初步情况。元学习框架假设了一个任务{Ti}i=1T\{T_i\}^T_{i=1}{Ti​}i=1T​的集合,其中每个任务都被假设来自任务分布p(T)p(T)p(T)。 每个任务TiT_iTi​由数据集DiD_iDi​的两个不相交的集合组成:支持集DiSD_i^SDiS​和查询集DiQD_i^QDiQ​。每个集合又由若干对输入x\pmb xxx和输出y\pmb yyy组成:DiS={(xis,yis)}sK=1D_i^S = \{(\pmb x^s_i , \pmb y_i^s)\}^K_s=1DiS​={(xxis​,yyis​)}sK​=1,DiQ={(xiq,yiq)}qM=1D_i^Q = \{(\pmb x^q_i , \pmb y_i^q)\}^M_q = 1DiQ​={(xxiq​,yyiq​)}qM​=1。

  元学习的目标是学习一种学习算法(由具有参数jjj的模型制定),该算法可以从任务分布p(T)p(T)p(T)中快速学习任务。然后,利用学到的学习算法,通过使用基础学习器(参数为θθθ)和使用任务支持实例DiSD_i^SDiS​来学习新的任务Ti,其由下式给出:

    θi=argθminL(DiS;θ,ϕ)\pmb θ_i = arg_{\pmb θ}\ min\ L(D^S_i;\pmb θ,\pmb \phi)θθi​=argθθ​ min L(DiS​;θθ,ϕϕ) (1)

  其中,LLL表示一个损失函数,用于评估一项任务的表现。由于支持集DiSD_i^SDiS​被用来学习一个任务,当每个任务有k个支持实例时,少样本学习通常被称为k-shot学习(∣DiS∣=K=k)(|D_i^S| = K = k)(∣DiS​∣=K=k)。
  由此产生的特定任务基础学习器由参数θi\pmbθ_iθθi​表示。然后,根据任务特定的基本学习器θi\pmb θ_iθθi​如何泛化到未知的查询示例DiQD_i^QDiQ​来 评估由ϕ\phiϕ指定的学习算法参数。因此,元学习算法的目标变成了:

    ϕ∗=argϕminETi∼p(T)[L(DiQ;θi,ϕ)]\phi^* = arg_\phi\ min \mathop{E}\limits_{T_i \sim p(T)}[L(D_i^Q;\pmbθ_i,\phi)]ϕ∗=argϕ​ minTi​∼p(T)E​[L(DiQ​;θθi​,ϕ)] (2)

3.1.2 模型无关元学习 Model-agnostic meta-learning

  MAML将先验知识编码为可学习的初始化,作为跨任务的基础学习器网络权重的一组良好的初始值。该公式中,基础学习器的元学习初始化导致了双层优化:内环优化和外环优化。 对于内环优化,基础学习器通过支持实例DiSD_i^SDiS​进行微调,从可学习的初始化θθθ开始,通过梯度下降对每个任务进行固定次数的权重更新。因此,在初始化θi,0=θ\pmbθ_{i,0}=\pmbθθθi,0​=θθ后,通过梯度下降最小化任务适应目标(等式(1))。第jjj步的内环优化表示为:

    θi,j+1=θi,j−α∇θi,jL(DiS;θi,j)\pmbθ_{i,j+1} = \pmbθ_{i,j} - \alpha\nabla_{\pmbθ_{i,j}}L(D_i^S;\pmbθ_{i,j})θθi,j+1​=θθi,j​−α∇θθi,j​​L(DiS​;θθi,j​) (3)

  然后,经过J次内环更新步骤,特定任务的基础学习器参数θi\pmbθ_iθθi​变成θi,J\pmbθ_{i,J}θθi,J​。

  在外环优化的情况下,元学习的初始化θ\pmbθθθ是由具有参数θi\pmbθ_iθθi​(或θi,J)的特定任务基础学习器在未见过的查询实例DiQD_i^QDiQ​上的泛化性能进行评估的。然后,将未知示例上评估的泛化用作反馈信号,以更新初始化θ\pmbθθθ。换句话说,MAML使元学习算法的目标最小化,如等式(2)所示,如下所示:

    θ←θ−η∇θ∑TiL(DiQ;θi)\pmbθ \leftarrow \pmbθ - \eta\nabla_{\pmbθ}\sum\limits_{T_i}L(D_i^Q;\pmb\theta_i)θθ←θθ−η∇θθ​Ti​∑​L(DiQ​;θθi​) (4)

3.2 具有任务自适应损失函数的元学习 MeTAL

3.2.1 概述

  以前的元学习公式假定对一个给定的任务Ti是完全监督的设置,他们使用支持集DiSD_i^SDiS​中的标记实例,通过最小化一个固定的给定损失函数LLL来寻找特定任务的基础学习器θiθ_iθi​。另一方面,我们的目标是控制或元学习损失函数本身,以调节整个自适应或内环优化过程,从而实现更好的泛化。我们从元学习一个内环优化损失函数Lϕ(⋅)L_\phi(·)Lϕ​(⋅)开始,该函数由一个具有元可学习参数 ϕ\pmb\phiϕϕ 的小型神经网络建模。因此等式(3)中的内环更新变为,
    θi,j+1=θi,j−α∇θi,jLϕ(τi,j)\pmb\theta_{i,j+1} = \pmb\theta_{i,j} - \alpha\nabla_{\pmb\theta_{i,j}}L_\phi(\tau_{i,j})θθi,j+1​=θθi,j​−α∇θθi,j​​Lϕ​(τi,j​) (5)
  其中τi,jτ_{i,j}τi,j​表示时间步长j处TiT_iTi​的任务状态,在典型元学习公式的情况下,通常只是支持集DiSD_i^SDiS​,如等式(3)所示。由于不同的任务(特别是在跨领域的情况下)在适应过程中可能更喜欢不同的正则化或辅助损失函数,甚至是损失函数本身,以实现更好的泛化,我们的目标是学习使损失函数本身适应每个任务。为了使元学习的损失函数具有自适应性,一种自然的设计选择可能是执行梯度下降,类似于方程(3)中基础学习者参数θiθ_iθi​的更新方式。 然而,这样的设计会导致一个大的计算图,特别是在用高阶梯度训练元学习算法时。另外,可以应用 仿射变换(Affine transformation) 来使损失函数适应给定的任务。在使特征响应自适应和使元学习初始化自适应方面,若干工作证明了以某些输入为条件的仿生变换是有效的。为了使损失函数任务具有自适应性而不产生巨大的计算负担,我们建议通过仿射变换动态变换损失函数参数ϕ\pmb\phiϕϕ:
    ϕ′=γϕ+β\pmb\phi' = \pmb\gamma\pmb\phi + \pmb\betaϕϕ′=γγϕϕ+ββ (6)
  其中,ψ\pmbψψψ是元学习损失函数参数,γ,β\pmbγ,\ \pmbβγγ, ββ是元学习器g(τj;ψ)g(τ_j;\ ψ)g(τj​; ψ)生成的变换参数,由ψ\pmbψψψ参数化。

  为了训练我们的元学习框架来泛化不同的任务,包括优化参数θ\pmbθθθ、ϕ\pmb\phiϕϕ和ψ\pmbψψψ,外环优化是在给定每个任务TiT_iTi​的情况下进行的,每个任务TiT_iTi​都是特定于任务的学习者θiθ_iθi​及其在查询集DiQD_i^QDiQ​中的示例,如中所示,
    (θ,ϕ,ψ)←(θ,ϕ,ψ)−η∇(θ,ϕ,ψ)∑TiL(DiQ;θi)(\pmb\theta,\pmb\phi,\pmb\psi) \leftarrow (\pmb\theta,\pmb\phi,\pmb\psi) - \eta\nabla_{(\pmb\theta,\pmb\phi,\pmb\psi)}\sum\limits_{T_i}L(D^Q_i;\pmb\theta_i)(θθ,ϕϕ,ψψ)←(θθ,ϕϕ,ψψ)−η∇(θθ,ϕϕ,ψψ)​Ti​∑​L(DiQ​;θθi​) (7)
  算法 1总结了我们方法的整个训练过程。

3.2.2 任务自适应损失函数 Task-adaptive loss function

  由于我们的损失元学习器LϕL_{\pmb\phi}Lϕϕ​和元学习器gψg_{\pmbψ}gψψ​是使用神经网络建模的,因此它们的输入可以被公式化为包含有关中间学习状态的辅助任务特定信息,我们将其定义为任务状态τ\pmbτττ。在给定任务TiT_iTi​的第jjj个内循环步骤中,除了经典的损失信息L(DiS;θi,j)L(D_i^S; \pmbθ_{i,j} )L(DiS​;θθi,j​)(在标记的支持集实例DiSD_i^SDiS​上评估),辅助的学习状态信息,如网络权重θi,jθ_{i,j}θi,j​和输出值f(xis;θi,j)f(\pmb x^s_i; \pmbθ_{i,j} )f(xxis​;θθi,j​),可以包含在任务状态τi,j\pmb τ_{i,j}ττi,j​中。

  此外,我们还可以在任务状态中包含来自查询集的未标记示例xiq\pmb x^q_ixxiq​的基本学习者响应,这使得内环优化可以进行半监督学习。这表明我们的框架可以利用这种额外的特定任务信息进行快速适应,这在以前基于MAML的元学习算法中很少被利用,而基于度量的元学习算法,如[21],试图利用未标记的查询实例来最大化性能。半监督内环优化最大限度地发挥了转换设置的优势(假设所有查询示例都同时可用),基于MAML的算法已经隐含地使用了直推设置(transductive setting) 以获得更好的性能[25]。算法2中组织了针对监督和半监督设置的具有任务自适应损失函数的内环优化过程。

3.2.3 架构

  对于我们的任务自适应损失函数Lϕ\pmb L_\phiLLϕ​,我们采用了一个2层MLP,层与层之间有ReLU激活,它返回一个标量值作为输出。为了提高计算效率,内环优化中使用的任务状态τi,j被表述为支持集损失的平均值L(DiS;θi,j)L(D_i^S; θ_{i,j})L(DiS​;θi,j​)、基础学习器权重的逐层平均值θi,j和基础学习器输出值的示例平均值f(xis;θi,j)f(\pmb x^s_i; \pmb θ_{i,j})f(xxis​;θθi,j​)的串联。 假设基础学习器 fff 的L层神经网络返回N维输出值(用于N向分类),任务状态τi,j\pmb τ_{i,j}ττi,j​的维度变为1+L+N,这在计算上是最小的。在半监督学习环境下,这可能会略有增加,因为额外的信息可以从基础学习者f(xiq;θi,j)f(\pmb x^q_i; \pmbθ_{i,j})f(xxiq​;θθi,j​)对未标记的查询例子的反应中得到。

  元网络gψ\large g_ψgψ​也采用了2层MLP,各层之间采用ReLU激活。该网络产生的层间仿射变换参数被应用于损失函数参数ϕ\large \phiϕ。由于我们的元学习框架没有对基础学习器f\ f f 及其目标应用施加任何约束,我们的公式是通用的,可以很容易地应用于任何基于梯度的可微学习算法。更多细节,请参考补充文件和我们的代码1。

4 实验

  在这一节中,我们对几个少样本学习问题进行了实验,如少样本分类、跨域少样本分类和少样本回归,以证实任务自适应损失函数的有效性。我们提出的MeTAL方法的所有实验结果都是在半监督的内环优化下进行的,其中标记的支持实例和未标记的查询实例一起被用于内环优化。请注意,我们没有使用额外的数据,MeTAL只是从直推设置中获得了更多的好处(所有的查询例子都是一次性的),其他的MAML变体也采用了这种方式来提高性能。

4.1 少样本分类

  在少样本分类中,每个任务被定义为N-way k-shot分类,其中N是类别的数量,k是每个类别的例子(样本)的数量。

4.1.1 数据集

  少样本分类最常用的数据集是两个ImageNet的衍生数据集:miniIm-ageNet和tieredImageNet。 这两个数据集都是由三个互不相干的子集(训练集、验证集和测试集)组成,每个子集都由大小为84×84的图像组成。数据集的不同之处在于如何将类拆分为不相交的子集。miniImageNet随机抽样并将类分组为64个类进行元训练,16个用于元验证,20个用于元测试。另一方面,tieredImageNet根据ImageNet类层次结构将类分为34类,并将组分为20类进行元训练,6类进行元验证,8类进行元测试,以尽量减少三个不相交集之间的类相似性。

4.1.2 实验结果

  我们评估了我们的方法MeTAL,并与miniImageNet和tieredImageNet上的其他MAML变体在两种典型设置下进行了比较。5-way 5-shot和5-way 1-shot分类,如 表 1 中所示。结果表明,MeTAL不仅大大提高了MAML的泛化性能,而且可以与其他MAML变体,如MAML++和ALFA结合应用,带来进一步的改善。MAML++学习固定的步骤和层级的内环学习率,而ALFA学习任务自适应内环学习率和正则化项。尽管这些方法不认为损失函数是可学习的,但如果损失函数被视为模型的一部分,那么MeTAL可以被看作是这些方法的一个更普遍的扩展。然而,MeTAL对这些方法的进一步改进表明,改进内环优化目标函数不仅仅是一个简单的扩展,而是一个互补和正交的因素。MeTAL的主要贡献在于将内环损失函数制定为可学习的和任务自适应的。 此外,MeTAL和ALFA[4]一起,大大优于其他使用较大网络(如DenseNet或WideResNet),或经过预训练或用数据增强训练的模型。这些结果表明,我们学习的任务自适应损失函数在实现更好的泛化方面是有效的。

4.2 跨域少样本分类 Cross-domain few-shot classification

  Chen等人[提出的跨域少样本分类,解决了一个更具挑战性和实用性的少样本分类场景,其中元训练任务和元测试任务是从不同的任务分布中抽取的。这样的场景被有意设计为在元训练和元测试之间创造一个大的领域差距,从而评估元学习算法对元水平过拟合的敏感性。 具体来说,如果一个元学习算法过于依赖以前看到的元训练任务的先验知识,而不是专注于给定的几个例子来学习一个新的任务,那么可以说这个算法是元过拟合的。这种元级的过拟合将导致学习系统更有可能无法适应从本质上不同的任务分布中采样的新任务。

4.2.1 数据集

  为了模拟这种具有挑战性的场景,Chen等人首先在miniImageNet上对算法进行了元训练,并在元测试期间对CUB数据集(CUB-200-2011)进行评估。与为一般分类任务编译的ImageNet相比,CUB的目标是细粒度的分类。按照[9]的协议,200个类的数据集被分成100个元训练、50个元验证和50个元测试集。

4.2.2 实验结果

  表 2 显示了MAML、MAML的一个变种ALFA和MeTAL在miniImageNet元训练集上训练并在CUB元测试集上评估的性能。与表1中概述的少样本的分类结果相似,即使在更具挑战性的跨域少样本分类场景下,MeTAL也能极大地提高泛化能力。事实上,MeTAL在跨域少样本分类(~8%)中比少样本分类(4%)中更大程度的提高了MAML和ALFA+MAML的性能。这意味着MeTAL在学习不同领域的新任务方面的有效性及其对领域差距的稳健性(鲁棒性),强调了任务自适应损失函数的重要性。 可以对另一个结果进行观察:MeTAL在ALFA+MAML上的泛化性能的提高和在MAML上的提高一样大,表明MeTAL试图解决的问题的正交性。ALFA[4]也旨在改善内环优化,但不同的是他们专注于开发一个新的权重更新规则(梯度下降)。另一方面,我们关注的是内环优化中使用的损失函数。MeTAL在不同基线和架构上的一致性泛化改进表明,设计一个更好的内环优化损失函数是重要的因素,也是对设计一个更好的权重更新规则的补充。

4.3 少样本回归 Few-shot regression

  为了证明我们的方法MeTAL的灵活性和适用性,我们在少样本回归或k-shot回归上评估MAML和MeTAL。在k-shot回归中,每个任务都是在只给出极少数(k)个采样点的情况下估计给定的未知函数。任务分布由具有目标函数的任务组成,目标函数的参数值在定义的范围内变化。在这项工作中,我们遵循Finn等人用于评估MAML的一般设置。具体来说,每个任务都有一个正弦曲线y(x)=Asin(ωx+b)y(x) = Asin(ωx+b)y(x)=Asin(ωx+b)作为目标函数,其参数值在以下范围内:振幅 A∈[0.1,5.0]A∈[0.1, 5.0]A∈[0.1,5.0],频率 ω∈[0.8,1.2]ω∈[0.8, 1.2]ω∈[0.8,1.2],相位 b∈[0,π]b∈[0, π]b∈[0,π]。对于每个任务,输入数据点 xxx 从[-5.0, 5.0]中取样。回归是通过对基础学习器执行单梯度下降来执行的,基础学习器的神经结构由3个大小为80的层组成,其间包含ReLU非线性激活函数。性能以估计输出值 y^\hat yy^​ 和真实输出值 yyy 之间的均方误差(MSE)衡量。
  表 3 概述了MAML和MeTAL在5-shot、10-shot和20-shot设置下的回归结果。MeTAL再次展示了在不同设置下的一致性能改进。这证明了MeTAL学习到的任务自适应损失函数的适用性和灵活性。

4.4 消融研究 Ablation studies

  为了研究MeTAL中每个模块的贡献,我们在本节中进行了消融研究实验。特别是,我们分析了任务状态信息、损失函数的学习、任务适应性损失函数和半监督内环优化公式的有效性。 所有的消融研究实验都是在5-way 5-shot 少样本分类下,用一个具有 4-CONV 主干的基础学习器进行的。

4.4.1 学习损失函数

  首先,我们分析了学习内环优化损失函数的重要性。详细来说,当内环优化是用一个没有经过自适应学习的损失函数(即模型(2),(3)仅使用元网络LϕL_\phiLϕ​)时,对性能进行测量,并与使用简单的交叉熵时(即MAML表示为模型(1) )进行比较。

  表 4 中总结的消融研究结果显示,学习的损失函数帮助MAML实现了更好的泛化,这表明元学习器已经成功地学习了对泛化有用的损失函数。此外,当交叉熵和学习损失一起使用时,与只使用学习损失时没有明显区别,这意味着学习损失能够保持作为输入的交叉熵损失信息。

4.4.2 任务自适应损失函数 Task-adaptive loss function

  然后,我们研究任务自适应损失函数对整个建议框架的影响。为此,我们使用元模型gψ\large g_ψgψ​来生成仿射变换参数,然后根据公式(6)来调整表4中模型(2)的损失函数元网络Lϕ\large L_\phiLϕ​的参数。得出的元学习算法,即没有半监督内环优化的MeTAL,在 表 5 中表示为模型(4)。如表所示,与固定的学习函数相比,元学习算法得益于任务自适应学习损失函数。

4.4.3 半监督内环优化 Semi-supervised inner-loop optimization

  接下来,我们研究了半监督内环优化公式的有效性。与任务自适应损失函数消融研究类似,我们首先推导出一个新的模型,该模型是通过在表4的模型(2)中加入半监督内环优化公式(将有标签的支持例子和无标签的查询例子放在一起,通过学习的损失函数进行快速适应)而创建的。因此,与我们的最终方法MeTAL相比,所产生的模型,即模型(5),缺乏任务自适应的特性。虽然半监督内环优化有助于提高性能,但它仍然落后于完整算法MeTAL(表示为模型(6)),这说明了任务自适应损失函数的重要性。

4.4.4 任务说明

  我们进行了另一项消融研究,以调查任务状态 τ\tauτ 的每个因素的影响:即基础学习器θi,j\pmb θ_{i,j}θθi,j​的当前权重值,网络的输出(f(xis;θi,j)(f(\pmb x^s_i ; \pmb θ_{i,j})(f(xxis​;θθi,j​)为支持,f(xiq;θi,j)f(\pmb x^q_i ; \pmb θ_{i,j})f(xxiq​;θθi,j​)为查询),以及支持集L(DiS;θi,j)=L(f(xis;θi),yis)L(D_i^S; \pmbθ_{i,j} )=L(f(\pmb x^s_i ; \pmb θ_i), \pmb y_i^s)L(DiS​;θθi,j​)=L(f(xxis​;θθi​),yyis​) 的交叉熵损失。消融结果总结在 表 6 中。当支持实例上的原始交叉熵不包括在任务状态中时,整个内环优化就变成了无监督学习设置,因为在内环优化过程中没有涉及到基础真值信息。在这种情况下,正如人们所期望的那样,MeTAL很难实现泛化,因此在表中排除了这些结果。

4.5 可视化

  图 2 说明了我们提出的元网络 g\large gg 在每个内环步骤的任务(用方框图表示)之间生成的仿射变换参数 γγγ 和 βββ。观察生成的γγγ和βββ值如何在内环步骤中变化,我们可以认为,随着内环优化过程中学习状态的变化,MeTAL可以动态调整损失函数。此外,生成的参数值在不同的任务中也是不同的,特别是在最后一个内环步骤中。这可能意味着整个框架经过训练,可以在最后一步的任务中发挥最大的作用。无论如何,任务之间生成的仿射变换参数值的动态范围验证了MeTAL在适应给定任务的损失函数方面的有效性。

   图2 我们提出的一个元网络g\large gg生成的仿生变换参数γγγ和βββ的图示。然后,使用这些值使损失元网络 lll 适应给定的任务。特别是,该图显示了损失元网络 lll 的第一层的权重的生成值。生成的值显示了它在整个内环步骤中的动态范围,这表明在每个步骤中都有不同的损失函数被选中。此外,在不同的任务中观察到不同的值,特别是在最后一个内环步骤中,暗示了任务对内环优化损失函数的不同偏好。

5 结论

  在这项工作中,我们提出了一个用于少样本学习的具有任务自适应损失函数的元学习框架。所提出的方案被命名为MeTAL,在内环优化过程中,根据当前的任务状态,学习一个适应每个任务的损失函数。因此,MeTAL能够学习每个任务特别需要的损失函数,以便更好地泛化。此外,MeTAL的灵活性不仅使其能够应用于不同的MAML变体和问题域,还允许半监督内循环优化,其中标记的支持示例和未标记的查询示例共同使用以适应任务。总的来说,实验结果强调了为每项任务学习一个好的损失函数的重要性,与少样本学习中的权重更新规则或初始化相比,该损失函数引起的关注相对较少。

论文阅读-2 | Meta-Learning with Task-Adaptive Loss Function for Few Shot Learning相关推荐

  1. 论文阅读——TR-GAN: Topology Ranking GAN with Triplet Loss for Retinal Artery/Vein Classification

    论文阅读--TR-GAN: Topology Ranking GAN with Triplet Loss for Retinal Artery/Vein Classification 基于对抗神经网络 ...

  2. 论文阅读 [TPAMI-2022] Ball $k$k-Means: Fast Adaptive Clustering With No Bounds

    论文阅读 [TPAMI-2022] Ball kkkk-Means: Fast Adaptive Clustering With No Bounds 论文搜索(studyai.com) 搜索论文: B ...

  3. 3D Instance Embedding Learning With a Structure-Aware Loss Function for Point Cloud Segmentation

    Abstract 这封信提出了一个在点云上进行3D实例分割的框架.使用3D卷积神经网络作为主干,同时生成语义预测和实例嵌入.除了嵌入信息,点云还提供反映点之间关系的3D几何信息.考虑到这两种类型的信息 ...

  4. 【论文阅读】 AdaptivePose: Human Parts as Adaptive Points

    DOI:https://doi.org/10.1609/aaai.v36i3.20185 AAAI 2022         Published:2022-06-28 Others阅读/整理:翻译1. ...

  5. 【论文阅读】AU检测|《Deep Adaptive Attention for Joint Facial Action Unit Detection and Face Alignment》

    <Deep Adaptive Attention for Joint Facial Action Unit Detection and Face Alignment>(ECCV 2018) ...

  6. 论文阅读笔记:《Neural3D: Light-weight Neural Portrait Scanning via Context-aware Correspondence Learning》

    Neural3D: Light-weight Neural Portrait Scanning via Context-aware Correspondence Learning 论文动机 方法 整体 ...

  7. 【论文阅读】【基于方面的情感分析】Deep Context- and Relation-Aware Learning for Aspect-based Sentiment Analysis

    文章目录 Deep Context- and Relation-Aware Learning for Aspect-based Sentiment Analysis 一.该论文关注的是解决ABSA问题 ...

  8. 【深度学习与智能反射面:论文阅读】:Enabling Large Intelligent Surfaces with Compressive Sensing and Deep Learning

    文章目录 前言 中心思想 具体实现 A.==COMPRESSIVE SENSING== BASED 智能反射面设计 B. ==DEEP LEARNING BASED== LIS I智能反射面设计 实现 ...

  9. 论文阅读(联邦学习):Exploiting Shared Representations for Personalized Federated Learning

    Exploiting Shared Representations for Personalized Federated Learning 原文传送门 http://proceedings.mlr.p ...

最新文章

  1. CommonJS规范与AMD规范的理解
  2. 解决SQL Server里sp_helptext输出格式错行问题
  3. Mysql 5.5 源码安装
  4. IT运维管理与ITIL
  5. C语言库函数大全及应用实例六
  6. vs中工具箱代表的意思_“日”除了代表太阳,其实还有这种意思,特别是出现在这些词语中的“日”...
  7. 检验密码强度的JS类(from thin's blog)
  8. Pwn2Own 2020 曝出的Linux 内核漏洞已修复
  9. atitit.基于bat cli的插件管理系统.doc
  10. 网络学习笔记之路由器基本命令行操作
  11. php面试题之三——PHP语言基础(基础部分)
  12. 20145213《Java程序设计》第五周学习总结
  13. Cell Host | 张群业/王哲/张澄-肠道微生物群失调加重腹主动脉瘤
  14. Python-pptx Slides
  15. QQ20岁:20年版本迭代只做一件事情!
  16. 书终于出来了:《Unity3D平台AR与VR开发快速上手》
  17. 乐视再次被外媒质疑,消息称FF首款汽车将无法按时出货
  18. POJ_1845_Sumdiv_各种数学
  19. Ardunio开发实例-磁簧开关
  20. Chrome浏览器主页被劫持的解决

热门文章

  1. PCIe资料分享-快速入门
  2. 情感饥渴导致“80一代”不敬业过上“跳蚤人生” 《转》
  3. MySQL几种编码格式的区别(utf8、utf8mb4、utf8mb4_general_ci、utf8mb4_unicode_ci 、utf8mb4_0900_ai_ci)
  4. Linux QQ 解决闪退的方法
  5. 领1000元红包,来钉钉领智办事开工福利!
  6. 消防设备电源监控系统对老旧小区升级的意义
  7. java3d dda算法,DDA.MidPoint.Bresenham三种算法的实现
  8. 移动设备和Sharepoint 2013 - 第四部分:位置
  9. 新版白话空间统计(20)空间关系概念化之点临近
  10. 【MicroPython ESP32】1.8“tft ST7735带中文驱动显示示例