欢迎大家前往腾讯云+社区,获取更多腾讯海量技术实践干货哦~

本文由张金超博士发表于云+社区专栏

导语: Google Tensor2Tensor系统是一套十分强大的深度学习系统,在多个任务上的表现非常抢眼。尤其在机器翻译问题上,单模型的表现就可以超过之前方法的集成模型。这一套系统的模型结构、训练和优化技巧等,可以被利用到公司的产品线上,直接转化成生产力。本文对Tensor2Tensor系统从模型到代码进行了全面的解析,期望能够给大家提供有用的信息。

第一章:概述

​ Tensor2Tensor(T2T)是Google Brain Team在Github上开源出来的一套基于TensorFlow的深度学习系统。该系统最初是希望完全使用Attention方法来建模序列到序列(Sequence-to-Sequence,Seq2Seq)的问题,对应于《Attention Is All You Need》这篇论文。该项工作有一个有意思的名字叫“Transformer”。随着系统的不断扩展,T2T支持的功能变得越来越多,目前可以建模的问题包括:图像分类,语言模型、情感分析、语音识别、文本摘要,机器翻译。T2T在很多任务上的表现很好,并且模型收敛比较快,在TF平台上的工程化代码实现的也非常好,是一个十分值得使用和学习的系统。

​ 如果是从工程应用的角度出发,想快速的上手使用T2T系统,只需要对模型有一些初步的了解,阅读一下workthrough文档,很快就能做模型训练和数据解码了。这就是该系统想要达到的目的,即降低深度学习模型的使用门槛。系统对数据处理、模型、超参、计算设备都进行了较高的封装,在使用的时候只需要给到数据路径、指定要使用的模型和超参、说明计算设备就可以将系统运行起来了。

​ 如果想深入了解系统的实现细节,在该系统上做二次开发或是实现一些研究性的想法,那就需要花费一定的时间和精力来对模型和代码进行研究。T2T是一个较复杂的系统,笔者近期对模型和代码实现进行了全面的学习,同时对涉及到序列到序列功能的代码进行了剥离和重构,投入了较多的时间成本。因笔者是做自然语言处理研究的,这篇文章里主要关注的是Transformer模型。写这篇文章一方面是总结和记录一下这个过程中的一些收获,另一方面是把自己对T2T的理解分享出来,希望能够提供一些有用的信息给同学们。

第二章:序列到序列任务与Transformer模型

2.1 序列到序列任务与Encoder-Decoder框架

​ 序列到序列(Sequence-to-Sequence)是自然语言处理中的一个常见任务,主要用来做泛文本生成的任务,像机器翻译、文本摘要、歌词/故事生成、对话机器人等。最具有代表性的一个任务就是机器翻译(Machine Translation),将一种语言的序列映射到另一个语言的序列。例如,在汉-英机器翻译任务中,模型要将一个汉语句子(词序列)转化成一个英语句子(词序列)。

​ 目前Encoder-Decoder框架是解决序列到序列问题的一个主流模型。模型使用Encoder对source sequence进行压缩表示,使用Decoder基于源端的压缩表示生成target sequence。该结构的好处是可以实现两个sequence之间end-to-end方式的建模,模型中所有的参数变量统一到一个目标函数下进行训练,模型表现较好。图1展示了Encoder-Decoder模型的结构,从底向上是一个机器翻译的过程。

图1: 使用Encoder-Decoder模型建模序列到序列的问题

​ Encoder和Decoder可以选用不同结构的Neural Network,比如RNN、CNN。RNN的工作方式是对序列根据时间步,依次进行压缩表示。使用RNN的时候,一般会使用双向的RNN结构。具体方式是使用一个RNN对序列中的元素进行从左往右的压缩表示,另一个RNN对序列进行从右向左的压缩表示。两种表示被联合起来使用,作为最终序列的分布式表示。使用CNN结构的时候,一般使用多层的结构,来实现序列局部表示到全局表示的过程。使用RNN建模句子可以看做是一种时间序列的观点,使用CNN建模句子可以看做一种结构化的观点。使用RNN结构的序列到序列模型主要包括RNNSearch、GNMT等,使用CNN结构的序列到序列模型主要有ConvS2S等。

2.2 神经网络模型与语言距离依赖现象

