作者:Madison May

编译:ronghuaiyang (AI3公园)

原文地址:

对Reformer的深入解读​mp.weixin.qq.com

导读

Reformer之前已经提过几次,这次带大家更加深入的了解一下这个方法的思想及背后的动机。

自从最初的"Attention is All You Need"论文在NLP社区掀起了Transformer热潮,似乎我们一直在不懈地追求更大的模型。在2019年夏天,英伟达发布了他们的MegatronLM论文 —— 83亿参数。在2020年2月,微软再次加大赌注,发布了一篇关于Turing-NLG的博客文章,拥有170亿个参数。

理解当我们增加参数数量和训练数据的时候,这些模型能到什么程度肯定是有价值的,我很高兴有这些资源可以进行大规模实验的公司已经这么做了。但是,相比来说,我们在如何把Transformer架构变的更加高效这件事情上,投入的太少了。

"Reformer: The Efficient Transformer"来自Nikita Kitaev,Łukasz Kaiser,Anselm Levskaya,与过去两年的“越大越好”的趋势形成了鲜明的对比,并在2020年的ICLR进行了报告。Reformer的论文读起来就像呼吸了一股清新的空气 —— 这篇文章主要关注自注意力操作是如何随序列长度扩展的,并提出了一种替代的注意了机制,可以将来自更长的上下文的信息整合到语言模型中。

使用Reformer对Transformer的改变,可以在单个加速器上对长度为64000的序列进行注意力操作,相比于 MegatronLM和TuringNLP中的1024的上下文长度,形成了鲜明的对比。这两个模型都采用了模型并行管道来拷贝大量的参数。

Self-Attention的回顾

在深入研究Reformer体系结构的细节之前,让我们简要回顾一下self-attention的形成过程,以获得一些在合并长上下文中所遇到的困难的背景知识。

为了简单起见,我们只讨论与单头的点积注意力,尽管在实践中使用了多头注意力。

如果你想要更深入的回顾一下self-attention机制,我强烈推荐Alexander Rush的Annotated Transformer,还有Jay Alammar的Illustrated Transformer。

我们可以把self-attention分为三个主要部分:

Query - Key - Value投影

QKV投影。尽管我们将这个操作画成三个独立的线性投影,但为了提高计算效率,它通常被实现为单个矩阵乘法。

在此阶段,每个token的当前隐藏状态通过线性投影分解为三个部分。

queries = np.matmul(query_weights, hidden) + query_bias
keys = np.matmul(key_weights, hidden) + key_bias
values = np.matmul(value_weights, hidden) + value_bias

Query / Key矩阵乘法

self-attention操作的核心 —— 一个矩阵乘法计算我们的keys和queries之间的两两相似度得分。

在投影之后,将queries和keys相乘以计算两两的相似度。这是用矩阵乘法实现的。

qk_agreement = np.matmul(queries, np.swapaxes(keys, -1, -2))

如果你的keys和queries是形状为(batch, sequence_length, hidden_size)的张量,那么矩阵乘法的输出就是形状为(batch, sequence_length, sequence_length)的张量。

这种看似无关紧要的矩阵乘法正是这种self-attention操作的计算复杂性问题的根源。对于序列长度的线性增加,计算输出所需的乘法次数以平方方式增加,因为我们需要为输入中每一对可能的token计算相似性。这O(L²)的复杂性意味着序列的长度超过1024的token使用原始的transformer结构是不切实际的。事实上,BERT和它的继任者RoBERTa中所选择的上下文长度只有512。

Softmax + Values的加权和

key / value 协同矩阵中的项除以了一个缩放因子sqrt(hidden_size),用来消除hidden size这个参数对注意力分布的影响。对于每个query,我们在所有keys上计算一个softmax,以确保矩阵的每一行和为1—— 确保新的隐藏状态的大小不依赖于序列长度。最后,我们用我们的注意力矩阵乘以我们的values矩阵,为每个token生成一个新的隐藏表示。

计算复杂度 — 解决方案

如前所述,虽然点积注意力方式非常好用,允许任意的token在我们的上下文中从任何其他的token中聚合信息,这种灵活性是有代价的,一个不幸的O (L²)计算复杂度。

有几篇论文提出了帮助解决这种计算复杂性的transformer的变体。"Generating Long Sequences With Sparse Transformers”建议使用成对的注意力操作和精心选择的注意力模式来分解注意操作。"Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context"引入了一种循环机制,允许整合来自比自注意力操作的上下文更大的距离的信息。

