文章目录

  • Neural Turing Machines (NTMs)
    • 读:
    • 写:
  • Memory-augmented Neural Networks
  • Meta Networks
    • 算法
  • 参考
  • 系列文章代码、数据集地址

  上一节说了Metrics-Based Methods,主要是将输入编码到一个相同的特征空间,然后比较相似度。但是人类很多时候能够快速学习的原因是对以往知识、经验的利用。因此通过扩展一个记忆模块似乎也能做到少样本学习。这一节主要介绍通过模型结构的设计,来做few shot learning

Neural Turing Machines (NTMs)

  • Neural Turing Machines

  LSTM将记忆藏在隐藏节点(hidden state)中,这样就会存在很多问题,一个是计算的开销,另外一个就是记忆会被经常改动,并且是那种牵一发动全身地改变。

  而NTM由一个controller和一个memory矩阵构成,通过特定的寻址读写机制,对相关的memory进行修改,并且易于扩展:

  当给定一个输入时,controller负责依据输入对memory进行读写操作,实现记忆更新。

  但是通过特定的行和列来读取memory的话,我们就没办法对整个网络求梯度了,不能使用微分算法来更新。那这样肯定不行,要想一些办法。

  我们需要对外部的memory(存储器)进行选择性地读写。人的大脑工作的时候首先是聚焦注意力(记忆中很大一块,比如说你昨晚吃啥了,一般会聚焦到昨晚那一大块时间段),然后寻找到特定的记忆(比如有啥菜)。因此有了模糊读写(blurry read and write)的概念,通过不同的权重与内存中的所有元素进行交互

读:

  假设记忆矩阵MtM_{t}Mtstep ttt 是一个拥有RRR行和CCC列的内存矩阵(CCC代表记忆中每一行的大小)。执行读写操作的网络输出称为headscontroller输出一个attention向量,长度为RRR,也称之为weight vector(wt)(w_{t})(wt),其中的每一个元素wt(i)w_{t}(i)wt(i)memory matrixiii行的weightweight通常都被归一化,用数学形式可表示为:

0≤wt(i)≤1∑i=1Rwt(i)Z=1\begin{aligned} & 0 \leq w_{t}(i) \leq 1 \\ &\sum_{i=1}^{R} w_{t}(i) Z =1 \end{aligned}0wt(i)1i=1Rwt(i)Z=1

  read head返回的就是记忆矩阵行的线性组合:

rt←∑iRwt(i)Mt(i)r_{t} \leftarrow \sum_{i}^{R} w_{t}(i) M_{t}(i)rtiRwt(i)Mt(i)

写:

  写的操作可以分为两步擦除(erasing)和添加(adding)。为了实现擦除操作,需要一个erase vector ete_{t}et其值在0-1之间,擦除操作可表示为:

Mterased(i)←Mt−1(i)[1−wt(i)et]M_{t}^{\text {erased}}(i) \leftarrow M_{t-1}(i)\left[\mathbf{1}-w_{t}(i) e_{t}\right]Mterased(i)Mt1(i)[1wt(i)et]

  当weight wt(i)w_{t}(i)wt(i)et\mathbf{e}_{t}et都为1时,memory被清空,当其中任意一个为0时,则不会有任何改变,这种方式也支持多个操作任意顺序的相互叠加。记忆矩阵可用一个长度为CCC的向量ata_{t}at更新:

Mt(i)←Mterased+wt(i)atM_{t}(i) \leftarrow M_{t}^{e r a s e d}+w_{t}(i) a_{t}Mt(i)Mterased+wt(i)at

  • 寻址

  读写操作的关键就在于权重矩阵www,控制器产生权重矩阵可以分为以下四步:

  1. content-based addressinghead产生一个长度为CCCkey vector ktk_{t}kt,然后用余弦相似度度量ktk_{t}kt与记忆矩阵MtM_{t}Mt的相似性:

K(u,v)=u⋅v∥u∥⋅∥v∥K(u, v)=\frac{u \cdot v}{\|u\| \cdot\|v\|}K(u,v)=uvuv

  对记忆矩阵每一行都进行一样的操作再归一化可得content weight vector

