Model-Agnostic Meta-Learning for Fast Adaptation of Deep Networks(MAML)研读笔记
这里是引用
MAML全文目录
- 论文地址
- 摘要
- 介绍
- 相关概念
- model-agnostic
- N-way K-shot
- Task
- 5-way 5-shot的实验设置
- 算法流程
- fine-tune算法流程
- 参考文献
论文地址
https://arxiv.org/abs/1703.03400
摘要
- 一种与模型无关的元学习算法
- 适用于适用梯度下降更新训练的模型,分类、回归、强化学习等
- 元学习的目标:在大量不同的任务上预先训练出一个模型,用这个模型可以在新任务的少量样本上进行训练。
预训练,新任务小样本
- 本文只需要经过几步梯度更新的微调就可以取得不错的效果
- 效果:在小样本的图像分类成绩最好,小样本回归问题和强化学习策略梯度中也取得了好的成绩
介绍
- 元学习阶段,使用全部任务做训练样本
- 没有扩展模型参数,也没有限制模型结构
- 学习过程是使新任务的损失函数对参数的敏感度最大化,当敏感度高时,参数微小的变化就可以使损失函数大幅度变化
- 与最先进的小样本学习模型相比,模型使用了更少的参数。回归任务中,在任务可变性下加速强化学习,大大优于初始化的直接预训练
- 训练过程:先从任务中抽样一个任务TiT_iTi,然后从qiq_iqi中抽取KKK个样本训练模型,然后得到TiT_iTi的损失函数LTiL_{T_i}LTi,然后在新样本上进程测试,通过qiq_iqi中抽样的新数据的测试误差随参数的变化情况,来提高模型的性能。
- 测试误差和训练误差:在所有抽样任务TiT_iTi上的测试误差构成了元学习训练阶段的训练误差。
- 评估阶段(最后阶段):从任务分布P(T)P(T)P(T)中抽样一些新任务,每个任务具有K个样本,通过KKK个样本的学习,来作为最后的模型评估
- 元测试的任务在元训练期间
- 创新:例如,神经网络可能学习广泛适用于P(T)P(T)P(T)中所有任务的内部特征,而不是单个任务的特征表示
- 方法:就是为了让模型fff在抽样的任务上快速适应,同时不产生过拟合。就是去找一组对任务中变化比较敏感的参数
- 梯度更新方法:θi′=θ−α∇θLTi(fθ)\theta_i^\prime=\theta-\alpha\nabla_\theta L_{T_i}(f_\theta)θi′=θ−α∇θLTi(fθ)
- f(θ)f(\theta)f(θ)是由参数θ\thetaθ表示的模型,LTi(fθ)L_{T_i}(f_\theta)LTi(fθ)是任务TiT_iTi的损失函数,θi′\theta_i^\primeθi′是是在TiT_iTi经过一次或者多次梯度下降更新得到的参数更新
- 优化方法:
- 模型的参数通过优化所有任务上的fθ′f_\theta^{'}fθ′来进行更新,minθ∑Ti∼p(T)LTi(fθ′)=∑Ti∼p(T)α∇θLTi(fθ)LTi(fθ−α∇θLTi(fθ))\min_\theta\sum_{T_i \sim p(T)}L_{T_i}(f_\theta^{'})=\sum_{T_i \sim p(T)}\alpha\nabla_\theta L_{T_i}(f_\theta)L_{T_i}(f_{\theta-\alpha\nabla_\theta L_{T_i}(f_\theta)})θminTi∼p(T)∑LTi(fθ′)=Ti∼p(T)∑α∇θLTi(fθ)LTi(fθ−α∇θLTi(fθ))
– 元学习阶段的优化是在模型参数θ\thetaθ上进行的,而上述目标是使用更新过的θ′\theta^{'}θ′得到的,提出的方法要在新任务上通过一次或几次梯度更新来优化模型参数 - meta的优化,通过梯度下降来进行更新的:θ←θ−β∇θ∑Ti∼p(T)LTi(fθ)\theta \gets \theta-\beta\nabla_\theta\sum_{T_i \sim p(T)} L_{T_i}(f_\theta)θ←θ−β∇θTi∼p(T)∑LTi(fθ)
- 模型的参数通过优化所有任务上的fθ′f_\theta^{'}fθ′来进行更新,minθ∑Ti∼p(T)LTi(fθ′)=∑Ti∼p(T)α∇θLTi(fθ)LTi(fθ−α∇θLTi(fθ))\min_\theta\sum_{T_i \sim p(T)}L_{T_i}(f_\theta^{'})=\sum_{T_i \sim p(T)}\alpha\nabla_\theta L_{T_i}(f_\theta)L_{T_i}(f_{\theta-\alpha\nabla_\theta L_{T_i}(f_\theta)})θminTi∼p(T)∑LTi(fθ′)=Ti∼p(T)∑α∇θLTi(fθ)LTi(fθ−α∇θLTi(fθ))
相关概念
model-agnostic
model-agnostic
:- model-agnostic即模型无关。
- MAML与其说是一个深度学习模型,倒不如说是一个框架,提供一个meta-learner用于训练base-learner
- meta-learner即MAML的精髓
- base-learner则是在目标数据集上被训练,并实际用于预测任务的真正的数学模型。
- 大多数深度学习模型都可以作为base-learner嵌入MAML
- meta-learner →\rightarrow→ base-learner
N-way K-shot
N-way K-shot
:- 是few-shot learning(小样本学习)中常用的实验设置。小样本学习指利用很少的被标记数据训练数学模型的过程(MAML擅长的)
- N-way指训练数据中有N个类别
- K-shot指每个类别下有K个被标记数据
Task
Task
:- 假设一个场景:我们需要利用MAML训练一个数学模型Mfine−tuneM_{fine-tune}Mfine−tune(fine-tune为微调),目的是对未知标签的图片做分类,类别包括 P1∼P5P_1 \sim P_5P1∼P5 (
每个类别有5个已标注样本用于训练
。另外每个类别有15个已标注样本用于测试
,一共100个已标注样本)。我们的训练数据除了P1∼P5P_1 \sim P_5P1∼P5 中已标注的样本外,还包括另外10个类别
的图片C1∼C10C_1 \sim C_{10}C1∼C10(每种类别有30个已标注样本,一共300个已标注),用于帮助训练元学习模型 MmetaM_{meta}Mmeta 。我们的实验设置为5-way 5-shot
,也就是
C1∼C10C_1 \sim C_{10}C1∼C10的随机抽取的5个类别中的样本是先来训练元学习模型
MmetaM_{meta}Mmeta - 训练过程大概为:
- MAML首先利用 C1∼C10C_1 \sim C_{10}C1∼C10 的数据集训练元模型MmetaM_{meta}Mmeta,再在P1∼P5P_1 \sim P_5P1∼P5的数据集上精调(
fine-tune
)得到最终的模型 Mfine−tuneM_{fine-tune}Mfine−tune ,下面的算法流程主要就是这部分
- MAML首先利用 C1∼C10C_1 \sim C_{10}C1∼C10 的数据集训练元模型MmetaM_{meta}Mmeta,再在P1∼P5P_1 \sim P_5P1∼P5的数据集上精调(
- C1∼C10C_1 \sim C_{10}C1∼C10 即
meta-train classes
,C1∼C10C_1 \sim C_{10}C1∼C10 包含的共计300个样本,即 Dmeta−trainD_{meta-train}Dmeta−train ,是用于训练MmetaM_{meta}Mmeta的数据集; - 与之相对的,P1∼P5P_1 \sim P_5P1∼P5 即
meta-test classes
,P1∼P5P_1 \sim P_5P1∼P5 包含的共计100个样本,即 Dmeta−testD_{meta-test}Dmeta−test ,是用于训练和测试 Mfine−tuneM_{fine-tune}Mfine−tune 的数据集
- 假设一个场景:我们需要利用MAML训练一个数学模型Mfine−tuneM_{fine-tune}Mfine−tune(fine-tune为微调),目的是对未知标签的图片做分类,类别包括 P1∼P5P_1 \sim P_5P1∼P5 (
5-way 5-shot的实验设置
5-way 5-shot
的实验设置:task T
,相当于普通深度学习模型训练过程中的一条训练数据:- 在训练MmetaM_{meta}Mmeta阶段,从C1∼C10C_1 \sim C_{10}C1∼C10中随机取
5个类别
,每个类别再随机取20个已标注样本
,组成一个task T
。 - 每个类别随机取
20个已标注的样本
中,其中的5个已标注的样本
称为T
的support set
,另外15个样本
称为T
的query set
- 这个
task T
,相当于普通深度学习模型训练过程中的一条训练数据。类似于SGD
,要搞batch
,即反复从训练数据分布中抽取若干(分布解释看李沐的)个这样的task T
,组成一个batch
- 在训练MmetaM_{meta}Mmeta阶段,从C1∼C10C_1 \sim C_{10}C1∼C10中随机取
- 训练Mfine−tuneM_{fine-tune}Mfine−tune阶段,
task、support set、query set
的含义和训练MmetaM_{meta}Mmeta阶段相同
算法流程
MAML
预训练阶段的算法的目的
:得到模型MmetaM_{meta}Mmeta- 第一行的
Require
:- 指的是在Dmeta−trainD_{meta-train}Dmeta−train 中
task
的分布。这里就是反复随机抽取的task T
,形成若干个T
(例如抽取1000个)组成的task
池,作为MAML
的训练集 - 这里有一个问题:训练样本Dmeta−trainD_{meta-train}Dmeta−train数量有限,要组合形成那么多的
task
,岂不是不同task
之间会存在样本的重复?或者某些task
的query set
会成为其他task
的support set
?- 答:
MAML
的目的,就在于fast adaptation
,即通过对大量task
的学习,获得足够强的泛化能力
,从而面对新的、从未见过的task
时,通过fine-tune
就可以快速拟合。task
之间,只要存在一定的差异即可
- 答:
MAML
的训练是基于task
的,而这里的每个task
就相当于普通深度学习模型训练过程中的一条训练数据
。
- 指的是在Dmeta−trainD_{meta-train}Dmeta−train 中
- 第二行的
require
:- α,β\alpha,\betaα,β:
step
size
其实就是学习率 MAML
是基于二重梯度
(gradient by gradient)(下面有讲什么是二重梯度),每次迭代包括两次参数更新的过程,所以有两个学习率需要调整
- α,β\alpha,\betaα,β:
- 步骤1:
- 随机初始化模型的参数θ\thetaθ
- 步骤2:
- 是一个循环,可以理解为一轮迭代过程或一个epoch,预训练的过程是可以有多个epoch的
- 步骤3:
- 从分布中随机采样若干数量的
task
(例如5个),形成一个batch
- 从分布中随机采样若干数量的
- 步骤4-7:
- 是第一次梯度更新的过程
- copy了原模型,然后在copy的模型上计算出新的参数,用于第二个梯度的计算过程中
每一个
task更新一次参数(5个
task就是更新5次),就是一个
batch结束
,这个在算法中可以反复执行多次
- 步骤5:
- 使用batch中某个task中的
support set
(一个task就是k个,例如5w5k实验设置中的就是5个已标注的样本),来计算参数的梯度,总的support set
有N*K
个(5w5k就是25个,取5个类别,从C1∼C10C_1 \sim C_{10}C1∼C10中随机取5个类别
,再从每个类别的20个已标注样本
中取5个
) Loss
方法,回归任务,就是MSE
;分类任务,就是cross-entropy(交叉熵)
(李沐视频有讲什么是交叉熵)
- 使用batch中某个task中的
- 步骤6:
- 第一次参数梯度的更新
- 步骤4-7完,
MAML
完成了第一次梯度更新。根据第一次梯度更新得到的参数,计算第二次梯度更新(步骤8)。第二次的梯度更新时计算出的梯度,直接通过SGD作用在原模型
(下面说)上,也就是模型真正用于更新其参数的梯度。也就是第一次梯度更新是为了第二次梯度更新,第二次梯度才是真正更新模型参数。- 二重梯度:
- 原模型θa\theta_aθa,先复制一份原模型,θa→copyθb\theta_a \rightarrow_{copy} \theta_bθa→copyθb,得到θb\theta_bθb。
- 在θb\theta_bθb上,做反向传播(这个要搞懂什么意思)及更新参数,得到第一次梯度更新的结果θb′\theta_b^{'}θb′。
- 在θb′\theta_b^{'}θb′上,计算第二次梯度更新,计算出来,不更新θb′\theta_b^{'}θb′,更新原模型θa\theta_aθa
这里需要理解是copy的模型是每一个task都会copy一份,如10个task会copy10个临时模型,在10个临时模型上,在各自的task上独立更新一个梯度(步骤4-7),然后整合起来用于步骤8,也就是更新原模型
这样做,就是因为每一个task都会更新一次参数,用原模型,会导致使用上一个task的更新过的参数
- 从原模型的角度来看,只进行了一次梯度更新(步骤8),但是第二次梯度更新(步骤8)依赖于第一次(步骤4-7)
- 总结:第一次梯度,不作用于原模型,第二次梯度用于原模型
- 二重梯度:
- 步骤8:
- 第二次梯度更新的过程
- 与步骤7不同处:
- 1.不是分别利用每个task的Loss更新梯度,直接和常用的模型训练一样,计算一个batch的loss和,对梯度进行随机梯度下降SGD
- 2.这里的样本是
task
的query set
(如5w5k中 15*5个样本),是为了增强模型在task上的泛化,避免过拟合support set
- 该步骤结束后,即完成在当前batch中的训练,回到步骤3,采样下一个batch
fine-tune算法流程
- 完成以上步骤,就是
MAML
预训练得到MmetaM_{meta}Mmeta的全部过程 - 接下来要完成的就是面对新的
task
,在MmetaM_{meta}Mmeta的基础上,精调得到Mfine−tuneM_{fine-tune}Mfine−tune:- 步骤1中,
fine-tune
不再随机初始化参数,而是利用训练好的MmetaM_{meta}Mmeta初始化参数 - 步骤3中,
fine-tune
只需要抽取一个task
进行学习,也不用形成batch,fine-tune
利用这个task
的support set
训练模型,利用query set
测试模型。- 实际操作时,会在Dmeta−testD_{meta-test}Dmeta−test上随机抽取许多个
task
(例如500个),分别微调模型MmetaM_{meta}Mmeta,并对最后的测试结果进行平均,从而避免极端情况
- 实际操作时,会在Dmeta−testD_{meta-test}Dmeta−test上随机抽取许多个
- 没有步骤8,因为
task
的query set
是用来测试模型的,标签对模型是未知的。因此这个过程没有第二次梯度更新,直接用第一次梯度计算的结果更新参数
- 步骤1中,
参考文献
1.Model-Agnostic Meta-Learning (MAML)模型介绍及算法详解 作者:徐不知
2.[meta-learning] 对MAML的深度解析 作者:周威
3.MAML 论文及代码阅读笔记 作者:Rust-in
4.MAML原论文
Model-Agnostic Meta-Learning for Fast Adaptation of Deep Networks(MAML)研读笔记相关推荐
- Model-Agnostic Meta-Learning for Fast Adaptation of Deep Networks(MAML)简析
在看MAML这篇论文的时候,因为是初学者,很多都不懂,网上查了许多资料也没看明白,最后来来回回找了很多资料结合原文才看懂一些,在这简单分享一下. 什么是元学习? 元学习(meta-learning)已 ...
- Model-Agnostic Meta-Learning for Fast Adaptation of Deep Networks阅读笔记
Model-Agnostic Meta-Learning for Fast Adaptation of Deep Networks-阅读笔记 Abstract MAML算法 问题设置 MAML算法 M ...
- MAML:Model-Agnostic Meta-Learning for Fast Adaptation of Deep Networks论文精读及详解
由于论文写得比较抽象,偏向于数学,因此,在开篇首先谈谈我自己对MAML的理解,在后面再简要的抽取一下论文的核心部分 元学习解决的问题 首先,对于深度学习领域,模型初始化的权重参数尤为重要,模型参数初始 ...
- 深度学习材料:从感知机到深度网络A Deep Learning Tutorial: From Perceptrons to Deep Networks
In recent years, there's been a resurgence in the field of Artificial Intelligence. It's spread beyo ...
- 理解Meta Learning 元学习,这篇文章就够了!
点击上方"小白学视觉",选择加"星标"或"置顶" 重磅干货,第一时间送达 AI编辑:我是小将 本文作者:谢杨易 1 什么是meta lear ...
- (转)Paper list of Meta Learning/ Learning to Learn/ One Shot Learning/ Lifelong Learning
Meta Learning/ Learning to Learn/ One Shot Learning/ Lifelong Learning 2018-08-03 19:16:56 本文转自:http ...
- Meta Learning在NLP领域的应用
Hi,这里是哈林,今天来跟大家聊一聊Meta Learning在NLP领域的一些应用. 哈林之前在学校科研的方向是NLP,个人对如何将先进的机器学习算法应用到NLP场景很感兴趣(因为好水paper), ...
- 强化学习-把元学习(Meta Learning)一点一点讲给你听
目录 0 Write on the front 1 What is meta learning? 2 Meta Learning 2.1 Define a set of learning algori ...
- 有关meta learning 要读的论文清单
参考博客:https://blog.csdn.net/weixin_37589575/article/details/92801610 论文链接 <Optimization as a Model ...
最新文章
- 第三个Sprint冲刺第八天(燃尽图)
- KDD 2019 | 结合属性随机游走的图递归网络
- linux优化网页加载过程,HTML页面加载和解析流程 介绍
- IFTTT 加入开源大家庭,已开源5个项目
- 青龙羊毛——飞鸽花转省毛毛(搬运)
- python最难学的是什么_python是最难学的语言吗
- 大熊君说说JS与设计模式之(门面模式Facade)迪米特法则的救赎篇------(监狱的故事)...
- iOS开发--TableView详细解释
- POJ 1664 把苹果
- 【Linux系统编程】线程同步与互斥:POSIX无名信号量
- SpringBoot2 整合 AXIS 服务端和客户端
- CTF【解密】字符串flag被加密成已知新字符串,请解密出flag,可以使用Python解码出WriteUp
- 【OpenCV 例程200篇】38. 图像的反色变换(图像反转)
- ios 支付宝支付 回调数据_iOS逆向支付宝
- MyEclipse用(JDBC)连接SQL出现的问题~
- ServletContext的应用(共享数据、获取初始化参数、请求转发、读取资源文件)【源码解析】
- 在Unbuntu 上安装Phalcon
- JDK8的shenandoah GC/zgc啥时能转正?
- 大庆金桥:基于 SpreadJS 开发实现计量器具检定证书的在线生成与打印
- 手把手教你用psp手动制作背景透明的图片
热门文章
- Nginx是什么,为什么使用Nginx
- react 调用微信jsdk扫一扫
- 文献阅读笔记 # Bitcoin: A Peer-to-Peer Electronic Cash System
- 使用PreTranslateMessage替代钩子函数处理键盘消息
- 如何抠图图片?这个方法值得点赞收藏
- 【ps功能精通】4.简单背景图片抠图
- 庖丁解牛linux内核 百度云,庖丁解牛Linux内核-1
- Sitemap网站地图生成工具(适用于所有网站)
- 掌握SQL Monitoring这些特性,SQL优化通通不在话下
- 高瓴投的澳斯康生物冲刺科创板:年营收4.5亿 丢掉与康希诺合作