背景

Meta Learning,又称为 learning to learn,Meta Learning希望使得模型获取一种“学会学习”的能力,使其可以在获取已有“知识”的基础上快速学习新的任务,对于新的类别,只需要少量的样本就能快速学习(Few-shot Learning)。

Few-shot Learning 是 Meta Learning 在监督学习领域的应用。

数据集

早期研究都基于以下两个图像数据集:

Omniglot:https://github.com/brendenlake/omniglot

包含1623个不同的火星文字符,每个字符包含20个手写的case

miniImageNet:https://github.com/yaoyao-liu/mini-imagenet-tools

包含100类共60000张彩色图片,其中每类有600个样本

主流算法

MAML(入门+重要)

2017年发表,到2022年7月12日已经收获493的引用 https://arxiv.org/pdf/1703.03400.pdf

MAML与其说是一个深度学习模型,倒不如说是一个框架,提供一个meta-learner(MAML的精髓所在,learing to learn)用于训练base-learner(根据新数据实际用于预测任务的模型)。

绝大多数深度学习模型都可以作为base-learner无缝嵌入MAML中。

(一)目的

MAML的目的是获取一组更好的模型初始化参数(即让模型自己学会初始化)

可以这么理解:假设我们目前有3个tasks,分别为T 1 , T 2 , T 3 。按照以前模型的训练方式,首先,我们随机初始化模型参数θ,然后开始训练任务T 1 ,接着最小化损失函数L 来更新网络的参数,这样我们就会得到新的参数θ 1 。同理,我们可以接着更新其他两个任务。但以前模型的训练方式,是每个任务都是随机初始化θ开始,每个任务都是独立的。如果我们把三个任务初始化的θ到公用的位置,则不需要更多的梯度更新步骤。MAML就是做这件事的。

(二)专有术语介绍:

构建的任务分为训练任务(Train Task),测试任务(Test Task)。