​ Transformer是一种建模序列的新方法,序列到序列的模型依然是沿用了上述经典的Encoder-Decoder结构,不同的是不再使用RNN或是CNN作为序列建模机制了,而是使用了self-attention机制。这种机制理论上的优势就是更容易捕获“长距离依赖信息(long distance dependency)”。所谓的“长距离依赖信息”可以这么来理解:1)一个词其实是一个可以表达多样性语义信息的符号(歧义问题)。2)一个词的语义确定,要依赖其所在的上下文环境。(根据上下文消岐)3)有的词可能需要一个范围较小的上下文环境就能确定其语义(短距离依赖现象),有的词可能需要一个范围较大的上下文环境才能确定其语义(长距离依赖现象)。

​ 举个例子,看下面两句话:“山上有很多杜鹃,春天到了的时候,会漫山遍野的开放,非常美丽。” “山上有很多杜鹃,春天到了的时候,会漫山遍野的啼鸣,非常婉转。”在这两句话中,“杜鹃”分别指花(azalea)和鸟(cuckoo)。在机器翻译问题中,如果不看距其比较远的距离的词,很难将“杜鹃”这个词翻译正确。该例子是比较明显的一个例子,可以明显的看到词之间的远距离依赖关系。当然,绝大多数的词义在一个较小范围的上下文语义环境中就可以确定,像上述的例子在语言中占的比例会相对较小。我们期望的是模型既能够很好的学习到短距离的依赖知识,也能够学习到长距离依赖的知识。

​ 那么,为什么Transformer中的self-attention理论上能够更好的捕获这种长短距离的依赖知识呢?我们直观的来看一下,基于RNN、CNN、self-attention的三种序列建模方法,任意两个词之间的交互距离上的区别。图2是一个使用双向RNN来对序列进行建模的方法。由于是对序列中的元素按顺序处理的,两个词之间的交互距离可以认为是他们之间的相对距离。W1和Wn之间的交互距离是n-1。带有门控(Gate)机制的RNN模型理论上可以对历史信息进行有选择的存储和遗忘,具有比纯RNN结构更好的表现,但是门控参数量一定的情况下,这种能力是一定的。随着句子的增长,相对距离的增大,存在明显的理论上限。

图2 使用双向RNN对序列进行建模

​ 图3展示了使用多层CNN对序列进行建模的方法。第一层的CNN单元覆盖的语义环境范围较小,第二层覆盖的语义环境范围会变大,依次类推,越深层的CNN单元,覆盖的语义环境会越大。一个词首先会在底层CNN单元上与其近距离的词产生交互,然后在稍高层次的CNN单元上与其更远一些词产生交互。所以,多层的CNN结构体现的是一种从局部到全局的特征抽取过程。词之间的交互距离,与他们的相对距离成正比。距离较远的词只能在较高的CNN节点上相遇,才产生交互。这个过程可能会存在较多的信息丢失。

图3 使用多层CNN对序列进行建模

​ 图4展示的是基于self-attention机制的序列建模方法。注意,为了使图展示的更清晰,少画了一些连接线,图中“sentence”层中的每个词和第一层self-attention layer中的节点都是全连接的关系,第一层self-attention layer和第二层self-attention layer之间的节点也都是全连接的关系。我们可以看到在这种建模方法中,任意两个词之间的交互距离都是1,与词之间的相对距离不存在关系。这种方式下,每个词的语义的确定,都考虑了与整个句子中所有的词的关系。多层的self-attention机制,使得这种全局交互变的更加复杂,能够捕获到更多的信息。

图4 使用self-attention对序列进行建模

​ 综上,self-attention机制在建模序列问题时,能够捕获长距离依赖知识,具有更好的理论基础。

2.3 self-attention机制的形式化表达

​ 上面小节介绍了self-attention机制的好处,本小结来介绍一下self-attention机制的的数学形式化表达。首先,从attention机制讲起。可以将attention机制看做一种query机制,即用一个query来检索一个memory区域。我们将query表示为key_q,memory是一个键值对集合(a set of key-value pairs),共有M项,其中的第i项我们表示为<key_m[i], value_m[i]>。通过计算query和key_m[i]的相关度,来决定查询结果中,value_m[i]所占的权重比例。注意,这里的key_q,key_m,value_m都是vector。

​ Attention的计算概括起来分三步:1)计算query和memory中每个key_m的相关度。2)对所有的相关度结果使用softmax函数进行概率归一化处理。3)根据概率归一化结果对memory中的所有value_m进行加权平均,得到最终的查询结果。计算过程,形式化为:

​ 常用的相关度计算函数有基于加法式(additive)的和乘法式(dot-product)的两种。加法式的函数,先要经过一个前向神经网络单元,再经过一个线性变换,得到一个实数值。乘法式的函数则是两个向量的直接点乘,得到一个实数值。

​ 在Encoder-Decoder框架中,attention机制一般用于连接Encoder和Decoder,即以Decoder的状态作为key,以源语言句子的分布式表示作为memory,从中查找出相关的源语言信息,生成目标语言的词语。在该机制中,memory中的key_m和value_m是相同的。在self-attention机制中,每个词汇以自己的embedding为query,查询由所有词汇的embedding构成的memory空间,得到查询结果作为本词的表示。假如句子长度为n,所有的词分别查询一遍memory得到的结果长度依然会是n。这些词的查询过程是可以并行的。如果relation函数是乘法式的,那么这个查询的过程就是矩阵的乘法,可以形式化为:

在self-attention中,Q=K=V,是一个由所有词的词向量构成的一个矩阵。

综上,self-attention是一种序列建模的方式,在对句子进行分布式表示的时候,句子中的所有的词都会发生直接的交互关系。

2.4 “Attention is All You Need”

​ 《Attention Is All You Need》这篇文章,描述了一个基于self-attention的序列到序列的模型,即“Transformer”。该模型将WMT2014英-德翻译任务的BLEU值推到了新高,在英-法翻译任务上,接近于之前报出的最好成绩,而这仅仅是Transformer单模型的表现。之前报出的最好成绩都是基于集成方法的,需要训练多个模型,最后做集成。同时该模型也被用在英语的成分句法分析任务上,表现也基本接近于之前报出的最好模型成绩。该模型的收敛速度也非常的快,在英-法3600万句对的训练集上,只需要8卡并行3.5天就可以收敛。

​ 该模型的表现的如此好的原因,其实不仅仅是一个self-attention机制导致的,实际上Transformer模型中使用了非常多有效的策略来使得模型对数据的拟合能力更强,收敛速度更快。整个Transformer的模型是一套解决方案,而不仅仅是对序列建模机制的改进。下面我们对其进行一一讲解。

2.4.1 Self-attention机制的变种

​ 首先,还是来讲一下Transformer中的self-attention机制。上面讲到了self-attention的基本形式,但是Transformer里面的self-attention机制是一种新的变种,体现在两点,一方面是加了一个缩放因子(scaling factor),另一方面是引入了多头机制(multi-head attention)。

​ 缩放因子体现在Attention的计算公式中多了一个向量的维度作为分母,目的是想避免维度过大导致的点乘结果过大,进入softmax函数的饱和域,引起梯度过小。Transformer中的self-attention计算公式如下:

多头机制是指,引入多组的参数矩阵来分别对Q、K、V进行线性变换求self-attention的结果,然后将所有的结果拼接起来作为最后的self-attention输出。这样描述可能不太好理解,一看公式和示意图就会明白了,如下:

图5 单头和多头的Attention结构

​ 这种方式使得模型具有多套比较独立的attention参数,理论上可以增强模型的能力。

2.4.2 位置编码(Positional Encoding)

​ self-attention机制建模序列的方式,既不是RNN的时序观点,也不是CNN的结构化观点,而是一种词袋(bag of words)的观点。进一步阐述的话,应该说该机制视一个序列为扁平的结构,因为不论看上去距离多远的词,在self-attention机制中都为1。这样的建模方式,实际上会丢失词之间的相对距离关系。举个例子就是,“牛 吃了 草”、“草 吃了 牛”,“吃了 牛 草”三个句子建模出来的每个词对应的表示,会是一致的。

​ 为了缓解这个问题,Transformer中将词在句子中所处的位置映射成vector,补充到其embedding中去。该思路并不是第一次被提出,CNN模型其实也存在同样的难以建模相对位置(时序信息)的缺陷,Facebook提出了位置编码的方法。一种直接的方式是,直接对绝对位置信息建模到embedding里面,即将词Wi的i映射成一个向量,加到其embedding中去。这种方式的缺点是只能建模有限长度的序列。Transformer文章中提出了一种非常新颖的时序信息建模方式,就是利用三角函数的周期性,来建模词之间的相对位置关系。具体的方式是将绝对位置作为三角函数中的变量做计算,具体公式如下:

