论文链接:https://arxiv.org/pdf/1901.02860.pdf
代码链接:https://github.com/kimiyoung/transformer-xl
参考来源:https://mp.weixin.qq.com/s/C1hXU3HMSXSY5Ru9r1CZAA

导读

今天学习的是谷歌大脑的同学和 CMU 的同学于 2019 年联合出品的论文《Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context》,目前被引次数超 200 次。

这篇论文提出的 Transformer-XL 主要是针对 Transformer 在解决长依赖问题中受到固定长度上下文的限制,如 Bert 采用的 Transformer 最大上下文为 512。

Transformer-XL 采用了一种 segment-level 的递归方法,不仅解决长依赖的问题,还解决了上下文碎片问题。最终,Transformer-XL 能学习到的长依赖超过 LSTM 80%,并比原来的 Transforner 多出 4.5 倍。而且 Transformer-XL 在长短序列中都获得了不错的性能,预测速度更是比原来快了 1800 多倍。

1、摘要

Transformer具有学习长依赖关系的潜力,但是受到语言建模中上下文长度固定的限制。为此,本文提出一种新的神经网络架构Transformer-XL,该网络结构能够在不破坏时间一致性的情况下,学习到超越固定长度的依赖性。该网络结构由片段级的循环机制(segment-level recurrence)和全新的位置编码策略(positional encoding scheme)组成。其优点是不仅可以捕获更长的依赖关系,还可以解决上下文碎片化(context fragmentation)的问题。从实验结果上来看,Transformer-XL 学习到的依赖性比 RNN 学习到的长 80%,比标准 Transformer 学到的长 450%,无论在长序列还是短序列中都得到了更好的结果,而且在评估时比标准 Transformer 快 1800+ 倍。值得一提的是,Transformer-XL还刷新了 bpc 和perplexity(困惑度)的当前最佳结果:在 enwiki8 上 bpc 从 1.06 提升至 0.99,在 text8 上从 1.13 提升至 1.08;在 WikiText-103 上困惑度从 20.5 提升到 18.3,在 One Billion Word 上从 23.7 提升到 21.8,在宾州树库(不经过微调的情况下)上从 55.3 提升到 54.5。本文模型的代码、预训练模型以及超参数在 TensorFlow 和 PyTorch 中都可以使用。

2、引言

语言建模需要对长期依赖性进行建模,它成功应用了无监督的预训练方法 (Peters et al., 2018; Devlin et al., 2018)。但要让神经网络对序列数据的长期依赖性建模一直都是一项挑战。RNN网络,特别是LSTM是一个标准的方案,它可以在多个benchmarks上获得健壮的结果(strong results)。尽管其使用广泛,但是RNNs由于梯度消失和梯度爆炸问题的存在,难以优化。纵使引入一些门限和梯度裁剪技术,仍然不足以完全解决该问题。此前的工作已经表明LSTM平均可以捕获200个word的上下文信息,这也指出了进一步改进的空间。

另一方面,通过attention机制直接连接长的word pairs可以缓解优化问题,同时习得长依赖(即原始的Transformer工作)。近来Al-Rfou 等人(2018)设计了一组辅助损失来训练深度 Transformer 网络进行字符级(character-level)语言建模,其结果远超LSTM。虽然已经取得成功,但是 Al-Rfou 等人(2018)的语言模型是在长度固定的几百个字符片段上独立训练的,没有任何跨片段的信息流(即多个segments,每个segment的长度固定,由数百个characters组成。但是segments之间没有信息交流)。由于上下文的长度是固定的,因此模型无法捕获任何超过预定义上下文长度的长依赖。此外,长度固定的segments都是在不考虑句子或其它语义边界的情况下,通过选择连续的符号块来创建的。因此,模型缺乏必要的上下文信息来很好地预测前几个符号,这就导致模型的优化效率和性能低下。我们将这个问题称为上下文碎片化(context fragmentation)。