每个任务都有自己的训练集(Support Set)、测试集( Query Set

N-ways,K-shot(数据中包含N个类别,每个类别有K个样本)

(三)训练流程

以训练 miniImage 数据集为例,按4:1划分数据集

Train Task:从训练集(80 个类,每类 600 个样本)中随机采样 5 个类,每个类 1 个样本(5-way 1-shot),构成Support Set,去学习 learner;然后从训练集的样本(采出的5 个类,每类剩下的样本)中抽 15 个样本采样构成Query Set,用来获得 learner 的 loss,去学习 meta leaner。

Test Task:(20 个类,每类 600 个样本)中随机采样5个类,每个类1 个样本(与training阶段一致,5-way 1-shot),构成支撑集 Support Set,去学习 learner;然后从测试集剩余的样本(采出的5 个类,每类剩下的样本)中抽 15 个样本采样构成 Query Set,用来获得 learner 的参数,进而得到预测的类别概率。

(四)实现代码

## 网络构建部分: refer: https://github.com/dragen1860/MAML-TensorFlow
​
#################################################
# 任务描述:5-ways,1-shot图像分类任务,图像统一处理成 84 * 84 * 3 = 21168的尺寸。
# support set:5 * 1
# query set:5 * 15
# 训练取1个batch的任务:batch size:4
# 对训练任务进行训练时,更新5次:K = 5
#################################################
​
print(support_x) # (4, 5, 21168)
print(query_x) # (4, 75, 21168)
print(support_y) # (4, 5, 5)
print(query_y) # (4, 75, 5)
print(meta_batchsz) # 4
print(K) # 5
​
model = MAML()
model.build(support_x, support_y, query_x, query_y, K, meta_batchsz, mode='train')
​
class MAML:def __init__(self):passdef build(self, support_xb, support_yb, query_xb, query_yb, K, meta_batchsz, mode='train'):""":param support_xb: [4, 5, 84*84*3] :param support_yb: [4, 5, n-way]:param query_xb:  [4, 75, 84*84*3]:param query_yb: [4, 75, n-way]:param K:  训练任务的网络更新步数:param meta_batchsz: 任务数,4"""
​self.weights = self.conv_weights() # 创建或者复用网络参数;训练任务对应的网络复用meta网络的参数training = True if mode is 'train' else False      def meta_task(input):""":param support_x:   [setsz, 84*84*3] (5, 21168):param support_y:   [setsz, n-way] (5, 5):param query_x:     [querysz, 84*84*3] (75, 21168):param query_y:     [querysz, n-way] (75, 5):param training:    training or not, for batch_norm:return:"""
​support_x, support_y, query_x, query_y = inputquery_preds, query_losses, query_accs = [], [], [] # 子网络更新K次,记录每一次queryset的结果## 第0次对网络进行更新support_pred = self.forward(support_x, self.weights, training) # 前向计算support setsupport_loss = tf.nn.softmax_cross_entropy_with_logits(logits=support_pred, labels=support_y) # support set losssupport_acc = tf.contrib.metrics.accuracy(tf.argmax(tf.nn.softmax(support_pred, dim=1), axis=1),tf.argmax(support_y, axis=1))grads = tf.gradients(support_loss, list(self.weights.values())) # 计算support set的梯度gvs = dict(zip(self.weights.keys(), grads))# 使用support set的梯度计算的梯度更新参数,theta_pi = theta - alpha * gradsfast_weights = dict(zip(self.weights.keys(), \[self.weights[key] - self.train_lr * gvs[key] for key in self.weights.keys()]))
​# 使用梯度更新后的参数对quert set进行前向计算query_pred = self.forward(query_x, fast_weights, training)query_loss = tf.nn.softmax_cross_entropy_with_logits(logits=query_pred, labels=query_y)query_preds.append(query_pred)query_losses.append(query_loss)# 第1到 K-1次对网络进行更新for _ in range(1, K):           loss = tf.nn.softmax_cross_entropy_with_logits(logits=self.forward(support_x, fast_weights, training),labels=support_y)grads = tf.gradients(loss, list(fast_weights.values()))gvs = dict(zip(fast_weights.keys(), grads))fast_weights = dict(zip(fast_weights.keys(), [fast_weights[key] - self.train_lr * gvs[key]for key in fast_weights.keys()]))query_pred = self.forward(query_x, fast_weights, training)query_loss = tf.nn.softmax_cross_entropy_with_logits(logits=query_pred, labels=query_y)# 子网络更新K次,记录每一次queryset的结果query_preds.append(query_pred)query_losses.append(query_loss)
​for i in range(K):query_accs.append(tf.contrib.metrics.accuracy(tf.argmax(tf.nn.softmax(query_preds[i], dim=1), axis=1),tf.argmax(query_y, axis=1)))result = [support_pred, support_loss, support_acc, query_preds, query_losses, query_accs]return result
​# return: [support_pred, support_loss, support_acc, query_preds, query_losses, query_accs]out_dtype = [tf.float32, tf.float32, tf.float32, [tf.float32] * K, [tf.float32] * K, [tf.float32] * K]result = tf.map_fn(meta_task, elems=(support_xb, support_yb, query_xb, query_yb),dtype=out_dtype, parallel_iterations=meta_batchsz, name='map_fn')support_pred_tasks, support_loss_tasks, support_acc_tasks, \query_preds_tasks, query_losses_tasks, query_accs_tasks = result
​if mode is 'train':self.support_loss = support_loss = tf.reduce_sum(support_loss_tasks) / meta_batchszself.query_losses = query_losses = [tf.reduce_sum(query_losses_tasks[j]) / meta_batchszfor j in range(K)]self.support_acc = support_acc = tf.reduce_sum(support_acc_tasks) / meta_batchszself.query_accs = query_accs = [tf.reduce_sum(query_accs_tasks[j]) / meta_batchszfor j in range(K)]
​# 更新meta网络,只使用了第 K步的query loss。这里应该是个超参,更新几步可以调调optimizer = tf.train.AdamOptimizer(self.meta_lr, name='meta_optim')gvs = optimizer.compute_gradients(self.query_losses[-1])# def ********

参考:

1.原论文:Model-Agnostic Meta-Learning for Fast Adaptation of Deep Networkshttps://arxiv.org/pdf/1703.03400.pdf

2.小样本学习(Few-shot Learning)综述小样本学习(Few-shot Learning)综述

3.一文入门元学习(Meta-Learning)(附代码)一文入门元学习(Meta-Learning)(附代码) - 知乎

4.Model-Agnostic Meta-Learning (MAML)模型介绍及算法详解Model-Agnostic Meta-Learning (MAML)模型介绍及算法详解 - 知乎

5.从代码上解析Meta-learning从代码上解析Meta-learning_洛克-李的博客-CSDN博客

meta—learning调研及MAML概述相关推荐

  1. Meta Learning在NLP领域的应用

    Hi,这里是哈林,今天来跟大家聊一聊Meta Learning在NLP领域的一些应用. 哈林之前在学校科研的方向是NLP,个人对如何将先进的机器学习算法应用到NLP场景很感兴趣(因为好水paper), ...

  2. 【李宏毅2020 ML/DL】P88-96 Meta Learning – MAML | Reptile

    我已经有两年 ML 经历,这系列课主要用来查缺补漏,会记录一些细节的.自己不知道的东西. 本节内容综述 元学习就是 Learn to learn ,让机器变成 a better learner .Me ...

  3. 8.7 Meta learning元学习全面理解、MAML、Reptile

    文章目录 1.介绍 为什么需要元学习? few-shot learning reinforcement learning 2.概念 3. Meta learning 三个步骤 定义一组learning ...

  4. 【李宏毅2020 ML/DL】P97-98 More about Meta Learning

    我已经有两年 ML 经历,这系列课主要用来查缺补漏,会记录一些细节的.自己不知道的东西. 本节内容综述 本节课由助教 陈建成 讲解. 本节 Outline 见小细节. 首先是 What is meta ...

  5. Meta Learning 元学习

    来源:火炉课堂 | 元学习(meta-learning)到底是什么鬼?bilibili 文章目录 1. 元学习概述 Meta 的含义 从 Machine Learning 到 Meta-Learnin ...

  6. 理解Meta Learning 元学习,这篇文章就够了!

    点击上方"小白学视觉",选择加"星标"或"置顶" 重磅干货,第一时间送达 AI编辑:我是小将 本文作者:谢杨易 1 什么是meta lear ...

  7. 元学习Meta learning深入理解

    目录 基本理解 元学习与传统的机器学习不同在哪里? 基本思想 MAML MAML与pre-training有什么区别呢? 1. 损失函数不同 2. 优化思想不同 MAML的优点及特点 MAML工作机理 ...

  8. Meta-Weight-Net[NIPS‘2019]:关于元学习/域自适应(meta learning/domain adaptation)优化噪声标签与类别不平衡的问题

    目录 研究背景 一.为什么存在类别不平衡现象? 二. Meta-Weight-Net[NIPS'2019] 1.Focal Loss 2.self-pacd learning 3.Meta-Weigh ...

  9. [转载]Meta Learning单排小教学

    原文链接:Meta Learning单排小教学 虽然Meta Learning现在已经非常火了,但是还有很多小伙伴对于Meta Learning不是特别理解.考虑到我的这个AI游乐场将充斥着Meta ...

  10. 强化学习-把元学习(Meta Learning)一点一点讲给你听

    目录 0 Write on the front 1 What is meta learning? 2 Meta Learning 2.1 Define a set of learning algori ...

最新文章

  1. linux的一些机制Signal, Fork,
  2. 【SpringCloud】Ribbon:负载均衡
  3. Django之用户上传文件的参数配置
  4. HTML/BODY的背景渲染原理
  5. java三年,Java开发三年,你不得不了解的JVM(一)
  6. java getparameter 乱码_request.getParameter(“参数名”) 中文乱码解决方法
  7. 把java程序打包成.exe
  8. 【渝粤教育】国家开放大学2018年春季 0275-21T内科护理学 参考试题
  9. BIM学习笔记(一)
  10. Matlab系列教程_数值计算_求方差和标准差
  11. 简洁404页面源码 | 自适应404页面HTML源代码下载
  12. WWW15年:改变世界的15个网站
  13. linux系统添加打印机失败,Linux下设置网络打印机
  14. a标签去掉下划线以及字体颜色
  15. java定时任务不执行_【SpringBoot】为什么我的定时任务不执行?
  16. ruby on rais3 入门——环境搭建详细步骤(windows下)
  17. R星安装不完全无法载入social club(错误码:1)解决办法
  18. 自己做游戏(一)-PhotonServer配置
  19. Windows 11 应用商店打不开,点了没反应解决办法,亲测可用
  20. 究竟什么是CRM(客户关系管理系统)呢?

热门文章

  1. html文字logo
  2. 项管:配置管理、变更管理、文档管理、知识管理及其他
  3. 期货反跟单行业里的恶意剥削
  4. ISIS协议基础知识
  5. JS中定义对象和集合
  6. 通用的电子商务商城后台管理界面模板——后台
  7. 单片机段式LCD驱动教程
  8. erp系统实施方案会遇到哪些问题?
  9. JAVA实现生成GIF动态图加文字(完整版无License带锯齿优化处理)
  10. 机器学习——Azure机器学习模型在线搭建实验原理+详细操作步骤+分析(以UCI数据库的数据为例)