​ 该公式的设计非常先验,尤其是分母部分,不太好解释。从笔者个人的观点来看,一方面三角函数有很好的周期性,也就是隔一定的距离,因变量的值会重复出现,这种特性可以用来建模相对距离;另一方面,三角函数的值域是[-1,1],可以很好的提供embedding元素的值。

2.4.3 多层结构

​ Transformer中的多层结构非常强大,使用了之前已经被验证过的很多有效的方法,包括:residual connection、layer normalization,另外还有self-attention层与Feed Forward层的堆叠使用,也是非常值得参考的结构。图6展示了Transformer的Encoder和Decoder一层的结构。

图6 Transformer模型结构

​ 图6中,左侧的Nx代表一层的Encoder,这一层中包含了两个子层(sub-layer),第一个子层是多头的self-attention layer,第二个子层是一个Feed Forward层。每个子层的输入和输出都存在着residual connection,这种方式理论上可以很好的回传梯度。Layer Normalization的使用可以加快模型的收敛速度。self-attention子层的计算,我们前面用了不少的篇幅讲过了,这里就不再赘述了。Feed Forward子层实现中有两次线性变换,一次Relu非线性激活,具体计算公式如下:

文章的附页中将这种计算方式也看做是一种attention的变种形式。

图6中,右侧是Decoder中一层的结构,这一层中存在三个子层结构,第一层是self-attention layer用来建模已经生成的目标端句子。在训练的过程中,需要一个mask矩阵来控制每次self-attention计算的时候,只计算到前t-1个词,具体的实现方式,我们会在后面讲代码实现的时候进行说明。第二个子层是Encoder和Decoder之间的attention机制,也就是去源语言中找相关的语义信息,这部分的计算与其他序列到序列的注意力计算一致,在Transformer中使用了dot-product的方式。第三个子层是Feed Forward层,与Encoder中的子层完全一致。每个子层也都存在着residual connection和layer normalization操作,以加快模型收敛。

Transformer中的这种多层-多子层的机制,可以使得模型的复杂度和可训练程度都变高,达到非常强的效果,值得我们借鉴。

2.4.4 优化方法与正则策略

​ 模型的训练采用了Adam方法,文章提出了一种叫warm up的学习率调节方法,如公式所示:

公式比较先验,看上去比较复杂,其实逻辑表达起来比较清楚,需要预先设置一个warmup_steps超参。当训练步数step_num小于该值时,以括号中的第二项公式决定学习率,该公式实际是step_num变量的斜率为正的线性函数。当训练步数step_num大于warm_steps时,以括号中的第一项决定学习率,该公式就成了一个指数为负数的幂函数。所以整体来看,学习率呈先上升后下降的趋势,有利于模型的快速收敛。

模型中也采用了两项比较重要的正则化方法,一个就是常用的dropout方法,用在每个子层的后面和attention的计算中。另一个就是label smoothing方法,也就是训练的时候,计算交叉熵的时候,不再是one-hot的标准答案了,而是每个0值处也填充上一个非0的极小值。这样可以增强模型的鲁棒性,提升模型的BLEU值。这个思路其实也是一定程度在解决训练和解码过程中存在的exposure bias的问题。

2.4.5 本章小结

​ Transformer系统的强大表现,不仅仅是self-attention机制,还需要上述的一系列配合使用的策略。设计该系统的研究者对深度学习模型和优化算法有着非常深刻的认识和敏锐的感觉,很多地方值得我们借鉴学习。Transformer的代码实现工程化比较好,但是也存在一些地方不方便阅读和理解,后面的章节中会对Transformer的代码实现进行详细讲解,将整体结构讲清楚,把其中的疑难模块点出来。

第三章:Tensor2Tensor系统实现深度解析

​ Tensor2Tensor的系统存在一些特点,导致使用和理解的时候可能会存在一些需要时间来思考和消化的地方,在此根据个人的理解,写出一些自己曾经花费时间的地方。

3.1 使用篇

​ Tensor2Tensor的使用是比较方便的,对于系统中可以支持的问题,直接给系统设置好下面的信息就可以运行了:数据,问题(problem),模型,超参集合,运行设备。这里的实现其实是采用了设计模型中的工厂模式,即给定一个问题名字,返回给相应的处理类;给定一个超参名,返回一套超参的对象。实现这种方式的一个重点文件是utils/registry.py。在系统启动的时候,所有的问题和超参都会在registry中注册,保存到_MODELS,_HPAPAMS,_RANGED_HPARAMS中等待调用。

