对Transformer中Positional Encoding的理解

  • 1. 什么是Positional Encoding?为什么Transformer需要使用Positional Encoding?
    • Transformer的输入
  • 2. Positional Encoding是怎么做的?
    • 公式表达
    • 源码展示
    • 直观理解
    • 计算过程
  • 参考资料

Transformer是最新的处理序列到序列问题的架构,由self-attention组成,其优良的可并行性以及可观的表现提升,让它在NLP领域中大受欢迎,GPT-3以及BERT都是基于Transformer实现的。

刚开始学Transformer,对一个模块Positional Encoding存在一些疑问,因此,参考了一些资料和博客,学习如何理解Positional Encoding。

1. 什么是Positional Encoding?为什么Transformer需要使用Positional Encoding?

在任何一门语言中,词语的位置和顺序对句子意思表达都是至关重要的。传统的RNN模型在处理句子时,以序列的模式逐个处理句子中的词语,这使得词语的顺序信息在处理过程中被天然的保存下来了,并不需要额外的处理。

对于Transformer来说,由于句子中的词语都是同时进入网络进行处理,顺序信息在输入网络时就已丢失。因此,Transformer是需要额外的处理来告知每个词语的相对位置的。其中的一个解决方案,就是论文中提到的Positional Encoding,将能表示位置信息的编码添加到输入中,让网络知道每个词的位置和顺序。

一句话概括,Positional Encoding就是句子中词语相对位置的编码,让Transformer保留词语的位置信息

Transformer的输入


首先给出Transformer的输入部分,如上图所示。X:[batch size,sequence length]指的是初始输入的多语句矩阵,多语句矩阵通过查表,得到词向量矩阵XembeddingX_{embedding}Xembedding:[batch size,sequence length,embedding dimension]。batch size指的是句子数,sequence length指的是输入的句子中最长的句子的字数,embedding dimension指的是词向量的长度(通过查表得到)。

XXXXembeddingX_{embedding}Xembedding的示意图如下图所示:

2. Positional Encoding是怎么做的?

要表示位置信息,首先出现在脑海里的一个点子可能是,给句子中的每个词赋予一个相位,也就是[0, 1]中间的一个值,第一个词是0,最后一个词是1,中间的词在0到1之间取值

但是这样会不会有什么问题呢?其中一个问题在于,你并不知道每个句子中词语的个数是多少,这会导致每个词语之间的间隔变化是不一致的。而对于一个句子来说,每个词语之间的间隔都应该是具有相同含义的。

那,为了保证每个词语的间隔含义一致,我们是不是可以给每个词语添加一个线性增长的时间戳呢?比如说第一个词是0,第二词是1,以此类推,第N个词的位置编码是N。

这样其实也会有问题。同样,我们并不知道一个句子的长度,如果训练的句子很长的话,这样的编码是不合适的。 另外,这样训练出来的模型,在泛化性上是有一定问题的。

因此,理想情况下,编码方式应该要满足以下几个条件:

  1. 对于每个位置的词语,它都能提供一个独一无二的编码
  2. 词语之间的间隔对于不同长度的句子来说,含义应该是一致的
  3. 能够随意延申到任意长度的句子

文中提出了一种简单且有效的编码方式,能够满足上述所有条件。

公式表达


其中,PE为二维矩阵,大小跟输入embedding的维度一样,行表示词语,列表示词向量;pospospos表示词语在句子中的位置;dmodeld_{model}dmodel表示词向量的维度;iii表示词向量的位置。因此,上述公式表示在每个词语的词向量的偶数位置添加sin变量,奇数位置添加cos变量,以此来填满整个PE矩阵,然后加到input embedding中去,这样便完成了位置编码的引入,

