transfromer-XL论文详解 – 潘登同学的NLP笔记

文章目录

    • transfromer-XL论文详解 -- 潘登同学的NLP笔记
  • Vanilla Transformer
  • Segment-Level Recurrence
    • Relative Position Encodeings
  • 最终总结

Transformer-XL是对Transformer的改进或变种,主要是解决长序列的问题,其中XL表示extra long,在最近流行的XLNet中就是使用Transformer-XL作为基础模块。在下文中,是将Trm-XL放在类似GPT这样的语言模型框架中来介绍,所以理解的时候要放在整个模型中去理解,而不是一个单独的Trm-XL。

Vanilla Transformer

transformer作为一种特征提取器,在NLP中有广泛的应用。但是Trm需要对输入序列设置一个固定的长度,比如在BERT中,默认长度是512。如果文本序列长度短于固定长度,可以通过填充的方式来解决。如果序列长度超过固定长度,处理起来就比较麻烦。一种处理方式,就是将文本划分为多个segments。训练的时候,对每个segment单独处理,segments之间没有联系,如下图(a)所示。这存在两个问题,1)因为segments之间独立训练,所以不同的token之间,最长的依赖关系,就取决于segment的长度;2)出于效率的考虑,在划分segments的时候,不考虑句子的自然边界,而是根据固定的长度来划分序列,导致分割出来的segments在语义上是不完整的。

在预测的时候,会对固定长度的segment做计算,一般取最后一个位置的隐向量作为输出。为了充分利用上下文关系,在每做完一次预测之后,就对整个序列向右移动一个位置,再做一次计算,如上图(b)所示,这导致计算效率非常低。

Segment-Level Recurrence

为了解决上面提到的问题,在Trm的基础上,Trm-XL提出了一个改进,在对当前segment进行处理的时候,缓存并利用上一个segment中所有layer的隐向量序列,而且上一个segment的所有隐向量序列只参与前向计算,不再进行反向传播,这就是所谓的segment-level Recurrence。

Trm本身是可以设置multi-heads,但是在后文中为了简化描述采用单个head。将两个连续的segments表示为

  • S τ = [ x τ , 1 , x τ , 2 , … , x τ , L ] S_{\tau}=[x_{\tau,1},x_{\tau,2},\ldots,x_{\tau,L}] Sτ=[xτ,1,xτ,2,,xτ,L]
  • S τ + 1 = [ x τ + 1 , 1 , x τ + 1 , 2 , … , x τ + 1 , L ] S_{\tau+1}=[x_{\tau+1,1},x_{\tau+1,2},\ldots,x_{\tau+1,L}] Sτ+1=[xτ+1,1,xτ+1,2,,xτ+1,L]

L是序列长度假设整个模型中,包含N层Trm,那么每个segment中就有N组长度为L的隐向量序列,将第 τ \tau τ个segment的第n层隐向量序列表示为 h τ n ∈ R L × d h_{\tau}^{n}\in R^{L\times d} hτnRL×d,d是隐向量维度.那么第 τ + 1 \tau+1 τ+1个segment的第n层隐向量序列,可以由下面的一组公式计算得出。
h ~ τ + 1 n − 1 = [ S G ( h τ n − 1 ) , h τ + 1 n − 1 ] ( 表示对两个向量的拼接 , 拼接后为 2 L × d ) q τ + 1 n , k τ + 1 n , v τ + 1 n = h τ + 1 n − 1 W q T , h ~ τ + 1 n − 1 W k T , h ~ τ + 1 n − 1 W v T h τ + 1 n − 1 = T r a n s f o r m e r L a y e r ( q τ + 1 n , k τ + 1 n , v τ + 1 n ) \tilde{h}_{\tau+1}^{n-1} = [SG(h_{\tau}^{n-1}),h_{\tau+1}^{n-1}] \qquad (表示对两个向量的拼接,拼接后为2L\times d) \\ \qquad \\ q_{\tau+1}^n,k_{\tau+1}^n,v_{\tau+1}^n = h_{\tau+1}^{n-1}W_q^T,\tilde{h}_{\tau+1}^{n-1}W_k^T,\tilde{h}_{\tau+1}^{n-1}W_v^T \\ \qquad \\ {h}_{\tau+1}^{n-1} = Transformer\quad Layer(q_{\tau+1}^n,k_{\tau+1}^n,v_{\tau+1}^n) h~τ+1n1=[SG(hτn1),hτ+1n1](表示对两个向量的拼接,拼接后为2L×d)qτ+1n,kτ+1n,vτ+1n=hτ+1n1WqT,h~τ+1n1WkT,h~τ+1n1WvThτ+1n1=TransformerLayer(qτ+1n,kτ+1n,vτ+1n)
注意q的计算方式不变,只使用当前segment中的隐向量,计算得到的q序列长度仍然是L。k和v采用拼接之后的 h ~ \tilde{h} h~来计算,计算出来的序列长度是2L。之后的计算就是标准的Transformer计算。计算出来的第n层隐向量序列长度仍然是L,而不是2L。Trm的输出隐向量序列长度取决于query的序列长度,而不是key和value。

