一、Vanilla Transformer的结构

首先,作者要解决的问题是字级别的LM,相比词级别的LM,字级别LM明显需要依赖的距离特别长,比如说一句话某个位置是应该使用she还是he,是依赖于前面的主语情况,这个主语可能距离此单词位置的有十几个单词,每个单词7-8字母长度,那么这就将近100+个字符长度了,作者使用transformer的结构主要原因是他认为该结构很容易做到在任意距离上的信息传递。相对而言,RNN(LSTM)这种结构,就需要按照时间一步一步的传递信息,不能做到跨越距离。

这篇文章虽然用到了transformer结构,但与Attention is all you need这篇文章(简称原Transformer)是有差异的。原Transformer整体是一个seq2seq结构,具体的细节见此处。而Vanilla Transformer只利用了原Transformer的decode的部分结构,也就是一个带有mask的attention层+一个ff层。

如果将 “一个带有mask的attention层+一个ff层” 称为一个layer,那么Vanilla Transformer一共有64个这样的layer,每一个layer有2个head,model_dim=512,ff层的hidden_units=2048,sequence的长度为512。对于训练语言模型来说,这已经是一个很深的网络了,要知道对于大名鼎鼎的BERT网络的层数也就12层(base)和24层(large)了。

另外,之所以使用mask结构是因为语言模型的定义是 p(xi|x0x1…xi-1),也就是根据前i个字符预测第i+1个字符,如果你已经提前看到了答案(也就是第i+1个字符甚至更后面的字符内容),那就没有预测的意义了,这里加mask与原Transformer的decode部分的带有mask的self-attention道理都是一样的。

Positional Embeddings:RNN结构的网络对于类似于LM这种序列性的数据编码带有天然的优势,但缺点就是不能并行,必须要step by step。而attention结构最大的优点就是可以实现并行,但它不能表达序列性,所以为了给网络加入识别序列性就要引入 位置编码 Positional Embeddings。在原Transformer中,位置编码的编码信息是固定的,不需要学习,具体编码方式如下,输出为pos embedding。将word embedding + pos embedding整体作为网络的输入,并且仅在第一层加入了位置编码,之后的每层都不会再次加入。而对于Vanilla Transformer,作者认为它的网络深度太深了,如果只在第一层加入pos embedding,那么经过多层传递,这个信息很容易丢失,所以它是每层都会将上一层的输出与pos embedding加在一起作为下一层的输入,而且,pos embedding是需要学习的。所以,光pos embedding模型就要学习 NLdim 个参数,其中N是网络的层数(本文64层),L是上下文的长度(本文512),dim是embedding的维度(本文=512)。

def positional_encoding(dim, seq_length, dtype=tf.float32):""":param dim: 编码后的维度:param seq_length: 序列的最大长度:param dtype::return:"""pos_encode = np.array([pos/np.power(10000, 2*i/dim) for pos in range(seq_length) for i in range(dim)])pos_encode[0::2] = np.sin(pos_encode[0::2])pos_encode[1::2] = np.cos(pos_encode[1::2])return tf.convert_to_tensor(pos_encode.reshape([seq_length, dim]), dtype=dtype, name='positional_encoding')

总之,从结构上来说,Vanilla Transformer没有什么太特别的地方,用的组件都是原Transformer这篇论文中用到的,甚至还精简了一些,无非就是Vanilla Transformer的网络深度非常深。这个深度导致在训练的时候很难收敛,个人认为这篇论文中值得学习的就是为了达到收敛目的,作者使用的一些小trick,这些小trick对于我们以后解决类似的问题是很有帮助的。

二、Vanilla Transformer训练时作者的一些小trick【3种辅助Loss】


为了方便,我们只以2层来展示,且每一个segment的length=4,原本我们是根据 t0t_0t0~t3t_3t3的输入,在 HHH 节点这个位置预测 t4t_4t4 的结果,loss就是 HHH 节点的输入计算一个交叉熵。【Figure 1 shows our initial model with the causal attention mask limiting information flow from left to right. Each character prediction is conditioned only on the characters that appeared earlier.】

作者在论文中说当网络的深度超过10的时候,就很难让模型收敛,准确率也很低,所以如果大家训练的网络深度超过10的时候就可以部分借鉴这篇论文中的训练方法:引入辅助的loss。【In initial experiments, we found training a network deeper than ten layers to be challenging, with slow convergence and poor accuracy. We were able to deepen the network to better effect through the addition auxiliary losses, which sped up convergence of the training significantly.】

