导读:在序列生成类任务中,如机器翻译、自动摘要等,Seq2Seq是一种非常强大的模型。作为一种条件语言模型,它直接对P(y|x)进行建模,在生成y的过程中,始终有x作为条件。当训练好了一个这样的模型后,在预测过程中,需要进行解码来找到最有可能的输出序列。这篇文章主要讲解Sequence to Sequence模型在预测阶段中使用的序列解码策略。

一些自然语言处理任务,如脚注生成、机器翻译等,会涉及到生成单词序列,也就是预测结果是单词组成的一个序列。通常情况下,针对这些问题开发的模型会为输出序列中的每个单词生成词汇表中的每个单词上的概率分布,然后用在解码过程中以把这些概率分布转换成最终的单词序列。

解码最有可能的输出序列会涉及到在所有可能的输出序列上基于其概率进行搜索。词汇表的大小通常是由成百上千甚至百万个单词。可以想象,搜索难度在输出序列的长度上是呈指数级增长的,并且穷举所有的可能情况是不切实际的。

因此,在实践中,采用启发式搜索方法来返回一个或更多的近似的或者足够好的被解码好的的输出序列作为最终的预测结果,依据它们的概率值对候选词序列进行评分。

“As the size of the search graph is exponential in the source sentence length, we have to use approximations to find a solution efficiently.” —— Handbook of Natural Language Processing and Machine Translation

很常见的做法就是使用贪心搜索(Greedy Search)或集束搜索(Beam Search)来确定候选文本序列。

“Each individual prediction has an associated score (or probability) and we are interested in output sequence with maximal score (or maximal probability) (…) One popular approximate technique is using greedy prediction, taking the highest scoring item at each stage. While this approach is often effective, it is obviously non-optimal. Indeed, using beam search as an approximate search often works far better than the greedy approach.”——Neural Network Methods in Natural Language Processing

贪心搜索解码器

一种简单直观的方法是使用贪心搜索进行序列解码,在输出序列的每一步中始终选择最有可能(即最大概率)的词。具体来说,就是在生成第一个词 y < 1 > y^{<1>} y<1> 的分布之后,根据条件语言模型挑选出最有可能的第一个词 y < 1 > y^{<1>} y<1>,然后生成第二个词 y < 2 > y^{<2>} y<2> 的概率分布,再挑选出第二个词 y < 2 > y^{<2>} y<2>,以此类推。贪心搜索方法的好处就是它非常快,但是它只能保证每一步都是最优的,无法保证最终的预测序列整体是最优的,特别是如果在 t t t 时刻贪心搜索选择的词不是全局最优,会导致 t t t 时刻往后的所有预测词都是错误的,没有回头路了。如果每个时间步都穷举所有可能的情况的话,时间复杂度 O ( V T ) O(V^T) O(VT) 又太高了。下面我们以机器翻译为例来说明。

法语句子:“Jane visite l’Afrique en septembre.”
翻译 1:Jane is visiting Africa in September.
翻译 2:Jane is going to be visiting Africa in September.

很明显,翻译 1 要比翻译 2 更好,更加简洁明了,相比之下,翻译 2 就显得啰嗦。如果贪心搜索算法挑选’Jane’、‘is’作为输出序列的前两个词,即 y < 1 > , y < 2 > = y^{<1>},y^{<2>}= y<1>,y<2>=(‘Jane’, ‘is’),那么当挑选第三个词 y < 3 > y^{<3>} y<3> 时,贪心搜索算法会选择’going’而不是’visiting’,因为在英语中’is going’比’is visiting’更加常见,'going’的概率是最大的,即 P ( ′ g o i n g ′ ∣ ′ J a n e ′ , ′ i s ′ ) > P ( ′ v i s i t i n g ′ ∣ ′ J a n e ′ , ′ i s ′ ) P('going' | 'Jane', 'is') > P('visiting' | 'Jane', 'is') P(′going′∣′Jane′,′is′)>P(′visiting′∣′Jane′,′is′)。最终你会得到一个翻译效果并不是最好的句子。

下面我们通过一个简单的例子来演示用贪心搜索方法进行解码的过程。假设我们需要预测一个由 10 个单词组成的序列,所使用的的词汇表由 5 个单词组成。所得到的输出序列中每个单词在整个词汇表中的每个单词上的概率分布如下:

# 定义一个由10个单词组成的序列,单词来自于大小为5的词汇表
data = [[0.1, 0.2, 0.3, 0.4, 0.5],[0.5, 0.4, 0.3, 0.2, 0.1],[0.1, 0.2, 0.3, 0.4, 0.5],[0.5, 0.4, 0.3, 0.2, 0.1],[0.1, 0.2, 0.3, 0.4, 0.5],[0.5, 0.4, 0.3, 0.2, 0.1],[0.1, 0.2, 0.3, 0.4, 0.5],[0.5, 0.4, 0.3, 0.2, 0.1],[0.1, 0.2, 0.3, 0.4, 0.5],[0.5, 0.4, 0.3, 0.2, 0.1]]
data = array(data)

贪心搜索算法在预测的每一步中选择最有可能的或概率值最大的单词作为输出,我们可以使用 argmax()这个函数来选择序列的每一步中最有可能的词索引值。下面的代码完整地演示了贪心搜索解码策略的过程:

from numpy import array
from numpy import argmax# 贪心搜索解码器
def greedy_decoder(data):# 每行最大概率的索引号return [argmax(s) for s in data]# 定义一个由10个单词组成的序列,单词来自于大小为5的词汇表
data = [[0.1, 0.2, 0.3, 0.4, 0.5],[0.5, 0.4, 0.3, 0.2, 0.1],[0.1, 0.2, 0.3, 0.4, 0.5],[0.5, 0.4, 0.3, 0.2, 0.1],[0.1, 0.2, 0.3, 0.4, 0.5],[0.5, 0.4, 0.3, 0.2, 0.1],[0.1, 0.2, 0.3, 0.4, 0.5],[0.5, 0.4, 0.3, 0.2, 0.1],[0.1, 0.2, 0.3, 0.4, 0.5],[0.5, 0.4, 0.3, 0.2, 0.1]]
data = array(data)
# 解码输出序列
result = greedy_decoder(data)
print(result)

运行这段样例程序会输出一个整型数字序列,这些数字表示对应单词在字典中的索引号。

[4, 0, 4, 0, 4, 0, 4, 0, 4, 0]

集束搜索解码器

另一种流行的方法是介于贪心搜索和穷举搜索之间的一种折中方案——集束搜索(Beam Search),它能够返回一个最有可能的输出序列的列表。当构造输出序列时它不是贪婪地选择最有可能的下一步,集束搜索扩展所有可能的下一步,仅保留 k 个最有可能的。其中,k 是一个用户指定的参数,控制着整个概率序列的集束(beams)或并行搜索(parallel searches)的数量。

”The local beam search algorithm keeps track of k states rather than just one. It begins with k randomly generated states. At each step, all the successors of all k states are generated. If any one is a goal, the algorithm halts. Otherwise, it selects the k best successors from the complete list and repeats.“——Artificial Intelligence: A Modern Approach

请注意,k 个最高概率不仅仅是指当前时刻 y ^ t \hat{y}_t y^​t​ 的最高概率,而且是截止目前这条路径上的累计概率之和,序列得分计算公式如下:

score ( y 1 , … , y t ) = log ⁡ P L M ( y 1 , … , y t ∣ x ) = ∑ i = 1 t log ⁡ P L M ( y i ∣ y 1 , … , y i − 1 , x ) \begin{array}{ll} \text{score}(y_1,\dots,y_t) &=\log P_{LM}(y_1,\dots,y_t|x) \\ &= \sum_{i=1}^t \log P_{LM}(y_i | y_1, \dots, y_{i-1}, x) \end{array} score(y1​,…,yt​)​=logPLM​(y1​,…,yt​∣x)=∑i=1t​logPLM​(yi​∣y1​,…,yi−1​,x)​

概率是小数,将小数相乘会产生非常小的数。 为避免浮点数下溢,将概率的自然对数相乘在一起以使数字更大且易于管理。 所有的分数取其负数,分数越高越好。最后,我们可以按照得分的升序排列所有候选序列,并选择前 k 个作为最可能的候选序列。

贪心搜索常用的集束宽(beam widths)是 1,在机器翻译的一些基准问题中常用的集束宽是 5 到 10 之间。更大的集束宽会让模型的表现变得更好,因为多个候选序列增加了更好地匹配到目标序列的可能性。但是,表现变好的同时消耗的资源越多,解码速度也会下降。

”The local beam search algorithm keeps track of k states rather than just one. It begins with k randomly generated states. At each step, all the successors of all k states are generated. If any one is a goal, the algorithm halts. Otherwise, it selects the k best successors from the complete list and repeats.“——Beam Search Strategies for Neural Machine Translation

