最近研究了一下用基于BERT的encoder-decoder结构做文本生成任务,碰巧管老师昨天的文章也介绍了以生成任务见长的GPT模型,于是决定用两篇文章大家介绍一下在文本生成任务中常用的解码策略Beam Search(集束搜索)。

解码及贪心搜索

生成式任务相比普通的分类、tagging等NLP任务会复杂不少。在生成的时候,模型的输出是一个时间步一个时间步依次获得的,而且前面时间步的结果还会影响后面时间步的结果。也就是说,每一个时间步,模型给出的都是基于历史生成结果的条件概率。为了生成完整的句子,需要一个称为解码的额外动作来融合模型多个时间步的输出,而且使得最终得到的序列的每一步条件概率连乘起来最大。

在文本生成任务中,每一个时间步可能的输出种类称为字典大小(vocabulary size,我们用 表示),进行T步随机的生成可能获得的结果总共有 种。拿中文文本生成来说, 的值大约是5000-6000,即常用汉字的个数。在如此大的基数下,遍历整个生成空间是不现实的。

最容易想到的策略是贪心搜索,即每一个时间步都取出一个条件概率最大的输出,再将从开始到当前步的结果作为输入去获得下一个时间步的输出,直到模型给出生成结束的标志。例如下图,每一个时间步都取出了条件概率最大一个结果,生成了序列[A,B,C]

贪心搜索示意图

很明显,这样做将原来指数级别的求解空间直接压缩到了与长度线性相关的大小。由于丢弃了绝大多数的可能解,这种关注当下的策略无法保证最终得到的序列概率是最优的。

Beam Search

而beam search是对贪心策略一个改进。思路也很简单,就是稍微放宽一些考察的范围。在每一个时间步,不再只保留当前分数最高的1个输出,而是保留num_beams个。当num_beams=1时集束搜索就退化成了贪心搜索。

下图是一个实际的例子,每个时间步有ABCDE共5种可能的输出,即 ,图中的num_beams=2,也就是说每个时间步都会保留到当前步为止条件概率最优的2个序列。

beam search示意图
  • 在第一个时间步,A和C是最优的两个,因此得到了两个结果[A],[C],其他三个就被抛弃了;

  • 第二步会基于这两个结果继续进行生成,在A这个分支可以得到5个候选人,[AA],[AB],[AC],[AD],[AE],C也同理得到5个,此时会对这10个进行统一排名,再保留最优的两个,即图中的[AB][CE]

  • 第三步同理,也会从新的10个候选人里再保留最好的两个,最后得到了[ABD],[CED]两个结果。

可以发现,beam search在每一步需要考察的候选人数量是贪心搜索的num_beams倍,因此是一种牺牲时间换性能的方法。

以上就是Beam Search的基本概念,下面我们解析一种高效率实现方式。

Beam Search代码解析

Beam Search的原理虽然简单,但实际实现的时候却有很多细节要考虑。下面要解析这个实现出自于NLP界著名Python包Transformers[1],我为了说明方便做了一些改动。

一个正确且高效的算法需要处理的问题大概有两个:

  • 充分利用硬件,可以处理批量数据,且尽量使用并行计算少用循环

  • 处理好长短不同的生成结果

下面是基础版的beam search函数定义。其中context是编码器编码获得的向量,batch_size是每批数据中包含的样本量,bos_token_id是句子开头标志的token id,pad_token_id是用于填充的token id,eos_token_id是句子结束标志的token id。这里给参数填上的默认值和我们后面讲解时使用的例子是一致的。

def beam_search_generate(context,batch_size=3,max_length=20,min_length=2,num_beams=2,bos_token_id=101,pad_token_id=0,eos_token_id=102,):pass

在函数中主要执行以下三个步骤:

  • 准备初始输入

  • 在当前生成的序列长度未达到max_length时扩展生成序列

  • 准备最终输出的序列

下面我们分别解析。

准备初始输入

# 建立beam容器,每个样本一个
generated_hyps = [BeamHypotheses(num_beams, max_length, length_penalty, early_stopping=early_stopping)for _ in range(batch_size)
]# 每个beam容器的得分,共batch_size*num_beams个
beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=encoder_input_ids.device)
beam_scores = beam_scores.view(-1)# 每个样本是否完成生成,共batch_size个
done = [False for _ in range(batch_size)]# 为了并行计算,一次生成batch_size*num_beams个序列
# 第一步自动填入bos_token
input_ids = torch.full((batch_size*num_beams, 1),bos_token_id,dtype=torch.long,device=next(self.parameters()).device,
)# 当前长度设为1
cur_len = 1