​ 在此主要以序列到序列的系统使用和实现为主线进行讲解。系统的运行分三个阶段:数据处理,训练,解码。对应着三个入口:t2t-datagen,t2t-trainer,t2t-decoder。

数据处理的过程包括:

​ 1.(下载)读取训练和开发数据。如果需要使用自己的数据的话,可以在问题中指定。

​ 2.(读取)构造词汇表。可以使用自己预先构造好的词汇表。系统也提供构建BPE词汇表的方法。注意,这里有个实现细节是系统在抽取BPE词汇表的时候,有个参数,默认并非使用全量的数据。通过多次迭代尝试,得到最接近预设词汇表规模的一个词汇表。在大数据量的时候,这个迭代过程会非常慢。

​ 3. 使用词汇表将单词映射成id,每个句子后会加EOS_ID,每个平行句对被构造成一个dict对象({‘inputs’:value,‘targets’:value}),将所有对象序列化,写入到文件中,供后面训练和评价使用。

模型训练的过程的过程主要通过高级的Tensorflow API来管理,只是需要指定数据、问题名、模型名、超参名、设备信息就可以运行了。比较关键的一个文件是utils/trainer_lib.py文件,在这个文件中,构建Experiment、Estimator、Monitor等来控制训练流程。使用者主要需要设置的就是训练过程的一些参数,比如训练最大迭代次数,模型评估的频率,模型评估的指标等。超参可以直接使用系统已有的参数集,也可以通过字符串的形式向内传参。简单的任务不太需要动超参,因为系统中的超参集合基本上都是经过实验效果验证的。需要注意的就是batch_size过大的时候,可能会导致显存不足,导致程序错误。一般是使用continuous_train_and_eval模式,使模型的训练和评估间隔进行,随时可以监控模型的表现。

解码的过程,可以提供整体文件、也可以是基于Dataset的,同时系统也提供server的方式,可以提供在线的服务,并没有什么特别好讲的。

3.2 深度掌握篇

3.2.1 Tensor2Tensor系统实现的特点

​ 下面列出了要深度掌握Tensor2Tensor系统时,可能因为其实现特点,会遇到的一些问题:

​ 1. 系统支持多任务,任务混杂,导致代码结构比较复杂。在实现的时候,要考虑到整体的结构,所以会存在各种封装、继承、多态的实现。可能你只想用其中的一个功能,理解该功能对应的代码,但是却需要排除掉大量的不相关的代码。

​ 2. 系统基于Tensorflow封装较高的API。使用了Tensorflow中比较高的API来管理模型的训练和预测,Experiment,Monitor,Estimator,Dataset对象的使用隐藏了比较多的控制流程,对于侧重应用的用户来说,可能是是好事情,设一设参数就能跑。但是对于想了解更多的开发人员来说,TF该部分的文档实在很少,说的也不清楚,很多时候需要去阅读源代码才能知道实验到底是不是按照自己预期的进行的。这种方式也不太方便找bug和调试。

​ 3. 某些方法调用比较深。原因应该还是出于整体结构和扩展性的考虑。这导致了实现一点很小功能的方法A,需要再调一个其他方法B,B再去调用方法C,实际上每个方法中就几行代码,甚至有的方法就是空操作。

​ 4. 多层继承和多态也降低了代码的可读性。追溯一个类的某个方法的时候,需要看到其父类的父类的父类。。。这些父类和子类之间的方法又存在着调来调去的关系,同名方法又存在着覆盖的关系,所以要花一些时间来确定当前的方法名到底是调用的的哪个类中的方法。

​ 5. 要求开发者有模型层面的理解和与代码实现的挂钩。肯定是要提高对模型逻辑的理解,但在读代码的过程中,会遇到两种问题:第一个,代码实现的是论文中的功能,但不是论文中的原始公式,可能要做变形以规避溢出的问题,或是实现更高的效率;第二个,某些代码实现与其论文中的表述存在不一致的情况。

3.2.2 总体逻辑模块

总体来说,对T2T系统的代码逻辑划分如下,共包括三个大的模块:

  1. **问题定义和数据管理的模块。**该模块用来定义问题和处理数据,比如定义一个翻译的问题,里面定义抽词汇表和构造训练样本的方法。
  2. **模型定义和计算图构建的模块。**该模块用来定义模型属性和计算图结构。
  3. **实验流程控制与并行化模块。**该模块用于实验流程控制,设置可用计算设备,提供模型并行化运行方法。

