pytorch实现attention_Longformer: 局部Attention和全局attention的混搭
最近要开始使用Transformer去做一些事情了,特地把与此相关的知识点记录下来,构建相关的、完整的知识结构体系,
以下是要写的文章,本文是这个系列的第十一篇:
- Transformer:Attention集大成者
- GPT-1 & 2: 预训练+微调带来的奇迹
- Bert: 双向预训练+微调
- Bert与模型压缩
- Bert与模型蒸馏:PKD和DistillBert
- ALBert: 轻量级Bert
- TinyBert: 模型蒸馏的全方位应用
- MobileBert: Pixel4上只需40ms
- 更多待续
- Bert与AutoML (待续)
- Bert变种
- Roberta: Bert调优
- Transformer优化之自适应宽度注意力
- Transformer优化之稀疏注意力
- Reformer: 局部敏感哈希和可逆残差带来的高效
- Longformer: 局部attentoin和全局attention的混搭(本篇)
- Linformer: 线性复杂度的Attention
- XLM: 跨语言的Bert
- T5 (待续)
- 更多待续
- GPT-3
- 更多待续
Overall
Bert模型虽然很强大,但双向attention它的时间和空间复杂度呈N^2的趋势增长,所以最初的Bert模型能够处理的最长长度是512。
而在这一限制的基础上,如果想处理较长的序列,就需要用妥协的方式:
- 直接截断成512长度的。这点普遍用于文本分类问题。
- 截成多个长度为512的序列段(这些序列段可以互相overlapping),每个都输入给Bert获得输出,然后将多段的输出拼接起来。
- 两个阶段去解决问题,一般用于Question-Answer问题,第一个阶段去选择相关文档,第二个阶段去找到对应的answer。
无论哪种方式,毫无疑问都会带来损失。如果能直接处理长序列就好了。
关于这一点,我们在Reformer和自适应宽度注意力这两篇中各自讲述了办法去解决:
- Reformer使用局部敏感哈希来解决性能问题,动机就在于attention起作用的在于top-N而不是全部。
- 自适应宽度使用一种动态窗口的方法来解决,动机在于attention有可能只attend最近的一些context。
而今天的这篇论文[1],用了一种更加直接的方式去对attention进行改造。那就是局部attention和全局attention的结合,局部attention用来捕捉局部信息,一般用于底层,这点和自适应宽度其实有点像。全局attention则捕捉全局信息,用于高层,目的在于保持住所有的信息。除此之外,论文还提供了改造后的attention的C++实现优化,使之相对于pytorch中的naive实现有了很大的提升。
下图中有个对比,可以看到,在计算速度上,Longformer与Full attention持平,但Full attention在超过一定长度后因为内存问题就无法运行了,论文提供的实现要比pytorch的原始实现块6倍。而在内存上,Longformer则是线性增长的。
滑动窗口Attention
论文的核心就在于局部Attention的设计,在这里采用的是滑动窗口来做,滑动窗口的大小为w,那么每个位置只attend前后w/2个位置。如下图b所示。
因为模型都是多层叠加的,所以层级越高,attend的视野域就越广。如果w=3,那么第一层只能注意3个位置,但到第二层能注意到第一层输出的三个位置,换算到第一层的输入,就是5个位置。所以随着层级越高,理论上每个位置注意到的区域就越大,所能存储的信息就越接近全局attention时的状态。
旁白君:这点和卷积神经网络很像。
另外,每一层的w其实可以不同,鉴于越高层需要的全局信息越多,可以在层级较高的时候把w调大。
因为w远小于长度,所以有了滑动窗口,内存占用就从l^2变成了l乘以w,也就是线性。
滑动窗口+空洞Attention
上面的滑动窗口很类似于卷积,那么相应的,我们还可以像卷积一样加空洞,如下图c所示,这里有个参数d,意为空洞的大小。
空洞可以帮助attention在不增大内存占用的同时,增大视野域d倍。
全局Attention
在现在Bert架构中,只靠上面两个局部attention是不够的,因为储存的信息毕竟有限制,为了解决这个问题,所以全局attention应运而生。
这里的全局attention并不是所有位置attend所有位置,而是选中一些位置让它们之间去做两两的attention。而这些位置的选择,则与具体的问题相关。例如,对于文本分类问题而言,[CLS]这个特殊token会被当做所有信息的聚合点,因而这个位置肯定要被选中。而对于QA问题而言,所有的question的token上要去做全局attention。
旁白君: 其实这也是一种妥协,放弃了任务上的通用性。
QKV线性映射
回顾一下Attention的计算,对于一个序列的embedding,我们需要让它经过三个矩阵Q,K,V分别转化为q,k,v,然后q和k计算相似度,得到的权重再去和v做组合。
而有了全局attention和局部attention的区分后,我们也需要将这两种attention对应的QKV矩阵区分开,所以有两套QKV矩阵。
CUDA实现
在Tensorflow和pytorch的原始实现中,并没有用于计算滑动窗口attention的专门的实现,因为这个实现需要实现矩阵乘积且只要对角线的位置的非0值需要存在内存里。
而如果用for循环又异常的慢。所以用high-level的python struct描述了这种算法,并基于TVM生成了可以在GPU上编译的代码。在最上图可以看到,比naive实现快了6倍。
Attention的设置
正如上面提到的,在底层的时候使用较小的w,而在高层的时候w变大。这个参数设置需要做参数搜索,来平衡性能和效果。
另外,在底层不适用空洞attention,因为底层需要学习直接的局部信息。而在较高的层次会使用空洞attention,不过只限定在2个头上。
实验效果
在字符级的语言模型上,可以看到Longformer比之前的算法都要好。
而在消融实验中,可以看到递增w的策略和空洞attention都能带来提升。
为了节省训练的时间,用Roberta的参数来初始化Longformer,当然,因为结构的不同,需要做一些变动。得到的结果如下,可以看到,loss降了约8%~10%。
思考
勤思考,多提问是Engineer的良好品德。
- 用Roberta的checkpoint来初始化Longformer,需要做哪些变动?应该如何考虑?
答案会公布在微信公众号【雨石记】,欢迎关注交流。
参考
- [1]. Beltagy, Iz, Matthew E. Peters, and Arman Cohan. "Longformer: The long-document transformer." arXiv preprint arXiv:2004.05150 (2020).
pytorch实现attention_Longformer: 局部Attention和全局attention的混搭相关推荐
- 为节约而生:从标准Attention到稀疏Attention
作者丨苏剑林 单位丨追一科技 研究方向丨NLP,神经网络 个人主页丨kexue.fm 如今 NLP 领域,Attention 大行其道,当然也不止 NLP,在 CV 领域 Attention 也占有一 ...
- attention机制、self-attention、channel attention、spatial attention、multi-head attention、transformer
文章目录 attention sequence attention attention 与 self-attention channel attention 与 spatial attention m ...
- 【Attention,Self-Attention Self Attention Self_Attention】通俗易懂
Attention is, to some extent, motivated by how we pay visual attention to different regions of an im ...
- Attention 与Hierarchical Attention Networks 原理
Attention 与Hierarchical Attention Networks 1. Attention 注意力机制 1.1 什么是Attention? 1.2 加入Attention的动机 1 ...
- soft attention and self attention
注意力模型最近几年在深度学习各个领域被广泛使用,无论是图像处理.语音识别还是自然语言处理的各种不同类型的任务中,都很容易遇到注意力模型的身影.所以,了解注意力机制的工作原理对于关注深度学习技术发展的技 ...
- Nat. Mach. Intell.|从局部解释到全局理解的树模型
今天介绍美国华盛顿大学保罗·艾伦计算机科学与工程学院的Su-In Lee团队在nature mechine intelligence 2020的论文,该论文提出了一种基于博弈论沙普利值的TreeExp ...
- java如何做全局缓存_传智播客JNI第七讲 – JNI中的全局引用/局部引用/弱全局引用、缓存jfieldID和jmethodID的两种方式...
讲解JNI中的全局引用/局部引用/弱全局引用.缓存jfieldID和jmethodID的两种方式,并编写两种缓存方式的示例代码. 1.从Java虚拟机创建的对象传到本地C/C++代码时会产生引用,根据 ...
- Self Attention和Multi-Head Attention的原理和实现
个人博客:http://www.chenjianqu.com/ 原文链接:http://www.chenjianqu.com/show-47.html 引言 使用深度学习做NLP的方法,一般是将单词转 ...
- 不能返回函数内部new分配的内存的引用_JNI开发之局部引用、全局引用和弱全局引用(三)...
阿里P7移动互联网架构师进阶视频(每日更新中)免费学习请点击:https://space.bilibili.com/474380680 这篇文章比较偏理论,详细介绍了在编写本地代码时三种引用的使用场景 ...
最新文章
- 【风之语】至贱城市之成都
- IPO与上市的关系?
- LwIP之协议栈接口
- grails安装部署_grails中文版
- 共模电感适用的频率_详解消灭EMC的三大利器:电容器/电感/磁珠!
- 学生管理系统功能设计c语言,C语言--学生信息管理系统设计.doc
- 怎么看计算机运行时间,怎么查看电脑运行时间_怎么查看电脑运行记录
- 【MAC技巧】 MAC下两款免费的风扇调节工具
- ELk日志分析系统搭建
- 《SEM长尾搜索营销策略解密》一一2.12 宝洁里的长尾与创新
- 托福、雅思、托业有什么区别?
- 如何实现用户名或密码错误,弹出重新登录的提示
- wamp mysql使用方法_wamp使用方法【总】
- 燕东微通过注册:预计年营收超20亿 亦庄国投与京东方是股东
- 同步异步+阻塞非阻塞-二述
- MCS:离散随机变量——Poisson分布
- java string时间类型天数运算
- 大话设计模式三之单一职责原则、开放-封闭原则、依赖倒置原则、里氏代换原则
- ABP继承自AbpController后路由无效
- android开发论坛
热门文章
- c# 图片加图片水印、文字水印和图片文字水印
- 解决使用requests_html模块,req.html.render()下载chromium速度慢问题
- Go的异常处理 defer, panic, recover
- 如何在Django中以GROUP BY查询?
- 什么!在CSS中的重要意义? [重复]
- 自用Java爬虫工具JAVA-CURL已开源
- mysql--------命令来操作表
- Zabbix Python API 应用实战
- python 对字典排序
- 在windows下架设openssh实现资源共享