这篇文章是MAML的升级版本,即MAML++。他针对MAML的一些不足之处做了对应的改进,如稳定性、收敛速度、表现力等均得到提升。
由于自己的算法实现中有用到MAML,为了让整体算法有一个好的性能,就来阅读了下这篇MAML升级版——MAML++。

参考列表:
①MAML++论文解读
②模拟退火算法简介
③模拟退火算法详解
④MAML++PyTorch源码

How to Train your MAML

  • 前言
    • MAML的不足之处
    • MAML vs MAML++
  • MAML
    • MAML简介
    • MAML的功能
    • MAML核心公式
  • MAML++
    • 存在的6个问题
      • Q1 Training Instability
      • Q2 Second Order Derivative Cost
      • Q3 Absence of Batch Normalization Statistic Accumulation
      • Q4 Shared (across step) BN Bias
      • Q5 Shared Inner Loop (across step and parameter) Lr
      • Q6 Fixed Outer Loop Lr
    • MAML++的解决方案
      • Multi-Step Loss Optimization
      • Derivative Annealing
      • LSLR
      • CA
  • 实验结果
  • 总结

前言

作者指出MAML的不足之处,针对这些不足之处做出相应改善,形成MAML++算法。

MAML的不足之处

  1. 对网络结构敏感。
  2. MAML训练不稳定。
  3. MAML需要进行较大的超参数调节。
  4. 计算量大。

MAML vs MAML++

Figure 1:

这是一组在MAML和MAML++之间对比的结果,可以看出MAML++在训练过程中更加稳定、收敛速度更快且表现力更强(正确率)。

MAML

MAML是2017年的paper,应该是比较熟悉的一篇文章了。

MAML简介

①关于MAML论文笔记,点这里。
②关于MAML简化版本FOMAML,点这里。
③关于MAML升级版本Meta-SGD,点这里。

MAML的功能

学习到一个既具备可以适应到很多环境(或者叫task),又可以快速适应(指内更新只更新一次)到具体某个新的task上的合适的参数。
合适的含义:假设参数都是一维的

  1. MAML的训练结果是蓝色的 ϕ \phi ϕ,而不是红色的 ϕ \phi ϕ。
  2. 虽然红色的 ϕ \phi ϕ在2个 l o s s loss loss上拥有更低的值,但是MAML训练的目标是要降低Fast-weights θ \theta θ在各自task上的 l o s s loss loss。
  3. 显然蓝色的 ϕ \phi ϕ在经过一次内更新之后可以更快速地使 θ 1 \theta_1 θ1​、 θ 2 \theta_2 θ2​在各自地task上达到最低的 l o s s loss loss;而红色的 ϕ \phi ϕ很难通过一次更新使得 θ 2 \theta_2 θ2​到达属于task2的最低 l o s s loss loss点。
  4. 因此MAML训练的核心就是训练出一个Meta-Learner的参数 ϕ \phi ϕ,要求是这个 ϕ \phi ϕ经过一次内更新之后得到的 θ \theta θ在各自的task上达到最低的 l o s s loss loss值。这里要注意的是——不是 ∑ i L T i ( f i ( ϕ ) ) \sum_i\mathcal{L}_{\mathcal{T}_i}(f_i(\phi)) ∑i​LTi​​(fi​(ϕ))最低,而是要 ∑ i L T i ( f i ( θ i ) ) \sum_i\mathcal{L}_{\mathcal{T}_i}(f_i(\theta_i)) ∑i​LTi​​(fi​(θi​))最低。
  5. 综上所述,也许当前蓝色 ϕ \phi ϕ作为初始参数不是很好,因为此时2个task的 l o s s loss loss都比较高;但是这个 ϕ \phi ϕ对于2个 θ \theta θ来说是不错的,因为接下来2个task可以很容易就降低 l o s s loss loss;MAML训练的目标不是在个别task上做到最优参数 θ ∗ \theta^* θ∗,而是可以适应到所有task上,能让他们更快的去学习,这也就是为什么MAML的 L o s s Loss Loss设计成所有task的 l o s s loss loss之和的平均。

MAML核心公式