如下图所示,这个辅助的loss分为3类:Multiple Positions; Intermediate Layer Losses; Multiple Targets。

We hypothesize that these losses not only speed up convergence but also serve as an additional regularizer.

During training, the auxiliary losses get added to the total loss of the network with discounted weights.

1、Multiple Positions

辅助loss的第一类loss就是:对于最后一层所有的节点都计算下一步应该预测的字符,即:

  • 在节点 EEE 处根据输入 t0t_0t0,预测输出为 t1t_1t1
  • 在节点 FFF 处根据输入为 t0t_0t0t1t_1t1,输出是 t2t_2t2,以此类推。

然后将每一个Positions处的loss加起来。

【第一类loss贯穿整个train的全部阶段,不发生衰减】

2、Intermediate Layer Losses

辅助loss的第二类是除了在最后一层计算交叉熵loss之外,在中间层也要计算,即:

  • 在节点 AAA 处根据输入 t0t_0t0,预测输出为 t1t_1t1
  • 在节点 BBB 处根据输入 t1t_1t1,预测输出为 t2t_2t2,以此类推;

但中间层的loss并不贯穿整个train始终,而是随着训练进行,逐渐衰减,衰减的方式是,一共有 nnn 层网络,当训练进行到 k2n\cfrac{k}{2n}2nk 时停止计算第 kkk 层 Loss。也就是说当训练进行到一半的时候,所有的 中间层 都不再贡献loss。

【中间层的loss并不贯穿整个train始终,而是随着训练进行,逐渐衰减,衰减的方式是,一共有n层网络,当训练进行到 $\cfrac{k}{2n}$ ​时停止计算第 $k$ 层 Loss。也就是说当训练进行到一半的时候,所有的 中间层 都不再贡献 Loss】

3、Multiple Targets

辅助 Loss 的第三类是在序列的每个位置,模型不仅对下一个token做预测,还会对下下一个token做预测。

在本论文中,每次预测下一步和下下步的字符结果,具体的看下面的图即可,非常清楚。

【但对于下下步的预测结果产生的loss是要发生衰减的,论文中该loss乘以0.5后再加入到整体的loss中】

三、Vanilla Transformer的相关结果

作者使用的数据集有enwik8,lm1b,text8这3个,列举了64层的transformer模型与12层的transformer模型(这个也是作者写的,目的是比较一下是否深度增加效果更好)还有一些RNN结构的模型进行了比较,实践证明该方法是比较好的,具体数据见论文,此处不列出。

但是作者有一个地方的比较结果我认为是很有意义的,这个对于我们以后设计模型有参考性,就是作者这篇论文里提到了加了3种辅助loss帮助训练,还有就是作者使用了momentum优化器训练,使用的pos embedding也是跟之前不同的。那么这些因素到底有没有用,如果有用,哪个用处大,有多大?针对这个问题作者进行了一个比较,比较的基线是上面讲的64层模型。

可以看出,辅助loss中的Multiple Positions和Intermediate Layer Losses效果是最明显的,至于使用了需要学习的pos embedding并没有太大的作用,优化器和Multiple Targets的辅助loss感觉效果都不大。

四、其他

该模型我认为的亮点就是添加了辅助loss帮助训练模型,缺点是计算量非常大,这一点作者自己也提到了,因为在预测阶段,每预测一个字符,就要将所有的结果重新计算一遍,它不能像RNN这种结构,隐节点保存了前面所有时刻的信息(保没保存住是另外一个维度的内容),只要给定的前一个时刻隐节点的信息和该时刻的输入,直接可以计算输出。




参考资料:
树状图:https://coggle.it/diagram/XcUjpYumpGC1iCzK/t/vanilla-transformer-2018
【读论文】Character-Level Language Modeling with Deeper Self-Attention(Vanilla Transformer)

