元学习是解决小样本学习问题的重要方法之一,现已取得较为优异的成绩。元学习方法大体上可以分为基于优化的和基于度量两种。基于度量的方法是非参数方法,包括孪生网络、关系网络、匹配网络等。基于优化的方法是参数化方法,典型代表之一是MAML(Model-Agnostic Meta-Learning)。MAML在训练任务上学习一个易于调节的初始化参数,面对新的测试任务时迁移该初始化参数,并利用梯度下降法微调该参数,以达到较好的效果。MAML算法思路简捷、效果优异,近年来产生了诸多变体。下面将带大家梳理其中较为典型的改进方法。

文章目录

  • MAML算法回顾
    • MAML
    • FOMAML
  • 提高运行速率
    • Reptile
    • DKT
  • 提高预测精度
    • MTNET
    • CAVIA
    • Pruning
    • TAML

MAML算法回顾

MAML

论文地址:https://arxiv.org/pdf/1703.03400.pdf

MAML内层循环(算法流程图中step4-step6)将θ\thetaθ向着最适合每个任务的方向更新为θi\theta _{\rm{i}}θi​(support集上),并在query集上计算损失和。在外层循环(step8)中,利用一批任务的损失共同更新θ\thetaθ。如下所示,(1)是内层更新,(2)是外层更新。值得注意的是,与预训练不同,MAML的初始化参数不是针对当前任务的最优参数,而是最易于调节的参数,该参数只需几步就能在新任务上达到最优,易于调节的性能依赖于在support上训练,在query上更新这一思路。
θi′=θ−α∇θLTi(fθ)(1){\theta _{\rm{i}}}^\prime = \theta - \alpha {\nabla _\theta }{\cal L_{{\cal T_i}}}({f_\theta }) \tag{1} θi​′=θ−α∇θ​LTi​​(fθ​)(1)
θ←θ−β∇θi′∑Ti∼p(T)LTi(fθi′)(2)\theta \leftarrow \theta - \beta {\nabla _{{\theta _i}^\prime }}{\sum _{{{\cal T_i}}\sim{p(\cal T)}}{\cal L_{{\cal T_i}}}({f_{{\theta _i}^\prime }})} \tag{2} θ←θ−β∇θi​′​Ti​∼p(T)∑​LTi​​(fθi​′​)(2)
具体流程如下所示:

FOMAML

原作者在MAML的基础上提出FOMAML,区别在(2)中求导对象不同,FOMAML无需计算二阶导,推导过程利用了多元函数的链式求导法则。
θi′=θ−α∇θLTi(fθ)(1){\theta _{\rm{i}}}^\prime = \theta - \alpha {\nabla _\theta }{\cal L_{{\cal T_i}}}({f_\theta })\tag{1} θi​′=θ−α∇θ​LTi​​(fθ​)(1)
θ←θ−β∇θi′∑Ti∼p(T)LTi(fθi′)(2)\theta \leftarrow \theta - \beta {\nabla _{{\theta _i}^\prime }}{\sum _{{{\cal T_i}}\sim{p(\cal T)}}{\cal L_{{\cal T_i}}}({f_{{\theta _i}^\prime }})}\tag{2} θ←θ−β∇θi​′​Ti​∼p(T)∑​LTi​​(fθi​′​)(2)


从做法来看,MAML的改进策略有传统数理方法(简化二阶导,FOMAML;隐函数积分,iMAML等)、计算机方法(MAML++等)、贝叶斯方法(BMAML等)以及强化学习(ESMAML)、在线学习和其他方法。
然而,这样的分类方法太过冗杂。从解决问题的角度,我将MAML的改进思路分为两种:提高运行速率和提高预测精度。下面依次介绍最经典的几个代表:

提高运行速率

Reptile

论文地址:https://arxiv.org/pdf/1803.02999.pdf

Reptile是最早的改进方法之一。它省略了外层循环,在support∪query集上多次求导,每次求导的方向是Fast weight的方向,其最终的更新方向是多次求导的矢量和与原参数的线性组合,也就是Slow weight的方向。内层循环如下所示:
ϕ←ϕ+ε1n∑i=1n(ϕ~i−ϕ)\phi \leftarrow \phi + \varepsilon \frac{1}{n}\sum\limits_{i = 1}^n {({{\tilde \phi }_i} - \phi )} ϕ←ϕ+εn1​i=1∑n​(ϕ~​i​−ϕ)

DKT

论文地址:https://arxiv.org/pdf/1910.05199.pdf