设 θ = θ 0 \theta=\theta_0 θ=θ0​是Meta-Learner的参数 ϕ \phi ϕ的初始值,一般都做随机化处理; S b S_b Sb​是support-set; θ i b \theta^b_i θib​是task b在第 i i i次更新之后的Learner参数,故对于每一个task b,MAML的内更新为:
θ i b = θ i − 1 b − α ∇ θ L S b ( f θ i − 1 b ( θ i − 1 b ) ) (1) \theta^b_i=\theta^b_{i-1}-\alpha\nabla_\theta\mathcal{L}_{S_b}(f_{\theta^b_{i-1}}(\theta^b_{i-1}))\tag{1} θib​=θi−1b​−α∇θ​LSb​​(fθi−1b​​(θi−1b​))(1)虽然MAML的特点就是更新一次,并不是说只能更新一次。
设 θ N b ( θ 0 ) \theta_N^b(\theta_0) θNb​(θ0​)表示从 θ 0 \theta_0 θ0​开始,在task b上内更新N次到达的参数值; T b T_b Tb​指的是在Query-set,故MAML外更新的目标函数为:
L m e t a ( θ 0 ) = ∑ b = 1 B L T b ( f θ N b ( θ 0 ) ( θ N b ) ) (2) \mathcal{L}_{meta}(\theta_0) = \sum^{\mathcal{B}}_{b=1}\mathcal{L}_{T_b}(f_{\theta^b_N(\theta_0)}(\theta^b_N))\tag{2} Lmeta​(θ0​)=b=1∑B​LTb​​(fθNb​(θ0​)​(θNb​))(2)外更新的目标函数和内更新不一样,外更新是在所有task上的平均值。需要注意的 θ N b ( θ 0 ) \theta^b_N(\theta_0) θNb​(θ0​)是在Query-set做的update,FOMAML就是在此基础上对MAML进行简化——省去了二阶导数的求取。
有了损失函数,就可以做MAML的外更新:
θ 0 = θ 0 − β ∇ θ ∑ b = 1 B L T b ( f θ N b ( θ 0 ) ) (3) \theta_0 = \theta_0 - \beta\nabla_\theta\sum^\mathcal{B}_{b=1}\mathcal{L}_{T_b}(f_{\theta^b_N(\theta_0)})\tag{3} θ0​=θ0​−β∇θ​b=1∑B​LTb​​(fθNb​(θ0​)​)(3)可以看出来标准的MAML算法对网络权值的更新是基于最后一个时间步 N N N上所有测试集上损失函数的平均值;每一次时间步从 0 → N 0\to N 0→N都是在同一个测试集上完成的。

MAML++

文章的论述结构就是指出MAML的6个问题,然后针对6个问题提出改进方案,也就是说改进了MAML一路过关升级,灭掉自己存在6个问题之后,就能升级到MAML++。

存在的6个问题

Q1 Training Instability

作者指出MAML的训练很不稳定。我们在做外循环的时候,由于Fast_weights需要多次通过网络产生,那么在从外循环backward的时候,就会从外循环梯度回传到内循环,并且网络的每一层都会被回传好几次,这样就会容易产生梯度爆炸或者梯度衰减,就很麻烦。并且如果你的网络结构的深度很大时,那么别说整个网络会被传几次的问题了,光传完一个网络就可能出现梯度爆炸或者梯度衰减。那么当梯度值会出现较大的问题时,参数的更新自然就会出事,那么通过这个参数产生的结果出现不稳定现象也是正常的,详情见上述Figure 1。针对梯度不稳定现象,作者提出了MSL(详情见后文)来解决。

Q2 Second Order Derivative Cost

标准MAML使用全二阶偏导下的参数更新,虽然理论上会很好的泛化性能,但是对计算资源与训练时间的消耗是不友好的。

Q3 Absence of Batch Normalization Statistic Accumulation

由于这部分在自己的算法中没有涉及,就没有去详细了解。

Q4 Shared (across step) BN Bias

由于这部分在自己的算法中没有涉及,就没有去详细了解。

Q5 Shared Inner Loop (across step and parameter) Lr

MAML有一个升级版本叫Meta-SGD,是一种去自己学习MAML内更新学习率 α \alpha α的元学习算法。该算法实现简单,但是由于需要针对model的每一个参数去学习一个learning rate,故对于计算资源以及存储资源的消耗是很大的,尤其是当你的model具有很大的参数的时候,Meta-SGD确实需要消耗一部分存储资源。

Q6 Fixed Outer Loop Lr

标准MAML的外循环采用固定学习率 β \beta β,这就意味着需要花费一定的时间去找到一个合适的学习率,这种静态学习率会使得算法不那么灵活。

MAML++的解决方案

Multi-Step Loss Optimization