wtc(i)=exp⁡(βtK(kt,Mt(i)))∑jexp⁡(βtK(kt,Mt(j)))w_{t}^{c}(i)=\frac{\exp \left(\beta_{t} K\left(k_{t}, M_{t}(i)\right)\right)}{\sum_{j} \exp \left(\beta_{t} K\left(k_{t}, M_{t}(j)\right)\right)}wtc(i)=jexp(βtK(kt,Mt(j)))exp(βtK(kt,Mt(i)))

  其中βt\beta_{t}βt可控制聚焦的精度,βt\beta_{t}βt越大,聚焦范围就越小。(这里可以看作是找一个大块的记忆。)

  1. location-based addressing,这里从特定的内存地址中进行读写。通过一个interpolation gate gt∈(0,1)g_{t} \in (0,1)gt(0,1)来混合content weight vector wtcw_{t}^{c}wtc和上一时刻head 产生的weight vector wt−1w_{t-1}wt1来产生gated weighting wtgw_{t}^{g}wtg,通过这种wtgw_{t}^{g}wtg控制方式可以考虑用或者不用content weight vector

wtg←gtwtc+(1−gt)wt−1w_{t}^{g} \leftarrow g_{t} w_{t}^{c}+\left(1-g_{t}\right) w_{t-1}wtggtwtc+(1gt)wt1

  1. interpolation之后,head产生一个normalized shift weighting sts_{t}st,对权重进行旋转位移,比如当前的权重值关注于某一个locationmemory,经过此步就会扩展到其周围的location,使得模型对周围的memory也会做出少量的读和写操作,采用循环卷积:

w~t(i)←∑j=0R−1wtg(j)st(i−j)\tilde{w}_{t}(i) \leftarrow \sum_{j=0}^{R-1} w_{t}^{g}(j) s_{t}(i-j)w~t(i)j=0R1wtg(j)st(ij)

  1. 卷积操作之后会使得权重分布趋于均匀化,这将会导致本来集中于单个位置的焦点出现发散,这里采用锐化操作,head产生一个标量γ≥1\gamma \geq 1γ1

wt(i)←w~t(i)γt∑jw~t(j)γtw_{t}(i) \leftarrow \frac{\tilde{w}_{t}(i)^{\gamma_{t}}}{\sum_{j} \tilde{w}_{t}(j)^{\gamma_{t}}}wt(i)jw~t(j)γtw~t(i)γt

  上述操作都是可微分的,因此可以使用微分算法对其进行优化。

Memory-augmented Neural Networks

  NMT中采用了content-based addressinglocation-based addressing,在MANN中只采用content-based addressing,因为只需要比较当前input是否和之前输入的input相似即可。

  MANN中的读取操作和NTM的读取操作非常类似,不同之处在于它只采取content-based addressing的方式,先产生一个归一化的权重向量wtrw_{t}^{r}wtr

wtr=exp⁡(K(kt,Mt(i)))∑jexp⁡(K(kt,Mt(j)))w_{t}^{r}=\frac{\exp \left(K\left(k_{t}, M_{t}(i)\right)\right)}{\sum_{j} \exp \left(K\left(k_{t}, M_{t}(j)\right)\right)}wtr=jexp(K(kt,Mt(j)))exp(K(kt,Mt(i)))

  其中K()K()K()表示余弦相似性,之后与NTM类似与记忆矩阵MtM_{t}Mt加权求和即可:

rt←∑iRwtr(i)Mt(i)r_{t} \leftarrow \sum_{i}^{R} w_{t}^{r}(i) M_{t}(i)rtiRwtr(i)Mt(i)

  这里采用了Least Recently Used Access (LRUA)的写入方式:

Mt(i)←Mt−1(i)+wtw(i)ktwtu←γwt−1u+wtr+wtw\begin{aligned} M_{t}(i) & \leftarrow M_{t-1}(i)+w_{t}^{w}(i) k_{t} \\ w_{t}^{u} & \leftarrow \gamma w_{t-1}^{u}+w_{t}^{r}+w_{t}^{w} \end{aligned}Mt(i)wtuMt1(i)+wtw(i)ktγwt1u+wtr+wtw

  其中wtuw_{t}^{u}wtu是使用权重,由读权重wtrw_{t}^{r}wtr和写权重wtww_{t}^{w}wtw,和上一时刻的使用权重wtuw_{t}^{u}wtu组成,γ\gammaγ为折扣因子。

  wtww_{t}^{w}wtw由上一时刻的读权重wtrw_{t}^{r}wtr(表示last used location)和最少使用(least-used weight)权重wluw^{lu}wlu基于参数α\alphaα组成:

wtw=σ(α)wt−1r+(1−σ(α))wt−1luw_{t}^{w}=\sigma(\alpha) w_{t-1}^{r}+(1-\sigma(\alpha)) w_{t-1}^{l u}wtw=σ(α)wt1r+(1σ(α))wt1lu

  这里就剩下最后一个问题,最少使用权重wluw^{lu}wlu怎么定义。定义如下:

wtlu=1wtu(i)≤m(wtu,n)w_{t}^{l u}=\mathbf{1}_{w_{t}^{u}(i) \leq m\left(w_{t}^{u}, n\right)}wtlu=1wtu(i)m(wtu,n)

  其中m(wtu,n)m\left(w_{t}^{u},n\right)m(wtu,n)表示wtuw_{t}^{u}wtu中第nnn个最小的元素,只有当期够小才为1,否者为0

Meta Networks

  传统的神经网络通过stochastic gradient descent方式做更新,如果batch_size1的话,更新就会很慢。如果train一个网络去预测目标任务的网络参数的话,这样的学习起来就会很快,称之为fast weights。由此我们可以知道meta network由两部分组成:

  1. meta-learner:它所要做的事情就是获取不同task的通用的知识。可以看作是一个embeddings function,判断两个不同的数据之间的差别。
  2. base-learner:期望去学一个target task,就是最常见的学习算法,比如做个分类这样。

  开始之前定义一些术语:

  • Support set:从训练集采样得到的一些数据点(x,y)(x,y)(x,y)
  • query set:同样也是从训练集采样得到的一些数据点(x,y)(x,y)(x,y),作为query set。
  • Embedding functionfθf_{\theta}fθmeta-learner的一部分,与siamese network类似,用于预测两个输入是否属于同一类。
  • Base-learner modelgϕg_{\phi}gϕ:就是一个需要处理完整任务的学习算法。
  • θ+\theta^{+}θ+Embedding functionfθf_{\theta}fθfast weight,由一个LSTM FwF_{w}Fw产生。
  • ϕ+\phi^{+}ϕ+Base-learner modelgϕg_{\phi}gϕfast weight,由一个网络GvG_{v}Gv产生。

  可以看出slow weights(θ,ϕ)(\theta,\phi)(θ,ϕ)构成了meta-learnersbase learners。两个不同的网络FwF_{w}FwGvG_{v}Gv生成fast weight

  meta networks的网络架构如下所示:

  可以看到meta networkbase learnermeta-learner组成,meta-learner给配了一个外部memory(external memory)。