图7 Tensor2Tensor主要逻辑模块

这里不会对代码做追踪式的分析,会分条的讲解一些阅读Tensor2Tensor系统代码时可能遇到的问题,点出一些重要的功能所在的位置和实现逻辑。

  1. **工厂模式。**系统使用工厂模式管理问题、模型、超参、模态等模块的方法。前面在使用篇讲到了registry.py这个比较关键的文件,是系统总体管理和调度模块的一个核心文件。如果要在系统中增加新的问题、模型、超参、模态等,也都需要通过在类前加装饰器的方式来注册到registry中,否则系统找不到新加的模块。
  2. **问题类(problem)。**data_generators/problem.py中的class Problem是后面所有problem的基类。之前说到系统中的类之间的多层继承关系导致代码读起来比较麻烦,举个例子来说,一个翻译问题继承路线是这样的:Problem>>Text2TextProblem>>TranslateProblem>>TranslateEndeWmtBpe32k>> TranslateEndeWmt32k,中间各种的方法和变量覆盖,父类和子类之间方法的穿插调用,导致一些阅读困难。总的来说,一个序列到序列的问题应该包括以下信息和方法:数据文件信息,词汇表文件名、类型、大小,构造词汇表的方法,序列化训练数据和开发数据的方法,读取数据文件为model(estimator)构造输入流input_fn的方法,设定问题评估metric的方法。可以总结来说,问题的属性定义、训练和评价样本的构造、数据的处理和读取,都由problem这个体系里面的类和方法来提供。
  3. **词汇表对象(TextEncoder)。**系统中有多种多样的词汇表(TextEncoder)对象,可以支持字母(character),子词(subword/bpe),词汇(token)等多重方式。TextEncoder主要功能就是构建词汇表、实现符号到id的映射。T2T里有构造bpe词汇表的方法,没有word piece词汇表的构造方法,也可以看出T2T研究团队和GNMT研究团队的区分。两个团队一直在交替的更新机器翻译任务的最高成绩。构建BPE词汇表的具体实现在SubwordTextEncoder中的 build_to_target_size()方法,该方法不是之前Sennrich使用迭代次数来控制词汇表大小的方式,而是使用二分查找的方式,通过搜索最优的minimum token count值来逼近预先设置的词汇表的大小。
  4. **T2TModel类。**utils/t2t_model.py中的class T2TModel是模型功能的基类,该类继承自layer,Transformer类便继承于此类。T2TModel类中定义了模型的计算图结构,即给定feature后,模型是怎么根据feature进行图计算,得到logit,loss,然后根据loss求梯度,调用optimizer进行梯度回传,进行参数更新的。构建计算图的目的是最终要构建tf.estimator.EstimatorSpec()对象。可以理解为,所有的模型图计算过程都在该对象中被表达了。T2TModel可以返回三种EstimatorSpec对象,分别用于训练、评价和解码。训练的过程可以支持数据并行,具体的实现是同时在多个数据片上激活计算图,得到的loss做平均,这是一种同步并行训练的方式。T2TModel中也提供了供解码的方法。
  5. **Transformer类。**models/transformer.py中的class Transformer继承自class T2TModel,为其父类构建图的时候,提供各种支持的方法,encode方法可以使用Encoder结构对源端进行压缩表示,decode方法使用Decoder结构对目标端进行生成。同时,transformer.py中有多套参数供选择。模型中feed-forward子层的实现也在该文件中(transformer_ffn_layer)。
  6. 数据并行类。devices.py和expert_utils.py配合使用,主要功能是根据用户给定的并行设备参数,列出可以使用的设备名,然后给定一个能够调用这些设备,并行执行方法的方法。
  7. **实验流程控制。**实验流程控制使用的是Tensorflow的高级API对象,主要包括Experiment对象、Estimator对象、Dataset对象。对这三个对象,我们可以这么理解:a) Experiment是一次运行的实验,用来控制实验流程,输送数据到模型。b) Estimator是具体的模型对象,可以包括训练、评估、解码三个功能。c) Dataset为运行的实验过程读数据文件提供数据流。
  8. Experiment对象。我们来看下图中Experiment初始化所需的形参就能更好的理解“实验”这个概念了。Experiment对象中需要迭代中的各种step参数,需要一个Estimator对象,两个输入流函数(input)。Experiment对象在运行中,将数据给到Estimator对象,然后控制训练和迭代流程。