这个MSL是用于解决MAML不稳定性的,也是全文这么多技巧中最有效的方法,它主要由2部分组成:

  1. 标准的MAML算法如公式(3)所示,meta-learner参数的更新取决于所有测试集上最后一步的loss值,因此你拿着这个值去做backward的时候,之前几个step的参数只能被隐式的优化,他们产生的loss无法直接拿来显式的优化,因此算法的稳定性就会欠缺一些。因此其核心思想就是改单步(最后一个step)为多步(Multi-Step)。属于是用消耗训练时间、计算量的代价来换取算法的稳定性,因为你反向传播的次数会变多。具体的loss结构如下:纵向是一个子list,里面是同一个时间步,但是每个loss来源于不同的测试集;横向是不同的几个子list,表示同一个测试集的不同时间步,绿色的最后一个时间步的loss。经典的MAML就是将绿色的全部加起来然后做backward,而MAML++的做法是将所有的格子都加起来,然后从左到右不同颜色的格子分别赋予不同的权值,相同颜色的权值一样(从左到右依次是时间步 0 → N 0\to N 0→N)。

  2. 另一个改进是加入了退火权重因子 v i v_i vi​如上图所示,这个因子针对不同的loss会有不同的值,越是往后的step具有更大的权重分配,毕竟后面的step更加重要,更加准确嘛,所以理应得到更大的关照。我自己在实现的时候,使用的是指数式滑动平均,类似于软更新 τ \tau τ的做法。

总的来说,MSL的提出使得MAML的稳定性得到提升,同时算法的性能也得了加强!

Derivative Annealing

导数退火算法:核心思想就是标准MAML算法(存在二阶偏导计算)和FOMAML(将标准MAML简化成一阶偏导)共同使用。
标准MAML采用二阶偏导:优点是计算准确,泛化能力强;缺点是backward消耗时间久,对计算资源消耗较大。
FOMAML采用一阶偏导近似二阶:优点是计算速度快;缺点是泛化能力不如二阶偏导,毕竟是近似的,就不那么准确。
那么作者的提出的思想也很简单:就是前50个episodes采用FOMAML,之后的训练采用二阶偏导。这么做的好处就是既能加快训练速度,又可以保证足够的泛化能力。这种做法包含着退火的思想,“粗糙”的近似能让模型得到更好的训练。
作者还指出一个有意思的发现:DA比单独使用标准MAML算法更加稳定,FOMAML的训练相当于标准MAML的一种预训练,让后期MAML对模型的训练更好之外,可以避免标准MAML算法出现梯度衰减、梯度爆炸现象。

LSLR

LSLR主要针对的是Li et al.(2017)提出的Meta-SGD算法并做出相应改善(关于Meta-SGD的论文以及笔记):
如上图所示,Meta-SGD针对网络每个参数都会去学习一个 l e a r n i n g − r a t e learning-rate learning−rate以及一个搜索方向。本文作者提出这样比较消耗计算资源,相应的提出针对网络的每个层去学习一个 l e a r n i n g − r a t e learning-rate learning−rate以及一个搜索方向。

CA

余弦退火调整学习率(Cosine Annealing scheduling):
余弦退火是一种简单有效的动态调整学习率的算法,其主要利用余弦函数缓慢下降、迅速下降、再缓慢下降的函数特征来映射到学习率的变化上,其中的退火指的是学习率在下降到余弦函数的最低点之后会立即回到最大值点,开启新的一个周期,这样的话就可以帮助模型逃离出次优点,去寻找更好的次优点(甚至全局最优),如下图所示(图来自上面的链接):

这里讲的是对MAML算法的外循环,也就是Meta-leraner训练的优化。之前标准的MAML算法采用静态(固定)的外循环学习率。固定的学习率使得超参数的调节十分苦难而且不灵活。作者指出CA算法可以使得外循环学习率 β \beta β动态可变,而且可以提高算法的性能。
PyTorch提供了有关余弦退火的实现:

  1. 一篇不错的参考文章点这里。
  2. 官网介绍:①torch.optim.lr_scheduler.CosineAnnealingLR、②torch.optim.lr_scheduler.CosineAnnealingWarmRestarts
  3. 还有我自己整理的一张用法框架说明:

实验结果

从实验结果来看,单个技巧对MAML的提升不明显,有的甚至不如标准MAML,但是组合拳的效果对标准MAML有了较大的提升。
提升效果如下:
MAML++在泛化性能(表现力)和稳定性上都远胜于标准MAML算法!

总结

  1. MSL可用于提升MAML的稳定性以及泛化性能,但会降低算法速率以及增加计算资源消耗。
  2. DA可以增加MAML训练效率以及缓解梯度爆炸、消失问题。
  3. LSLR可以减少Meta-SGD给MAML算法带来增加的存储消耗问题,同时可以有一个可学习的内循环学习率 α \alpha α。
  4. CA可以为MAML带来动态调整的外更新学习率 β \beta β,可以帮助算法缓解次陷入次优解的问题。

以上就是MAML++带来的几个小技巧,大家在用MAML的时候可以看情况往自己的MAML里添加,总的来说还是几个不错的idea!

