论文标题:

Learning to Encode Position for Transformer with Continuous Dynamical Model

论文作者:

Xuanqing Liu (UCLA), Hsiang-Fu Yu (Amazon), Inderjit Dhillon (UT Austin, Amazon), Cho-Jui Hsieh (UCLA)

论文链接:

https://arxiv.org/pdf/2003.09229.pdf

代码链接:

https://github.com/xuanqing94/FLOATER


随着Transformer时代的到来,各种花式位置编码方法被提出,但是,它们要么需要手动地设计,要么受到文本长度的限制。

本文提出一种基于连续动态系统(Continuous Dynamic Model)的位置编码,使用常微分方程(ODE)求解器学习,既不受文本长度的限制,又能建模位置上的关系,非常灵活。在NMT和NLU等任务上能实现比较好的结果。

位置编码

以Transformer为代表的使用自注意力(Self-Attention)的模型具有位置置换不变性:打乱句子中的词模型会得到同样的特征。为此,此类模型需要加入“位置编码”,让模型能够识别什么位置有什么词。

当前已经有一些关于位置编码的研究,如Transformer原文提出的三角函数编码、可学习参数编码,和后来的相对位置编码等,但这些编码方式都存在一些问题。

比如三角函数编码,尽管可以处理理论上很长的句子,但是由于它是人为设计的而不是自动从数据中学习,那么就可能在效果上欠佳。

而可学习的参数编码,尽管是模型自己学到的,但它能处理的文本长度是有限的,因为其需要的参数量是是文本长度。

相对位置编码需要的参数量是,和文本长度无关,但是它在一定程度上牺牲了远距离的位置依赖。

我们希望位置编码有以下特点:

  • 可归约性:能够处理任何长度的文本

  • 可学习性:不是人为指定的而是从数据中学到的

  • 低参数性:引入的参数量是有限制的,而不是无限增长的

基于此,本文提出将位置编码的学习归入一种连续动态系统,这样一来,就可以通过学习这个系统(模型)得到每个位置编码,而不是单独地为每个位置学习一个独有的编码。

同时,它也满足了以上三个条件:(1)定义域为,可以学习任何长度的文本;(2)位置编码是学习得到的;(3)参数量就是该系统的所有参数。

为了学习这个模型,本文使用了神经常微分方程(Neural ODE)求解方法。总的来说,本文贡献如下:

  • 提出FLOATER——一种新的位置编码方案,通过连续动态系统和ODE学习;

  • FLOATER克服了以往位置编码的若干缺点,可以处理任何长度文本;

  • FLOATER可以被运用到任何基于Transformer的模型中;

  • 在机器翻译、自然语言理解和问答等任务上,FLOATER实现了较好的效果提升。

Transformer位置编码

在介绍FLOATER之前,我们先简要介绍一下Transformer和位置编码,并引入一些记号。

为模型的第层,是第层的注意力层,是第层的前馈层,那么,Transformer的编码层就可以表示为:

这里,是输入序列。进行如下的自注意力操作:

以上没有考虑位置编码,如果把位置编码加进来,那么每一层就可以表示为:

这里,上标是第层。的选择有很多。Transformer给出的方案是三角函数,和可学习的参数。

FLOATER:基于连续动态系统的位置编码

首先要明确,所谓的位置编码其实是离散的,也即一个向量序列,然后依次加到输入特征上。

但是从上面的概要中我们发现,这些序列在开始输入的时候彼此之间是独立的,如果想要建模位置编码的相关性又该如何做呢?我们可以想象有这样一个模型,它能接受前一个位置的编码,得到下一个位置的编码,即

基于此,我们可以考虑一个连续版本的位置编码,再考虑一个函数,这样一来,我们就可以把域中的点映射为想要的高维位置编码了。

现在的问题是,如何构造函数。我们可以使用一个连续动态系统:

并有初值。这里是一个神经网络,参数为。这个式子的意思是,要得到时刻的值,只需要考虑它前面的一个位置,计算之间的“增量”即可(即积分部分)。

因为函数是连续且定义在正实数域上的,而实际的位置编码是定义在自然数域上的,所以在得到之后,我们可以建立一个的映射,比如,这样一来,第个位置编码就可以是,其中是间隔,可以自主设置(本文设置为0.1)。

现在剩下的问题就是,如何求解函数(注意到是一个输入为点位置和该点值的神经网络)。这等价于解如下常微分方程(ODE):

这个怎么解呢?我们在下面简要说明,不感兴趣的读者可以略过下面的一节,或者可以参考原文附录A和论文Neural Ordinary Differential Equations。

