谨以此篇记录第一次学习元学习的相关论文+代码。

主要参考:

原文:https://arxiv.org/abs/1703.03400

Model-Agnostic Meta-Learning (MAML)模型介绍及算法详解 - 知乎

MAML 论文及代码阅读笔记 - 知乎

超级详细的两篇文章,从文章讲解到代码实现。结合起来看效果更优

元学习: 学习如何学习【译】

这篇博文,主要看了关于基于优化的元学习方法,公式推导特别清晰。看过公式推导,再结合代码,会发现思路很清晰。

需要注意:

1:元学习,训练集、测试集均由任务构成,而每一个任务都包含{support,query}。

假设:任务池有10,000任务,epoch=1; batch_size=4,即有4个任务; 5way1shot分类任务。

则总需要轮次1(epoch=1)轮。每次处理4个任务(batch_size=4),处理2500(10000/4)次。每个任务构成为:T_i={support set,query_set}, support set大小: 4*(5*1)  ,query_set大小(4*(5*15)).query set 每个类别所取样本数默认为15个。

如下图: 一个方框表示一个样本,A,B,C,D,E表示5种不同的的类别。

2:MAML源代码中将所有样本之间先做成任务,直接在任务池中抽取需要的任务。

3:batchsize,批处理的样本数,在元学习中表示为任务数。

4:在support set上求loss,梯度更新update_step(5)次。在query set 上 fine-tuning 时梯度更新 update_step_test (10)次。

5:计算loss有所不同。

这里针对一个batch(即4个任务)解释。

foward:

针对每个任务,记录第一次更新的情况。先利用第一个任务的support set 计算初始化fast_weights,然后对比初始化参数与fast_weights正确率。

具体到某个任务:(在内部更新update_step次)

1)support set上利用第一次计算的fast_weights 计算loss,求梯度,梯度更新求新的fast_weights.

2)query set 上利用更新的fast_weights计算loss. 这里不需要更新fast_weights。而是记录query样本上的loss以及预测的准确率。

同时更新update_step次。

待到该batch所有任务(4个任务)都计算结束后,将此时的loss汇总,然后更新元学习器的参数。

fine-tuning:

分析方式与foward一致。针对一个batch(即4个任务)解释。

同样的,处理方式。

以上两处与forward一样。

具体到某个任务时,有些差别。在fine-tuning 阶段,针对一个任务,更新update_step_test (10)次.

1)在support set上,利用之前的fast_weights计算loss,计算梯度,更新fast_weights,得到新的fast_weights.

2)在query set 上,此时为x_qry,(此处为1个任务。forward 处为x_qry[i],当时有4个任务),利用新的fast_weights,经过网络预测结果,计算loss。不再需要外层的梯度更新。

记录预测的正确样本数。

6:双层优化,可能指的是,在内部support set 上更新参数:(当前任务参数更新)

theta_pi = theta_pi - train_lr * grad

而在外层更新时,是在query set所有样本上的loss和上更新参数:(元参数更新)

self.meta_optim.zero_grad()
loss_q.backward()
self.meta_optim.step()

个人理解:

虽然MAML一般使用FOMAML(一阶MAML),该方法的源码将梯度更新的最后一次记录下来,这里更新时得到是多次更新的结果(五阶MAML,update_step的次数)

以上仅为个人理解,如有错误还请海涵。另,诚邀各位批评指正。

