Transformer-XL解读(论文 + PyTorch源码)
本文将主要针对模型原理及其PyTorch实现进行逐一对照解读,因笔者能力有限,如有不详尽之处,可移步文末的传送门进行详细阅读,并欢迎指出~
文章目录
- 前言
- 一. 回顾Transformer
- 二. vanilla Transformer
- 三. Transformer-XL
- 1. 引入循环机制
- 2. 相对位置编码
- 3. 整体计算公式
- 四. PyTorch实现
- 五. 实验结果
- 1. 语言建模指标
- 2. 两个创新点的优势
- 3. 测试阶段的速度
- 六. 总结
- 1. 模型特点
- 2. 优点
- 3. 不足
- 传送门
一. 回顾Transformer
二. vanilla Transformer
为何要提这个模型?因为Transformer-XL是基于这个模型进行的改进。
三. Transformer-XL
1. 引入循环机制
原则上只要GPU内存允许,该方法可以利用前面更多段的信息,测试阶段也可以获得更长的依赖。
2. 相对位置编码
在Transformer-XL中,对上述的attention计算方式进行了变换,转为相对位置的计算,而且不仅仅在第一层这么计算,在每一层都是这样计算。
对比来看,主要有三点变化:
- 在(b)和(d)这两项中,将所有绝对位置向量U j U_jUj都转为相对位置向量R i − j R_{i-j}Ri−j,与Transformer一样,这是一个固定的编码向量,不需要学习。
- 在(c)这一项中,将查询的U i T W q T U_i^TW_q^TUiTWqT向量转为一个需要学习的参数向量u uu,因为在考虑相对位置的时候,不需要查询的绝对位置i ii,因此对于任意的i ii,都可以采用同样的向量。同理,在(d)这一项中,也将查询的U i T W q T U_i^TW_q^TUiTWqT向量转为另一个需要学习的参数向量v vv。
- 将键的权重变换矩阵W k W_kWk转为W k , E W_{k, E}Wk,E和W k , R W_{k, R}Wk,R,分别作为content-based key vectors和location-based key vectors。
从另一个角度来解读这个公式的话,可以将attention的计算分为如下四个部分:
3. 整体计算公式
四. PyTorch实现
笔者在这里主要研究的是核心模型部分,将针对关键的实现细节进行剖析,想要看完整代码的读者请戳这里。
class PositionalEmbedding(nn.Module):def __init__(self, demb):super(PositionalEmbedding, self).__init__()self.demb = dembinv_freq = 1 / (10000 ** (torch.arange(0.0, demb, 2.0) / demb))def forward(self, pos_seq):sinusoid_inp = torch.ger(pos_seq, self.inv_freq)pos_emb = torch.cat([sinusoid_inp.sin(), sinusoid_inp.cos()], dim=-1)return pos_emb[:,None,:]
class MultiHeadAttn(nn.Module):def __init__(self, n_head, d_model, d_head, dropout, dropatt=0,tgt_len=None, ext_len=None, mem_len=None, pre_lnorm=False):super(MultiHeadAttn, self).__init__()self.n_head = n_headself.d_model = d_modelself.d_head = d_headself.dropout = dropoutself.qkv_net = nn.Linear(d_model, 3 * n_head * d_head, bias=False)self.drop = nn.Dropout(dropout)self.dropatt = nn.Dropout(dropatt)self.o_net = nn.Linear(n_head * d_head, d_model, bias=False)self.layer_norm = nn.LayerNorm(d_model)self.scale = 1 / (d_head ** 0.5)self.pre_lnorm = pre_lnormself.r_net = nn.Linear(self.d_model, self.n_head * self.d_head, bias=False)def _rel_shift(self, x, zero_triu=False):zero_pad = torch.zeros((x.size(0), 1, *x.size()[2:]),device=x.device, dtype=x.dtype)x_padded = torch.cat([zero_pad, x], dim=1)x_padded = x_padded.view(x.size(1) + 1, x.size(0), *x.size()[2:])x = x_padded[1:].view_as(x)if zero_triu:ones = torch.ones((x.size(0), x.size(1)))x = x * torch.tril(ones, x.size(1) - x.size(0))[:,:,None,None]return xdef forward(self, w, r, r_w_bias, r_r_bias, attn_mask=None, mems=None):qlen, rlen, bsz = w.size(0), r.size(0), w.size(1)if mems is not None:cat = torch.cat([mems, w], 0)if self.pre_lnorm:w_heads = self.qkv_net(self.layer_norm(cat))else:w_heads = self.qkv_net(cat)r_head_k = self.r_net(r)w_head_q, w_head_k, w_head_v = torch.chunk(w_heads, 3, dim=-1)w_head_q = w_head_q[-qlen:]else:if self.pre_lnorm:w_heads = self.qkv_net(self.layer_norm(w))else:w_heads = self.qkv_net(w)r_head_k = self.r_net(r)w_head_q, w_head_k, w_head_v = torch.chunk(w_heads, 3, dim=-1)klen = w_head_k.size(0)w_head_q = w_head_q.view(qlen, bsz, self.n_head, self.d_head) # qlen x bsz x n_head x d_headw_head_k = w_head_k.view(klen, bsz, self.n_head, self.d_head) # qlen x bsz x n_head x d_headw_head_v = w_head_v.view(klen, bsz, self.n_head, self.d_head) # qlen x bsz x n_head x d_headr_head_k = r_head_k.view(rlen, self.n_head, self.d_head) # qlen x n_head x d_head#### compute attention scorerw_head_q = w_head_q + r_w_bias # qlen x bsz x n_head x d_headAC = torch.einsum('ibnd,jbnd->ijbn', (rw_head_q, w_head_k)) # qlen x klen x bsz x n_headrr_head_q = w_head_q + r_r_biasBD = torch.einsum('ibnd,jnd->ijbn', (rr_head_q, r_head_k)) # qlen x klen x bsz x n_headBD = self._rel_shift(BD)# [qlen x klen x bsz x n_head]attn_score = AC + BDattn_score.mul_(self.scale)#### compute attention probabilityif attn_mask is not None and attn_mask.any().item():if attn_mask.dim() == 2:attn_score = attn_score.float().masked_fill(attn_mask[None,:,:,None], -float('inf')).type_as(attn_score)elif attn_mask.dim() == 3:attn_score = attn_score.float().masked_fill(attn_mask[:,:,:,None], -float('inf')).type_as(attn_score)# [qlen x klen x bsz x n_head]attn_prob = F.softmax(attn_score, dim=1)attn_prob = self.dropatt(attn_prob)#### compute attention vectorattn_vec = torch.einsum('ijbn,jbnd->ibnd', (attn_prob, w_head_v))# [qlen x bsz x n_head x d_head]attn_vec = attn_vec.contiguous().view(attn_vec.size(0), attn_vec.size(1), self.n_head * self.d_head)##### linear projectionattn_out = self.o_net(attn_vec)attn_out = self.drop(attn_out)if self.pre_lnorm:##### residual connectionoutput = w + attn_outelse:##### residual connection + layer normalizationoutput = self.layer_norm(w + attn_out)return output
- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
- 9
- 10
- 11
- 12
- 13
- 14
- 15
- 16
- 17
- 18
- 19
- 20
- 21
- 22
- 23
- 24
- 25
- 26
- 27
- 28
- 29
- 30
- 31
- 32
- 33
- 34
- 35
- 36
- 37
- 38
- 39
- 40
- 41
- 42
- 43
- 44
- 45
- 46
- 47
- 48
- 49
- 50
- 51
- 52
- 53
- 54
- 55
- 56
- 57
- 58
- 59
- 60
- 61
- 62
- 63
- 64
- 65
- 66
- 67
- 68
- 69
- 70
- 71
- 72
- 73
- 74
- 75
- 76
- 77
- 78
- 79
- 80
- 81
- 82
- 83
- 84
- 85
- 86
- 87
- 88
- 89
- 90
- 91
- 92
- 93
- 94
- 95
- 96
- 97
- 98
- 99
- 100
- 101
- 102
- 103
- 104
- 105
- 106
- 107
- 108
- 109
- 110
- 111
- 112
- 113
- 114
def _update_mems(self, hids, mems, qlen, mlen):# does not deal with Noneif mems is None: return None# mems is not Noneassert len(hids) == len(mems), 'len(hids) != len(mems)'# There are `mlen + qlen` steps that can be cached into mems# For the next step, the last `ext_len` of the `qlen` tokens# will be used as the extended context. Hence, we only cache# the tokens from `mlen + qlen - self.ext_len - self.mem_len`# to `mlen + qlen - self.ext_len`.with torch.no_grad():new_mems = []end_idx = mlen + max(0, qlen - 0 - self.ext_len)beg_idx = max(0, end_idx - self.mem_len)for i in range(len(hids)):cat = torch.cat([mems[i], hids[i]], dim=0)new_mems.append(cat[beg_idx:end_idx].detach())return new_mems
这里的hids
是当前段每层的输出,mems
为当前段每层依赖的memory,qlen
为序列长度,mlen
为当前段依赖的memory的长度。
从代码来看的话,前面的循环示意图似乎有些问题?感觉在训练阶段,对于每个段里面的第二个位置开始的点,都应该连到第一个位置连到的最前面memory?因为用的是同样长度的memory。
五. 实验结果
1. 语言建模指标
2. 两个创新点的优势
3. 测试阶段的速度
六. 总结
1. 模型特点
在 AI-Rfou 等人提出的vanilla Transformer上做了两点创新:
2. 优点
- 在几种不同的数据集(大/小,字符级别/单词级别等)均实现了最先进的语言建模结果。
- 结合了深度学习的两个重要概念——循环机制和注意力机制,允许模型学习长期依赖性,且可能可以扩展到需要该能力的其他深度学习领域,例如音频分析(如每秒16k样本的语音数据)等。
- 在inference阶段非常快,比之前最先进的利用Transformer模型进行语言建模的方法快300~1800倍。
- 有详尽的源码!含TensorFlow和PyTorch版本的,并且有TensorFlow预训练好的模型及各个数据集上详尽的超参数设置。
3. 不足
- 尚未在具体的NLP任务如情感分析、QA等上应用。
- 没有给出与其他的基于Transformer的模型,如BERT等,对比有何优势。
- 在Github源码中提到,目前的sota结果是在TPU大集群上训练得出,对于我等渣机器党就只能玩玩base模式了。
传送门
Transformer-XL解读(论文 + PyTorch源码)相关推荐
- ELMo解读(论文 + PyTorch源码)
ELMo的概念也是很早就出了,应该是18年初的事情了.但我仍然是后知后觉,居然还是等BERT出来很久之后,才知道有这么个东西.这两天才仔细看了下论文和源码,在这里做一些记录,如果有不详实的地方,欢迎指 ...
- XLM解读(论文 + PyTorch源码)
这篇论文是Facebook在BERT的基础上发展出来的Cross-Lingual版本,即多语的.BERT的github上实际上也有一个多语版本的,但却没有提到是怎么训练的,也没有任何的信息.这里的XL ...
- 小白学习pytorch源码(二):setup.py最详细解读
小白学习pytorch源码(二) pytorch setup.py最全解析 setup.py与setuptools setup.py最详细解读 setup.py 环境检查 setup.py setup ...
- PyTorch 源码解读之 torch.utils.data:解析数据处理全流程
目录 0 前言 1 Dataset 1.1 Map-style dataset 1.2 Iterable-style dataset 1.3 其他 dataset 2 Sampler 3 DataLo ...
- pytorch源码解析2——数据处理torch.utils.data
迭代器 理解 Python 的迭代器是解读 PyTorch 中 torch.utils.data 模块的关键. 在 Dataset, Sampler 和 DataLoader 这三个类中都会用到 py ...
- pytorch 测试每一类_DeepFM全方面解析(附pytorch源码)
写在前面 最近看了DeepFM这个模型.把我学习的思路和总结放上来给大家和未来的自己做个参考和借鉴.文章主要希望能串起学习DeepFM的各个环节,梳理整个学习思路.以"我"的角度浅 ...
- 2023 XL软件库App后端源码 可自定义易支付 完整版
2023 XL软件库App后端源码 可自定义易支付 完整版 安装教程 先导入sql数据库,然后修改config.php 里边填数据库信息 再倒入app源码到iapp,打开源码main.iyu载入界面, ...
- 基于深度学习的文本分类6大算法-原理、结构、论文、源码打包分享
导读:文本分类是NLP领域一项基础工作,在工业界拥有大量且丰富的应用场景.传统的文本分类需要依赖很多词法.句法相关的human-extracted feature,自2012年深度学习技术快速发展之后 ...
- 多智能体系统——竞争网络下异构多智能体系统的分组一致性问题 Group consensus of heterogeneous multi-agent system (附论文链接+源码Matlab)
多智能体系统--竞争网络下异构多智能体系统的分组一致性问题 (附论文链接+源码Matlab) Yu F, Ji L, Yang S. Group consensus for a class of he ...
最新文章
- Spring Cloud(九)高可用的分布式配置中心 Spring Cloud Config 集成 Eureka 服务
- PHP_常用字符串处理函数
- Hacking PostgreSQL
- SQL Server 中常见的十张系统表
- reactor线程模型_简单了解Java Netty Reactor三种线程模型
- 分布式系统原理 之4 Quorum 机制
- 如何使ArrayList 线程安全
- 2021新媒体内容生态数据报告
- c语言圆周率计算_C语言入门这一篇就够了
- python中for语句的使用_python中for in的用法
- 微软云计算介绍与实践(实践之二十七)
- YV12数据与AVFrame的相互转换
- W801单片机学习笔记——内部结构,总线架构篇
- 自然场景文本检测识别 - 综述
- 微信小程序开发——设置默认图片、错误加载图片
- 3.3 费马质数测试
- 2023 年openEuler 社区技术委员会增选,新增2位委员
- mysql ( )_MYSQL (一)
- 1循环结构程序设计-第5关:C循环-寻找完数
- 以问题为导向剖析一些矩阵等价类的本质(合同篇)
热门文章
- html 链接app store,App Store 连接失败
- Ubuntu16.04运行VoxelNetRos
- mysql8 允许外网访问
- pyspark pipline
- shell中join链接多个域_Linux Shell中使用awk完成两个文件的关联Join
- Python应用实战案例-一文通读时间序列在Python中的应用
- MATLAB从入门到精通-如何用matlab来提取txt文本中的实验数据
- 次元网站女装穿起来,从A站到Z站,你知道哪个?谁才是你的最爱?
- 分布式计算Hadoop系列之如何修改Eclipse插件
- 共享文件夹的网络路径_Win10创建网络共享文件夹|设置局域网共享文件夹