Transformer的潜在竞争对手QRNN论文解读,训练更快的RNN
来源:DeepHub IMBA
本文约2100字,建议阅读5分钟
本文我们将讨论论文“拟递归神经网络”中提出的QRNN模型。
使用递归神经网络(RNN)序列建模业务已有很长时间了。但是RNN很慢因为他们一次处理一个令牌无法并行化处理。此外,循环体系结构增加了完整序列的固定长度编码向量的限制。为了克服这些问题,诸如CNN-LSTM,Transformer,QRNNs之类的架构蓬勃发展。
在本文中,我们将讨论论文“拟递归神经网络”(https://arxiv.org/abs/1611.01576)中提出的QRNN模型。从本质上讲,这是一种将卷积添加到递归和将递归添加到卷积的方法。
LSTM
LSTM是RNN最著名的变体。红色块是线性函数或矩阵乘法,蓝色块是无参数元素级块。LSTM单元应用门控功能(输入,遗忘,输出)以获得输出和称为隐藏状态的存储元素。此隐藏状态包含整个序列的上下文信息。由于单个向量编码完整序列,因此LSTM无法记住长期依赖性。而且,每个时间步长的计算取决于前一个时间步长的隐藏状态,即LSTM一次计算一个时间步长。因此,计算不能并行进行。
CNN
CNN可以捕获空间特征(主要用于图像)。红色块是卷积运算,蓝色块是无参数池化运算。CNN使用内核(或过滤器)通过滑动窗口捕获要素之间的对应关系。这克服了固定长度的隐藏表示形式(以及由此带来的长期依赖问题)以及RNN缺乏并行性限制的问题。但是,CNN不显示序列的时间性质,即时间不变性。池化层只是在不考虑序列顺序信息的情况下降低了通道的维数。
Quasi-Recurrent Neural Networks (QRNN)
QRNN解决了两种标准架构的缺点。它允许并行处理并捕获长期依赖性,例如CNN,还允许输出依赖序列中令牌的顺序,例如RNN。
因此,首先,QRNN体系结构具有2个组件,分别对应于CNN中的卷积(红色)和池化(蓝色)组件。
卷积分量
卷积组件的操作如下:
形状的输入序列:(batch_size,sequence_length,embed_dim)
每个“ bank”的形状为“ hidden_dim”的内核:(batch_size,kernel_size,embed_dim)。
输出是一个形状序列:(batch_size,sequence_length,hidden_dim)。这些是序列的隐藏状态。
卷积运算在序列以及小批量上并行应用。
为了保留模型的因果关系(即,只有过去的标记才可以预测未来),使用了一种称为遮罩卷积(masked-convolutions)的概念。也就是说,输入序列的左边是“ kernel_size-1”零。因此,只有'sequence_length-kernel_size + 1'过去的标记可以预测给定的标记。为了更好理解,请参考下图:
接下来,我们基于池化功能(将在下一节中讨论)使用额外的内核库,以获取类似于LSTM的门控向量:
这里,*是卷积运算;Z是上面讨论的输出(称为“输入门”输出);F是使用额外的内核库W_f获得的“忘记门”输出;O是使用额外的内核库W_o获得的“输出门”输出。
如上所述,这些卷积仅应用于过去的“ sequence_length-kernel_size + 1”令牌。因此,如果我们使用kernel_size = 2,我们将得到类似LSTM的方程式:
池化组件
通常,合并是一种无参数的函数,可捕获卷积特征中的重要特征。对于图像,通常使用最大池化和平均池化。但是,在序列的情况下,我们不能简单地获取特征之间的平均值或最大值,它需要有一些循环。因此,QRNN论文提出了受传统LSTM单元中元素级门控体系结构启发的池化功能。本质上,它是一个无参数函数,它将跨时间步混合隐藏状态。
最简单的选项是“动态平均池化”,它仅使用了“忘记门”(因此称为f-pooling):
⊙是逐元素矩阵乘法。它以忘记门为参数,几乎等于输出的“移动平均值”。
另一种选择是使用忘记门以及输出门(所以被称作,fo-pooling):
除此以外,池化可能另外具有专用的输入门(ifo-pooling):
正则化
在检查了各种递归退出方案之后,QRNN使用了一种扩展方案,称为“区域退出”(‘zone out),它本质上是在每个时间步选择一个随机子集来退出,对于这些通道,它只是将当前通道值复制到下一次 步骤,无需任何修改。
这等效于将QRNN的“忘记门”通道的子集随机设置为1,或在1-F上进行dropout -- QRNN Paper
来自DenseNet的想法
DenseNet体系结构建议在每一层与其前面的每一层之间都具有跳过连接,这与在后续层上具有跳过连接的惯例相反。因此,对于具有L个层的网络,将存在L(L-1)个跳过连接。这有助于梯度流动和收敛,但要考虑二次空间。
使用QRNN构建seq2seq
在基于RNN的常规seq2seq模型中,我们只需使用编码器的最后一个隐藏状态初始化解码器,然后针对解码器序列对其进行进一步修改。我们无法对循环池层执行此操作,因为在这里,编码器状态无法为解码器的隐藏状态做出很大贡献。因此,作者提出了一种改进的解码器架构。
将编码器的最后一个隐藏状态(最后一个令牌的隐藏状态)线性投影(线性层),并在应用任何激活之前,将其添加到解码器层每个时间步长的卷积输出中(广播,因为编码器矢量较小):
V是应用于最后一个编码器隐藏状态的线性权重。
注意力机制
注意力仅应用于解码器的最后隐藏状态。
其中s是编码器的序列长度,t是解码器的序列长度,L表示最后一层。
首先,将解码器的未选通的最后一层隐藏状态的点积与最后一层编码器隐藏状态相乘。这将导致形状矩阵(t,s)。将Softmax替代s,并使用该分数获得形状(t,hidden_dim)的注意总和k_t。然后,将k_t与c_t一起使用,以获取解码器的门控最后一层隐藏状态。
性能测试
与LSTM架构相比,QRNN可以达到相当的准确度,在某些情况下甚至比LSTM架构略胜一筹,并且运算速度提高了17倍。
最近,基于QRNN的模型pQRNN在序列分类上仅用1.3M参数就取得了与BERT相当的结果(与440M参数的BERT相对):
结论
我们深入讨论了新颖的QRNN架构。我们看到了它如何在基于卷积的模型中增加递归,从而加快了序列建模的速度。QRNN的速度和性能也许真的可以替代Transformer。
编辑:王菁
校对:林亦霖
Transformer的潜在竞争对手QRNN论文解读,训练更快的RNN相关推荐
- EfficientNetV2震撼发布!87.3%准确率!模型更小,训练更快!谷歌大脑新作
EfficientNetV2: Smaller Models and Faster Training paper: https://arxiv.org/abs/2104.00298 code(官方TF ...
- GitHub趋势榜第一:超强PyTorch目标检测库Detectron2,训练更快,支持更多任务
栗子 发自 凹非寺 量子位 报道 | 公众号 QbitAI PyTorch目标检测库Detectron2诞生了,Facebook出品. 站在初代的肩膀上,它训练比从前更快,功能比从前更全,支持的模型也 ...
- 目标检测新网络——Matrix Net (xNet)参数更少,训练更快
摘要 提出了一种新的深度目标检测体系结构--矩阵网(xNets).xNets将具有不同大小和长宽比的对象映射到层中,这些层中对象的大小和长宽比几乎一致.因此,xnet提供了支持比例和高宽比的体系结构. ...
- Fully Convolutional Networks for Semantic Segmentation----2014CVPR FCN论文解读
Fully Convolutional Networks for Semantic Segmentation----2014CVPR论文解读 Abstract 卷积网络在特征分层领域是非常强大的视 ...
- 论文解读:《功能基因组学transformer模型的可解释性》
论文解读:<Explainability in transformer models for functional genomics> 1.文章概括 2.背景 3.相关工作 4.方法 4. ...
- 论文解读:《基于BERT和二维卷积神经网络的DNA增强子序列识别transformer结构》
论文解读:<A transformer architecture based on BERT and 2D convolutional neural network to identify DN ...
- 微软最新论文解读 | 基于预训练自然语言生成的文本摘要方法
作者丨张浩宇 学校丨国防科技大学计算机学院 研究方向丨自然语言生成.知识图谱问答 本文解读的是一篇由国防科技大学与微软亚洲研究院共同完成的工作,文中提出一种基于预训练模型的自然语言生成方法. 摘要 在 ...
- 论文解读丨LayoutLM: 面向文档理解的文本与版面预训练
摘要:LayoutLM模型利用大规模无标注文档数据集进行文本与版面的联合预训练,在多个下游的文档理解任务上取得了领先的结果. 本文分享自华为云社区<论文解读系列二十五:LayoutLM: 面向文 ...
- AI论文解读:基于Transformer的多目标跟踪方法TrackFormer
摘要:多目标跟踪这个具有挑战性的任务需要同时完成跟踪目标的初始化.定位并构建时空上的跟踪轨迹.本文将这个任务构建为一个帧到帧的集合预测问题,并提出了一个基于transformer的端到端的多目标跟踪方 ...
最新文章
- 什么是整型?Python整型详细介绍
- 安卓按键精灵_[按键精灵教程]学了这个你也能做出稳定的脚本
- Android-源码剖析CountDownTimer(倒计时类)
- C语言程序设计 | 动态内存管理:动态内存函数介绍,常见的动态内存错误,柔性数组
- php 通用购物车,PHP实现购物车代码[可重复使用]
- 放弃redis使用mongodb做任务队列支持增删改管理
- pycharm2020版本界面中英文注释
- 乱码385b1b926a38153d38957556c0dc55b5
- python零基础自学教材-python萌新:从零基础入门到放弃
- zabbix监控 nginx 进程
- 工作中线程池使用不当的问题记录(get是阻塞式的)
- 可以免费领取卡巴斯基激活码的活动
- 第二人生的源码分析(七十六)判断程序运行多个实例
- 年轻人逃离算法?更懂你的时尚推荐算法,你会拒绝吗?| FashionHack 专栏
- matlab 积分函数曲线,matlab积分函数
- ceph---ceph osd DNE状态对集群的影响
- win10系统升级,提示VirtualBox 立即卸载此应用,因为它与Windows 10 不兼容
- 二叉树前序,中序求后续;中序,后续求前序
- 算法中的『前缀和』及『差分』思想详解
- controll层跳转页面_以SpringMVC注解的形式 从Controller跳到另一个Controller 实现登入页面的跳转...
热门文章
- python电影情感评论分析_Kaggle电影评论情感分析
- matlab中gen2par函数,R语言中绘图par()函数用法
- java 读取url https_如何获取URL链接是http还是https
- linux 3.5.0-23-generic内核版本系统调用数目,Linux操作系统分析(三)- 更新内核与添加系统调用...
- CNCF接纳Harbor为沙箱项目
- 解决 java “错误:编码GBK 的不可映射字符”
- java如何读取excel文件
- 上海市金山区财政局容灾项目竞争性谈判600万元
- pandas.apply 有源码github
- tensorflow全联接层fully_connected参数解释正确的