为了解决上文提到的上下文固定长度的限制,本文提出了一种叫做Transformer-XL(超长)的新架构。我们将循环(recurrence)概念引入了深度自注意力网络。我们不再从头计算每个新segment的隐藏状态,而是复用从之前segments中获得的隐藏状态。被复用的隐藏状态视为当前segment的memory,而当前的segment为segments之间建立了循环连接(recurrent connection)。因此,超长依赖性建模成为了可能,因为信息可以通过循环连接来传播。同时,从之前的segment传递信息也可以解决上下文碎片化的问题。更重要的是,本文展示了使用相对位置而不是用绝对位置进行编码的必要性,这样做可以在不造成时间混乱(temporal confusion)的情况下,实现状态的复用。因此,作为额外的技术贡献,文本引入了简单但有效的相对位置编码公式,它可以泛化至比在训练过程中观察到的长度更长的注意力长度。

从单词级(word-level)到字符级(character level)的五个语言建模数据集上,Transformer-XL都获得了很好的结果。Transformer-XL在仅基于100M tokens训练的基础上也可以生成相对连贯的长文本文章。

本文的主要贡献包括:
(1)在纯粹的自注意力模型中引入了recurrence的概念,即循环连接。
(2)推导了一种新的位置编码方案。

这两种技术构成了一组完整的解决方案,因为其中任何一种单独都不能解决上下文长度固定的问题。Transformer-XL是首个从实质上不管是character-level还是word-level都比RNN更优秀的自注意力模型。

3、Transformer-XL模型

3.1、Vanilla Transformer

要想将 Transformer 应用到模型中,要解决的核心问题是如何训练 Transformer 使其可以将任意大小的上下文编码为固定大小的 Representation。

如果不考虑计算资源和内存的话,最简单粗暴的方法就是直接使用 Transformer 来对整个序列进行编码。但我们知道这种方法是不可能的。

还有一种可行但是比较粗糙的方法是将整个语料库分为多个大小相同的片段(segment),然后只在每个片段上训练而忽视所有的上下文信息,这种方法我们称为 Vanilla Transformer:

在预测过程中,Vanilla Transformer 也采用与训练相同大小的片段来预测最后一个位置,然后每次基于滑动窗口向右移动一个位置:

这种方法一定程度上确保了在预测过程中尽可能大的利用上下文,缓解了上下文碎片问题,但由于每次移动,新的片段都需要重新计算一次,所以其计算代价昂贵。

3.2、Segment-Level Recurrence

为了解决固定长度上下文的带来的问题,作者建议在 Transformer 架构中引入递归机制(Recurrence Mechanism)。在训练过程中,前一段计算出来的隐藏层状态会被被固定并缓存下来,当模型处理下一个新段时作为扩展上下文而被重用:

这种附加的连接可以随着网络深度的增加而增大依赖项的最大长度(想不通的可以想一下 GCN 的一阶领域)。除此之外,这种递归机制还可以解决上下文碎片问题,为新段前端的令牌提供必要的上下文信息。

我们来给出具体计算过程的数学公式:

假设现在有两个连续的分割片段sτ=[xτ,1,⋯,xτ,L]s_{\tau}=[x_{\tau,1},\cdots,x_{\tau,L}]sτ=[xτ,1,,xτ,L]sτ+1=[xτ+1,1,⋯,xτ+1,L]s_{\tau+1}=[x_{\tau+1,1},\cdots,x_{\tau+1,L}]sτ+1=[xτ+1,1,,xτ+1,L] ,其中 xxx 表示 token,LLL为序列长度, sτs_{\tau}sτ表示第τ\tauτ 个分割片段。

假设 Transformer 有 NNN 层,那么每个片段sτs_{\tau}sτ 就有 NNN 个隐藏层状态,我们将第 τ\tauτ 个片段的第 nnn 个隐藏层状态表示为 hτnh_{\tau}^nhτn, 那么第 τ+1\tau+1τ+1 个片段的第 nnn 层隐藏层状态就可以通过下式得出:h~τ+1n−1=[SG(hτn−1)∘hτ+1n−1]\tilde{h}_{\tau+1}^{n-1}=[SG(h_{\tau}^{n-1})\circ h_{\tau+1}^{n-1}]h~τ+1n1=[SG(hτn1)hτ+1n1]qτ+1n,kτ+1n,vτ+1n=hτ+1n−1WqT,h~τ+1n−1WkT,h~τ+1n−1WvTq_{\tau+1}^n,k_{\tau+1}^n,v_{\tau+1}^n=h_{\tau+1}^{n-1}W_q^T,\tilde{h}_{\tau+1}^{n-1}W_k^T,\tilde{h}_{\tau+1}^{n-1}W_v^Tqτ+1n,kτ+1n,vτ+1n=hτ+1n1WqT,h~τ+1n1WkT,h~τ+1n1WvThτ+1n=Transformer−Layer(qτ+1n,kτ+1n,vτ+1n)h_{\tau+1}^n=Transformer-Layer(q_{\tau+1}^n,k_{\tau+1}^n,v_{\tau+1}^n)hτ+1n=TransformerLayer(qτ+1n,kτ+1n,vτ+1n)其中,SG 是指 Stop-Gradient,表示状态固定,虽然提供信息但不再进行反向传播;h~τ+1n−1\tilde{h}_{\tau+1}^{n-1}h~τ+1n1 是一个临时符号,表示对两个连续片段第 n−1n-1n1 层隐藏层状态的拼接,qτn,kτn,vτnq_{\tau}^n,k_{\tau}^n,v_{\tau}^nqτn,kτn,vτn 分别表示 query、key 和 value 向量;注意,仔细看下公式,query 的计算方式不变,而 key 和 value 是利用拼接后的h~\tilde{h}h~来计算。