算法流程图:
DKT(深度核迁移)方法把模型初始化参数认为是点估计的先验信息,通过先验和似然来估计后验分布。之前的最小化损失函数等价于这里的最大化似然函数。

该方法从贝叶斯定理角度出发,为MAML提供概率解释和不确定性度量。面对新任务时,不光迁移模型初始化参数ϕ\phiϕ,同时迁移高斯核参数θ\thetaθ。与Reptile类似,该方法只需一层循环。具体推导采用第二类最大似然法(ML-Ⅱ),把P(Tty∣Ttx,θ^,ϕ^)P({\cal T}_t^y|{\cal T}_t^x,\boldsymbol{\hat \theta} ,\hat \phi )P(Tty​∣Ttx​,θ^,ϕ^​)写成积分形式并用条件概率公式展开即可。

本文的另一个创新点在于考虑了跨域问题,即训练任务和测试任务分别取自不同的数据集。


提高预测精度

MTNET

论文地址:https://arxiv.org/pdf/1801.05558.pdf

该方法认为外层循环要保证所有任务总损失最小,这样损失了一个自由度,会导致每个任务梯度更新不够灵活。因而在外层循环中再学习一个矩阵(T-net),相当于对原始参数的线性变换,投影到子空间上。另外,该方法还学习了一类随机变量,该随机变量生成MASK矩阵,决定每个训练任务上更新哪些层,这样减少了过拟合的风险(MT-net)。MT-net如下所示:

CAVIA

论文地址:https://arxiv.org/pdf/1810.03642.pdf

从MAML内外层更新的思路来看,MAML和DKT都假定每个任务的所有参数都是任务特定的,需要在内层循环中更新,而CAVIA则假定每个任务的参数分为任务共享的部分和任务特定的部分。

该方法将需要更新的参数分为任务相关的部分(ϕ\phiϕ)和任务共享的部分(θ\thetaθ)两种。任务相关的参数又叫上下文参数,只在内层循环中更新,任务共享的参数则在外层循环中更新。对于测试任务,只做内层循环,更新任务特定的部分。这样就避免了过拟合问题。