MAML代码学习记录相关推荐

  1. DAB-Deformable-DETR代码学习记录之模型构建

    DAB-DETR的作者在Deformable-DETR基础上,将DAB-DETR的思想融入到了Deformable-DETR中,取得了不错的成绩.今天博主通过源码来学习下DAB-Deformable- ...

  2. ECCV2022细粒度图像检索SEMICON代码学习记录

    代码链接:GitHub - aassxun/SEMICON 环境配置 # 创建&激活虚拟环境 conda create -n semicon python==3.8.5 conda activ ...

  3. 深度学习+心脏医学图像分割——自动心脏诊断挑战赛(ACDC)项目的代码学习记录

    自己的研究方向是心脏AI相关(心脏MRI+深度学习这样子),最近在学习医学图像分割--自动心脏诊断挑战赛(ACDC)的代码: GitHub - baumgach/acdc_segmenter: Pub ...

  4. 【LVI-SAM代码学习记录】

    文章目录 目录 文章目录 前言 一.思路 二.LIO部分代码阅读 1.imagaProjection() 2.featureTracker() 3.imuPreintegration() 4.mapO ...

  5. OTFS代码学习记录Ⅰ

    该代码是Monash University的几位老师开源的(Raviteja Patchava, Yi Hong, and Emanuele Viterbo)首先感谢一下^_^ 接下来就开始我们的代码 ...

  6. 小样本学习记录————用于深度网络快速适应的模型不可知元学习(MAML)

    小样本学习记录----MAML用于深度网络快速适应的模型不可知元学习 相关概念 小样本学习(Few-Shot Learning) 元学习(Meta-Learning) MAML思想 MAML算法 论文 ...

  7. Pytorch学习记录-torchtext和Pytorch的实例( 使用神经网络训练Seq2Seq代码)

    Pytorch学习记录-torchtext和Pytorch的实例1 0. PyTorch Seq2Seq项目介绍 1. 使用神经网络训练Seq2Seq 1.1 简介,对论文中公式的解读 1.2 数据预 ...

  8. 2021-01-22学习记录 || 通过二维数组初始化窗体并进行代码重构

    今天主要是通过二维数组将整个界面16个数字块展示出来,并为了下一步添加左移.右移功能创建子类MainFrame继承JFrame类并进行代码重构. 二维数组展示初始化界面 由于2048小游戏需要16个数 ...

  9. AMBA总线协议之AHB学习记录(1)—ahb_bus(附verilog代码)

    目录 0.前言 1.AHB简介 2.ahb_bus实现(verilog) 3.总结反思 & 后面学习计划 0.前言 前段时间粗略过了一下riscv指令集相关内容,并对开源项目tinyriscv ...

  10. Opencv+Python学习记录9:掩膜(掩码)的使用(内附详细代码)

    一,基本概念 OpenCV中的很多函数都会指定一个掩模,也被称为掩码,例如: 计算结果=cv2.add(参数1,参数2,掩模) 当使用掩模参数时,操作只会在掩模值为非空的像素点上执行,并将其他像素点的 ...

最新文章

  1. 认知实习培训第四天总结
  2. linq to sql 插入值,以及如何取回自增的ID
  3. vue/return-in-computed-property Enforce that a return statement is present in computed property
  4. v8引擎和v12引擎_v8和v12发动机的区别
  5. Pytorch有什么节省内存(显存)的小技巧?
  6. [导入][凤穿牡丹][2008精品年代剧][全38集][李小冉 应采儿]
  7. 在LaTeX中添加Visio绘图
  8. recover的用法
  9. WIN10下Apache启动失败
  10. 数据中台到底是什么?
  11. 一文带你了解redux的工作流程——action/reducer/store
  12. 电脑上怎么看主板型号
  13. 安装 python cuda
  14. 老大让我优化数据库,我上来就分库分表,他过来就是一jio
  15. 关于某蔡傅里叶变换课的思考(元旦前更新)
  16. 求正多边形各顶点的坐标(数学)
  17. 正则表达式同时匹配中英文_,还控制长度
  18. [netplus] 初心之让人人能写高性能网络服务器
  19. 中医测试体质的软件,中医体质测试系统,中医九种体质测试,在线测试中医体质,中医体质自测...
  20. spring mvc + xheditor编辑器的使用

热门文章

  1. 【Python】判断多边形的形状为凸多边形还是凹多边形
  2. 深度linux双显卡死机,Deepin配置IntelNvidia双显卡
  3. 产品经理项目流程(四)——需求文档
  4. CentOS7自行搭建KMS服务器
  5. 简述u盘安装计算机系统的方法,电脑系统安装常见的两种方式(U盘)
  6. Postman的安装
  7. HTML页面跳转的5种方法分析介绍
  8. Java和Android笔试题
  9. vue3实现动态组件加载写法
  10. 【线性代数】范德蒙行列式