对于每个候选序列来说,要么达到最大序列长度,要么遇到序列终止符号,亦或者达到某个概率阈值,集束搜索过程就会终止。

接下来,我们举例说明集束搜索解码序列的详细过程。假设 k=2,第一个时间步保留 2 个最高概率的词为"he"和"I",它们分别作为下一个时间步的输入。“he"输入预测输出的前 2 名是"hit"和"struck”,则"hit"这条路的累加概率是"he"的概率加上"hit"的概率等于-1.7。同样地,可以计算出其他几个词对应路径的概率得分。最后,在这 4 条路径上保留 k=2 条路径,所以"hit"和"was"对应的路径被保留,作为下一个时间步的输入,"struck"和"got"对应的路径被剪枝。

最终的搜索树如下图所示,可以看到在每个时间步都只保留了 k=2 个节点往下继续搜索。最后"pie"对应的路径得分最高,通过回溯法得到概率最高的翻译句子。

请注意,集束搜索作为一种剪枝策略,并不能保证得到全局最优解,但它能以较大的概率得到全局最优解,同时相比穷举搜索,它极大地提高了搜索效率。

下面我们也通过一个简单的例子来演示对给定的概率序列和参数 k 用集束搜索方法进行解码的过程。所得到的概率分布与贪心搜索方法一致,完整的样例代码如下:

from math import log
from numpy import array
from numpy import argmax# 集束搜索
def beam_search_decoder(data, k):sequences = [[list(), 1.0]]# 遍历序列中的每一步for row in data:all_candidates = list()# 扩展每个候选项for i in range(len(sequences)):seq, score = sequences[i]for j in range(len(row)):candidate = [seq + [j], score * -log(row[j])]all_candidates.append(candidate)# 根据分数排列所有候选项ordered = sorted(all_candidates, key=lambda tup:tup[1])# 选择k个最有可能的sequences = ordered[:k]return sequences# 定义一个由10个单词组成的序列,单词来自于大小为5的词汇表
data = [[0.1, 0.2, 0.3, 0.4, 0.5],[0.5, 0.4, 0.3, 0.2, 0.1],[0.1, 0.2, 0.3, 0.4, 0.5],[0.5, 0.4, 0.3, 0.2, 0.1],[0.1, 0.2, 0.3, 0.4, 0.5],[0.5, 0.4, 0.3, 0.2, 0.1],[0.1, 0.2, 0.3, 0.4, 0.5],[0.5, 0.4, 0.3, 0.2, 0.1],[0.1, 0.2, 0.3, 0.4, 0.5],[0.5, 0.4, 0.3, 0.2, 0.1]]
data = array(data)
# 解码输出序列
result = beam_search_decoder(data, 3)
# 打印结果
for seq in result:print(seq)

运行这段样例程序会输出 k 个整型数字(数字代码单词在字典中的索引值)序列以及对应的 log 概率值。

[[4, 0, 4, 0, 4, 0, 4, 0, 4, 0], 0.025600863289563108]
[[4, 0, 4, 0, 4, 0, 4, 0, 4, 1], 0.03384250043584397]
[[4, 0, 4, 0, 4, 0, 4, 0, 3, 0], 0.03384250043584397]

最后,需要注意的是当集束搜索过程结束时,需要从 n 条候选路径中选一个得分最高的路径作为最终结果。由于不同路径的长度不一样,累加越多得分越低,所以需要用长度对得分进行归一化。归一化公式如下:

1 t ∑ i = 1 t log ⁡ P L M ( y i ∣ y 1 , … , y i − 1 , x ) \frac{1}{t} \sum_{i=1}^t \log P_{LM}(y_i | y_1, \dots, y_{i-1}, x) t1​i=1∑t​logPLM​(yi​∣y1​,…,yi−1​,x)


想要了解更多的自然语言处理最新进展、技术干货及学习教程,欢迎关注微信公众号“语言智能技术笔记簿”或扫描二维码添加关注。