图8 Experiment对象的部分形参

9.Estimator对象。可以理解为模型对象,可以通过Estimator执行模型的训练、评估、解码。Estimator对象最重要的一个形参是model_fn,也就是具体执行训练、评估、解码的函数入口。三个入口分别对应着三个EstimatorSpec对象,如图9,10所示。

图9 Estimator中最重要的形参是model_fn图10 Estimator中的三种model_fn,实现三种功能

​ 从图10可以看出,用于训练的EstimatorSpec对象需要描述计算图中feature和(loss,train_op)之间的关系;用于评估的EstimatorSpec对象需要描述计算图中feature和(loss,eval_metrics_ops)之间的关系;用于评估的EstimatorSpec对象需要描述features和predictions之间的关系。

  1. Dataset对象。该对象是读文件,构造训练和评估的数据流。训练和评估对应着两种不同的数据输入流,如图11所示。

图11 Dataset对象提供数据流

\11. Positional encoding的实现。论文中的实现和代码中的实现存在公式变形和不一致的情况,可能会导致困惑,故在此指出。论文中Positional encoding中三角函数的参数部分公式如下:

​ 代码中的实现需要对该公式做变形,以规避数值溢出的风险,公式变形过程如下:

​ 还需要指出的是,论文中根据维度下标的奇偶性来交替使用sin和cos函数的说法,在代码中并不是这样实现的,而是前一半的维度使用sin函数,后一半的维度使用cos函数,并没有考虑奇偶性

​ 12. **以token数量作为batch size。**这种方式比起以句子个数作为batch size的方式来,能到batch占显存的空间更加平均,不会导致因为训练数据导致的显存占用忽上忽下,造成显存空间不够用,导致程序崩溃。

\13. 如何做mask。由于模型是以batch为单位进行训练的,batch的句长以其中最长的那个句子为准,其他句子要做padding。padding项在计算的过程中如果不处理的话,会引入噪音,所以就需要mask,来使padding项不对计算起作用。mask在attention机制中的实现非常简单,就是在softmax之前,把padding位置元素加一个极大的负数,强制其softmax后的概率结果为0。举个例子,[1,1,1]经过softmax计算后结果约为[0.33,0.33,0.33],[1,1,-1e9] softmax的计算结果约为[0.5, 0.5,0]。这样就相当于mask掉了数组中的第三项元素。在对target sequence进行建模的时候,需要保证每次只attention到前t-1个单词,这个地方也需要mask,整体的mask是一个上三角矩阵,非0元素值为一个极大的负值。

\14. 基于batch的解码。解码的时候,如果是基于文件的,那么就会将句子组成batch来并行解码。这里有个小trick,就是先对句子进行排序,然后从长的句子开始组batch,翻译,再把句子恢复成原先的顺序返回。这种方式可以很好的检测到显存不足的错误,因为解句子最长的一个batch的时候,显存都是够得,那其他的batch也不存在问题。

总结

​ 本文对Google的Tensor2Tensor系统进行了深度的解读,涉及到了比较多的方面,笔者也还需要对其进行更加深入的学习和研究,希望能够与对该模型以及DL for NLP技术感兴趣的同学们一起交流,共同进步!

问答

docker和docker-compose有什么不同?

相关阅读

深度学习之神经网络核心原理与算法-归一化与参数初始化

启发式寻路算法

深度学习(5)——RBF算法简介

此文已由作者授权腾讯云+社区发布,原文链接:https://cloud.tencent.com/developer/article/1116709?fromSource=waitui

欢迎大家前往腾讯云+社区或关注云加社区微信公众号(QcloudCommunity),第一时间获取更多海量技术实践干货哦~

海量技术实践经验,尽在云加社区! https://cloud.tencent.com/developer?fromSource=waitui