其中BeamHypotheses是一个容器类,每个样本绑定一个。每个容器中会维护num_beams个当前最优的序列。当往容器中添加一个序列而导致序列数大于num_beams的时候,它会自动踢掉分数最低的那个序列。类代码如下。

class BeamHypotheses(object):def __init__(self, num_beams, max_length, length_penalty):self.max_length = max_length - 1   # ignoring bos_tokenself.num_beams = num_beamsself.beams = []self.worst_score = 1e9def __len__(self):return len(self.beams)def add(self, hyp, sum_logprobs):score = sum_logprobs / len(hyp) ** self.length_penaltyif len(self) < self.num_beams or score > self.worst_score:# 可更新的情况:数量未饱和或超过最差得分self.beams.append((score, hyp))if len(self) > self.num_beams:# 数量饱和需要删掉一个最差的sorted_scores = sorted([(s, idx) for idx, (s, _) in enumerate(self.beams)])del self.beams[sorted_scores[0][1]]self.worst_score = sorted_scores[1][0]else:self.worst_score = min(score, self.worst_score)def is_done(self, best_sum_logprobs, cur_len=None):"""相关样本是否已经完成生成。best_sum_logprobs是新的候选序列中的最高得分。"""if len(self) < self.num_beams:return Falseelse:if cur_len is None:cur_len = self.max_lengthcur_score = best_sum_logprobs / cur_len ** self.length_penalty# 是否最高分比当前保存的最低分还差ret = self.worst_score >= cur_scorereturn ret

序列扩展

序列扩展是beam search的核心过程,我们特地画了一张图来解释这个版本的实现策略。

序列扩展示意图

下面对照这个图来讲解代码。

while cur_len < max_length:# 将编码器得到的上下文向量和当前结果输入解码器,即图中1output = decoder.decode_next_step(context, input_ids)# 输出矩阵维度为:(batch*num_beams)*cur_len*vocab_size# 取出最后一个时间步的各token概率,即当前条件概率# (batch*num_beams)*vocab_sizescores = next_token_logits = output[:, -1, :]############################ 这里可以做一大堆操作减少重复 ############################# 计算序列条件概率的,因为取了log,所以直接相加即可。得到图中2矩阵# (batch_size * num_beams, vocab_size)next_scores = scores + beam_scores[:, None].expand_as(scores)# 为了提速,将结果重排成图中3的形状next_scores = next_scores.view(batch_size, num_beams * vocab_size)  # (batch_size, num_beams * vocab_size)# 取出分数最高的token(图中黑点)和其对应得分# sorted=True,保证返回序列是有序的next_scores, next_tokens = torch.topk(next_scores, 2 * num_beams, dim=1, largest=True, sorted=True)# 下一个时间步整个batch的beam列表# 列表中的每一个元素都是三元组# (分数, token_id, beam_id)next_batch_beam = []# 对每一个样本进行扩展for batch_idx in range(batch_size):# 检查样本是否已经生成结束if done[batch_idx]:# 对于已经结束的句子,待添加的是pad tokennext_batch_beam.extend([(0, pad_token_id, 0)] * num_beams)  # pad the batchcontinue# 当前样本下一个时间步的beam列表next_sent_beam = []# 对于还未结束的样本需要找到分数最高的num_beams个扩展# 注意,next_scores和next_tokens是对应的# 而且已经按照next_scores排好顺序for beam_token_rank, (beam_token_id, beam_token_score) in enumerate(zip(next_tokens[batch_idx], next_scores[batch_idx])):# get beam and word IDs# 这两行可参考图中3进行理解beam_id = beam_token_id // vocab_sizetoken_id = beam_token_id % vocab_sizeeffective_beam_id = batch_idx * num_beams + beam_id# 如果出现了EOS token说明已经生成了完整句子if (eos_token_id is not None) and (token_id.item() == eos_token_id):# if beam_token does not belong to top num_beams tokens, it should not be addedis_beam_token_worse_than_top_num_beams = beam_token_rank >= num_beamsif is_beam_token_worse_than_top_num_beams:continue# 往容器中添加这个序列generated_hyps[batch_idx].add(input_ids[effective_beam_id].clone(), beam_token_score.item(),)else:# add next predicted word if it is not eos_tokennext_sent_beam.append((beam_token_score, token_id, effective_beam_id))# 扩展num_beams个就够了if len(next_sent_beam) == num_beams:break# 检查这个样本是否已经生成完了,有两种情况# 1. 已经记录过该样本结束# 2. 新的结果没有使结果改善done[batch_idx] = done[batch_idx] or generated_hyps[batch_idx].is_done(next_scores[batch_idx].max().item(), cur_len=cur_len)# 把当前样本的结果添加到batch结果的后面next_batch_beam.extend(next_sent_beam)# 如果全部样本都已经生成结束便可以直接退出了if all(done):break# 把三元组列表再还原成三个独立列表beam_scores = beam_scores.new([x[0] for x in next_batch_beam])beam_tokens = input_ids.new([x[1] for x in next_batch_beam])beam_idx = input_ids.new([x[2] for x in next_batch_beam])# 准备下一时刻的解码器输入# 取出实际被扩展的beaminput_ids = input_ids[beam_idx, :]# 在这些beam后面接上新生成的tokeninput_ids = torch.cat([input_ids, beam_tokens.unsqueeze(1)], dim=-1)# 更新当前长度cur_len = cur_len + 1# end of length while

