Transformer结构详解(有图,有细节)
文章目录
- 1. transformer的基本结构
- 2. 模块详解
- 2.1 模块1:Positional Embedding
- 2.2 模块2:Multi-Head Attention
- 2.2.1 Scaled Dot-Product Attention
- 2.2.2 Multi-Head
- 2.3 模块3:ADD
- 2.4 模块4:Layer Normalization
- 2.5 模块5:Feed Forward NetWork
- 2.6 模块6:Masked Multi-Head Attention
- 2.7 模块7: Multi-Head Attention
- 2.8 模块8:Linear
- 2.9 模块9:SoftMax
- 3. transformer在机器翻译任务中的使用
- 4 transformer 相关的其它问题
1. transformer的基本结构
2. 模块详解
2.1 模块1:Positional Embedding
- p o s pospos代表的是一个字在句子中的位置,从0到名字长度减1,是下图中红色的序号。
- i ii代表的是dim 的序号,是下图中蓝色的序号:
- 当i ii为偶数时,此位置的值使用 s i n ( p o s / 1000 0 2 i / d m o d e l ) sin(pos/10000^{2i/d_{model}})sin(pos/100002i/dmodel)来填充。
- 当i ii为奇数时,些位置的值使用 c o s ( p o s / 1000 0 2 i / d m o d e l ) cos(pos/10000^{2i/d_{model}})cos(pos/100002i/dmodel)来填充
实现代码:
class PositionalEncoding(nn.Module):"Implement the PE function."def __init__(self, d_model, dropout, max_len=5000):super(PositionalEncoding, self).__init__()self.dropout = nn.Dropout(p=dropout)# Compute the positional encodings once in log space.pe = torch.zeros(max_len, d_model).float()position = torch.arange(0, max_len).unsqueeze(1).float()div_term = torch.exp(torch.arange(0, d_model, 2).float() *-(math.log(10000.0) / d_model)).float()pe[:, 0::2] = torch.sin(position * div_term)pe[:, 1::2] = torch.cos(position * div_term)pe = pe.unsqueeze(0)self.register_buffer('pe', pe)def forward(self, x):x = x + Variable(self.pe[:, :x.size(1)],requires_grad=False)return self.dropout(x)
- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
- 9
- 10
- 11
- 12
- 13
- 14
- 15
- 16
- 17
- 18
- 19
- 20
- 21
至于为什么选择这种方式,论文中给出的解释是:
- 我们之所以选择这个函数,是因为我们假设它可以让模型很容易地通过相对位置来学习,因为对任意确定的偏移k kk, P E p o s + k PE_{pos+k}PEpos+k可以表示为P E p o s PE_{pos}PEpos的线性函数。
理解:
由s i n ( α + β ) = s i n α c o s β + s i n β c o s α c o s ( α + β ) = c o s α c o s β − s i n β s i n α sin(\alpha+\beta)=sin\alpha cos\beta + sin\beta cos\alpha\\ cos(\alpha+\beta)=cos\alpha cos\beta - sin\beta sin\alphasin(α+β)=sinαcosβ+sinβcosαcos(α+β)=cosαcosβ−sinβsinα
推出:
P E ( p o s + k , 2 i ) = s i n ( ( p o s + k ) / 1000 0 2 i / d m o d e l ) = s i n ( p o s / 1000 0 2 i / d m o d e l ) c o s ( k / 1000 0 2 i / d m o d e l ) + s i n ( p o s / 1000 0 2 i / d m o d e l ) c o s ( k / 1000 0 2 i / d m o d e l ) = P E ( p o s , 2 i ) P E ( k , 2 i + 1 ) − P E ( p o s , 2 i + 1 ) P E ( k , 2 i )PE(pos+k,2i)=sin((pos+k)/100002i/dmodel)=sin(pos/100002i/dmodel)cos(k/100002i/dmodel)+sin(pos/100002i/dmodel)cos(k/100002i/dmodel)=PE(pos,2i)PE(k,2i+1)−PE(pos,2i+1)PE(k,2i)PE(pos+k,2i)=sin((pos+k)/100002i/dmodel)=sin(pos/100002i/dmodel)cos(k/100002i/dmodel)+sin(pos/100002i/dmodel)cos(k/100002i/dmodel)=PE(pos,2i)PE(k,2i+1)−PE(pos,2i+1)PE(k,2i)
PE(pos+k,2i)=sin((pos+k)/100002i/dmodel)=sin(pos/100002i/dmodel)cos(k/100002i/dmodel)+sin(pos/100002i/dmodel)cos(k/100002i/dmodel)=PE(pos,2i)PE(k,2i+1)−PE(pos,2i+1)PE(k,2i)
P E ( p o s + k , 2 i + 1 ) = c o s ( ( p o s + k ) / 1000 0 2 i / d m o d e l ) = c o s ( p o s / 1000 0 2 i / d m o d e l ) c o s ( k / 1000 0 2 i / d m o d e l ) − s i n ( p o s / 1000 0 2 i / d m o d e l ) s i n ( k / 1000 0 2 i / d m o d e l ) = P E ( p o s , 2 i + 1 ) P E ( k , 2 i + 1 ) − P E ( p o s , 2 i ) P E ( k , 2 i )PE(pos+k,2i+1)=cos((pos+k)/100002i/dmodel)=cos(pos/100002i/dmodel)cos(k/100002i/dmodel)−sin(pos/100002i/dmodel)sin(k/100002i/dmodel)=PE(pos,2i+1)PE(k,2i+1)−PE(pos,2i)PE(k,2i)PE(pos+k,2i+1)=cos((pos+k)/100002i/dmodel)=cos(pos/100002i/dmodel)cos(k/100002i/dmodel)−sin(pos/100002i/dmodel)sin(k/100002i/dmodel)=PE(pos,2i+1)PE(k,2i+1)−PE(pos,2i)PE(k,2i)
PE(pos+k,2i+1)=cos((pos+k)/100002i/dmodel)=cos(pos/100002i/dmodel)cos(k/100002i/dmodel)−sin(pos/100002i/dmodel)sin(k/100002i/dmodel)=PE(pos,2i+1)PE(k,2i+1)−PE(pos,2i)PE(k,2i)
以P E ( p o s + k , 2 i ) = P E ( p o s , 2 i ) P E ( k , 2 i + 1 ) − P E ( p o s , 2 i + 1 ) P E ( k , 2 i ) PE(pos+k,2i)=PE(pos,2i)PE(k,2i+1)-PE(pos,2i+1)PE(k,2i)PE(pos+k,2i)=PE(pos,2i)PE(k,2i+1)−PE(pos,2i+1)PE(k,2i)为例,当k kk确定时: P E ( k , 2 i + 1 ) PE(k,2i+1)PE(k,2i+1)、P E ( k , 2 i ) PE(k,2i)PE(k,2i)均为常数,P E ( p o s + k , 2 i ) = P E ( p o s , 2 i ) ∗ 常 数 2 i + 1 k − P E ( p o s , 2 i + 1 ) ∗ 常 数 i k PE(pos+k,2i)=PE(pos,2i) * 常数_{2i+1}^k - PE(pos,2i+1) * 常数_{i}^kPE(pos+k,2i)=PE(pos,2i)∗常数2i+1k−PE(pos,2i+1)∗常数ik
上式为即为1)中所说的线性函数。我们知道,每个位置(pos)的PE值均不同,因此我们可以根据PE的值区分位置,而由上面的线性函数,我们可以计量出两个位置的相对距离。 - 我们还尝试使用预先学习的positional embeddings 来代替正弦波,发现这两个版本产生了几乎相同的结果 。我们之所以选择正弦曲线,是因为它允许模型扩展到比训练中遇到的序列长度更长的序列。
理解:
第二点很好理解就是说了下正弦波的优点。这里我着重讲下正弦波存在的问题。在transformer架构里,我们计算两个特征的关系用的是点积的的方式(因为使用了Dot-Product Attention)。所以两个PE的关系(距离)实际是以它们的点积来表示的。举例如下[ 1 ] ^{[1]}[1]:
我们令c i = 1 / 1000 0 2 i / d m o d e l c_i=1/10000^{2i/d_{model}}ci=1/100002i/dmodel,则第t tt及t + 1 t+1t+1个位置的positional embedding 是:
P E t = [ s i n ( c 0 t ) c o s ( c 0 t ) s i n ( c 1 t ) c o s ( c 1 t ) ⋮ s i n ( c d 2 − 1 t ) c o s ( c d 2 − 1 t ) ] T PE_t={\left[ {sin(c0t)cos(c0t)sin(c1t)cos(c1t)⋮sin(cd2−1t)cos(cd2−1t)sin(c0t)cos(c0t)sin(c1t)cos(c1t)⋮sin(cd2−1t)cos(cd2−1t)
} \right]^T}PEt=⎣⎢⎢⎢⎢⎢⎢⎢⎢⎢⎡sin(c0t)cos(c0t)sin(c1t)cos(c1t)⋮sin(c2d−1t)cos(c2d−1t)⎦⎥⎥⎥⎥⎥⎥⎥⎥⎥⎤T
P E t + k = [ s i n ( c 0 ( t + k ) ) c o s ( c 0 ( t + k ) ) s i n ( c 1 ( t + k ) ) c o s ( c 1 ( t + k ) ) ⋮ s i n ( c d 2 − 1 ( t + k ) ) c o s ( c d 2 − 1 ( t + k ) ) ] T PE_{t+k}={\left[ {sin(c0(t+k))cos(c0(t+k))sin(c1(t+k))cos(c1(t+k))⋮sin(cd2−1(t+k))cos(cd2−1(t+k))sin(c0(t+k))cos(c0(t+k))sin(c1(t+k))cos(c1(t+k))⋮sin(cd2−1(t+k))cos(cd2−1(t+k))
} \right]^T}PEt+k=⎣⎢⎢⎢⎢⎢⎢⎢⎢⎢⎡sin(c0(t+k))cos(c0(t+k))sin(c1(t+k))cos(c1(t+k))⋮sin(c2d−1(t+k))cos(c2d−1(t+k))⎦⎥⎥⎥⎥⎥⎥⎥⎥⎥⎤T
则:P E t P E t + k = Σ j = 0 d 2 [ s i n ( c j t ) s i n ( c j ( t + k ) + c o s ( c j t ) c o s ( c j ( t + k ) ] = Σ j = 0 d 2 c o s ( c j ( t − ( t + k ) ) = Σ j = 0 d 2 c o s ( c j k )PEtPEt+k=Σd2j=0[sin(cjt)sin(cj(t+k)+cos(cjt)cos(cj(t+k)]=Σd2j=0cos(cj(t−(t+k))=Σd2j=0cos(cjk)PEtPEt+k=Σj=0d2[sin(cjt)sin(cj(t+k)+cos(cjt)cos(cj(t+k)]=Σj=0d2cos(cj(t−(t+k))=Σj=0d2cos(cjk)
PEtPEt+k=Σj=02d[sin(cjt)sin(cj(t+k)+cos(cjt)cos(cj(t+k)]=Σj=02dcos(cj(t−(t+k))=Σj=02dcos(cjk)
上式的第二行是使用了 c o s ( α − β ) = s i n α s i n β + c o s α c o s β cos(\alpha-\beta)=sin\alpha sin\beta + cos\alpha cos\betacos(α−β)=sinαsinβ+cosαcosβ 这个公式进行的变换。从最终的结果我们可以看出,两个embedding的距离度量只与间隔k kk有关,而c o s coscos函数关于y轴对称,即c o s x = c o s ( − x ) cosx=cos(-x)cosx=cos(−x),所以,P E t P E t + k PE_tPE_{t+k}PEtPEt+k的度量只与k kk的大小有关,与谁在前,谁在后无关。即,经过dot-attention机制后,我们把positional embedding中的顺序信息丢失了。所以,从这方面看,正弦波这种位置PE并不太适合在用在transformer结构中,这也可能是后面的bert,t5都采用的基于学习的positional embedding。(注:模块3会把顺序信息传递下去,但我们还是在算法的核心处理上丢失了信息。)
2.2 模块2:Multi-Head Attention
这个模块是transformer的核心,我们把这块拆成两部分来理解,先讲下其中的Scaled Dot-Product Attention(缩放的点积注意力机制),再讲Multi-Head。
2.2.1 Scaled Dot-Product Attention
我们先看下论文中的 Scaled Dot-Product Attention 步骤,如下图:
下面我们对着上面的图讲一下,具体的看下每步做了什么。
由于linear的输入和输出均为d m o d e l d_{model}dmodel,所以Q,K,V的大小和input_sum的大小是一致的。
MatMul: 这步是实际是计算的 Q ∗ K T Q*K^TQ∗KT, 如下图:
从上图可以看出Q ∗ K T Q*K^TQ∗KT的结果s c o r e s scoresscores是一个L ∗ L L*LL∗L的矩阵(L为句字长度),其中scores中的[ i , j ] [i,j][i,j]位置表示的是Q QQ中的第i ii行的字和K T K^TKT中第j jj列的相似度(也可以说是重要度,我们可以这么理解,在机器翻译任务中,当我们翻译一句话的第i ii个字的的时候,我们要考虑原文中哪个位置的字对我们现在要翻译的这个位置的字的影响最大)。Scale :这部分就是对上面的s c o r e s scoresscores进行了个类似正则化的操作。
s c o r e s = s c o r e s d q scores=\frac{scores}{\sqrt{d_q}}scores=dqscores (这里要说一下d q d_{q}dq,论文中给出的是d h d_{h}dh,即d m o d e l / h d_{model}/hdmodel/h, 因为论文中做了multi-head,所以 d q = d h d_q=d_{h}dq=dh),这里解释下除以d q \sqrt{d_q}dq的原因,原文是这样说的:“我们认为对于大的d k d_kdk,点积在数量级上增长的幅度大,将softmax函数推向具有极小梯度的区域4 ^44。为了抵消这种影响,我们对点积扩展1 d k \frac{1}{\sqrt{d_k}}dk1倍”。Mask: 这步使用一个很小的值,对指定位置进行覆盖填充。这样,在之后计算softmax时,由于我们填充的值很小,所以计算出的概率也会很小,基本就忽略了。(如果不填个很小的值的话,后面我们计算softmax时,e x i ∑ i = 1 k e x i \frac{e^{x_i}}{\sum_{i=1}^{k}{e^{x_i}}}∑i=1kexiexi ,当x = 0 x=0x=0时(padding的值),分子e 0 = 1 e^{0}=1e0=1这可不是一个很小的值。),mask操作在encoder和decoder过程中都存在,在encoder中我们是对padding的值进行mask,在decoder中我们主要是为了不让前面的词在翻译时看到未来的词,所以对当前词之后的词的信息进行mask。下面我们先看看encoder中关于padding的mask是怎么做的。
如上图,输入中有两个pad字符,s c o r e s scoresscores中的x都是pad参与计算产生的,我们为了排除pad产生的影响,我们提供了如图的mask,我们把scores与mask的位置一一对应,如果mask的值为0,则scores的对应位置填充一个非常小的负数(例如:− e 9 -e^9−e9)。最终得到的是上图最后一个表格。说了这么多,其实在pytorch中就一句话。
scores = scores.masked_fill(mask == 0, -1e9)
- 1
- SoftMax: 对scores中的数据按行做softmax。这样就把权得转换成了概率。
- MatMul: 这步就是使用softmax后的概率值与V VV矩阵做矩阵乘法。
附上代码:
def attention(query, key, value, mask=None):d_k = query.size(-1)scores = torch.matmul(query, key.transpose(-2, -1)) \/ math.sqrt(d_k)if mask is not None:scores = scores.masked_fill(mask == 0, -1e9)p_attn = F.softmax(scores, dim = -1)return torch.matmul(p_attn, value)
- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
2.2.2 Multi-Head
2.3 模块3:ADD
2.4 模块4:Layer Normalization
2.5 模块5:Feed Forward NetWork
2.6 模块6:Masked Multi-Head Attention
2.7 模块7: Multi-Head Attention
2.8 模块8:Linear
此模块的目的是把模型的输transformer decoder的输出从d m o d e l d_{model}dmodel维度映射到词表大小的维度。linear本身也比较简单,这里不再细讲了。
2.9 模块9:SoftMax
此模块会把上层linear的输出转化成概率,对应到某个字的概率。
3. transformer在机器翻译任务中的使用
4 transformer 相关的其它问题
这部分我是想写写transformer的并行等其它问题,但今天写的太累了,主要的也都写完了,就先发了。
References
[ 1 ] [1][1] https://zhuanlan.zhihu.com/p/166244505
Transformer结构详解(有图,有细节)相关推荐
- Transformer(二)--论文理解:transformer 结构详解
转载请注明出处:https://blog.csdn.net/nocml/article/details/110920221 本系列传送门: Transformer(一)–论文翻译:Attention ...
- 数据结构图,图存储结构详解
1. 数据结构的图存储结构 我们知道,数据之间的关系有 3 种,分别是 "一对一"."一对多" 和 "多对多",前两种关系的数据可分别用线性 ...
- Transformer模型详解(图解最完整版)
前言 Transformer由论文<Attention is All You Need>提出,现在是谷歌云TPU推荐的参考模型.论文相关的Tensorflow的代码可以从GitHub获取, ...
- ResNet结构详解
ResNet结构详解 ResNet的层数34,50,101到底指什么? 首先看ResNet34的对比图 然后再看这个表 ResNet 到底是个什么结构 ResNet-34 虚线结构 ResNet-50 ...
- 一文看懂Transformer(详解)
文章目录 Transformer 前言 网络结构图: Encoder Input Embedding Positional Encoder self-attention Padding mask Ad ...
- ViT( Vision Transformer)详解
文章目录 (一)参考博客和PPT原文件下载连接 (二)VIT原理详解 2.1.self-attention 2.2.sequence序列之间相关性 α \boldsymbol{\alpha} α的求解 ...
- Windows GPT磁盘GUID结构详解
前一篇 Windows磁盘MBR结构详解 中我们介绍了Basic Disk中的Master Boot Record结构.GPT Disk作为Windows 2003以后引入的分区结构.使用了GUID分 ...
- 微信小程序01【目录结构详解、视图与渲染、事件、input、scroll-view】
学习地址:https://www.bilibili.com/video/BV1sx411z77P 笔记01:https://blog.csdn.net/weixin_44949135/article/ ...
- 搞一下 车载以太网实战 | 01 车载以太网帧结构详解
前言 搞SOA.搞 AP & CP AUTOSAR.搞异构SoC.搞车载以太网.搞车载OS等就找搞一下汽车电子. 全系内容可在<搞一下汽车电子>后台回复 "系列" ...
最新文章
- Makefile经典教程
- 简单粗暴地理解js原型链–js面向对象编程
- python+requests+re匹配抓取猫眼上映电影信息
- Pollar Rho算法
- 诗与远方:无题(六十六)- 清明时节雨纷下
- 2.3 logistic 回归损失函数
- MDC记录activiti流程ID
- visio画图复制粘贴到word_visio复制粘贴到word中
- 各类w3school网站的区别小记
- mysql pxc 原理_mysql PXC配置
- XTU 1236 Fibonacci
- 压摆率和上升时间的区别
- 证监会计算机类笔试上岸经验,公务员考试笔试166分上岸经验(全干货)
- 【趣读官方文档】1.管家的抉择 (Android进程生命周期)
- EXCEL电子表格使用技巧
- 香橙派 One Plus 像单片机一样硬件寄存器 控制GPIO 点灯
- 投身开源,需要持之以恒的热爱与贡献 —— Apache Spark Committer 姜逸坤
- turtlepen画出小黄人
- linux虚拟机和电脑ping通(可上网)
- 节流(Throttle)与防抖(Debounce)区别与demo实现+ 图解
热门文章
- 请简述计算机软件系统与硬件系统的关系,电脑硬件与软件的关系是什么?
- C++继承中的访问级别
- linux 设备驱动 百度,Linux设备驱动之input子系统
- php页面中文乱码分析,PHP页面中文乱码分析
- html5 项目案例_互动案例技术分析(3)
- 5获取按钮返回值消息_大数据从入门到深入:JavaEE 之 项目实战 项目基础编码阶段(5)...
- tf.dynamic_stitch 和 tf.dynamic_partition
- 645. Set Mismatch(python)
- anaconda不同虚拟环境下使用jupyter的问题
- 文巾解题 7. 整数反转