元学习之《Matching Networks for One Shot Learning》代码解读
元学习系列文章
- optimization based meta-learning
- 《Model-Agnostic Meta-Learning for Fast Adaptation of Deep Networks》 论文翻译笔记
- 元学习方向 optimization based meta learning 之 MAML论文详细解读
- MAML 源代码解释说明 (一)
- MAML 源代码解释说明 (二)
- 元学习之《On First-Order Meta-Learning Algorithms》论文详细解读
- 元学习之《OPTIMIZATION AS A MODEL FOR FEW-SHOT LEARNING》论文详细解读
- metric based meta-learning
- 元学习之《Matching Networks for One Shot Learning》代码解读
- model based meta-learning: 待更新…
文章目录
- 前言
- Matching Network
- 特征提取模块
- Full Context Embeddings 模块
- 距离度量模块
- attention模块
- 实验结果
- 参考资料
前言
此篇是 metric-based metalearning 的第一篇,所谓 metric-based 即通过某种度量方式来判断测试样本和训练集中的哪个样本最相似,进而把最相似样本的 label 作为测试样本的 label,总体思想有点类似于 KNN。
Matching Network
此篇论文的核心思想就是构造了一个端到端的最近邻分类器,并通过 meta-learning 的训练,可以使得该分类器在新的少样本任务上快速适应,并对该任务的测试样本进行预测。下图是 Matching Network 的网络结构:
初看论文时看到这个图时会比较懵,以及论文里的各种公式也让人摸不着头脑,但是看作者的代码就能理清楚这里面的结构了,话不多上代码。
def build(self, support_set_image, support_set_label, image):"""the main graph of matching networks"""# image [None, 28, 28, 1] -> [None, 1*1*64]# 1. 原始图片特征提取模块image_encoded = self.image_encoder(image) # (batch_size, 64)#[(batch_size, 64), ] list 长度是 n*ksupport_set_image_encoded = [self.image_encoder(i) for i in tf.unstack(support_set_image, axis=1)]# 2. Full Context Embeddings 模块if self.use_fce:g_embedding = self.fce_g(support_set_image_encoded) # (n * k, batch_size, 64)f_embedding = self.fce_f(image_encoded, g_embedding) # (batch_size, 64)else:g_embedding = tf.stack(support_set_image_encoded) # (n * k, batch_size, 64)f_embedding = image_encoded # (batch_size, 64)# c(f(x_hat), g(x_i))# 3. 距离度量模块# g 已知 label,f 是 test,未知 labelembeddings_similarity = self.cosine_similarity(f_embedding, g_embedding) # (batch_size, n * k)# compute softmax on similarity to get a(x_hat, x_i)# 4. attention 模块attention = tf.nn.softmax(embeddings_similarity)# \hat{y} = \sum_{i=1}^{k} a(\hat{x}, x_i)y_i# [batch_size, 1, n*k] * [batch_size,n*k, n] = [batch_size, 1, n]y_hat = tf.matmul(tf.expand_dims(attention, 1), tf.one_hot(support_set_label, self.n_way))# [batch_size,1,n] -> [batch_size, n]self.logits = tf.squeeze(y_hat) # (batch_size, n)self.pred = tf.argmax(self.logits, 1)
整个网络结构可以分为四个模块:
- 原始图片特征提取模块
- Full Context Embeddings 模块
- 距离度量模块
- attention 模块
特征提取模块
特征提取模块比较简单,就是用一个4层的卷积网络,提取原始图片的全连接层特征,全连接层维度是64,即卷积网络后的输出shape是 [batch_size, 64]。该卷积网络的代码如下:
def image_encoder(self, image):"""the embedding function for image (potentially f = g)For omniglot it's a simple 4 layer ConvNet, for mini-imagenet it's VGG or Inception"""with slim.arg_scope([slim.conv2d], num_outputs=64, kernel_size=3, normalizer_fn=slim.batch_norm):net = slim.conv2d(image)net = slim.max_pool2d(net, [2, 2])net = slim.conv2d(net)net = slim.max_pool2d(net, [2, 2])net = slim.conv2d(net)net = slim.max_pool2d(net, [2, 2])net = slim.conv2d(net)net = slim.max_pool2d(net, [2, 2])return tf.reshape(net, [-1, 1 * 1 * 64])
Full Context Embeddings 模块
此模块是论文的重点也是创新的地方,即如何对某个抽样任务的训练集样本进行 embedding 得到 gθg_\thetagθ,如何对该任务的测试样本进行 embedding 得到 fθf_\thetafθ。
- gθg_\thetagθ
其中对训练样本,是输入到一个双向LSTM网络中,LSTM的前向和后向隐藏层单元数都是 32,LSTM 网络的输出是一个长为n*k
的list,list中每个元素的shape是 (batch_size,64)。最后将输入embedding和 LSTM output 相加,相加后的结果即是 gθg_\thetagθ,相当于做了一个 skip connection的操作。gθg_\thetagθ的实现过程如下:
def fce_g(self, encoded_x_i):"""the fully conditional embedding function gThis is a bi-directional LSTM, g(x_i, S) = h_i(->) + h_i(<-) + g'(x_i) where g' is the image encoderFor omniglot, this is not used.encoded_x_i: g'(x_i) in the equation. length n * k list of (batch_size ,64)"""fw_cell = rnn.BasicLSTMCell(32) # 32 is half of 64 (output from cnn)bw_cell = rnn.BasicLSTMCell(32)# outputs: [(batch_size, 64), (batch_size, 64), ...], list 长度是 n*koutputs, state_fw, state_bw = rnn.static_bidirectional_rnn(fw_cell, bw_cell, encoded_x_i, dtype=tf.float32)# [n*k, batch_size, 64] + [n*k, batch_size, 64]return tf.add(tf.stack(encoded_x_i), tf.stack(outputs))
其中需要注意的是 batch_size 是随机抽样的 batch_size 个task,每个 task 共有 n*k
个训练样本,n值该task是n分类任务,k指每个类别共有k个样本。实际训练时,相当于LSTM网络共有 n*k
个时刻,每个时刻的输入shape都是(batch_size,64),每个时刻的前向输出shape是(batch_size,32),后向输出shape是(batch_size,32)。LSTM的训练过程示意图如下:
- fθf_\thetafθ
对测试任务的样本求 embedding 时,同样也是输入到一个LSTM网络中,只不过这个LSTM是有固定步数的单向lstm,共有 processing_steps
步,processing_steps
可以提取设定。特殊的地方是,在每步的计算中加了 attention 部分,即让上一步的输出状态 h 乘以 gθg_\thetagθ。最后将最后一个时刻 lstm 网络的 softmax 输出作为 fθf_\thetafθ。此部分的代码实现如下:
def fce_f(self, encoded_x, g_embedding):"""the fully conditional embedding function fThis is just a vanilla LSTM with attention where the input at each time step is constant and the hidden stateis a function of previous hidden state but also a concatenated readout vector.For omniglot, this is not used.encoded_x: f'(x_hat) in equation (3) in paper appendix A.1. (batch_size, 64)g_embedding: g(x_i) in equation (5), (6) in paper appendix A.1. (n * k, batch_size, 64)"""cell = rnn.BasicLSTMCell(64)prev_state = cell.zero_state(self.batch_size, tf.float32) # state[0] is c, state[1] is hfor step in xrange(self.processing_steps):output, state = cell(encoded_x, prev_state) # output: (batch_size, 64)h_k = tf.add(output, encoded_x) # (batch_size, 64)content_based_attention = tf.nn.softmax(tf.multiply(prev_state[1], g_embedding)) # (n * k, batch_size, 64)r_k = tf.reduce_sum(tf.multiply(content_based_attention, g_embedding), axis=0) # (batch_size, 64)prev_state = rnn.LSTMStateTuple(state[0], tf.add(h_k, r_k))
距离度量模块
现在有了 gθg_\thetagθ和fθf_\thetafθ,其中gθg_\thetagθd shape是(n*k, batch_size, 64)
,fθf_\thetafθ的shape是(batch_size,64)
。距离度量模块就是针对每个 task,求出test和train中每个样本的余弦距离,最后输出shape为(batch_size,n*k)
。余弦相似性的代码实现如下:
def cosine_similarity(self, target, support_set):"""the c() function that calculate the cosine similarity between (embedded) support set and (embedded) targetnote: the author uses one-sided cosine similarity as zergylord said in his repo (zergylord/oneshot)"""#target_normed = tf.nn.l2_normalize(target, 1) # (batch_size, 64)target_normed = targetsup_similarity = []for i in tf.unstack(support_set):i_normed = tf.nn.l2_normalize(i, 1) # (batch_size, 64)similarity = tf.matmul(tf.expand_dims(target_normed, 1), tf.expand_dims(i_normed, 2)) # (batch_size, )sup_similarity.append(similarity)return tf.squeeze(tf.stack(sup_similarity, axis=1)) # (batch_size, n * k)
attention模块
此模块将求出每个测试样本的label。所谓 attention,其实很简单,就是将上一步求出的相似度结果做了 softmax 激活操作,然后将最大值处的train label作为 test label。此模块的代码实现如下:
embeddings_similarity = self.cosine_similarity(f_embedding, g_embedding) # (batch_size, n * k)# compute softmax on similarity to get a(x_hat, x_i)# 4. attention 模块attention = tf.nn.softmax(embeddings_similarity)# \hat{y} = \sum_{i=1}^{k} a(\hat{x}, x_i)y_i# [batch_size, 1, n*k] * [batch_size,n*k, n] = [batch_size, 1, n]y_hat = tf.matmul(tf.expand_dims(attention, 1), tf.one_hot(support_set_label, self.n_way))# [batch_size,1,n] -> [batch_size, n]self.logits = tf.squeeze(y_hat) # (batch_size, n)self.pred = tf.argmax(self.logits, 1)
实验结果
参考资料
- https://github.com/markdtw/matching-networks
- https://github.com/karpathy/paper-notes/blob/master/matching_networks.md
元学习之《Matching Networks for One Shot Learning》代码解读相关推荐
- 一文弄懂元学习 (Meta Learing)(附代码实战)《繁凡的深度学习笔记》第 15 章 元学习详解 (上)万字中文综述
<繁凡的深度学习笔记>第 15 章 元学习详解 (上)万字中文综述(DL笔记整理系列) 3043331995@qq.com https://fanfansann.blog.csdn.net ...
- 【转载】Few-shot learning(少样本学习)和 Meta-learning(元学习)概述
版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明. 本文链接:https://blog.csdn.net/weixin_37589575/arti ...
- 元学习:实现通用人工智能的关键!
1 前言 Meta Learning(元学习)或者叫做 Learning to Learn(学会学习)已经成为继Reinforcement Learning(增强学习)之后又一个重要的研究分支(以后仅 ...
- 元学习与小样本学习 | (2) Few-shot Learning 综述
原文地址 分类非常常见,但如果每个类只有几个标注样本,怎么办呢? 笔者所在的阿里巴巴小蜜北京团队就面临这个挑战.我们打造了一个智能对话开发平台--Dialog Studio,以赋能第三方开发者来开发各 ...
- Few-shot learning(少样本学习)和 Meta-learning(元学习)概述
目录 (一)Few-shot learning(少样本学习) 1. 问题定义 2. 解决方法 2.1 数据增强和正则化 2.2 Meta-learning(元学习) (二)Meta-learning( ...
- 小样本学习元学习经典论文整理||持续更新
本文整理了近些年来有关小样本学习的经典文章,并附上了原文下载链接以及论文解读链接.关注公众号"深视",回复"小样本学习",可以打包下载全部文章.该文我会持续 ...
- (转)Paper list of Meta Learning/ Learning to Learn/ One Shot Learning/ Lifelong Learning
Meta Learning/ Learning to Learn/ One Shot Learning/ Lifelong Learning 2018-08-03 19:16:56 本文转自:http ...
- Meta Learning 元学习
来源:火炉课堂 | 元学习(meta-learning)到底是什么鬼?bilibili 文章目录 1. 元学习概述 Meta 的含义 从 Machine Learning 到 Meta-Learnin ...
- 元学习概述(Meta-Learning)
转载自: 凉爽的安迪-深度瞎学 一文入门元学习(Meta-Learning) 写在前面:迄今为止,本文应该是网上介绍[元学习(Meta-Learning)]最通俗易懂的文章了( 保命),主要目的是想对 ...
最新文章
- Java面试题(一)部分题目
- 【正一专栏】最好的回击是打得你好无脾气
- 【NLP】基于预训练的中文NLP工具介绍:ltp 和 fastHan
- 【转载保存】webCollector使用教程
- C++_程序内存模型_内存四区_代码区_全局区_每种区域都存放什么样的变量---C++语言工作笔记028
- MinIO环境搭建及使用
- 1.UNIX 环境高级编程--UNIX基础知识
- 排序算法基础+冒泡排序+冒泡排序的小优化
- python画spc控制图_如何选择最适合我们的SPC控制图?
- 黄永成-thinkphp讲解-个人博客讲解26集
- matlab地震振幅属性分析,洛马普列塔地震分析
- MATLAB Simulink Example
- MathWorks 中国...
- ​LeetCode刷题实战371:两整数之和
- Non-local Neural Networks论文理解
- 【考前冲刺整理】20220812
- 谷粒商城-分布式事务
- 遥感深度学习数据集汇总(更新中)
- java8 stream map flatMap
- NOIP 2012初赛普及组C/C++答案详解
- hadoop 错误锦集
- 简述java的发展历史,22年最新
热门文章
- 饭谈:你凭什么觉得自己的简历很好?就凭有面试?
- 在线教育平台的数据分析——业务流程指标的计算
- 新春特辑 | 新基建专题合辑 报告下载
- 阿里云物联网平台HTTP连接通信
- 玩转TM4C1294XL(2)——建立Keil工程模板
- 王世吹摩托车是假的吧?中国达人秀上吹了,现实没有
- java抽组件_GitHub - ysc/HtmlExtractor: HtmlExtractor是一个Java实现的基于模板的网页结构化信息精准抽取组件。...
- java计算机毕业设计中小型超市管理系统源码+数据库+系统+lw文档+mybatis+运行部署
- 黑龙江省大兴安岭地区谷歌高清卫星地图下载
- 办营业执照时公司经营范围变更注意事项