这里是引用

MAML全文目录

  • 论文地址
  • 摘要
  • 介绍
  • 相关概念
    • model-agnostic
    • N-way K-shot
    • Task
    • 5-way 5-shot的实验设置
  • 算法流程
  • fine-tune算法流程
  • 参考文献

论文地址

https://arxiv.org/abs/1703.03400

摘要

  • 一种与模型无关的元学习算法
  • 适用于适用梯度下降更新训练的模型,分类、回归、强化学习等
  • 元学习的目标:在大量不同的任务上预先训练出一个模型,用这个模型可以在新任务的少量样本上进行训练。预训练,新任务小样本
  • 本文只需要经过几步梯度更新的微调就可以取得不错的效果
  • 效果:在小样本的图像分类成绩最好,小样本回归问题和强化学习策略梯度中也取得了好的成绩

介绍

  • 元学习阶段,使用全部任务做训练样本
  • 没有扩展模型参数,也没有限制模型结构
  • 学习过程是使新任务的损失函数对参数的敏感度最大化,当敏感度高时,参数微小的变化就可以使损失函数大幅度变化
  • 与最先进的小样本学习模型相比,模型使用了更少的参数。回归任务中,在任务可变性下加速强化学习,大大优于初始化的直接预训练
  • 训练过程:先从任务中抽样一个任务TiT_iTi​,然后从qiq_iqi​中抽取KKK个样本训练模型,然后得到TiT_iTi​的损失函数LTiL_{T_i}LTi​​,然后在新样本上进程测试,通过qiq_iqi​中抽样的新数据的测试误差随参数的变化情况,来提高模型的性能。
  • 测试误差和训练误差:在所有抽样任务TiT_iTi​上的测试误差构成了元学习训练阶段的训练误差。
  • 评估阶段(最后阶段):从任务分布P(T)P(T)P(T)中抽样一些新任务,每个任务具有K个样本,通过KKK个样本的学习,来作为最后的模型评估
  • 元测试的任务在元训练期间
  • 创新:例如,神经网络可能学习广泛适用于P(T)P(T)P(T)中所有任务的内部特征,而不是单个任务的特征表示
  • 方法:就是为了让模型fff在抽样的任务上快速适应,同时不产生过拟合。就是去找一组对任务中变化比较敏感的参数
  • 梯度更新方法:θi′=θ−α∇θLTi(fθ)\theta_i^\prime=\theta-\alpha\nabla_\theta L_{T_i}(f_\theta)θi′​=θ−α∇θ​LTi​​(fθ​)
  • f(θ)f(\theta)f(θ)是由参数θ\thetaθ表示的模型,LTi(fθ)L_{T_i}(f_\theta)LTi​​(fθ​)是任务TiT_iTi​的损失函数,θi′\theta_i^\primeθi′​是是在TiT_iTi​经过一次或者多次梯度下降更新得到的参数更新
  • 优化方法:
    • 模型的参数通过优化所有任务上的fθ′f_\theta^{'}fθ′​来进行更新,min⁡θ∑Ti∼p(T)LTi(fθ′)=∑Ti∼p(T)α∇θLTi(fθ)LTi(fθ−α∇θLTi(fθ))\min_\theta\sum_{T_i \sim p(T)}L_{T_i}(f_\theta^{'})=\sum_{T_i \sim p(T)}\alpha\nabla_\theta L_{T_i}(f_\theta)L_{T_i}(f_{\theta-\alpha\nabla_\theta L_{T_i}(f_\theta)})θmin​Ti​∼p(T)∑​LTi​​(fθ′​)=Ti​∼p(T)∑​α∇θ​LTi​​(fθ​)LTi​​(fθ−α∇θ​LTi​​(fθ​)​)
      – 元学习阶段的优化是在模型参数θ\thetaθ上进行的,而上述目标是使用更新过的θ′\theta^{'}θ′得到的,提出的方法要在新任务上通过一次或几次梯度更新来优化模型参数
    • meta的优化,通过梯度下降来进行更新的:θ←θ−β∇θ∑Ti∼p(T)LTi(fθ)\theta \gets \theta-\beta\nabla_\theta\sum_{T_i \sim p(T)} L_{T_i}(f_\theta)θ←θ−β∇θ​Ti​∼p(T)∑​LTi​​(fθ​)

相关概念

model-agnostic

  • model-agnostic:

    • model-agnostic即模型无关。
    • MAML与其说是一个深度学习模型,倒不如说是一个框架,提供一个meta-learner用于训练base-learner
    • meta-learner即MAML的精髓
    • base-learner则是在目标数据集上被训练,并实际用于预测任务的真正的数学模型。
    • 大多数深度学习模型都可以作为base-learner嵌入MAML
    • meta-learner →\rightarrow→ base-learner

N-way K-shot

  • N-way K-shot:

    • 是few-shot learning(小样本学习)中常用的实验设置。小样本学习指利用很少的被标记数据训练数学模型的过程(MAML擅长的)
    • N-way指训练数据中有N个类别
    • K-shot指每个类别下有K个被标记数据

Task

  • Task:

    • 假设一个场景:我们需要利用MAML训练一个数学模型Mfine−tuneM_{fine-tune}Mfine−tune​(fine-tune为微调),目的是对未知标签的图片做分类,类别包括 P1∼P5P_1 \sim P_5P1​∼P5​ (每个类别有5个已标注样本用于训练。另外每个类别有15个已标注样本用于测试,一共100个已标注样本)。我们的训练数据除了P1∼P5P_1 \sim P_5P1​∼P5​ 中已标注的样本外,还包括另外10个类别的图片C1∼C10C_1 \sim C_{10}C1​∼C10​(每种类别有30个已标注样本,一共300个已标注),用于帮助训练元学习模型 MmetaM_{meta}Mmeta​ 。我们的实验设置为5-way 5-shot也就是C1∼C10C_1 \sim C_{10}C1​∼C10​的随机抽取的5个类别中的样本是先来训练元学习模型 MmetaM_{meta}Mmeta​
    • 训练过程大概为:
      • MAML首先利用 C1∼C10C_1 \sim C_{10}C1​∼C10​ 的数据集训练元模型MmetaM_{meta}Mmeta​,再在P1∼P5P_1 \sim P_5P1​∼P5​的数据集上精调(fine-tune)得到最终的模型 Mfine−tuneM_{fine-tune}Mfine−tune​ ,下面的算法流程主要就是这部分
    • C1∼C10C_1 \sim C_{10}C1​∼C10​ 即meta-train classes,C1∼C10C_1 \sim C_{10}C1​∼C10​ 包含的共计300个样本,即 Dmeta−trainD_{meta-train}Dmeta−train​ ,是用于训练MmetaM_{meta}Mmeta​的数据集;
    • 与之相对的,P1∼P5P_1 \sim P_5P1​∼P5​ 即meta-test classes,P1∼P5P_1 \sim P_5P1​∼P5​ 包含的共计100个样本,即 Dmeta−testD_{meta-test}Dmeta−test​ ,是用于训练和测试 Mfine−tuneM_{fine-tune}Mfine−tune​ 的数据集

5-way 5-shot的实验设置

  • 5-way 5-shot的实验设置:

    • task T,相当于普通深度学习模型训练过程中的一条训练数据:

      • 在训练MmetaM_{meta}Mmeta​阶段,从C1∼C10C_1 \sim C_{10}C1​∼C10​中随机取5个类别,每个类别再随机取20个已标注样本,组成一个task T
      • 每个类别随机取20个已标注的样本中,其中的5个已标注的样本称为Tsupport set,另外15个样本称为Tquery set
      • 这个task T,相当于普通深度学习模型训练过程中的一条训练数据。类似于SGD,要搞batch,即反复从训练数据分布中抽取若干(分布解释看李沐的)个这样的task T,组成一个batch
    • 训练Mfine−tuneM_{fine-tune}Mfine−tune​阶段,task、support set、query set的含义和训练MmetaM_{meta}Mmeta​阶段相同

算法流程

  • MAML预训练阶段的算法的目的:得到模型MmetaM_{meta}Mmeta​
  • 第一行的Require:
    • 指的是在Dmeta−trainD_{meta-train}Dmeta−train​ 中task的分布。这里就是反复随机抽取的task T,形成若干个T(例如抽取1000个)组成的task池,作为MAML的训练集
    • 这里有一个问题:训练样本Dmeta−trainD_{meta-train}Dmeta−train​数量有限,要组合形成那么多的task,岂不是不同task之间会存在样本的重复?或者某些taskquery set会成为其他tasksupport set
      • 答:MAML的目的,就在于fast adaptation,即通过对大量task的学习,获得足够强的泛化能力,从而面对新的、从未见过的task时,通过fine-tune就可以快速拟合。task之间,只要存在一定的差异即可
    • MAML的训练是基于task的,而这里的每个task就相当于普通深度学习模型训练过程中的一条训练数据
  • 第二行的require:
    • α,β\alpha,\betaα,β:step size其实就是学习率
    • MAML是基于二重梯度(gradient by gradient)(下面有讲什么是二重梯度),每次迭代包括两次参数更新的过程,所以有两个学习率需要调整
  • 步骤1:
    • 随机初始化模型的参数θ\thetaθ
  • 步骤2:
    • 是一个循环,可以理解为一轮迭代过程或一个epoch,预训练的过程是可以有多个epoch的
  • 步骤3:
    • 从分布中随机采样若干数量的task(例如5个),形成一个batch
  • 步骤4-7:
    • 是第一次梯度更新的过程
    • copy了原模型,然后在copy的模型上计算出新的参数,用于第二个梯度的计算过程中
    • 每一个task更新一次参数(5个task就是更新5次),就是一个batch结束,这个在算法中可以反复执行多次
  • 步骤5:
    • 使用batch中某个task中的support set(一个task就是k个,例如5w5k实验设置中的就是5个已标注的样本),来计算参数的梯度,总的support setN*K个(5w5k就是25个,取5个类别,从C1∼C10C_1 \sim C_{10}C1​∼C10​中随机取5个类别,再从每个类别的20个已标注样本中取5个
    • Loss方法,回归任务,就是MSE;分类任务,就是cross-entropy(交叉熵)(李沐视频有讲什么是交叉熵)
  • 步骤6:
    • 第一次参数梯度的更新
  • 步骤4-7完,MAML完成了第一次梯度更新。根据第一次梯度更新得到的参数,计算第二次梯度更新(步骤8)。第二次的梯度更新时计算出的梯度,直接通过SGD作用在原模型(下面说)上,也就是模型真正用于更新其参数的梯度。也就是第一次梯度更新是为了第二次梯度更新,第二次梯度才是真正更新模型参数。
    • 二重梯度:

      • 原模型θa\theta_aθa​,先复制一份原模型,θa→copyθb\theta_a \rightarrow_{copy} \theta_bθa​→copy​θb​,得到θb\theta_bθb​。
      • 在θb\theta_bθb​上,做反向传播(这个要搞懂什么意思)及更新参数,得到第一次梯度更新的结果θb′\theta_b^{'}θb′​。
      • 在θb′\theta_b^{'}θb′​上,计算第二次梯度更新,计算出来,不更新θb′\theta_b^{'}θb′​,更新原模型θa\theta_aθa​
      • 这里需要理解是copy的模型是每一个task都会copy一份,如10个task会copy10个临时模型,在10个临时模型上,在各自的task上独立更新一个梯度(步骤4-7),然后整合起来用于步骤8,也就是更新原模型
      • 这样做,就是因为每一个task都会更新一次参数,用原模型,会导致使用上一个task的更新过的参数
      • 从原模型的角度来看,只进行了一次梯度更新(步骤8),但是第二次梯度更新(步骤8)依赖于第一次(步骤4-7)
      • 总结:第一次梯度,不作用于原模型,第二次梯度用于原模型
  • 步骤8:
    • 第二次梯度更新的过程
    • 与步骤7不同处:
      • 1.不是分别利用每个task的Loss更新梯度,直接和常用的模型训练一样,计算一个batch的loss和,对梯度进行随机梯度下降SGD
      • 2.这里的样本是taskquery set(如5w5k中 15*5个样本),是为了增强模型在task上的泛化,避免过拟合support set
      • 该步骤结束后,即完成在当前batch中的训练,回到步骤3,采样下一个batch

fine-tune算法流程

  • 完成以上步骤,就是MAML预训练得到MmetaM_{meta}Mmeta​的全部过程
  • 接下来要完成的就是面对新的task,在MmetaM_{meta}Mmeta​的基础上,精调得到Mfine−tuneM_{fine-tune}Mfine−tune​:
    • 步骤1中,fine-tune不再随机初始化参数,而是利用训练好的MmetaM_{meta}Mmeta​初始化参数
    • 步骤3中,fine-tune只需要抽取一个task进行学习,也不用形成batch,fine-tune利用这个tasksupport set训练模型,利用query set测试模型。
      • 实际操作时,会在Dmeta−testD_{meta-test}Dmeta−test​上随机抽取许多个task(例如500个),分别微调模型MmetaM_{meta}Mmeta​,并对最后的测试结果进行平均,从而避免极端情况
    • 没有步骤8,因为taskquery set是用来测试模型的,标签对模型是未知的。因此这个过程没有第二次梯度更新,直接用第一次梯度计算的结果更新参数

参考文献

1.Model-Agnostic Meta-Learning (MAML)模型介绍及算法详解 作者:徐不知
2.[meta-learning] 对MAML的深度解析 作者:周威
3.MAML 论文及代码阅读笔记 作者:Rust-in
4.MAML原论文

Model-Agnostic Meta-Learning for Fast Adaptation of Deep Networks(MAML)研读笔记相关推荐

  1. Model-Agnostic Meta-Learning for Fast Adaptation of Deep Networks(MAML)简析

    在看MAML这篇论文的时候,因为是初学者,很多都不懂,网上查了许多资料也没看明白,最后来来回回找了很多资料结合原文才看懂一些,在这简单分享一下. 什么是元学习? 元学习(meta-learning)已 ...

  2. Model-Agnostic Meta-Learning for Fast Adaptation of Deep Networks阅读笔记

    Model-Agnostic Meta-Learning for Fast Adaptation of Deep Networks-阅读笔记 Abstract MAML算法 问题设置 MAML算法 M ...

  3. MAML:Model-Agnostic Meta-Learning for Fast Adaptation of Deep Networks论文精读及详解

    由于论文写得比较抽象,偏向于数学,因此,在开篇首先谈谈我自己对MAML的理解,在后面再简要的抽取一下论文的核心部分 元学习解决的问题 首先,对于深度学习领域,模型初始化的权重参数尤为重要,模型参数初始 ...

  4. 深度学习材料:从感知机到深度网络A Deep Learning Tutorial: From Perceptrons to Deep Networks

    In recent years, there's been a resurgence in the field of Artificial Intelligence. It's spread beyo ...

  5. 理解Meta Learning 元学习,这篇文章就够了!

    点击上方"小白学视觉",选择加"星标"或"置顶" 重磅干货,第一时间送达 AI编辑:我是小将 本文作者:谢杨易 1 什么是meta lear ...

  6. (转)Paper list of Meta Learning/ Learning to Learn/ One Shot Learning/ Lifelong Learning

    Meta Learning/ Learning to Learn/ One Shot Learning/ Lifelong Learning 2018-08-03 19:16:56 本文转自:http ...

  7. Meta Learning在NLP领域的应用

    Hi,这里是哈林,今天来跟大家聊一聊Meta Learning在NLP领域的一些应用. 哈林之前在学校科研的方向是NLP,个人对如何将先进的机器学习算法应用到NLP场景很感兴趣(因为好水paper), ...

  8. 强化学习-把元学习(Meta Learning)一点一点讲给你听

    目录 0 Write on the front 1 What is meta learning? 2 Meta Learning 2.1 Define a set of learning algori ...

  9. 有关meta learning 要读的论文清单

    参考博客:https://blog.csdn.net/weixin_37589575/article/details/92801610 论文链接 <Optimization as a Model ...

最新文章

  1. 第三个Sprint冲刺第八天(燃尽图)
  2. KDD 2019 | 结合属性随机游走的图递归网络
  3. linux优化网页加载过程,HTML页面加载和解析流程 介绍
  4. IFTTT 加入开源大家庭,已开源5个项目
  5. 青龙羊毛——飞鸽花转省毛毛(搬运)
  6. python最难学的是什么_python是最难学的语言吗
  7. 大熊君说说JS与设计模式之(门面模式Facade)迪米特法则的救赎篇------(监狱的故事)...
  8. iOS开发--TableView详细解释
  9. POJ 1664 把苹果
  10. 【Linux系统编程】线程同步与互斥:POSIX无名信号量
  11. SpringBoot2 整合 AXIS 服务端和客户端
  12. CTF【解密】字符串flag被加密成已知新字符串,请解密出flag,可以使用Python解码出WriteUp
  13. 【OpenCV 例程200篇】38. 图像的反色变换(图像反转)
  14. ios 支付宝支付 回调数据_iOS逆向支付宝
  15. MyEclipse用(JDBC)连接SQL出现的问题~
  16. ServletContext的应用(共享数据、获取初始化参数、请求转发、读取资源文件)【源码解析】
  17. 在Unbuntu 上安装Phalcon
  18. JDK8的shenandoah GC/zgc啥时能转正?
  19. 大庆金桥:基于 SpreadJS 开发实现计量器具检定证书的在线生成与打印
  20. 手把手教你用psp手动制作背景透明的图片

热门文章

  1. Nginx是什么,为什么使用Nginx
  2. react 调用微信jsdk扫一扫
  3. 文献阅读笔记 # Bitcoin: A Peer-to-Peer Electronic Cash System
  4. 使用PreTranslateMessage替代钩子函数处理键盘消息
  5. 如何抠图图片?这个方法值得点赞收藏
  6. 【ps功能精通】4.简单背景图片抠图
  7. 庖丁解牛linux内核 百度云,庖丁解牛Linux内核-1
  8. Sitemap网站地图生成工具(适用于所有网站)
  9. 掌握SQL Monitoring这些特性,SQL优化通通不在话下
  10. 高瓴投的澳斯康生物冲刺科创板:年营收4.5亿 丢掉与康希诺合作