算法

  整个训练数据被分为两部分support setS=(xi′,yi′)S=(x_{i}^{\prime},y_{i}^{\prime})S=(xi,yi),和query set U=(xi,yi)U=(x_{i},y_{i})U=(xi,yi),我们要做的事情就是学四个网络(f(θ),g(ϕ),Fw,Gvf(\theta),g(\phi),F_{w},G_{v}f(θ),g(ϕ),Fw,Gv)的参数。

  1. support set随机采样K个样本对。循环将其中每个样本1-K送入embedding function f(θ)f(\theta)f(θ),并计算cross-entropy lossLembeddingsL_{e m b e d d i n g s}Lembeddings
  2. 计算得到的cross-entropy lossLembeddingsL_{e m b e d d i n g s}Lembeddings再经过LSTM计算θ+\theta^{+}θ+θ+=Fw(∇Lembeddings)\theta^{+}=F_{w}\left(\nabla L_{\text {embeddings}}\right)θ+=Fw(Lembeddings)
  3. 之后对于support set中的每个样本计算fast weight,同时基于embeddings更新external memory。首先循环样本1-K,经过base learner gϕ(xi)g_{\phi}(x_{i})gϕ(xi),计算loss LitaskL_{i}^{task}Litask,对其求梯度,然后求fast weight ϕi+=Gv(∇Litas⁡k)\phi_{i}^{+}=G_{v}\left(\nabla L_{i}^{\operatorname{tas} k}\right)ϕi+=Gv(Litask)。然后将ϕi+\phi_{i}^{+}ϕi+存储在memory MMM的第iiilocation

  然后将fastslow weight合并:

  support sample再经过这个网络得到 ri′=fθ,θ+(xi′)r_{i}^{\prime}=f_{\theta, \theta^{+}}\left(x_{i}^{\prime}\right)ri=fθ,θ+(xi),将ri′r_{i}^{\prime}ri存储在memory RRR的第iiilocation

  1. 基于query set U=(xi,yi)U=(x_{i},y_{i})U=(xi,yi)来构造损失函数,开始时Ltrain=0L_{train}=0Ltrain=0。从1-L循环所有样本:拿query set的数据经过embeddings network rj=f^θ,θ+(xj)r_{j}=\hat{f}_{\theta, \theta^{+}}\left(x_{j}\right)rj=f^θ,θ+(xj),然后计算其与support set中的样本经过embeddings的输出,也就是memory R的相似度aj=cosine⁡(R,rj)a_{j}=\operatorname{cosine}\left(R, r_{j}\right)aj=cosine(R,rj)。基于此计算base learnerfast weightϕ+\phi^{+}ϕ+ϕj+=softmax⁡(aj)TM\phi_{j}^{+}=\operatorname{softmax}\left(a_{j}\right)^{T} Mϕj+=softmax(aj)TM,其中MMMsupport set samples。然后计算loss LitaskL_{i}^{task}LitaskLtrain←Ltrain+Ltask(gϕ,ϕ+(xi),yi)L_{t r a i n} \leftarrow L_{t r a i n}+L^{t a s k}\left(g_{\phi, \phi^{+}}\left(x_{i}\right), y_{i}\right)LtrainLtrain+Ltask(gϕ,ϕ+(xi),yi)
  2. LtrainL_{t r a i n}Ltrain更新f(θ),g(ϕ),Fw,Gvf(\theta),g(\phi),F_{w},G_{v}f(θ),g(ϕ),Fw,Gv网络参数。

  matching networksLSTM meta-learners其实是使用了相同的策略,都有利用额外的信息,一个是contextual embeddings,一个是meta information,期望抽取出一些对于整个task比较重要的信息。

参考

  • NTM-Lasagne: A Library for Neural Turing Machines in Lasagne
  • Neural Turing Machines