使用sin编码和cos编码的原因是可以得到词语之间的相对位置,因为:
sin⁡(α+β)=sin⁡αcos⁡β+cos⁡αsin⁡β\sin{(\alpha+\beta)} = \sin{\alpha}\cos{\beta}+\cos{\alpha}\sin{\beta}sin(α+β)=sinαcosβ+cosαsinβ
cos⁡(α+β)=cos⁡αcos⁡β−sin⁡αsin⁡β\cos{(\alpha+\beta)} = \cos{\alpha}\cos{\beta} - \sin{\alpha}\sin{\beta}cos(α+β)=cosαcosβsinαsinβ

即由sin⁡(pos+k)\sin(pos+k)sin(pos+k)可以得到,通过线性变换获取后续词语相对当前词语的位置关系。

源码展示

class PositionalEncoding(nn.Module):"Implement the PE function."def __init__(self, d_model, dropout, max_len=5000):super(PositionalEncoding, self).__init__()self.dropout = nn.Dropout(p=dropout)# Compute the positional encodings once in log space.pe = torch.zeros(max_len, d_model)position = torch.arange(0, max_len).unsqueeze(1)div_term = torch.exp(torch.arange(0., d_model, 2) * (-math.log(10000.0) / d_model))pe[:, 0::2] = torch.sin(position * div_term)pe[:, 1::2] = torch.cos(position * div_term)pe = pe.unsqueeze(0)self.register_buffer('pe', pe)def forward(self, x):x = x + Variable(self.pe[:, :x.size(1)], requires_grad=False)return self.dropout(x)

其中,div_term是上述公式经过简单的数学变换得到的,具体如下:
1/100002i/dmodel=elog(10000−2i/dmodel)=e−2i/dmodel∗log10000=e2i∗(−log10000/dmodel)1/10000^{2i/d_{model}}=e^{log{(10000^{-2i/d_{model}})}}=e^{-2i/d_{model} * log{10000}}=e^{2i * (-log{10000}/d_{model})}1/100002i/dmodel=elog(100002i/dmodel)=e2i/dmodellog10000=e2i(log10000/dmodel)

直观理解

为什么这样简单的sines和cosines的组合可以表达位置信息呢?一开始的确有点难理解。举个二进制的例子就明白了。可以观察一下下面这个表,将数字用二进制表示出来。可以发现,每个比特位的变化率是不一样的,越低位的变化越快,红色位置0和1每个数字会变化一次,而黄色位,每8个数字才会变化一次。

不同频率的sines和cosines组合其实也是同样的道理,通过调整三角函数的频率,可以实现这种低位到高位的变化,这样的话,位置信息就表示出来了。

计算过程


如上图所示,word embedding指的是词向量由每个词根据查表得到,pos embedding就是我们要求的Positional Encoding,也就是位置编码。可以看到word embedding和pos embedding逐点相加得到composition,即包含语义信息和位置编码信息的最终矩阵。

回到公式中,我们可以得知:pospospos指当前字符在句子中的位置(如:“你好啊”,这句话里面“你”的pos=0pos=0pos=0),dmodeld_{model}dmodel指的是word embedding的长度(比如说:查表得到“民主”这个词的word embedding为[1,2,3,4,5][1,2,3,4,5][1,2,3,4,5],则dmodel=5d_{model}=5dmodel=5),iii的取值范围是:i=0,1,...,dmodel−1i=0,1,...,d_{model}-1i=0,1,...,dmodel1。当iii的值为偶数时使用上面那条公式,当iii的值为奇数时使用下面那条公式。当pos=3,dmodel=128pos=3, d_{model}=128pos=3,dmodel=128时Positional Encoding(或者说是pos embedding)的计算结果为:

每一个字所计算出来的Positional Encoding并不是一个值而是一个向量,它的长度和这个字的word embedding的长度一致,从而方便它们两个逐点相加得到既包含word embedding又包含位置信息的最终向量。

参考资料

1.https://zhuanlan.zhihu.com/p/338592312
2.https://blog.csdn.net/weixin_44012382/article/details/113059423
3.https://arxiv.org/pdf/1706.03762.pdf