NLP-生成模型-2018:Vanilla Transformer【将长文本序列划截断为多个固定长度的段;段与段之间没有上下文依赖性;无法建模字符之间超过固定长度的依赖,关系导致上下文碎片化】相关推荐

  1. 这六大方法,如何让 Transformer 轻松应对高难度长文本序列?

    2020-06-08 05:24:09 编译 | Mr Bear 编辑 | 丛末 众所周知,多头注意力机制 (Multi-Head Self-Attention) 的计算开销很大.在处理长度为 n 的 ...

  2. IJCAI 2018 基于主题信息的神经网络作文生成模型

    本文介绍哈尔滨工业大学社会计算与信息检索研究中心(SCIR)录用于IJCAI 2018的论文<Topic-to-Essay Generation with Neural Networks> ...

  3. 《预训练周刊》第14期:World-GAN:Minecraft 世界的生成模型、CMU博士论文探究可控文本生成...

    No.14 智源社区 预训练组 预 训 练 研究 观点 资源 活动 关于周刊 超大规模预训练模型是当前人工智能领域研究的热点,为了帮助研究与工程人员了解这一领域的进展和资讯,智源社区整理了第14期&l ...

  4. 如何解决NLP分类任务的11个关键问题:类别不平衡低耗时计算小样本鲁棒性测试检验长文本分类 JayLou娄杰

    原文链接:https://zhuanlan.zhihu.com/p/183852900 欢迎关注<高能AI>公众号- 声明:文中观点谨代表笔者个人立场,盲目搬运有风险- 在2020这个时间 ...

  5. 看我逆向小米rom层应用做碎片化适配

    作者博客 http://www.jianshu.com/u/abc8086489c7 文章目录 前言 基础知识 dex odex smali rom层应用分析 odex与apk 框架文件 逆向工具箱 ...

  6. 谷歌提出 RNN 版 Transformer,或为长文本建模的当前最优解

    文 | 小轶 今天给大家介绍一篇谷歌的最新工作,解决的是 Transformer 的长文本处理问题.在原生 Transformer 中,attention 的复杂度是输入序列长度的平方级别,因此限制了 ...

  7. java文件编译_【java】javac编译多个有依赖关系的java文件为class文件

    历史文章: [javac命令不能使用,提示不是内部命令或外部命令,请查看历史文章] =================需求说明========================== 之前的文章中,仅说明 ...

  8. IoT 物联网碎片化是云厂商的桎梏,中小企业的机会

    伴随着物联网概念的出现,直到今天,在提及物联网产业相关发展问题时,碎片化仍是一个绕不开的话题.很多人也都认可,是碎片化的问题阻碍了物联网的快速发展,导致企业无法大规模复制某个物联网项目. 01. &q ...

  9. 如何避免碎片化知识的危害

    长期通过微博.微信.知乎等平台接收碎片化的知识有什么弊端? Lachel,关注思考,生活,方法论 | 转载约稿请私信 你所接受的一切信息,构成了你的思维方式. 所以,长期接受碎片信息的后果,就是让你的 ...

最新文章

  1. Bzoj4016: [FJOI2014]最短路径树问题
  2. 计算机二级python基础知识总结-计算机二级python 知识点篇(程序的控制结构)...
  3. linux 脚本 获取当前目录,Linux下获取脚本当前工作目录的一点感触
  4. mssql 2008恢复xp_cmdshell
  5. 【kafka】 kafka如何设置指定分区进行发送和消费
  6. 惠普光影精灵拆机换屏幕_聊聊惠普游戏本大军的“先遣部队”
  7. 宅在家里写数据库中事务(ACID)
  8. pure CSS3 triangle icon
  9. java byte 转换int_java byte负数转换int失真?
  10. ie8下a标签中的图片出现边框
  11. 香港推广“绿色年宵” 呼吁商贩和市民惜物减废
  12. css 居中50%,CSS中的translate(-50%,-50%)实现水平垂直居中效果
  13. JCreator使用技巧
  14. matlab如何求空间一点到直线距离,空间点到直线距离怎么求
  15. 全面图解路由器接口及连接
  16. 简历的教育经历怎么写计算机,简历中教育经历怎么写?
  17. 运用特征脸方法的基于Opencv的猫脸检测实现
  18. xp怎么修改桌面图标?
  19. Docker数据卷挂载相关
  20. Paddle-NEAT——飞桨进化神经网络组件

热门文章

  1. Mybatis提取BigDecimal字段值显示丢失末尾0精度的问题解决
  2. 1734: 炮兵阵地
  3. 原来系统还能这样重装!看这里,教您一键重装win10系统
  4. Simulink对突变信号用transfer fcn的迟滞平滑处理
  5. wireshark 抓 蓝牙数据_利用nRF Sniffer对蓝牙BLE通信数据进行嗅探和分析
  6. 艾诺迪亚【八门神器+超级教程】
  7. 全球化(4):CultureInfo
  8. 为什么计算机模拟试题无法评分,科目一电脑模拟打分答题
  9. 基于两阶段鲁棒优化算法的微网多电源容量配置(Matlab)
  10. Kubernetes Pod垂直自动伸缩