求解编码函数

假设我们的输入序列长度为,那么我们可以首先求出这个位置编码:

然后按照常规的流程,把这些位置编码加到输入特征上,继续往下走,直到最后产生损失:。那么为了更新,我们就要计算损失对它的梯度,这就可以用ODE的方法解决,如下图所示:

于是,梯度可以计算为:

其中,可以通过下式得出:

权重共享

研究表明,在每一层都加入位置编码会提高最终的效果,于是,第层的位置编码就可以同样表示为:

为了更高效地学习,我们共享所有层的模型参数,只不过是对不同的层有不同的初值

与普通Transformer的关系

那么,FLOATER引入的位置编码和普通的Transformer的关系是什么呢?回忆一下,普通Transformer计算Query的方式是这样的:

这里是普通的位置编码,比如三角函数编码和可学习的参数编码。那么,FLOATER的计算方式是:

显然,FLOATER等价于在原来Transformer的基础上增加一个偏置项,既然如此,我们直接去学习一个偏置项函数即可:

这时候,如果,则,这就退化到了普通的位置编码了。这说明,普通的位置编码是FLOATER的特例。

下图是FLOATER的示意图。

实验

我们在机器翻译、自然语言理解和问答上实验。实验设置、模型初始化详见原文附录。下表是机器翻译的结果。可以看到,相比三角函数编码和参数编码,FLOATER编码能够实现较大的提升。

下表是NLU任务的结果。从表中可以看到,FLOATER几乎在所有任务上都能超过RoBERTa,尤其是在大模型上有更大的优势。在问答方面,FLOATER也略好于RoBERTa。

接下来看看在不同文本长度上各编码方案的优劣。如下图所见,当文本越长时,FLOATER的相对优势就越明显,这表明,FLOATER学到的编码函数可以有较强的泛化能力。

其次,我们发现FLOATER和RNN是有一定的相似度的,这体现在位置编码的计算方式上,如果我们通过下面的方式(RNN)来计算位置编码又如何呢:

这里的表示第个位置,要么是(scalar),要么是三角函数表示的向量(vector)。在得到整个位置编码序列之后,我们同样地把它们和Transformer的输入相加。

下表是几种计算位置编码方法的结果。可以看到,用RNN去计算位置编码效果也不错,但都没有FLOATER好。

最后我们来看看几种位置编码的可视化,如下图所示。

显然,三角函数编码(a)的结构化程度最好,而参数化编码(b)就显得比较杂乱,RNN编码(d)几乎就没有结构化信息,而FLOATER(c)和三角函数编码比较类似,具有一定结构化信息。

注意到,并不是说结构化程度越高效果就越好,此处只是在阐释不同位置编码具有怎样的模式。

另一个值得注意的地方是,参数化编码(b)的底部几乎是常数,这是因为长文本在数据集中总的来说还是比较少的,所以这些比较远的位置就难以得到更新。

换句话说,参数化编码难以泛化到比较远的地方。而FLOATER(c)则不然,尽管长文本比较少,但是它仍然有很好的泛化能力。

小结

本文提出了一种基于连续动态系统的位置编码方法,可以不受文本长度的限制,可以从数据中学习,并且引入的参数量也不大。

实验表明,这种位置编码方式可以提升基线模型的表现,在机器翻译、自然语言理解和问答等任务上表现良好。

近些年来,ODE/PDE和神经网络结合的工作开始涌现,从物理上解释、提升神经网络是一条有前景的道路。

比如,Understanding and Improving Transformer From a Multi-Particle Dynamic System Point of View 这篇文章从ODE的角度试图解释Transformer,并且实现了很好的结果。我们期待未来有更多结合可解释性的文章。

????

现在,在「知乎」也能找到我们了

进入知乎首页搜索「PaperWeekly」

点击「关注」订阅我们的专栏吧

关于PaperWeekly

PaperWeekly 是一个推荐、解读、讨论、报道人工智能前沿论文成果的学术平台。如果你研究或从事 AI 领域,欢迎在公众号后台点击「交流群」,小助手将把你带入 PaperWeekly 的交流群里。