推导一下:

  • Q [ L × d ] ⋅ K T [ d × 2 L ] = [ L × 2 L ] ⋅ V [ 2 L × d ] = [ L × d ] Q[L\times d] \cdot K^T[d\times 2L] = [L\times 2L] \cdot V[2L\times d] = [L\times d] Q[L×d]KT[d×2L]=[L×2L]V[2L×d]=[L×d]

训练和预测过程如下图所示。这张图上有一个点需要注意,在当前segment中,第n层的每个隐向量的计算,都是利用下一层中包括当前位置在内的,连续前L个长度的隐向量,这是在上面的公式组中没有体现出来的,也是文中没有明说的。每一个位置的隐向量,除了自己的位置,都跟下一层中前(L-1)个位置的token存在依赖关系,而且每往下走一层,依赖关系长度会增加(L-1),如下图中Evaluation phase所示,所以最长的依赖关系长度是N(L-1),N是模型中layer的数量。N通常要比L小很多,比如在BERT中,N=12或者24,L=512,依赖关系长度可以近似为 O ( N × L ) O(N\times L) O(N×L) 。在对长文本进行计算的时候,可以缓存上一个segment的隐向量的结果,不必重复计算,大幅提高计算效率。

上文中,我们只保存了上一个segment,实际操作的时候,可以保存尽可能多的segments,只要内存或者显存放得下。论文中的试验在训练的时候,只缓存一个segment,在预测的时候,会缓存多个segments。

Relative Position Encodeings

在vanilla Trm中,为了表示序列中token的顺序关系,在模型的输入端,对每个token的输入embedding,加一个位置embedding。位置编码embedding或者采用正弦\余弦函数来生成,或者通过学习得到。在Trm-XL中,这种方法行不通,每个segment都添加相同的位置编码,多个segments之间无法区分位置关系。Trm-XL放弃使用绝对位置编码,而是采用相对位置编码,在计算当前位置隐向量的时候,考虑与之依赖token的相对位置关系。具体操作是,在算attention score的时候,只考虑query向量与key向量的相对位置关系,并且将这种相对位置关系,加入到每一层Trm的attention的计算中。

我们对两种方法做个对比。下面一组公式是vanilla Trm计算attention的方式, E x E_x Ex表示token的输入embedding,U是绝对位置编码embedding,两个W分别是query矩阵和key矩阵。下面的公式是对 ( E x i + U i ) W q T W k ( E x j + U j ) (E_{x_i}+U_i)W_q^TW_k(E_{x_j}+U_j) (Exi+Ui)WqTWk(Exj+Uj)做了分解。
A i , j a b s = E x i T W q T W K E x j + E x i T W q T W K U j + U i T W q T W K E x j + U i T W q T W K U j A_{i,j}^{abs} = E_{x_i}^TW_q^TW_KE_{x_j} + E_{x_i}^TW_q^TW_KU_j + U_{i}^TW_q^TW_KE_{x_j} + U_{i}^TW_q^TW_KU_{j} Ai,jabs=ExiTWqTWKExj+ExiTWqTWKUj+UiTWqTWKExj+UiTWqTWKUj

下面一组公式,是Trm-XL计算attention的方式。首先,将绝对位置编码U,替换成了相对位置编码 R i − j R_{i-j} Rij 。插一句,因为i只利用之前的序列,所以i-j>=0。并且把 W k W_k Wk矩阵分为 W k , E 和 W k , R W_{k,E}和W_{k,R} Wk,EWk,R,用于分别生成基于内容的key向量和基于位置的key向量,
A i , j r e l = E x i T W q T W k , E E x j + E x i T W q T W k , R R i − j + U i T W q T W k , E E x j + U j T W q T W k , R R i − j A_{i,j}^{rel} = E_{x_i}^TW_q^TW_{k,E}E_{x_j} + E_{x_i}^TW_q^TW_{k,R}R_{i-j} + U_{i}^TW_q^TW_{k,E}E_{x_j} + U_{j}^TW_q^TW_{k,R}R_{i-j} Ai,jrel=ExiTWqTWk,EExj+ExiTWqTWk,RRij+UiTWqTWk,EExj+UjTWqTWk,RRij