由于这是递归机制,所以层数越高,所能依赖到的范围越大,最大可能依赖长度为O(N×L)O(N\times L)O(N×L) ,如下图阴影部分所示:

除了实现超长的上下文依赖和解决碎片问题外,递归机制的另一个好处就是显著加快了计算速度。具体来说,Vanilla Transformer 每次都需要重新计算,而现在可以重用以前的片段,只要 GPU 内存允许,我们可以尽可能多的缓存之前的片段,并重用之前的片段以作为额外的上下文。

3.3、Relative Positional Encoding

在 Vanilla Transformer 中,由于每个片段相互独立每次都会重新计算,且使用了绝对位置编码的方式,所以不会出现位置混乱的情况。但是在 Transformer-XL 中,每个片段都是用相同的位置编码会导致在重用过程中无法保证位置信息的一致性。

为了避免这种情况,Transformer-XL 使用了相对位置信息编码的方式,从概念上来说,位置编码会为模型提供 token 相对顺序的线索。为了达到同样的目的,Transformer 在计算当前位置隐向量时,考虑和它存在依赖的 token 的相对位置。具体来说,在计算 Attention 评分时不需要知道 Query 和 key 的绝对位置,只要知道相对位置即可,并将这种相对位置关系动态的注入到每一层的 Attention 评分计算中,而不是静态地将偏差加入到初始 Embedding 中。

我们来对比一下绝对位置和相对位置:Ai,jabs=qiTkj=(Exi+Ui)TWqTWk(Exj+Uj)=ExiTWqTWkExj+ExiTWqTWkUj+UiTWqTWkExj+UiTWqTWkUjA_{i,j}^{abs}=q_i^Tk_j=(E_{x_i}+U_i)^TW_q^TW_k(E_{x_j}+U_j)\\=E_{x_i}^TW_q^TW_kE_{x_j}+E_{x_i}^TW_q^TW_kU_j+U_i^TW_q^TW_kE_{x_j}+U_i^TW_q^TW_kU_jAi,jabs=qiTkj=(Exi+Ui)TWqTWk(Exj+Uj)=ExiTWqTWkExj+ExiTWqTWkUj+UiTWqTWkExj+UiTWqTWkUj

其中, ExiE_{x_i}Exi为 token xix_ixi 的输入编码;UiU_iUi 为绝对位置编码;Wq,WkW_q,W_kWq,Wk 分别为 query 和 key 矩阵。Ai,jrel=ExiTWqTWk,EExj⏟(a)+ExiTWqTWk,RRi−j⏟(b)+uTWk,EExj⏟(c)+vTWk,RRi,j⏟(d)A_{i,j}^{rel}=\underbrace{E_{x_i}^TW_q^TW_{k,E}E_{x_j}}_{(a)}+\underbrace{E_{x_i}^TW_q^TW_{k,R}R_{i-j}}_{(b)}+\underbrace{u^TW_{k,E}E_{x_j}}_{(c)}+\underbrace{v^TW_{k,R}R_{i,j}}_{(d)}Ai,jrel=(a)

ExiTWqTWk,EExj+(b)

ExiTWqTWk,RRij
+
(c)

uTWk,EExj
+
(d)

vTWk,RRi,j