Meta-Learning之How to train your MAML相关推荐

  1. 8.7 Meta learning元学习全面理解、MAML、Reptile

    文章目录 1.介绍 为什么需要元学习? few-shot learning reinforcement learning 2.概念 3. Meta learning 三个步骤 定义一组learning ...

  2. 强化学习-把元学习(Meta Learning)一点一点讲给你听

    目录 0 Write on the front 1 What is meta learning? 2 Meta Learning 2.1 Define a set of learning algori ...

  3. 李宏毅机器学习35——meta learning

    摘要: 这节课学习了meta learning(元学习) 元学习是通过一系列task的训练,让机器成为一个更好的学习者,当机器遇到新的学习任务时,就能更快的完成.类比到现实生活中,元学习不是指某一类的 ...

  4. 【李宏毅2020 ML/DL】P88-96 Meta Learning – MAML | Reptile

    我已经有两年 ML 经历,这系列课主要用来查缺补漏,会记录一些细节的.自己不知道的东西. 本节内容综述 元学习就是 Learn to learn ,让机器变成 a better learner .Me ...

  5. 理解Meta Learning 元学习,这篇文章就够了!

    点击上方"小白学视觉",选择加"星标"或"置顶" 重磅干货,第一时间送达 AI编辑:我是小将 本文作者:谢杨易 1 什么是meta lear ...

  6. 元学习Meta learning深入理解

    目录 基本理解 元学习与传统的机器学习不同在哪里? 基本思想 MAML MAML与pre-training有什么区别呢? 1. 损失函数不同 2. 优化思想不同 MAML的优点及特点 MAML工作机理 ...

  7. 【李宏毅2020 ML/DL】P97-98 More about Meta Learning

    我已经有两年 ML 经历,这系列课主要用来查缺补漏,会记录一些细节的.自己不知道的东西. 本节内容综述 本节课由助教 陈建成 讲解. 本节 Outline 见小细节. 首先是 What is meta ...

  8. Meta-Weight-Net[NIPS‘2019]:关于元学习/域自适应(meta learning/domain adaptation)优化噪声标签与类别不平衡的问题

    目录 研究背景 一.为什么存在类别不平衡现象? 二. Meta-Weight-Net[NIPS'2019] 1.Focal Loss 2.self-pacd learning 3.Meta-Weigh ...

  9. [转载]Meta Learning单排小教学

    原文链接:Meta Learning单排小教学 虽然Meta Learning现在已经非常火了,但是还有很多小伙伴对于Meta Learning不是特别理解.考虑到我的这个AI游乐场将充斥着Meta ...

最新文章

  1. 如何用DELPHI区分彩色图和黑白图(多色与单色)??
  2. php imap配置,php中的自定义IMAP命令
  3. JVM内存堆布局图解分析
  4. 计算机复位启动如何操作,电脑能开机但进入不了系统,怎么办?按什么键能恢复系统?...
  5. jQuery文本动画效果
  6. 如何保证Redis与数据库的双写一致性?进阶加薪全靠它!
  7. 数据库-MySQL-数据库和表
  8. 函数平移口诀_初三二次函数平移规律的口诀
  9. python 课后作业_python 学习笔记 五 课后作业
  10. 【报告分享】美好城市指数:短视频与城市繁荣关系白皮书.pdf(附下载链接)...
  11. 全网音乐解析_Android干货 | 手机端的音乐神器
  12. 数据分析最难的就是怎样读懂数据!
  13. Android lowmemorykiller分析
  14. 《构建之法》8、9、10
  15. 软件研发软件基础设施的建设
  16. 5操作系统的运行机制和体系结构
  17. 【CVPR2020】Designing Network Design Spaces
  18. 从两个和尚挑水的故事得到的启发
  19. vue+ElementUI页面模板
  20. 数据结构—栈与队列【顺序存储、链式存储、卡特兰数、优先级队列】

热门文章

  1. 这4个文档排版方式掌握了,工作效率提高的不止一点点!
  2. 28个UI免费漂亮的切换开关PSD下载
  3. 【操作系统基础】操作系统的分类与发展
  4. Excel无法完全显示超过11位数的数字?超简单方法1秒教给你!
  5. iar 预编译会把非条件的去掉_IAR使用最全方法.docx
  6. 华为mate40计算机,华为Mate40兼容欧米多手机电脑,秒变笔电轻办公
  7. 揭开手机app中摇一摇的神秘面纱
  8. 把多个txt文件合成一个txt文件
  9. 根据程序流程图化程序流图_程序流程图如何画?绘图软件有哪些
  10. 让文本框输入的文字距离左边框4px