乍一看是不是有些复杂,我感觉关键的有以下几点:

  • 只有出现了EOS token才会将生成的序列装进该样本对应的容器中

  • 当前input_ids保存着当前得分最高的num_beams个序列

准备输出

上面那个while循环跳出意味着已经生成了长度为max_length的文本,比较理想的情况是所有的句子都已经生成出了eos_token_id,即句子生成结束了。但并不是所有情况都这样,对于那些”意犹未尽“的样本,我们需要先手动结束。

# 将未结束的生成结果结束,并置入容器中
for batch_idx in range(batch_size):# 已经结束的样本不需处理if done[batch_idx]:continue# 把结果加入到generated_hyps容器for beam_id in range(num_beams):effective_beam_id = batch_idx * num_beams + beam_idfinal_score = beam_scores[effective_beam_id].item()final_tokens = input_ids[effective_beam_id]generated_hyps[batch_idx].add(final_tokens, final_score)

经过上面的处理,所有生成好的句子都已经保存在generated_hyps容器中,每个容器内保存着num_beams个序列,最后就是输出期望个数的句子。

# select the best hypotheses,最终输出
# 每个样本返回几个句子
output_num_return_sequences_per_batch = 1
# 记录每个返回句子的长度,用于后面pad
sent_lengths = input_ids.new(output_batch_size)
best = []# 对每个样本取出最好的output_num_return_sequences_per_batch个句子
for i, hypotheses in enumerate(generated_hyps):sorted_hyps = sorted(hypotheses.beams, key=lambda x: x[0])for j in range(output_num_return_sequences_per_batch):effective_batch_idx = output_num_return_sequences_per_batch * i + jbest_hyp = sorted_hyps.pop()[1]sent_lengths[effective_batch_idx] = len(best_hyp)best.append(best_hyp)# 如果长短不一则pad句子,使得最后返回结果的长度一样
if sent_lengths.min().item() != sent_lengths.max().item():sent_max_len = min(sent_lengths.max().item() + 1, max_length)# 先把输出矩阵填满PAD tokendecoded = input_ids.new(output_batch_size, sent_max_len).fill_(pad_token_id)# 填入真正的内容for i, hypo in enumerate(best):decoded[i, : sent_lengths[i]] = hypo# 填上eos tokenif sent_lengths[i] < max_length:decoded[i, sent_lengths[i]] = eos_token_id
else:# 所有生成序列都还没结束,直接堆叠即可decoded = torch.stack(best).type(torch.long).to(next(self.parameters()).device)# 返回的结果包含BOS token
return decoded

总结

好了,上面就是最基础的beam search算法。这样生成出来的结果已经会比贪心搜索好一些,但还是会遇到诸如词语重复这样的问题。其实已经有很多针对重复问题的研究,我们在代码中也已经留出了位置,下期再见咯。

参考资料

[1]

Transformers: https://github.com/huggingface/transformers

个人微信:加时请注明 (昵称+公司/学校+方向)