其中,Ri−jR_{i-j}Rij 是相对位置编码矩阵,RRR是正弦信号编码矩阵。由于query 向量对于所有查询位置都是相同的,所以用 uTu^TuT 代替UiTWqTU_i^TW_q^TUiTWqT ,同样的原因,我们用 vTv^TvT代替UiTWqTU_i^TW_q^TUiTWqT ;将WkW_kWkWk,E,Wk,RW_{k,E},W_{k,R}Wk,E,Wk,R 分别代替,以细分表示基于内容的 key 向量和基于位置信息的 key 向量。

在相对位置中,每个位置都有直观的含义:

  • (a)编码相邻内容的影响;
  • (b)编码与相邻内容相关的位置偏差;
  • (c)编码全局内容偏差;
  • (d)编码全局位置偏差。

Vanilla Transformer 只有前两种含义,而没有后两种含义。

最后我们来看下整体的公式:h~τ+1n−1=[SG(hτn−1)∘hτ+1n−1]\tilde{h}_{\tau+1}^{n-1}=[SG(h_{\tau}^{n-1})\circ h_{\tau+1}^{n-1}]h~τ+1n1=[SG(hτn1)hτ+1n1]qτ+1n,kτ+1n,vτ+1n=hτ+1n−1WqT,h~τ+1n−1WkT,h~τ+1n−1WvTq_{\tau+1}^n,k_{\tau+1}^n,v_{\tau+1}^n=h_{\tau+1}^{n-1}W_q^T,\tilde{h}_{\tau+1}^{n-1}W_k^T,\tilde{h}_{\tau+1}^{n-1}W_v^Tqτ+1n,kτ+1n,vτ+1n=hτ+1n1WqT,h~τ+1n1WkT,h~τ+1n1WvTAτ,i,jn=qτ,inTkτ,jn+qτ,inTWk,RnRi−j+uTkτ,j+vTWk,RnRi−jA_{\tau,i,j}^n={q_{\tau,i}^n}^Tk_{\tau,j}^n+{q_{\tau,i}^n}^TW_{k,R}^nR_{i-j}+u^Tk_{\tau,j}+v^TW_{k,R}^nR_{i-j}Aτ,i,jn=qτ,inTkτ,jn+qτ,inTWk,RnRij+uTkτ,j+vTWk,RnRijατn=Masked−Softmax(Aτn)vτn\alpha_{\tau}^n=Masked-Softmax(A_{\tau}^n)v_{\tau}^nατn=MaskedSoftmax(Aτn)vτnoτn=LayerNorm(Linear(ατn)+hτn−1)o_{\tau}^n=LayerNorm(Linear(\alpha_{\tau}^n)+h_{\tau}^{n-1})oτn=LayerNorm(Linear(ατn)+hτn1)hτn=Positionwise−Feed−Forward(oτn)h_{\tau}^n=Positionwise-Feed-Forward(o_{\tau}^n)hτn=PositionwiseFeedForward(oτn)

4、实验

模型在不同数据集下的表现:






各模型的相对有效长度(最长依赖长度)

5、结论

Transformer-XL 从解决长距离依赖问题为目标,提出了循环机制和相对位置编码这两个创新点,在解决了长依赖问题的同时也解决了上下文碎片的问题。此外,由于循环机制重用了先前隐藏层状态,其预测速度也得到了显著提升。诸多试验证明,Transformer-XL 相对 Vanilla Transformer 而言具有很好的性能。

5.1、优点

  • 在几种不同的数据集(大/小,字符级别/单词级别等)均实现了最先进的语言建模结果。
  • 结合了深度学习的两个重要概念——循环机制和注意力机制,允许模型学习长期依赖性,且可能可以扩展到需要该能力的其他深度学习领域,例如音频分析(如每秒16k样本的语音数据)等。
  • 在inference阶段非常快,比之前最先进的利用Transformer模型进行语言建模的方法快300~1800倍。
  • 有详尽的源码!含TensorFlow和PyTorch版本的,并且有TensorFlow预训练好的模型及各个数据集上详尽的超参数设置。

5.2、不足

  • 尚未在具体的NLP任务如情感分析、QA等上应用。
  • 没有给出与其他的基于Transformer的模型,如BERT等,对比有何优势。
  • 在Github源码中提到,目前的sota结果是在TPU大集群上训练得出,对于我等渣机器党就只能玩玩base模式了。

