元学习系列文章

  1. optimization based meta-learning

    1. 《Model-Agnostic Meta-Learning for Fast Adaptation of Deep Networks》 论文翻译笔记
    2. 元学习方向 optimization based meta learning 之 MAML论文详细解读
    3. MAML 源代码解释说明 (一)
    4. MAML 源代码解释说明 (二)
    5. 元学习之《On First-Order Meta-Learning Algorithms》论文详细解读
    6. 元学习之《OPTIMIZATION AS A MODEL FOR FEW-SHOT LEARNING》论文详细解读
  2. metric based meta-learning
    1. 元学习之《Matching Networks for One Shot Learning》代码解读
  3. model based meta-learning: 待更新…

文章目录

  • 前言
  • Matching Network
    • 特征提取模块
    • Full Context Embeddings 模块
    • 距离度量模块
    • attention模块
  • 实验结果
  • 参考资料

前言

此篇是 metric-based metalearning 的第一篇,所谓 metric-based 即通过某种度量方式来判断测试样本和训练集中的哪个样本最相似,进而把最相似样本的 label 作为测试样本的 label,总体思想有点类似于 KNN。

Matching Network

此篇论文的核心思想就是构造了一个端到端的最近邻分类器,并通过 meta-learning 的训练,可以使得该分类器在新的少样本任务上快速适应,并对该任务的测试样本进行预测。下图是 Matching Network 的网络结构:

初看论文时看到这个图时会比较懵,以及论文里的各种公式也让人摸不着头脑,但是看作者的代码就能理清楚这里面的结构了,话不多上代码。

    def build(self, support_set_image, support_set_label, image):"""the main graph of matching networks"""# image [None, 28, 28, 1] -> [None, 1*1*64]# 1. 原始图片特征提取模块image_encoded = self.image_encoder(image)   # (batch_size, 64)#[(batch_size, 64), ] list 长度是 n*ksupport_set_image_encoded = [self.image_encoder(i) for i in tf.unstack(support_set_image, axis=1)]# 2. Full Context Embeddings 模块if self.use_fce:g_embedding = self.fce_g(support_set_image_encoded)     # (n * k, batch_size, 64)f_embedding = self.fce_f(image_encoded, g_embedding)    # (batch_size, 64)else:g_embedding = tf.stack(support_set_image_encoded)       # (n * k, batch_size, 64)f_embedding = image_encoded                             # (batch_size, 64)# c(f(x_hat), g(x_i))# 3. 距离度量模块# g 已知 label,f 是 test,未知 labelembeddings_similarity = self.cosine_similarity(f_embedding, g_embedding) # (batch_size, n * k)# compute softmax on similarity to get a(x_hat, x_i)# 4. attention 模块attention = tf.nn.softmax(embeddings_similarity)# \hat{y} = \sum_{i=1}^{k} a(\hat{x}, x_i)y_i# [batch_size, 1, n*k] * [batch_size,n*k, n] = [batch_size, 1, n]y_hat = tf.matmul(tf.expand_dims(attention, 1), tf.one_hot(support_set_label, self.n_way))# [batch_size,1,n] -> [batch_size, n]self.logits = tf.squeeze(y_hat)   # (batch_size, n)self.pred = tf.argmax(self.logits, 1)

整个网络结构可以分为四个模块:

  1. 原始图片特征提取模块
  2. Full Context Embeddings 模块
  3. 距离度量模块
  4. attention 模块

特征提取模块

特征提取模块比较简单,就是用一个4层的卷积网络,提取原始图片的全连接层特征,全连接层维度是64,即卷积网络后的输出shape是 [batch_size, 64]。该卷积网络的代码如下:

    def image_encoder(self, image):"""the embedding function for image (potentially f = g)For omniglot it's a simple 4 layer ConvNet, for mini-imagenet it's VGG or Inception"""with slim.arg_scope([slim.conv2d], num_outputs=64, kernel_size=3, normalizer_fn=slim.batch_norm):net = slim.conv2d(image)net = slim.max_pool2d(net, [2, 2])net = slim.conv2d(net)net = slim.max_pool2d(net, [2, 2])net = slim.conv2d(net)net = slim.max_pool2d(net, [2, 2])net = slim.conv2d(net)net = slim.max_pool2d(net, [2, 2])return tf.reshape(net, [-1, 1 * 1 * 64])

Full Context Embeddings 模块