【AI理论学习】对Transformer中Positional Encoding的理解相关推荐

  1. transformer中QKV的通俗理解(渣男与备胎的故事)

    transformer中QKV的通俗理解(渣男与备胎的故事) 用vit的时候读了一下transformer的思想,前几天面试结束之后发现对QKV又有点忘记了, 写一篇文章来记录一下 参考链接: 哔哩哔 ...

  2. 深度学习笔记--Transformer中position encoding的源码理解与实现

    1--源码 import torch import math import numpy as np import torch.nn as nnclass Pos_Embed(nn.Module):de ...

  3. transformer中相对位置编码理解

    对于一副图像,位置信息占有非常重要的地位,ViT中用了绝对位置编码,Swin中用到了相对位置编码.看了Swin的源码,参考了https://blog.csdn.net/qq_37541097/arti ...

  4. 对Transformer中的MASK理解

    对Transformer中的MASK理解 Padding Masked Self-Attention Masked 上一篇文章我们介绍了 对Transformer中FeedForward层的理解,今天 ...

  5. 基于Transformer的文本情感分析编程实践(Encoder编码器-Decoder解码器框架 + Attention注意力机制 + Positional Encoding位置编码)

    日萌社 人工智能AI:Keras PyTorch MXNet TensorFlow PaddlePaddle 深度学习实战(不定时更新) Encoder编码器-Decoder解码器框架 + Atten ...

  6. Positional Encodings in ViTs 近期各视觉Transformer中的位置编码方法总结及代码解析 1

    Positional Encodings in ViTs 近期各视觉Transformer中的位置编码方法总结及代码解析 最近CV领域的Vision Transformer将在NLP领域的Transo ...

  7. transformer:self-attention,muti-head attention,positional encoding

    文章目录 transformer和RNN.LSTM相比 seq2seq 编码器-解码器架构 What is Input? What is Output? N-N:each vector has a l ...

  8. 透彻分析Transformer中的位置编码(positional enconding)

    一.Transformer中为什么要使用位置编码positional encoding 在<Attention Is All You Need>这篇论文中首次提到了transformer模 ...

  9. 深度学习之图像分类(十七)-- Transformer中Self-Attention以及Multi-Head Attention详解

    深度学习之图像分类(十七)Transformer中Self-Attention以及Multi-Head Attention详解 目录 深度学习之图像分类(十七)Transformer中Self-Att ...

最新文章

  1. PowerDesigner的数据类型
  2. Ubuntu14.04 ROS Indigo安装教程,以及卸载方法
  3. R开发(part11)--基于S4的面向对象编程
  4. 微服务技术发展的现状与展望
  5. Linux下redmine安装插件报错
  6. 对vuex在项目中的使用
  7. 【读】这一次,让我们再深入一点 - TCP协议
  8. JAVA中super和this关键字的区别
  9. Markdown 编辑器 Editor.md 使用
  10. C# List最大值最小值问题 List排序问题 List Max/Min
  11. Dagger2 学习
  12. 华为IE和思科IE哪个好?
  13. php如何开启COM组件
  14. HTML获取当前IP和当前位置
  15. keil c支持汇编语言吗,keil中用汇编实现hello.c的功能
  16. Linux安装Googlepinyin
  17. PIN PUK1
  18. 14年优质服务 海科融通进军P2P资金托管
  19. Collection类和泛型
  20. Xiaojie雷达之路---MATLAB仿真---RD(range-doppler)图

热门文章

  1. Java并发编程面试题(2022最新版)
  2. C语言冒泡排序(起泡法)
  3. 04 | 连接池:别让连接池帮了倒忙
  4. Nginx正向代理和反向代理的区别
  5. C语言_函数递归举例
  6. mysql联合索引如何创建
  7. 基于C#通过PLCSIM ADV仿真软件实现与西门子1500PLC的S7通信方法演示
  8. python字典键值对的添加和遍历
  9. 隔离式DC/DC高压模块5V12V24V转50V110V250V300V380V600V1100V短路保护直流升压可调开关控制电源模块
  10. AIS(ACL,IJCAI,SIGIR)(2019)论文报告会,感受大佬的气息...