元学习—模型不可知元学习(MAML)

在之前的文章中,我们介绍了神经图灵机和记忆增强网络(MANN),主要介绍了其对于内存中信息的读取与写入。有兴趣的读者可以参考我之前的博客元学习—神经图灵机。在今天的文章中,我们来介绍一种更加常见的元学习的学习方法,即模型不可知元学习。

1. MAML原理

1.1 MAML引入

MAML是一种最近被提出的,最为主流的一种元学习的方法。其是元学习上的一个重大突破。在元学习中,众所周知,其目标是学会学习。在元学习中,我们从大量的相关学习任务中获取一小部分的样本点,然后通过元学习器来生成一个快速的学习器,再通过少量的样本作用在新的相关的任务之上。

MAML背后的思想是寻找出更好的初始化参数。通过这种更好的初始化参数,模型可以通过少量的梯度下降的步骤来应用到新的任务之上。

下面我们举一个使用神经网络的分类任务作为例子。一般的来讲,我们初始训练过程往往是从一组随机参数开始的,通过最小化loss函数来实现梯度下降的过程,以此对于参数进行调优。即我们通过Loss函数来计算损失,通过梯度下降的方式来寻找新的参数值,新的参数值能够保证Loss变的更小,通过不断的迭代,我们将Loss值降到最小,同时最小的Loss值对应的参数值即为最优值(注意:这个Loss值最小,大多数是局部最小,并非全局最小。)

在MAML中,根据我们上面的描述,我们的目标是希望获取一组相对最优的参数来作为模型的初始化参数,那么应该如何获取这种最优参数呢? 在MAML中,我们使用的是从一些相似数据分布和相似任务上来进行获取。因此,当有一个新的任务开始时,我们不会使用一个随机的参数来进行初始化,我们可以通过将其他相关任务的最优参数进行迁移,作为新任务的初始化参数。这样做的好处有两个,第一个是可以减少梯度下降的步骤,而第二个是可以减少训练过程的数据需求。

这里,我们举一个例子来理解一下MAML计算参数与一般模型计算参数的过程对比,假设我们当前有三个任务,分别使用T1,T2,T3T_1,T_2,T_3T1​,T2​,T3​来进行标记。对于一般的模型而言,首先,我们随机的初始化我们的模型参数θ,并利用模型来实现对任务T1T_1T1​进行训练。然后,通过梯度下降的方式来最小化损失函数L。通过这一次的训练过程,我们可以为任务T1T_1T1​寻找到一个相对最优的参数θ1′θ_1'θ1′​。类似的方式,通过随机初始化参数,可以为任务T2,T3T_2,T_3T2​,T3​寻找相对最优的参数θ2′,θ3′θ_2',θ_3'θ2′​,θ3′​。即,我们通过一组随机初始化的参数θ,可以生成三个相对最优的参数θ1′,θ2′,θ3′θ_1',θ_2',θ_3'θ1′​,θ2′​,θ3′​。即如下图所示:

进一步,在MAML中。为了在初始化的时候替换掉随机生成的参数,以此来减少梯度下降的步数,缩短训练时间。这里选择其他相关任务训练出来的参数θ′θ'θ′来指导初始的参数θ,即如下图所示:

这里,值得考虑的一个问题是,我们选择的指导参数θ′θ'θ′是否能够同时适应三个任务T1,T2,T3T_1,T_2,T_3T1​,T2​,T3​?,从这个角度出发,就需要我们考虑的指导参数θ′θ'θ′应该是一种共同的,泛化的参数。

进一步,当有新的任务T4T_4T4​的时候,我们可以选择使用优化之后的参数θθθ来进行作为新任务的初始化参数。

最后,我们简单的总结一下MAML的基本思路,即寻找一个优化的参数θ,这个参数对于相关任务是通用的,其能够帮助我们使用更少量的样本进行学习,缩短训练时间。这也意味着我们可以将MAML应用到任意的使用梯度下降的学习方法中。下面,我们来具体探索MAML中原理和细节。

1.2 MAML算法流程