十分钟读懂Beam Search(1/2)相关推荐

  1. 十分钟读懂游戏研发、发行、渠道那些事儿

    国庆在家写了7天东西,实在是累得够呛.我重新梳理了一下以前做过的事儿,正好把去年年初发到腾讯GAD的文章拿过来做个修改补充,算是再做个总结吧. 这篇文章主要是介绍游戏行业的上下游产业链有哪些玩家,游戏 ...

  2. 十分钟读懂『卡尔曼滤波算法』

    我是勤劳的搬运工,转自: 1.http://blog.csdn.net/karen99/article/details/7771743 2.http://blog.csdn.net/tudouniur ...

  3. 【干货】十分钟读懂浏览器渲染流程

    在之前写过的一篇<"天龙八步"细说浏览器输入URL后发生了什么>一文中,和大家分享了从在浏览器中输入网址URL到最终页面展示的整个过程.部分读者向我反馈对于最后的浏览器 ...

  4. 产品入门一——十分钟读懂产品经理

    1.用一个故事来引入我们今天的话题:  2.在这个故事中,对于产品经理的解释和补充说明!!!! 1)发现市场需求 补充说明:产品经理要拥有一个善于去发现需求,什么是需求,问题即需求,现实生活中你遇到了 ...

  5. 【十分钟读懂系列】之什么是SLF,PSL,MLF,SLO?

    受国际经济金融形势不确定性增强以及各种影响流动性的因素波动较大影响,近年来我国银行体系短期流动性供求的波动性有所加大,尤其是当多个因素相互叠加或市场预期发生变化时,有可能出现市场短期资金供求缺口难以通 ...

  6. 十分钟读懂AES加密算法

    偶阅博客一篇,漫画式的讲解十分有趣,故转之. 原文地址:https://blog.csdn.net/lrwwll/article/details/78069013 ------------------ ...

  7. 十分钟读懂『K-Means 算法』

    我是勤劳的搬运工,转载自:http://coolshell.cn/articles/7779.html ------------------------------------------------ ...

  8. 通俗易懂,十分钟读懂DES,详解DES加密算法原理,DES攻击手段以及3DES原理。Python DES实现源码

    文章目录 1.什么是DES 2.DES的基本概念 3.DES的加密流程 4.DES算法步骤详解 4.1 初始置换(Initial Permutation,IP置换) 4.2 加密轮次 4.3 F轮函数 ...

  9. 十分钟看懂图像语义分割技术

    转载于:十分钟看懂图像语义分割技术 大多数人接触"语义"都是在和文字相关的领域,或语音识别,期望机器能够识别你发出去的消息或简短的语音,然后给予你适当的反馈和回复.嗯,看到这里你应 ...

最新文章

  1. 多任务的介绍(并发、并行)
  2. Windows应用程序进程级别统一监控实践
  3. 安装yaml报错:ERROR: Cannot uninstall 'PyYAML'.
  4. C# OracleParameter 传参 实例
  5. 算法设计7—哈希表1
  6. 使用工具快速找出custom work center使用的ui component
  7. .NET Core+Selenium+Github+Travis CI =amp;gt; SiteHistory
  8. 讲解Linux数据库安装
  9. 好物推荐|下载超过 23w 次的 IDE 插件,让效率飞速提升
  10. Python——PyCharm常用快捷键
  11. Swift新手教程3-字符串String
  12. 如何最大程度地提高cin和cout的效率
  13. MATLAB 四点定球及三点定圆(完整代码)
  14. 《JAVASE系列》一个小小的图书管理系统
  15. 苹果真伪查询_汇课堂:再见盗版MT4!独家揭秘5种方法辨别真伪MT4
  16. 支持M1芯片的Photoshop 2021安装教程 已经支持M1芯片ARM M1处理器安装PS2021解决方案教程 最新方法!
  17. 你需要掌握的 Koa 洋葱模型和中间件
  18. 网站卡其cdn后不能访问_网站使用CDN加速后,网站无法访问如何解决
  19. 65Z5芯片,65Z5三极管,稳压输出3V的IC资料
  20. Android按键音无效

热门文章

  1. 质量好的自行车品牌有哪些辐轮王土拨鼠全球顶级自行车品牌排行榜
  2. 【EF框架】聊一聊EF框架
  3. 欧拉公式:世界上最完美的公式(上帝公式) 复变函数 、平面几何 、拓扑学、 初等数论、 物理学
  4. GitHub图片加载失败原因追究及解决方案
  5. python中map函数的作用是_python中map()函数
  6. 【element】InputNumber计数器 动态渲染设置默认值后加减号失效问题
  7. Golang环境配置步骤
  8. 【TechJoy第三期】在白云山,住着广州的秋天
  9. 使用chatgpt写6.5分作文范文
  10. 移动互联的未来,谁在紧握命运的咽喉?