此模块是论文的重点也是创新的地方,即如何对某个抽样任务的训练集样本进行 embedding 得到 gθg_\thetagθ,如何对该任务的测试样本进行 embedding 得到 fθf_\thetafθ

  1. gθg_\thetagθ
    其中对训练样本,是输入到一个双向LSTM网络中,LSTM的前向和后向隐藏层单元数都是 32,LSTM 网络的输出是一个长为 n*k的list,list中每个元素的shape是 (batch_size,64)。最后将输入embedding和 LSTM output 相加,相加后的结果即是 gθg_\thetagθ,相当于做了一个 skip connection的操作。gθg_\thetagθ的实现过程如下:
   def fce_g(self, encoded_x_i):"""the fully conditional embedding function gThis is a bi-directional LSTM, g(x_i, S) = h_i(->) + h_i(<-) + g'(x_i) where g' is the image encoderFor omniglot, this is not used.encoded_x_i: g'(x_i) in the equation.   length n * k list of (batch_size ,64)"""fw_cell = rnn.BasicLSTMCell(32) # 32 is half of 64 (output from cnn)bw_cell = rnn.BasicLSTMCell(32)# outputs: [(batch_size, 64), (batch_size, 64), ...], list 长度是 n*koutputs, state_fw, state_bw = rnn.static_bidirectional_rnn(fw_cell, bw_cell, encoded_x_i, dtype=tf.float32)# [n*k, batch_size, 64] + [n*k, batch_size, 64]return tf.add(tf.stack(encoded_x_i), tf.stack(outputs))

其中需要注意的是 batch_size 是随机抽样的 batch_size 个task,每个 task 共有 n*k 个训练样本,n值该task是n分类任务,k指每个类别共有k个样本。实际训练时,相当于LSTM网络共有 n*k个时刻,每个时刻的输入shape都是(batch_size,64),每个时刻的前向输出shape是(batch_size,32),后向输出shape是(batch_size,32)。LSTM的训练过程示意图如下:

  1. fθf_\thetafθ

对测试任务的样本求 embedding 时,同样也是输入到一个LSTM网络中,只不过这个LSTM是有固定步数的单向lstm,共有 processing_steps步,processing_steps可以提取设定。特殊的地方是,在每步的计算中加了 attention 部分,即让上一步的输出状态 h 乘以 gθg_\thetagθ。最后将最后一个时刻 lstm 网络的 softmax 输出作为 fθf_\thetafθ。此部分的代码实现如下:

    def fce_f(self, encoded_x, g_embedding):"""the fully conditional embedding function fThis is just a vanilla LSTM with attention where the input at each time step is constant and the hidden stateis a function of previous hidden state but also a concatenated readout vector.For omniglot, this is not used.encoded_x: f'(x_hat) in equation (3) in paper appendix A.1.     (batch_size, 64)g_embedding: g(x_i) in equation (5), (6) in paper appendix A.1. (n * k, batch_size, 64)"""cell = rnn.BasicLSTMCell(64)prev_state = cell.zero_state(self.batch_size, tf.float32) # state[0] is c, state[1] is hfor step in xrange(self.processing_steps):output, state = cell(encoded_x, prev_state) # output: (batch_size, 64)h_k = tf.add(output, encoded_x) # (batch_size, 64)content_based_attention = tf.nn.softmax(tf.multiply(prev_state[1], g_embedding))    # (n * k, batch_size, 64)r_k = tf.reduce_sum(tf.multiply(content_based_attention, g_embedding), axis=0)      # (batch_size, 64)prev_state = rnn.LSTMStateTuple(state[0], tf.add(h_k, r_k))

距离度量模块

现在有了 gθg_\thetagθfθf_\thetafθ,其中gθg_\thetagθd shape是(n*k, batch_size, 64),fθf_\thetafθ的shape是(batch_size,64)。距离度量模块就是针对每个 task,求出test和train中每个样本的余弦距离,最后输出shape为(batch_size,n*k)。余弦相似性的代码实现如下:

    def cosine_similarity(self, target, support_set):"""the c() function that calculate the cosine similarity between (embedded) support set and (embedded) targetnote: the author uses one-sided cosine similarity as zergylord said in his repo (zergylord/oneshot)"""#target_normed = tf.nn.l2_normalize(target, 1) # (batch_size, 64)target_normed = targetsup_similarity = []for i in tf.unstack(support_set):i_normed = tf.nn.l2_normalize(i, 1) # (batch_size, 64)similarity = tf.matmul(tf.expand_dims(target_normed, 1), tf.expand_dims(i_normed, 2)) # (batch_size, )sup_similarity.append(similarity)return tf.squeeze(tf.stack(sup_similarity, axis=1)) # (batch_size, n * k)

attention模块

此模块将求出每个测试样本的label。所谓 attention,其实很简单,就是将上一步求出的相似度结果做了 softmax 激活操作,然后将最大值处的train label作为 test label。此模块的代码实现如下:

        embeddings_similarity = self.cosine_similarity(f_embedding, g_embedding) # (batch_size, n * k)# compute softmax on similarity to get a(x_hat, x_i)# 4. attention 模块attention = tf.nn.softmax(embeddings_similarity)# \hat{y} = \sum_{i=1}^{k} a(\hat{x}, x_i)y_i# [batch_size, 1, n*k] * [batch_size,n*k, n] = [batch_size, 1, n]y_hat = tf.matmul(tf.expand_dims(attention, 1), tf.one_hot(support_set_label, self.n_way))# [batch_size,1,n] -> [batch_size, n]self.logits = tf.squeeze(y_hat)   # (batch_size, n)self.pred = tf.argmax(self.logits, 1)