“变形金刚”为何强大:从模型到代码全面解析Google Tensor2Tensor系统相关推荐

  1. 六种人体姿态估计的深度学习模型和代码总结

    六种人体姿态估计的深度学习模型和代码总结 姿态估计的目标是在RGB图像或视频中描绘出人体的形状,这是一种多方面任务,其中包含了目标检测.姿态估计.分割等等.有些需要在非水平表面进行定位的应用可能也会用 ...

  2. TF之p2p:基于TF利用p2p模型部分代码实现提高图像的分辨率

    TF之p2p:基于TF利用p2p模型部分代码实现提高图像的分辨率 目录 一.tfimage.py文件功能解释 二.process.py添加一个新操作 一.tfimage.py文件功能解释 1.此处的c ...

  3. 模型开发-GBDT决策树模型开发代码

    GBDT(Gradient Boosting Decision Tree) 又叫 MART(Multiple Additive Regression Tree),是一种迭代的决策树算法,该算法由多棵决 ...

  4. 数学建模——TOPSIS综合评价模型Python代码

    数学建模--TOPSIS综合评价模型Python代码 正常代码 import numpy as np # 导入numpy包并将其命名为np ##定义正向化的函数 def positivization( ...

  5. 人脸口罩检测现开源PyTorch、TensorFlow、MXNet等全部五大主流深度学习框架模型和代码...

    号外!号外! 现在,AIZOO开源PyTorch.TensorFlow.MXNet.Keras和Caffe五大主流深度学习框架的人脸检测模型和代码啦! 先附上Github链接为敬. https://g ...

  6. 三维点云网络PointNet——模型及代码分析

    PointNet架构 PointNet主要架构如下图所示: 主要包含了点云对齐/转换.mlp学习.最大池化得到全局特征三个主要的部分. -T-Net用于将不同旋转平移的原始点云和点云特征进行规范化: ...

  7. thinkphp 关联模型配置代码

    原文:thinkphp 关联模型配置代码 <?php /*** 公司与部门关联模型*/ class CompanyRelationModel extends RelationModel{//主表 ...

  8. SIR传染模型Matlab代码,sir传染病模型 MATLAB代码运行不了,

    问题描述: sir传染病模型 MATLAB代码运行不了, function y=ill(t,x) a=1;b=0.3; y=[a*x(1)*x(2)-b*x(1),-a*x(1)*x(2)]'; ts ...

  9. 对角阵在matlab,MATLABSimulink实现对角阵解耦(模型和代码).pdf

    MATLABSimulink实现对角阵解耦(模型和代码) MATLAB/Simulink 实现对角阵解耦 (模型和代码) 1.耦合与解耦: 在控制系统中,不同被控量之间往往存在相互影响,比如某封闭罐体 ...

最新文章

  1. 进一步封装axios并调用其读取数据(吐槽~在安卓9.0以下或者IOS10.X以下手机端H5页面不支持,在这两种情况下的系统只能使用ajax或者原生js请求后台数据)
  2. CAS 服务器端取消 https的配置 方法
  3. input的type为number
  4. USACO-Section1.4 Ski Course Design (枚举)
  5. IBM携手MIT组建新实验室:人工智能将有像人一样的视听功能
  6. 怎么看xp计算机是32位还是64位,教你查看XP系统的不同32位还是64位详细的步骤
  7. DataGridView 获取当前行数据
  8. can例程 ecu_ECU程序及CAN总线实现
  9. 老庙黄金2016春晚抢红包活动技术架构详解
  10. JavaScript高级程序设计读书笔记(第6章面向对象的程序设计之创建对象)
  11. Gbox开源:比RN和WebView更轻的高性能动态化业务容器,你掌握了多少
  12. Java自学书籍推荐,java程序员面试算法宝典
  13. 钉钉在线课程开启屏幕共享时电脑蓝屏问题解决办法
  14. NKOI 1905 慢跑小路
  15. java ice c_Java的Ice包接收中文乱码
  16. 隐私政策网址 (URL)
  17. mysql数据库插入数据为空_插入数据成功,但是数据库中显示为空(菜鸟提问)...
  18. Remoting学习
  19. LTE default bearer dedicated bearer and radio bearer
  20. 【Http协议】Http协议简介

热门文章

  1. 为开发者准备的9个实用PHP代码片段(转)
  2. 多线程socket 端口扫描程序,实现了,但是速度不行,求指点。
  3. vsphere client中部署OVF项目后为项目分配IP
  4. 遍历Map keySet和entrySet
  5. idea导入gitlab上面的项目
  6. 文件源码读取 php伪协议,include(文件包含漏洞,php伪协议)
  7. ubuntu安装ftp_如何在 Ubuntu 20.04 上安装 Webmin
  8. celery 可视化_Django中Celery的实现介绍(一)
  9. 浅谈HotSpot逃逸分析
  10. 第四部分 Calendar使用示例