Seq2Seq模型中的序列解码策略相关推荐

  1. Seq2Seq模型中的label bias和exposure bias问题

    从序列到序列的seq2seq模型中,存在着label bias和exposure bias问题.这两个偏差问题是由于不同的原因导致的.先给出结论在分别解释 label bias:根本原因是我们真实的目 ...

  2. Seq2Seq模型中的集束搜索(Beam Search)

    1. 引入 用Seq2Seq模型开发翻译系统时,假设输入一句法语,输出英文.在Decoder输出部分,选择不同的单词,输出(翻译)的结果也会不同. 这里用下图来举例说明: 一个法语句子,被Seq2Se ...

  3. Seq2Seq模型中的贪心搜索(Greedy Search)

    1. 引入 用Seq2Seq模型开发翻译系统时,假设输入一句法语,输出英文.在Decoder输出部分,选择不同的单词,输出(翻译)的结果也会不同. 这里用下图来举例说明: 一个法语句子,被Seq2Se ...

  4. pytorch seq2seq模型中加入teacher_forcing机制

    在循环内加的teacher forcing机制,这种为目标确定的时候,可以这样加. 目标不确定,需要在循环外加. decoder.py 中的修改 """ 实现解码器 &q ...

  5. 程序化模型中常用的止损策略

    一.价差止损 最新价与基准价之间的价差触发设定的条件时进行止损平仓.我们将以资金盈亏额为条件的止损策略也归为这一类.比较常用的策略有限价止损.追踪止损.阶梯止损等. 优秀的止损策略,既要避免被无谓的随 ...

  6. 【Attention】深度学习中的注意机制:理解序列模型中的注意机制How Attention works in Deep Learning

    [学习资源] How Attention works in Deep Learning: understanding the attention mechanism in sequence model ...

  7. 浅谈文本生成或者文本翻译解码策略《转》

    原文链接,感谢原作者 目录 一.文本生成和翻译的基本流程 翻译类模型的训练和解码 训练过程 解码过程 生成类模型的训练和解码(GPT系列) 训练过程 解码过程 二.解码策略 1.贪心搜索(greedy ...

  8. 浅谈文本生成或者文本翻译解码策略

    目录 一.文本生成和翻译的基本流程 翻译类模型的训练和解码 训练过程 解码过程 生成类模型的训练和解码(GPT系列) 训练过程 解码过程 二.解码策略 1.贪心搜索(greedy search) 2. ...

  9. 构建seq2seq模型的常见问题

    1. seq2seq模型,输入是一个词向量,而不是词向量列表,对吧? 是的,对于seq2seq模型,输入和输出都需要被转换成词向量形式. 对于输入来说,通常会将一个句子转换成一个词向量序列.具体地,对 ...

最新文章

  1. php跳过当前后续代码,PHP用continue跳过本次循环中剩余代码的注意点
  2. 由逻辑异或运算符而发现的PHP诡异运算符优先级
  3. 详解Vue中watch的高级用法
  4. 山西台达plc可编程控制器_可编程控制器2(PLC)控制原理
  5. 线程池最佳线程数量到底要如何配置?
  6. iPhone 12无线充电模块曝光:AirPower有戏了!
  7. 2015-12-02 计划任务维护数据库
  8. Facebook如何“养号”干货分享
  9. 嵌入式系统的接口类型有哪些
  10. 人工智能之不确定推理方法
  11. linux系统查看u盘容量,在LINUX系统中编程查询U盘或软盘格式信息:总容量、空余容量、已用容量等。...
  12. html5 midi,源码:MIDI 文件生成音乐乐谱(Midi To Sheet Music)
  13. Android11对比IOS14,iPhone11升级至iOS14,对比苹果iOS13,迎来3大新变化
  14. 【博客117】内核如何巧妙实现:min与max函数
  15. 计算机主板怎么学,初学者怎么去记住和理解电脑主板?
  16. 机器学习算法之——走近卷积神经网络(CNN)
  17. 数据结构-启发式算法(隐式图搜索)
  18. 2021年P气瓶充装报名考试及P气瓶充装考试内容
  19. 电脑保护眼睛的颜色设置-把你的电脑窗口颜色设置为淡淡的绿色
  20. 充电器一个就够了:倍思GaN2 Pro氮化镓充电器

热门文章

  1. JVM虚拟机——初识
  2. windows.edb文件过大,导致c盘空间极小问题
  3. Android 实现人脸识别
  4. 用png格式图片和非png格式图片做水印图片
  5. JDBC规范——(3)新特性
  6. 自媒体创作怎么细分领域?怎么进行选题?
  7. Codeforces Round #644 (Div. 3) H.Binary Median
  8. 怎么设置可以将资料横向打印出来
  9. 京东内网遭开源的“顶级”SpringCloud实战手册,GitHub列为首推
  10. JAVA实现导出mysql表结构到Word详细注解版