实验结果


参考资料

  • https://github.com/markdtw/matching-networks
  • https://github.com/karpathy/paper-notes/blob/master/matching_networks.md

元学习之《Matching Networks for One Shot Learning》代码解读相关推荐

  1. 一文弄懂元学习 (Meta Learing)(附代码实战)《繁凡的深度学习笔记》第 15 章 元学习详解 (上)万字中文综述

    <繁凡的深度学习笔记>第 15 章 元学习详解 (上)万字中文综述(DL笔记整理系列) 3043331995@qq.com https://fanfansann.blog.csdn.net ...

  2. 【转载】Few-shot learning(少样本学习)和 Meta-learning(元学习)概述

    版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明. 本文链接:https://blog.csdn.net/weixin_37589575/arti ...

  3. 元学习:实现通用人工智能的关键!

    1 前言 Meta Learning(元学习)或者叫做 Learning to Learn(学会学习)已经成为继Reinforcement Learning(增强学习)之后又一个重要的研究分支(以后仅 ...

  4. 元学习与小样本学习 | (2) Few-shot Learning 综述

    原文地址 分类非常常见,但如果每个类只有几个标注样本,怎么办呢? 笔者所在的阿里巴巴小蜜北京团队就面临这个挑战.我们打造了一个智能对话开发平台--Dialog Studio,以赋能第三方开发者来开发各 ...

  5. Few-shot learning(少样本学习)和 Meta-learning(元学习)概述

    目录 (一)Few-shot learning(少样本学习) 1. 问题定义 2. 解决方法 2.1 数据增强和正则化 2.2 Meta-learning(元学习) (二)Meta-learning( ...

  6. 小样本学习元学习经典论文整理||持续更新

      本文整理了近些年来有关小样本学习的经典文章,并附上了原文下载链接以及论文解读链接.关注公众号"深视",回复"小样本学习",可以打包下载全部文章.该文我会持续 ...

  7. (转)Paper list of Meta Learning/ Learning to Learn/ One Shot Learning/ Lifelong Learning

    Meta Learning/ Learning to Learn/ One Shot Learning/ Lifelong Learning 2018-08-03 19:16:56 本文转自:http ...

  8. Meta Learning 元学习

    来源:火炉课堂 | 元学习(meta-learning)到底是什么鬼?bilibili 文章目录 1. 元学习概述 Meta 的含义 从 Machine Learning 到 Meta-Learnin ...

  9. 元学习概述(Meta-Learning)

    转载自: 凉爽的安迪-深度瞎学 一文入门元学习(Meta-Learning) 写在前面:迄今为止,本文应该是网上介绍[元学习(Meta-Learning)]最通俗易懂的文章了( 保命),主要目的是想对 ...

最新文章

  1. Java面试题(一)部分题目
  2. 【正一专栏】最好的回击是打得你好无脾气
  3. 【NLP】基于预训练的中文NLP工具介绍:ltp 和 fastHan
  4. 【转载保存】webCollector使用教程
  5. C++_程序内存模型_内存四区_代码区_全局区_每种区域都存放什么样的变量---C++语言工作笔记028
  6. MinIO环境搭建及使用
  7. 1.UNIX 环境高级编程--UNIX基础知识
  8. 排序算法基础+冒泡排序+冒泡排序的小优化
  9. python画spc控制图_如何选择最适合我们的SPC控制图?
  10. 黄永成-thinkphp讲解-个人博客讲解26集
  11. matlab地震振幅属性分析,洛马普列塔地震分析 - MATLAB Simulink Example - MathWorks 中国...
  12. ​LeetCode刷题实战371:两整数之和
  13. Non-local Neural Networks论文理解
  14. 【考前冲刺整理】20220812
  15. 谷粒商城-分布式事务
  16. 遥感深度学习数据集汇总(更新中)
  17. java8 stream map flatMap
  18. NOIP 2012初赛普及组C/C++答案详解
  19. hadoop 错误锦集
  20. 简述java的发展历史,22年最新

热门文章

  1. 饭谈:你凭什么觉得自己的简历很好?就凭有面试?
  2. 在线教育平台的数据分析——业务流程指标的计算
  3. 新春特辑 | 新基建专题合辑 报告下载
  4. 阿里云物联网平台HTTP连接通信
  5. 玩转TM4C1294XL(2)——建立Keil工程模板
  6. 王世吹摩托车是假的吧?中国达人秀上吹了,现实没有
  7. java抽组件_GitHub - ysc/HtmlExtractor: HtmlExtractor是一个Java实现的基于模板的网页结构化信息精准抽取组件。...
  8. java计算机毕业设计中小型超市管理系统源码+数据库+系统+lw文档+mybatis+运行部署
  9. 黑龙江省大兴安岭地区谷歌高清卫星地图下载
  10. 办营业执照时公司经营范围变更注意事项