The Reformer

"Reformer: The Efficient Transformer"的作者采用了一种完全不同的方法来处理序列长度问题。首先,他们观察到学习不同的keys和queries的投影并不是严格必要的。他们丢弃了query投影,并将注意力权重替换为key的函数。

有点令人惊讶的是,尽管他们从注意力模块中移除了一些参数,他们的模型在enwiki8上的性能并没有下降。

在enwiki8上把key和query的投影放到一起做可以获得相同的性能。

现在,注意力块不再包含queries的单独投影,我们只有key和value对。然而,计算key的协同矩阵(通过将每个key与其他key进行比较)仍然是非常昂贵的。

不幸的是我们可能并没有利用好所有的这些计算。softmax的输出通常由几个关键元素控制 — 其余的往往在噪声中消失。我们在计算softmax的时候,并不一定需要那些注意力权重很小的token。

在编写传统软件时,我们总是会遇到这个问题。如果我们想找到与给定key对应的value,我们通常不会遍历所有key的列表并检查每个key是否匹配。相反,我们使用散列映射数据结构来执行O(1)的查找,而不是O(n)比较。

方便的是,向量空间的哈希映射确实存在类似的情况,它被称为“局部敏感哈希”(LSH)。正是基于这种方法,Reformer的论文的作者们希望产生一个transformer的替代方案,以避免使用点积注意力的平方复杂性。

局部敏感哈希 (LSH)

局部敏感哈希是一组将高维向量映射到一组离散值(桶/集群)的方法。它最常用来作为近似最近邻搜索的一种方法,用于近似的重复检测或视觉搜索等应用。

局部敏感哈希方法尝试将高维空间中相近的向量以高概率分配到相同的哈希。有效的哈希函数有很多种,最简单的可能是随机投影。

lsh_proj = np.random.randn(hidden_size, hash_size)
hash_value = np.sign(np.dot(x, lsh_proj.T))

换句话说,我们选择一个随机的向量集合,观察输入向量在每个向量上的投影是正的还是负的,然后使用这个二值向量来表示给定向量的预期存储区。下图说明了LSH投影矩阵“u”中单个向量的处理过程。绿色的正号表示与向量u点积为正的点,而红色的负号表示与向量u点积为负的点。

LSH注意力

Reformer的论文选择了局部敏感哈希的angular变体。它们首先约束每个输入向量的L2范数(即将向量投影到一个单位球面上),然后应用一系列的旋转,最后找到每个旋转向量所属的切片。

该图演示了一个用4个桶进行3轮哈希的设置。下面的图中的向量映射到了同一个bucket,因为它们的输入很接近,而上一张图中的向量映射到第一个和最后一个bucket。

找到给定的向量选择之后属于哪个桶也可以看成是找到和输入最一致的向量 —— 下面是Reformer的代码:

# simplified to only compute a singular hash
random_rotations = np.random.randn(hidden_dim, n_buckets // 2)
rotated_vectors = np.dot(x, random_rotations)
rotated_vectors = np.hstack([rotated_vectors, -rotated_vectors])
buckets = np.argmax(rotated_vectors, axis=-1)

在为每个token计算一个桶之后,将根据它们的桶对这些token进行排序,并将标准的点积注意力应用到桶中的token的块上。

有了足够多的桶,这就大大减少了所有的给定的token需要处理的token的数量 —— 在实验中,Reformer的论文运行的模型被配置为使用128块大小的块。因此,LSH操作将昂贵的key协同矩阵乘法的上下文大小限制为更易于管理的值。

我们现在的时间复杂度为O (L*log(L)) ,而不是时间复杂度成正比O (L²), 这允许我们把注意力操作扩展到更长的序列的时候不会由于运行时间而受到影响。

因为这个分桶过程是随机的,所以Reformer有选择地多次运行这个过程,以减少两个在输入空间很近的向量被随机地放在不同的桶中的可能性。当所有的事情都做了之后,你就有了一个完全替代标准的多头注意力的方法,它可以与计算完整的注意力矩阵相媲美。

内存复杂度

不幸的是,实现更好的时间复杂度只是问题的一半。如果我们将新的LSH注意力块替换为标准的多头注意力,并尝试输入新长度的信息,我们将很快认识到系统中的下一个瓶颈 — 内存复杂性。

即使我们已经非常小心地最小化了注意力操作的计算复杂度,我们仍然必须将所有的key和value存储在内存中,更糟糕的是,在训练期间,我们需要缓存激活以计算参数更新。

Reformer论文使用了序列长度为64k的enwiki8语言建模数据集来做实验,隐藏单元的大小为1024,层数为12层,这意味着存储key和value需要2 * 64000 * 1024 * 12 = ~ 1.5B个浮点数,大约是6GB的内存。使用这种内存使用方式,我们将无法在训练期间使用大的批处理大小,从而影响我们的运行时间。

一个选择是实现gradient checkpoint来帮助限制我们的内存使用。允许我们减少内存使用,只存储从正向传递中的关键的激活,剩余的在反向传递中重新计算。因此,我们可以选择只在key和value投影之前存储隐藏状态,而不是存储key和value,然后第二次重新投影隐藏状态来计算梯度。

不幸的是,这使我们的后向传递的成本增加了一倍,因此我们能够支持更大的批处理大小所获得的好处将通过重新计算得到部分缓解。更重要的是,即使我们选择只存储输入的一小部分,存储单个层的激活需要250MB的空间,这意味着我们很难在12GB的GPU上支持超过12个样本的批处理大小。

RevNets

幸运的是,我们还有其他方法来减少内存使用。RevNet。

RevNets有个非常聪明的计算技巧,通过以一种特定的方式构造每一层,使内存使用与网络深度保持一致。每一层分为两个部分,X₁和X₂,前向计算如下:

def forward_pass(x1, x2, Wf, Wg):"""Need an extra node in the computational graphbecause the gradient of the loss with respect to z1       # differs from the gradient of loss with respect to y1x1: one half of layer inputx2: other half of layer inputWf: weights that parameterize function fWg: weights that parameterize function g"""z1 = x1 + f(Wf, x2)y2 = x2 + g(Wg, z1)y1 = z1

可视化一下,看起来就是这样:

来自RevNet论文的图,图(a)为RevNet的前向,图(b)为相应的反向。

由于该层的特定结构,我们可以编写一个自定义函数参数更新,这意味着我们不需要缓存任何激活来计算我们的后向传播。类似于使用梯度检查点,我们仍然需要做一些冗余计算。然而,由于每一层的输入都可以很容易地从它的输出中构造出来,我们的内存使用不再随网络中层数的增加而增加。

def backward_pass(y1, y2, d_y1, d_y2, Wf, Wg):"""Pseudocode for RevNet of backward passy1: one half of layer outputy2: second half of layer outputd_y1: derivative of y1d_y2: derivative of y2Wf: weights that parameterize function fWg: weights that parameterize function g"""z1 = y1# Extra computation -- the price we pay for memory# complexity that doesn't scale with n_layers# Importantly this means we don't have to store x1 or x2!x2 = y2 - g(Wg, z1)x1 = y1 - f(Wf, x2)# Standard backprop:# vjp --> Vector Jacobian Productd_Wf, partial_x2 = jax.vjp(f, Wf, x2)(d_z1)d_Wg, partial_z1 = jax.vjp(g, Wg, z1)(d_y2)d_z1 = d_y1 + partial_z1d_x2 = d_y2 + partial_x2d_x1 = d_z1return x1, x2, d_x1, d_x2, d_Wf, d_Wg

在实践中,Reformer定义f(x)是LSH注意力块,g (x)是标准的前向块,来自transformer结构。

有了RevNet架构,我们只需要在内存中存储单层的激活,就可以在训练期间使用更大的批处理大小!现在我们不再受训练期间激活的内存占用的限制,我们可以利用LSH注意力块改进时间复杂度。

重要的是,语言模型的loss不会因为可逆层结构而降低。

这些变化实现起来并不容易 —— 很明显Nikita Kitaev, Łukasz Kaiser和Anselm Levskaya付出巨大的努力在平衡时间和内存。

总的来说,这些变化使得序列长度的扩展成为可能。虽然结果是初步的,但在enwiki8上的实验表明,在语言建模任务上,Reformer可以与它的重量级前辈竞争。

总结

位置敏感哈希的注意力和可逆层构成了Reformer的蓝图,非常高兴可以看到基于transformer的结构选择去优化和处理长序列的问题,而不是简单的扩展之前的工作。

英文原文:https://www.pragmatic.ml/reformer-deep-dive/

对Reformer的深入解读相关推荐

  1. 解读Reformer

    论文地址:https://openreview.net/forum?id=rkgNKkHtvB 代码:https://github.com/lucidrains/reformer-pytorch 介绍 ...

  2. 7 Papers | 腾讯王者荣耀绝悟AI;ICLR高分论文Reformer

    点击上方"深度学习技术前沿",选择"星标"公众号 资源干货,第一时间送达 本周 7 Papers 包含多篇 AAAI 2020.ICLR 2020 入选论文,如 ...

  3. Python Re 模块超全解读!详细

    内行必看!Python Re 模块超全解读! 2019.08.08 18:59:45字数 953阅读 121 re模块下的函数 compile(pattern):创建模式对象 > import ...

  4. Bert系列(二)——源码解读之模型主体

    本篇文章主要是解读模型主体代码modeling.py.在阅读这篇文章之前希望读者们对bert的相关理论有一定的了解,尤其是transformer的结构原理,网上的资料很多,本文内容对原理部分就不做过多 ...

  5. Bert系列(三)——源码解读之Pre-train

    https://www.jianshu.com/p/22e462f01d8c pre-train是迁移学习的基础,虽然Google已经发布了各种预训练好的模型,而且因为资源消耗巨大,自己再预训练也不现 ...

  6. NLP突破性成果 BERT 模型详细解读 bert参数微调

    https://zhuanlan.zhihu.com/p/46997268 NLP突破性成果 BERT 模型详细解读 章鱼小丸子 不懂算法的产品经理不是好的程序员 ​关注她 82 人赞了该文章 Goo ...

  7. 解读模拟摇杆原理及实验

    解读模拟摇杆原理及实验 Interpreting Analog Sticks 当游戏支持控制器时,玩家可能会一直使用模拟摇杆.在整个体验过程中,钉住输入处理可能会对质量产生重大影响.让来看一些核心概念 ...

  8. 自监督学习(Self-Supervised Learning)多篇论文解读(下)

    自监督学习(Self-Supervised Learning)多篇论文解读(下) 之前的研究思路主要是设计各种各样的pretext任务,比如patch相对位置预测.旋转预测.灰度图片上色.视频帧排序等 ...

  9. 自监督学习(Self-Supervised Learning)多篇论文解读(上)

    自监督学习(Self-Supervised Learning)多篇论文解读(上) 前言 Supervised deep learning由于需要大量标注信息,同时之前大量的研究已经解决了许多问题.所以 ...

最新文章

  1. 图形交互界面_人机交互界面UI简介
  2. 【机器视觉】探索机器学习理论的最新进展,走近云、端、芯上的视觉计算
  3. 拍个自拍,让Python告诉你,军训过后你黑了几度?
  4. spring boot同时启动多个服务副本(同一服务启动在不同端口)配置方法
  5. 第四周课程总结试验报告(二)
  6. POJ - 3922 A simple stone game(K倍博弈-斐波那契博弈进阶)
  7. 如何在 C# 中使用匿名类型
  8. 英语六级翻译训练:教育专题
  9. java list indexof_Java LinkedList indexOf()方法
  10. PSP战神 斯巴达勇士 游戏ISO文件和完美通关存档和金手指
  11. 正睿OIday8-day10
  12. 展望下未来的计算机400字,展望未来的作文400字
  13. iOS音效和音乐播放
  14. sql查询大于平均得分的球员的名字和得分,并追加显示平均得分的列
  15. 经典 Fuzzer 工具 AFL 模糊测试指南
  16. gcms基峰有什么用_请收下!来自前辈的“气质联用”经验分享
  17. uncooked 计算机术语,物流专业英语和计算机基础练习题[1]
  18. 《Unity Shader入门精要》笔记01 前言
  19. 企业邮箱排名,收费企业邮箱哪家好?
  20. php 输入 输出,php的文件输入输出流php://input

热门文章

  1. 字符串数组-获取两个字符串中最大的相同子串(最大相同子串有且只有一个)
  2. RuntimeError: dictionary changed size during iteration 解决办法
  3. cp -r dir1/. dir2 表示将dir1下的文件复制到dir2,不包括dir1目录
  4. cbow 和skip-gram比较
  5. HMM中文分词分析 知乎
  6. Python 3.8 即将到来,这是你需要关注的几大新特性
  7. 【连载】优秀程序员的45个习惯之39——架构师必须写代码
  8. Robert Hoekman,Jr 继《一目了然》后的又一经典力作
  9. [PyTorch]一个非常好的抢救outofmemory的方法
  10. TensorFlow安装 通过Anaconda Prompt Win10 64位安装 cpu版 tensorflow