ICML 2020 | 基于连续动态系统学习更加灵活的位置编码相关推荐

  1. [深度学习] 自然语言处理---Transformer 位置编码介绍

    2017年来自谷歌的Vaswani等人提出了Transformer模型,一种新颖的纯粹采用注意力机制实现的Seq2Seq架构,它具备并行化训练的能力,拥有非凡的性能表现,这些特点使它深受NLP研究人员 ...

  2. ICML 2020 | 基于类别描述的文本分类模型

    论文标题: Description Based Text Classification with Reinforcement Learning 论文作者: Duo Chai, Wei Wu, Qing ...

  3. CIKM 2020 | 基于多视图协作学习的人岗匹配研究

    论文简介 论文:Learning to Match Jobs with Resumes from Sparse Interaction Data using Multi-View Co-Teachin ...

  4. “交通·未来”第10期:基于深度学习的动态系统复杂数据建模方法:以铁路列车晚点预测为例...

    前一阵公众号正式推出了"交通·未来"系列线上公益学术活动等你来~, 9月21日晚19:00,我们将迎来活动的第10期. 1.讲座主题 基于深度学习的动态系统复杂数据建模方法:以铁路 ...

  5. 基于逐维反向学习的动态适应布谷鸟算法

    文章目录 一.理论基础 1.布谷鸟搜索算法 2.DA-DOCS算法 (1)逐维反向学习策略 (2)动态适应 (3)DA-DOCS算法流程 二.实验与结果分析 三.参考文献 一.理论基础 1.布谷鸟搜索 ...

  6. 【CIKM 2020】基于多视图协作学习的人岗匹配研究

    点击上方,选择星标或置顶,每天给你送干货! 阅读大概需要16分钟 跟随小博主,每天进步一丢丢 来自:RUC AI BOX 近日,第29届国际计算机学会信息与知识管理大会(CIKM 2020)在线上召开 ...

  7. 复杂系统学习(二):动态系统和混沌

    目录 1. 引言:什么是动态系统? 2. 函数和迭代 3. 案例学习:物种增长 3.1 完善我们的模型 4. 对混乱的定义 4.1 蝴蝶效应 4.2 对初始条件的敏感依赖 4.3 分岔图 1. 引言: ...

  8. 基于Java毕业设计在线学习平台源码+系统+mysql+lw文档+部署软件

    基于Java毕业设计在线学习平台源码+系统+mysql+lw文档+部署软件 基于Java毕业设计在线学习平台源码+系统+mysql+lw文档+部署软件 本源码技术栈: 项目架构:B/S架构 开发语言: ...

  9. 基于Html5的个性化学习系统的设计与实现

    目 录 摘 要 I Abstract II 第1章 绪论 1 1.1 课题背景及意义 1 1.2 国内外研究现状 1 1.2.1国内研究现状 1 1.2.2国外研究现状 4 1.3开发工具及技术 5 ...

最新文章

  1. EJB基础 作者 Richard Monson-HaefelTim Rohaly
  2. center6linux ip设置,centos6固定ip地址
  3. 吴恩达 coursera AI 第四课总结+作业答案
  4. 怎么判断ajax返回是否成功,如何判断jquery的ajax请求已经返回
  5. 超详细解读:神经语义解析的结构化表示学习 | 附代码分析
  6. Python合并字典的七种方式!
  7. 华为AppCube入选Forrester《中国低代码平台市场分析报告》
  8. java代码隐藏面消除算法,java常面的几种排序算法
  9. 德鲁伊 oltp oltp_内存中OLTP系列–简介
  10. 【考古向翻译】Pwn2Own 2010 Windows 7 Internet Explorer 8 exploit
  11. 去掉图片黑背景输出为透明png(算法和工具)
  12. 线程如何同步?如何使用同步方法?
  13. HTML5七夕情人节表白网页制作【生日快乐粒子烟花】HTML+CSS+JavaScript 生日祝福网页代码
  14. 3D人体重建方法漫谈
  15. 跳槽 ,你跳的是工资,还是阶层?
  16. 查询学过“叶平”老师所教的所有课的同学的学号、姓名
  17. 【matlab 基础篇 02】基础知识一键扫盲,看完即可无障碍编程(超详细+图文并茂)
  18. 冲破百亿天花板,木浪云用云和智能突破备份边界
  19. 小科普:什么是屏幕分辨率
  20. android 静音接口,android 静音方法

热门文章

  1. Lucene mysql app查询_集成Lucene,查询相关数据
  2. 多个数字数组_1分钟彻底理解JavaScript的数组与函数
  3. CEF3开发者系列之CEF3入门
  4. python-configparser生成ini配置文件
  5. @PathVariable详解
  6. Leetcode中单链表题总结
  7. linux 进程管理 ppt,linux操作系统-进程管理和打印管理.ppt
  8. mysql dms_关于MySQL与DMsql探寻
  9. oracle 01775,set Autotrace使用的问题与解决
  10. 假设检验 Hypothesis testing