如上图所示,神经元的输入取决于上一层的神经元和上下文参数。
hi(l)=g(∑j=1Jθj,i(l,h)hj(l−1)+∑k=1Kθk,i(l,Φ)Φ0,k+b){h_i}^{(l)} = g(\sum\limits_{j = 1}^J {{\theta _{j,i}}^{(l,h)}{h_j}^{\left( {l - 1} \right)} + } \sum\limits_{k = 1}^K {{\theta _{k,i}}^{(l,\Phi )}{\Phi _{0,k}}} + b)hi​(l)=g(j=1∑J​θj,i​(l,h)hj​(l−1)+k=1∑K​θk,i​(l,Φ)Φ0,k​+b)
作者阐述了该方法在FNN、CNN和RL的应用,在CNN中,作者利用FilM仿射变换(论文地址:https://arxiv.org/pdf/1709.07871.pdf)学习上下文参数。

作者在实验中还阐述了CAVIA对内层循环的学习率α\alphaα具有很好的鲁棒性,在sine实验的结果如下图所示:

Pruning

论文地址:https://arxiv.org/pdf/2007.03219.pdf

该方法利用了元学习剪枝的思想,又称为dense-sparse-dense (DSD)。基于Reptile,预训练一个初始化权重,在每个任务上训练时利用MASK选择一部分参数更新,然后再整体训练几轮,这样就减少了任务的过拟合问题。

TAML

论文地址:https://arxiv.org/pdf/1805.07722.pdf

TAML认为不同的任务对优化起的作用是不同的,这种重要性的度量可以用熵变或者经济学中的一些指标度量。算法图如下:


该方法对损失函数稍加改进,有效地平衡了不同任务的贡献度,在分类问题上取得了较为良好的效果。

基于MAML的改进方法总结相关推荐

  1. 文献学习(part52)--基于泛岭估计对岭估计过度压缩的改进方法

    学习笔记,仅供参考,有错必纠 文章目录 基于泛岭估计对岭估计过度压缩的改进方法 摘要 引言 岭估计方法 岭估计的主要问题 改进的岭估计方法 基于泛岭估计对岭估计过度压缩的改进方法 摘要 岭估计是解决多 ...

  2. 基于能耗均衡的LEACH改进方法

    文章目录 一.理论基础 1.基于能量的簇头选择阈值 2.算法描述 二.仿真与结果分析 1.仿真参数 2.结果分析 三.参考文献 一.理论基础 1.基于能量的簇头选择阈值 由于簇的规模和簇头选择对WSN ...

  3. 《基于场景的工程方法》作者问答录

    <基于场景的工程方法>(Scenario-Focused Engineering,本书中文版正在翻译中)一书描述了在开发与交付基于软件的产品时,一种以客户为中心的精益与敏捷方法.本书所描述 ...

  4. 《机器学习实战》笔记(04):基于概率论的分类方法 - 朴素贝叶斯分类

    基于概率论的分类方法:朴素贝叶斯分类 Naive Bayesian classification 这大节内容源于带你理解朴素贝叶斯分类算法,并非源于<机器学习实战>.个人认为<机器学 ...

  5. faster rcnn resnet_RCNN, Fast R-CNN 与 Faster RCNN理解及改进方法

    RCNN 这个网络也是目标检测的鼻祖了.其原理非常简单,主要通过提取多个Region Proposal(候选区域)来判断位置,作者认为以往的对每个滑动窗口进行检测算法是一种浪费资源的方式.在RCNN中 ...

  6. 【文献学习】强化学习1:基于值函数的方法

    参考文献: [1]<机器学习>,周志华(西瓜书) [2]<强化学习>,邹伟,等(鳄鱼书) (今天看书总是走神,干脆总结一下,希望帮自己理清思路.如果碰巧能被大神看到,如有不正确 ...

  7. 基于matlab的prony方法实现,基于MATLAB的Prony方法实现

    基于MATLAB的Prony方法实现 本文介绍了Prony方法在MATLAB中的实现和应用.首先叙述了Prony方 (本文共2页) 阅读全文>> 瞬时频率是信号重要的瞬时特征参数,由于其在 ...

  8. cnn 回归 坐标 特征图_RCNN, Fast R-CNN 与 Faster RCNN理解及改进方法

    RCNN 这个网络也是目标检测的鼻祖了.其原理非常简单,主要通过提取多个Region Proposal(候选区域)来判断位置,作者认为以往的对每个滑动窗口进行检测算法是一种浪费资源的方式.在RCNN中 ...

  9. 基于麻雀算法改进的无线传感器网络Dv-hop定位算法 - 附代码

    基于麻雀算法改进的无线传感器网络Dv-hop定位算法 文章目录 基于麻雀算法改进的无线传感器网络Dv-hop定位算法 1.DV-Hop算法原理 2.麻雀算法改进DV-Hop算法原理 3.算法测试 4. ...

最新文章

  1. 基于CAN总线的家居安防系统设计
  2. JavaScript的对象
  3. 海量数据库的查询优化及分页算法方案(一)
  4. axure原件 总是丢失_Axure实现提示文本单击显示后自动消失的效果
  5. django报表系统_django使用echarts
  6. Python 获取当前时间或当前时间戳,通过时间戳获取hash
  7. C#LeetCode刷题之#819-最常见的单词(Most Common Word)
  8. C++之继承探究(九):多态的代价
  9. mysql可以装到其他端口吗_linux下怎么在另一个端口安装高版本mysql
  10. hashMap的具体实现
  11. php sns 源码,ThinkSNS v4
  12. 青龙面板之【追书神器】——5.29
  13. APP游戏运营:如何运用数据来指导手游运营
  14. matlab tdb,计算相图中的TDB文件 - 计算模拟 - 小木虫 - 学术 科研 互动社区
  15. java.lang.IllegalArgumentException: 字符[_]在域名中永远无效。 at
  16. spring boot整合JDBC
  17. 2021程序员笔记本电脑推荐
  18. 小学计算机上课课前导入视频教程,小学信息技术教学中微视频的导入实践分析...
  19. 加强中学理化生实验室建设要求,深化教学改革
  20. java 用验证码的形式验证邮箱

热门文章

  1. MA、EMA、SMA的区别
  2. java体温_java实现体温单实例-eclipse-java工程
  3. 【Cinemachine】VirtualCamera虚拟相机详解(一)
  4. 上海交大工科试验班计算机科学与技术,【专业分流】上海交通大学关于2019级工科平台和自然科学试验班专业分流结果公示的通知...
  5. react 添加css_在JS中使用情感CSS将暗模式添加到您的React应用中
  6. idea 2019.1.3注册码(亲测可用)
  7. CarbonData部署和使用
  8. 苹果无线网服务器改什么速度快,苹果改dns提高网速(iphone国内最快的dns)
  9. pytdx 调用沪深300 所有股票实时行情
  10. 编译compile和连接Link