Transformer-XL语言模型:超长上下文依赖相关推荐

  1. transformer xl在文本生成上面的应用

    Transformer_xl相关介绍:https://zhuanlan.zhihu.com/p/84159401 从文本生成看Seq2Seq模型:https://zhuanlan.zhihu.com/ ...

  2. transformer xl 用于文本生成

    本文尝试用transformer xl做中文文本续写,基于论文为:<Transformer-XL: Attentive Language Models Beyond a Fixed-Length ...

  3. BERT: Bidirectional Encoder Representations from Transformers双向Transformer用于语言模型 NAACL 2018

    论文链接:https://arxiv.org/abs/1810.04805 tensorflow版本代码链接:https://github.com/google-research/bert pytor ...

  4. NLP-预训练模型-2019:XLM-Roberta【一种多语言预训练模型】

    <原始论文:Unsupervised Cross-lingual Representation Learning at Scale> Facebook AI团队于2019年11月发布了XL ...

  5. NLP-生成模型-2019:TransformerXL【对Vanilla Transformer的改进:片段级递归机制、相对位置编码】【超出固定长度上下文的注意力语言模型】

    <原始论文:Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context> 一.概述 一句话简介:Tran ...

  6. 图解OpenAI的秘密武器GPT-2:可视化Transformer语言模型

    大数据文摘出品 来源:github 编译:小七.池俊辉.Andy 今年,我们见识了许多令人眼花缭乱的机器学习的应用成果.其中OpenAI训练的GPT-2模型就展示出了惊艳的能力,它能够撰写出连贯而富有 ...

  7. 上车!带你一文了解GPT-2模型(transformer语言模型可视化)

    全文共9517字,预计学习时长28分钟 来源:Pexels 今年,各种机器学习的应用程序纷纷涌现.其中OpenAI GPT-2能够创作出逻辑清晰且激情盎然的文章,远远超出了人们对当前语言模型创造力的预 ...

  8. Transfomer XL翻译

    翻译:月入上万_ 审核:yphacker 原论文 论文代码 Transfomer XL翻译 摘要 1.简介 2.相关工作 3.模型 3.1 普通的Transformer模型 3.2 Segment-L ...

  9. Attention Mechanism[Transformer、Transformer-XL、XLNet]

    Content Attention Mechanism--->聚焦关键点 1 History 2 Introduction 3 structure 4 application situation ...

最新文章

  1. linux除了eeprom其他的保存方法,linux的EEPROM的读写控制.doc
  2. SVG(网页加载显示的加载进度动态图)
  3. 自增或自减例子:i++和++i的相同点和不同点
  4. mysql表 spid program_oracle 解锁某张表 和编译存储过程卡死问题处理
  5. SPC-Light显示正常的日期与时间
  6. NOI2005 瑰丽华尔兹
  7. Foobar2000目前最强解码方案
  8. linux firefox 解雇ie,Fire IE
  9. linux gz解压 指定目,linux解压tar.gz到指定文件夹或目录
  10. 蓝牙技术|伦茨科技带你了解蓝牙音频
  11. 思科服务器如何进入网站,思科路由器怎么进入设置网站
  12. workerman php使用,workerman怎么用
  13. 使用 TestFlight 进行 iOS App 内测
  14. valgrind安装及使用
  15. STLINK : Warning: Connection to device 0x413 is lost
  16. Java程序员必知必会之JVM运行时数据区
  17. JAVA学习笔记JEECG BOOT介绍
  18. 实用 Windows 软件系列分享(五)
  19. java编写车类_用Java程序创建一个汽车接口,接口中要定义汽车应有的属性和行为,随后编写多个汽车接口的实现类,...
  20. 相机内参 k_4K相机与智能手机中的4K视频相比如何

热门文章

  1. Black Box(POJ 1442·TREAP实现)
  2. Exchange Server 2013多域名证书申请
  3. SpringCloud Sentinel 使用restTemplate的两种配置介绍
  4. 工程师软技能4:找出你的短板
  5. vue-cli本地的一个websocket
  6. 查看linux网络带宽
  7. Netty入门笔记-Linux网络I/O模型介绍
  8. SpringBoot配置@PropertySource、@ImportResource、@Bean注解
  9. textarea输入中文和数字换行解决方法
  10. Mybatis There is no getter for property named 'XXX' in 'class java.lang.XXX