相对位置关系用一个位置编码矩阵 R ∈ R L m a x × d R\in R^{L_{max}\times d} RRLmax×d 来表示,第i行表示相对位置间隔为i的位置向量。论文中强调R采用正弦函数生成,而不是通过学习得到的,好处是预测时,可以使用比训练距离更长的位置向量。

最终总结

最后来看一下Trm-XL的完整计算公式,如下所示,只有前3行与vanilla Trm不同,后3行是一样的。第3行公式中,计算A的时候直接采用query向量,而不再使用 表示。最后需要注意的是,每一层在计算attention的时候,都要包含相对位置编码。而在vanilla Trm中,只有在输入embedding中才包含绝对位置编码,在中间层计算的时候,是不包含位置编码的。

h ~ τ + 1 n − 1 = [ S G ( h τ n − 1 ) , h τ + 1 n − 1 ] q τ + 1 n , k τ + 1 n , v τ + 1 n = h τ + 1 n − 1 W q T , h ~ τ + 1 n − 1 W k T , h ~ τ + 1 n − 1 W v T h τ + 1 n − 1 = T r a n s f o r m e r L a y e r ( q τ + 1 n , k τ + 1 n , v τ + 1 n ) A i , j r e l = E x i T W q T W k , E E x j + E x i T W q T W k , R R i − j + U i T W q T W k , E E x j + U j T W q T W k , R R i − j α τ n = M a s k e d S o f t m a x ( A τ n ) V τ n o τ n = L a y e r N o r m ( L i n e a r ( α τ n ) + h τ + 1 n − 1 ) h τ n = P o s i t i o n w i s e F e e d F o r w a r d ( o τ n ) \tilde{h}_{\tau+1}^{n-1} = [SG(h_{\tau}^{n-1}),h_{\tau+1}^{n-1}] \\ \qquad \\ q_{\tau+1}^n,k_{\tau+1}^n,v_{\tau+1}^n = h_{\tau+1}^{n-1}W_q^T,\tilde{h}_{\tau+1}^{n-1}W_k^T,\tilde{h}_{\tau+1}^{n-1}W_v^T \\ \qquad \\ {h}_{\tau+1}^{n-1} = Transformer\quad Layer(q_{\tau+1}^n,k_{\tau+1}^n,v_{\tau+1}^n) \\ \qquad \\ A_{i,j}^{rel} = E_{x_i}^TW_q^TW_{k,E}E_{x_j} + E_{x_i}^TW_q^TW_{k,R}R_{i-j} + U_{i}^TW_q^TW_{k,E}E_{x_j} + U_{j}^TW_q^TW_{k,R}R_{i-j} \\ \qquad \\ \alpha_{\tau}^n = Masked\quad Softmax(A_{\tau}^n)V_{\tau}^n \\ \qquad \\ o_{\tau}^n = LayerNorm(Linear(\alpha_{\tau}^n)+{h}_{\tau+1}^{n-1}) \\ \qquad \\ h_{\tau}^n = Positionwise\quad Feed\quad Forward(o_{\tau}^n) h~τ+1n1=[SG(hτn1),hτ+1n1]qτ+1n,kτ+1n,vτ+1n=hτ+1n1WqT,h~τ+1n1WkT,h~τ+1n1WvThτ+1n1=TransformerLayer(qτ+1n,kτ+1n,vτ+1n)Ai,jrel=ExiTWqTWk,EExj+ExiTWqTWk,RRij+UiTWqTWk,EExj+UjTWqTWk,RRijατn=MaskedSoftmax(Aτn)Vτnoτn=LayerNorm(Linear(ατn)+hτ+1n1)hτn=PositionwiseFeedForward(oτn)

总结,Trm-XL为了解决长序列的问题,对上一个segment做了缓存,可供当前segment使用,但是也带来了位置关系问题,为了解决位置问题,又打了个补丁,引入了相对位置编码。