通过之前的描述,我们对于MAML的背景已经有了一定的了解,下面我们来探索MAML中的一些细节问题。假设,我们的模型为fff,并且其可以通过参数θθθ来进行描述,即fθf_θfθ​。这里,我们在定义一些相关的任务T,T中任务的分布概率为P(T)P(T)P(T)。

首先,我们先用随机值对于参数θθθ进行随机的初始化。进一步,我们通过概率分布P(T)P(T)P(T)对于任务集合中的任务进行采用,这里选择5个相关任务,作为一个batch,即表达为T={T1,T2,T3,T4,T5}T=\{T_1,T_2,T_3,T_4,T_5\}T={T1​,T2​,T3​,T4​,T5​}。然后,对于每一个任务TiT_iTi​,我们可以采用k个样本点来训练这个模型。至此,根据每一个任务,我们可以计算出来其损失函数LTi(fθ)L_{T_i}(f_θ)LTi​​(fθ​),我们通过梯度下降来最小化这个损失,寻找能够使得的损失函数最小的参数,即:
θi′=θ−α▽θLTi(fθ)θ_i'=θ-α▽_θL_{T_i}(f_θ)θi′​=θ−α▽θ​LTi​​(fθ​)
其中,θi′θ_i'θi′​表示的是对于任务TiT_iTi​的最优化参数,θθθ表示的是初始化参数,α是一个超参数,LTi(fθ)L_{T_i}(f_θ)LTi​​(fθ​)表示的是梯度计算结果。