少样本学习系列(二)【Model-Based Methods】相关推荐

  1. 【转载】Few-shot learning(少样本学习)和 Meta-learning(元学习)概述

    版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明. 本文链接:https://blog.csdn.net/weixin_37589575/arti ...

  2. Few-shot learning 少样本学习

    N-way K-shot用来衡量网络泛化能力,但小样本在实际应用中并不是很好. 背景 深度学习已经广泛应用于各个领域,解决各类问题,在图像分类的问题下,可以很轻松的达到94%之上.然而,deep le ...

  3. Few-shot learning(少样本学习)和 Meta-learning(元学习)概述

    目录 (一)Few-shot learning(少样本学习) 1. 问题定义 2. 解决方法 2.1 数据增强和正则化 2.2 Meta-learning(元学习) (二)Meta-learning( ...

  4. 2020年最全 | 少样本学习(FSL)相关综述、数据集、模型/算法和应用资源整理分享...

    文章来源 | 深度学习与NLP Few Shot Learning(FSL)又称少样本学习,这是做AI研究经常遇到的一个问题.深度学习技术需要大量的数据来训练一个好的模型.例如典型的 MNIST 分类 ...

  5. NeurIPS 2021 | 微软研究院提出CLUES,用于NLU的少样本学习评估

    ©作者 | 雪麓 单位 | 北京邮电大学 研究方向 | 序列标注 自然语言理解 (NLU) 的最新进展部分是由 GLUE.SuperGLUE.SQuAD 等基准驱动的.事实上,许多 NLU 模型现在在 ...

  6. 图机器学习(GML)图神经网络(GNN)原理和代码实现(前置学习系列二)

    图机器学习(GML)&图神经网络(GNN)原理和代码实现(PGL)[前置学习系列二] 上一个项目对图相关基础知识进行了详细讲述,下面进图GML networkx :NetworkX 是一个 P ...

  7. Few-shot learning(少样本学习,入门篇)

    本文介绍一篇来自 https://www.analyticsvidhya.com/ 关于少样本学习的的博客. 原文地址 文章目录 1. 少样本学习 1.1 为什么要有少样本学习?什么是少样本学习? 1 ...

  8. 少样本学习新突破!创新奇智入选ECCV 2020 Oral论文

    点击上方"机器学习与生成对抗网络",关注"星标" 获取有趣.好玩的前沿干货! 转自 创新奇智 近日,创新奇智有关少样本学习(Few-shot Learning) ...

  9. A.图机器学习(GML)图神经网络(GNN)原理和代码实现(前置学习系列二)

    图学习图神经网络算法专栏简介:主要实现图游走模型(DeepWalk.node2vec):图神经网络算法(GCN.GAT.GraphSage),部分进阶 GNN 模型(UniMP标签传播.ERNIESa ...

  10. 少样本学习(一):了解一些基础概念

    Few-shot learning(少样本学习)和 Meta-learning(元学习)概述 参考:Few-shot learning(少样本学习)和 Meta-learning(元学习)概述_Cao ...

最新文章

  1. 大数据笔记10:大数据之Hadoop的MapReduce的原理
  2. windows 生成 deploy key_推荐一个免费生成点线/方格/横线纸张的网站
  3. python分配buffer_Node.js中的buffer如何和python中的buffer相对应
  4. 【Python】用Pyecharts制作炫酷的可视化大屏
  5. 今年最有档次的9个词!(不看后悔)
  6. node.js 实现udp传输_Node.js实战15:通过udp传输文件。
  7. linux perl 单例模式,Perl脚本学习经验(三)--Perl中ftp的使用
  8. MySQL在windows的my-default.ini配置
  9. apache是干嘛用的_同学,其实用免费版的IDEA来创建SpringBoot项目挺方便的...
  10. 呼吸灯 裸机 S3C2416
  11. Jobdu 1005
  12. java爬虫 教程_Java爬虫其实也很简单,教你实用的入门级爬虫
  13. java1.8.0_java jdk官方下载|java jdk v1.8.0 官方免费版-520下载站
  14. 谷歌地图网页版_安卓版谷歌地图新增专用的街景图层
  15. 企业微信怎么批量加人?怎么管理员工?看看这套系统
  16. Word转图片的方法(两种)
  17. Docker入门教程 Part 1 基础概念 - 镜像、容器、仓库
  18. 内外盘期货市场的介绍(一)
  19. flash中国官网显示可能损害计算机,重橙网络:Flash Player 中国官网最新版可解决使用异常的问题...
  20. SAP 各大常用模块汇总介绍(一)

热门文章

  1. ubuntu 12.04 lts搭建android 编译环境
  2. VBA实战技巧精粹013:宏代码保存工作簿的3种方法
  3. Debian - RAID5搭建(热备)
  4. HCIE-RS面试--环路产生及防环机制
  5. web之nginx相关配置二
  6. 姚前:分布式账本与传统账本的异同及其现实意义
  7. 八、JVM视角浅理解并发和锁
  8. [C#] readonly vs const
  9. 整理的部分Java和C#不同点
  10. html5 canvas类库 实例