transfromer-XL论文详解相关推荐

  1. swin-Transformer论文详解

    swin-Transformer论文详解 – 潘登同学的深度学习笔记 文章目录 swin-Transformer论文详解 -- 潘登同学的深度学习笔记 前言 网络架构 Swin transformer ...

  2. 智能城市dqn算法交通信号灯调度_博客 | 滴滴 KDD 2018 论文详解:基于强化学习技术的智能派单模型...

    原标题:博客 | 滴滴 KDD 2018 论文详解:基于强化学习技术的智能派单模型 国际数据挖掘领域的顶级会议 KDD 2018 在伦敦举行,今年 KDD 吸引了全球范围内共 1480 篇论文投递,共 ...

  3. Fast R-CNN论文详解

    Fast R-CNN论文详解 作者:ture_dream &创新点 规避R-CNN中冗余的特征提取操作,只对整张图像全区域进行一次特征提取: 用RoI pooling层取代最后一层max po ...

  4. 限时9.9元 | 快速领取数学建模竞赛备战必备技巧与论文详解!

    全世界只有3.14 % 的人关注了 青少年数学之旅 大家晚上好,随着美赛时间的公布以及大大小小的数学建模竞赛的进行,小天经常可以收到来自很多小伙伴们提出的问题,"竞赛中如何去考虑选题?&qu ...

  5. KernelGAN论文详解分享

    KernelGAN- Blind Super-Resolution Kernel Estimation using an Internal-GAN论文详解 论文地址:https://arxiv.org ...

  6. ShuffleNetv2论文详解

    ShuffleNet v2 论文详解 近期在研究轻量级 backbone 网络,我们所熟悉和工业界能部署的网络有 MobileNet V2.ShuffleNet V2.RepVGG 等,本篇博客是对 ...

  7. 【论文精读3】MVSNet系列论文详解-P-MVSNet

    P-MVSNet全名为"P-MVSNet: Learning Patch-wise Matching Confidence Aggregation for Multi-View Stereo ...

  8. Spark RDD 论文详解(三)Spark 编程接口

    前言 本文隶属于专栏<1000个问题搞定大数据技术体系>,该专栏为笔者原创,引用请注明来源,不足和错误之处请在评论区帮忙指出,谢谢! 本专栏目录结构和参考文献请见1000个问题搞定大数据技 ...

  9. Spark 3.2.0 版本新特性 push-based shuffle 论文详解(一)概要和介绍

    前言 本文隶属于专栏<大数据技术体系>,该专栏为笔者原创,引用请注明来源,不足和错误之处请在评论区帮忙指出,谢谢! 本专栏目录结构和参考文献请见大数据技术体系 目录 Spark 3.2.0 ...

最新文章

  1. mysql 类型 自动转化_自动MySQL数据类型转换
  2. gitlab创建分支上传文件_Gitlab管理和使用基本教程
  3. 贪心算法之活动选择问题
  4. 匹配特殊字符的正则表达式
  5. 【UDP通过多线程改进,在一个窗口中同时接收又发送】
  6. 力扣——搜索旋转排序数组
  7. 相干检测--概念,原理,科斯塔斯环
  8. C#3.0 new features: Lambda expression
  9. PS2022新增功能简介
  10. 批处理之for /f
  11. 想成为产品经理,应该怎么起步?
  12. 如何理解冲突域和广播域?(转)
  13. 条形码简介_条形码基本常识_条形码基本原理
  14. 【tensorboard】解决ValueError: Duplicate plugins for name projector
  15. Android选择颜色,尺码联动
  16. TencentOS-Tiny在苹果MacOS初上手
  17. win10 qq远程不上服务器未响应,win10 qq远程协助能移动鼠标却点击不了怎么办
  18. win7不显示移动硬盘_如何在移动硬盘中安装win10系统?
  19. 通用表查询返回所有行(只适用于单表)
  20. linux获取触控板信息,关于linux:Linux-下通过命令行和脚本开关笔记本触控板和其他输入外设...

热门文章

  1. Redis 缓存回收的7种策略volatile设置过期时间及allkeys所有数据范围内
  2. 7-7 韩信点兵 (10 分)
  3. 五个学习管理系统的优点
  4. Python turtle入门:用小海龟画美队盾牌 (内附画五角星的详细代码)
  5. Android获取硬件设备详细信息
  6. 我闺蜜她男朋友要我用Python写个脚本,每天不同时间段用微信给闺蜜发消息
  7. Chrome游览器下载
  8. 2022年全球20大国际航运中心榜单公布,上海蝉联第三,与新加坡伦敦差距缩小 | 美通社头条...
  9. linux skyeye,移植LINUX到SKYEYE上
  10. unity3d简单的粒子特效