对于T中5个任务都进行计算之后,我们可以获得各个任务的相对最优的参数集合,即θ′={θ1′,θ2′,θ3′,θ4′,θ5′}θ'=\{θ_1',θ_2',θ_3',θ_4',θ_5'\}θ′={θ1′​,θ2′​,θ3′​,θ4′​,θ5′​}。在采样下一个batch的任务之前,我们使用一个元更新或者元优化的策略。在之前的一步中,我们通过梯度下降计算出了相对最优的参数θi′θ_i'θi′​,并且通过任务TiT_iTi​中的参数对应的梯度,来更新了我们初始化的随机参数θ,这使得我们初始随机的参数θ,移动到了一个相对最优的位置。在一个批次的任务的训练中,减少了梯度下降的步数,这一步被称为“元步”,“元更新”,“元优化”或者“元训练”。通过公式,可以将其描述为:
θ=θ−β▽θ∑Ti−p(T)LTi(fθi′)θ=θ-β▽_θ∑_{T_i-p(T)}L_{T_i}(f_{θ_i'})θ=θ−β▽θ​Ti​−p(T)∑​LTi​​(fθi′​​)
在上述的公式中,θ表示的是初始化的参数,β表示的是一个超参数。LTi(fθi′)L_{T_i}(f_{θ_i'})LTi​​(fθi′​​)表示的是通过参数θi′θ_i'θi′​所计算出来的关于任务TiT_iTi​的梯度结果。这里,我们可以进一步的使用对于各个任务的相对最优参数θi′θ_i'θi′​对于的梯度和的平均值来进行计算。

最后,我们对于MAML算法的流程进行一下简单的总结。MAML算法一共可以分成两个循环,其中一个内部循环被用来确定当前任务集合中的各个任务对应的最优参数θi′θ_i'θi′​。外层的循环用于通过内层计算出来的最优参数对应的梯度来更新我们的初始的随机参数θ。我们使用一张图来描述一下这个过程:

2 MAML模型的应用

2.1 监督学习中的MAML模型

MAML模型善于去寻找最优的模型初始化参数。进一步,我们来描述一下其在监督学习过程中的使用过程。首先,我们先给出监督学习的损失函数的定义形式:

如果是监督学习中的回归学习,我们可以采用均方误差的形式来定义其损失函数:
LTi(fθ)=∑xj,yj−Ti∣∣fθ(xi)−yi∣∣22L_{T_i}(f_θ)=∑_{x_j,y_j-T_i}||f_θ(x_i)-y_i||_2^2LTi​​(fθ​)=xj​,yj​−Ti​∑​∣∣fθ​(xi​)−yi​∣∣22​
如果是监督学习中的分类任务,我们使用交叉熵的损失函数:
LTi(fθ)=∑xj,yj−Tiyjlogfθ(xj)+(1−yj)log(1−fθ(xj))L_{T_i}(f_θ)=∑_{x_j,y_j-T_i}y_jlogf_θ(x_j)+(1-y_j)log(1-f_θ(x_j))LTi​​(fθ​)=xj​,yj​−Ti​∑​yj​logfθ​(xj​)+(1−yj​)log(1−fθ​(xj​))

下面,我们来逐步的介绍MAML的使用过程

  1. 假设我们当前拥有一个模型f,可以通过参数θ来进行描述。并且我们有一个分布为p(T)p(T)p(T)的相关任务集合。首先,我们来随机初始化参数θ。
  2. 我们对任务集合中的任务进行采样,假设我们当前采样的任务集合为T={T1,T2,T3}T=\{T_1,T_2,T_3\}T={T1​,T2​,T3​}
  3. 内层循环:对于当前任务集合T中的每一个任务TiT_iTi​,我们采样K个样本点来生成当前任务的训练集和测试集
    Ditrain={(x1,y1),(x2,y2),...,(xk,yk)}D_i^{train}=\{(x_1,y_1),(x_2,y_2),...,(x_k,y_k)\}Ditrain​={(x1​,y1​),(x2​,y2​),...,(xk​,yk​)}
    Ditest={(x1,y1),(x2,y2),....,(xk,yk)}D_i^{test}=\{(x_1,y_1),(x_2,y_2),....,(x_k,y_k)\}Ditest​={(x1​,y1​),(x2​,y2​),....,(xk​,yk​)}
    这里值得注意的是,我们的这里训练集的样本和测试集的样本是相同的,训练数据集的样本是在内层循环中为具体任务寻找最优参数θi的时候用的。而测试集是在外层循环中,寻找最优的参数θ时被用到。这里的测试集的目的不是来检查模型的表现。其基础的作用是作为外层循环的训练集。我们也可以将我们的测试集称为元训练集
    至此,我们使用监督学习算法作用在DitrainD_i^{train}Ditrain​上面,计算出损失,并使用梯度下降算法来减小损失,获取相对最优参数θi′θ_i'θi′​,即:θi′=θ−α▽θLTi(fθ)θ_i'=θ-α▽_θL_{T_i}(f_θ)θi′​=θ−α▽θ​LTi​​(fθ​)。对于任务集合中的每一个任务,我们都采样K个样本点来在其训练集上进行最小化损失,获取最优参数的操作。最后,我们可以获取一组最优参数:{θ1′,θ2′,θ3′}\{θ_1',θ_2',θ_3'\}{θ1′​,θ2′​,θ3′​}。
  4. 外层循环: 这里我们使用之前定义的测试集来进行元优化。这里,我们使用测试集DitestD_i^{test}Ditest​来最小化损失。通过我们之前计算出来的最优参数{θ1′,θ2′,θ3′}\{θ_1',θ_2',θ_3'\}{θ1′​,θ2′​,θ3′​}对应的梯度结果,我们来最小化外层循环的损失,更新之前的随机参数,即θ=θ−β▽θ∑Ti−p(T)LTi(fθi′)θ=θ-β▽_θ∑_{T_i-p(T)}L_{T_i}(f_{θ_i'})θ=θ−β▽θ​∑Ti​−p(T)​LTi​​(fθi′​​)。
  5. 我们重复第2步到第5步来进行迭代,以此来获取最优的参数θ’。

最后,我们使用一个图来总结一下上述的流程:

元学习—模型不可知元学习(MAML)相关推荐

  1. 小样本学习记录————用于深度网络快速适应的模型不可知元学习(MAML)

    小样本学习记录----MAML用于深度网络快速适应的模型不可知元学习 相关概念 小样本学习(Few-Shot Learning) 元学习(Meta-Learning) MAML思想 MAML算法 论文 ...

  2. 深度学习模型那么多,科学研究选哪个?

    2020-04-26 18:32 导语:深度学习助力科研! 以深度学习为代表的机器学习技术,已经在很大程度颠覆了传统学科的研究方法.然后,对于传统学科的研究人员,机器学习算法繁杂多样,到底哪种方法更适 ...

  3. 受小动物大脑结构启发,研究人员开发出新的深度学习模型:更少神经元,更多智能...

    大数据文摘出品 来源:sciencedaily 编译: 朱科锦.coolboy   从搜索引擎到自动驾驶汽车,人工智能已经进入了我们的日常生活.这与近年来计算能力的巨大提升有关.但是,最新的人工智能研 ...

  4. C++环境下部署深度学习模型方案

    目录 一.问题背景 二.解决方案 2.1 C++调用python 2.2 Python服务接口 2.3 Python转c++(不推荐) 2.4 深度学习部署框架(推荐) 三.总结 3.1 接口形式分类 ...

  5. 用于阿尔茨海默症分期早期检测的多模态深度学习模型

    目前大多数阿尔茨海默症(AD)和轻度认知障碍(MCI)研究使用单一数据模式来预测,例如AD的分期.多种数据模式的融合可以提供AD分期分析的整体视图.因此,我们使用深度学习对成像(磁共振成像(MRI)) ...

  6. AI Earth 深度学习模型替换数值天气预报模型中的参数化方案-大气辐射传输方案

    1.背景 太阳辐射和热辐射是大气和海洋运动的最根本的驱动力.大气辐射传输过程实际上已经可以通过一种叫做LBLRTM的辐射模型精确计算,但是LBLRTM模型同时也最为耗时.因此,有各种各样的辐射传输参数 ...

  7. 在英特尔硬件上部署深度学习模型的无代码方法 OpenVINO 深度学习工作台的三部分系列文章 - CPU AI 第一部

    作者 Taylor, Mary, 翻译 李翊玮 关于该系列 了解如何转换.微调和打包推理就绪的 TensorFlow 模型,该模型针对英特尔®硬件进行了优化,仅使用 Web 浏览器.每一步都在云中使用 ...

  8. 该如何训练好深度学习模型?

    博主不是专业搞竞赛出生,仅依赖博主为数有限的工程经验,从工程实践的方面讨论如何把深度学习模型训练好.如果写的不好,欢迎各位大佬在评论区留言互动.训练好深度学习模型与训练深度学习模型差别还是很大的,训练 ...

  9. Python深度学习(一)深度学习基础

    翻译自Deep Learning With Python(2018) 第一章 深度学习基础:https://www.jianshu.com/p/6c08f4ceab4c 第二章 深度学习的数学构建模块 ...

最新文章

  1. 【技术干货】卷积神经网络中十大拍案叫绝的操作
  2. oracle查看字典结构体,Oracle数据字典的实操
  3. [Linux] 命令行工具
  4. VC++实现获取DNS服务器
  5. ROS中配置主从机需注意的几点
  6. java 窗口GUI
  7. from import 导入时找不到module的解决办法(Python模块包中_init_.py文件的作用)
  8. wxpython按钮形状如何修改_Python图形化界面入门教程 - 使用wxPython自定义表
  9. 使用Java和Scala将Play Framework 2应用程序部署到Openshift
  10. 简述Qt编程中遇到的编码格式问题
  11. Laravel 中asset 函数支持https 协议
  12. 机器学习、深度学习方面书籍收集(持续更新……)
  13. 再来过-docker
  14. 福州化工实验室建设注意隐患分析
  15. WiFi无缝漫游详解
  16. Hugo Travis
  17. 保留至百位并且向上取整 例如:125631 >>125700
  18. linux下的企业级DNS服务器的操作和加速
  19. 共享充电宝方案怎么做
  20. 安卓逆向入门练习之电影天堂APP逆向分析

热门文章

  1. 天猫精灵对接智能设备
  2. python bitwise operator 位运算
  3. ansj分词器的配置
  4. C语言程序设计编程题[七](山西大学876)
  5. DBFS CLI : 02-文件操作相关常用命令
  6. LateX使用笔记(持续更新)
  7. 直播写代码,今晚8点见!
  8. 软件开发过程与项目管理(10.软件项目人员与沟通计划)
  9. 使用xiaopiu常见技巧
  10. 怎么判断数字n是否为2的x次方,即2的幂次呢,比如2,4,8,16,32