引言

Beam search是一种动态规划算法,能够极大的减少搜索空间,增加搜索效率,并且其误差在可接受范围内,常被用于Sequence to Sequence模型,CTC解码等应用中

时间复杂度

对于 T × N T\times N T×N的时间序列,如果我们要遍历所有可能能,则其所需的时间复杂度为 O ( N + N 2 + N 3 + . . . + N T ) \mathcal{O}(N+N^2+N^3+...+N^T) O(N+N2+N3+...+NT),在每一时间节点,所需遍历的节点数呈指数增加。对于Viterbi算法来说,时间复杂度为 O ( N + ( T − 1 ) N 2 ) \mathcal{O}(N+(T-1)N^2) O(N+(T−1)N2),在每个时间节点输入为N个best节点,需要比较的次数为 N 2 N^2 N2,然而这个时间复杂度还是太高。在N比较大的情况下,Beam Search为更好的选择,其时间复杂度为 O ( N + ( T − 1 ) ∗ b e a m s i z e ∗ N ) \mathcal{O}(N+(T-1)*beamsize*N) O(N+(T−1)∗beamsize∗N),每个时间节点的输入为beamsize个best节点,需要比较的次数为 b e a m s i z e ∗ N beamsize*N beamsize∗N

常规Beam Search (BS)


如上图所示,常规的beam search在每个时间节点,对输入的每个节点比较N次,并从 b e a m s i z e ∗ N beamsize*N beamsize∗N个比较结果中,选择 b e a m s i z e beamsize beamsize个结果作为下一时间节点的输入,其python的简单实现如下

import numpy as np
import mathdef beam_search(nodes, topk=1):# log-likelihood可以相加paths = {'A':math.log(nodes[0]['A']), 'B': math.log(nodes[0]['B']), 'C':math.log(nodes[0]['C'])}calculations = []for l in range(1, len(nodes)):# 拷贝当前路径paths_ = paths.copy()paths = {}nows = {}cur_cal = 0for i in nodes[l].keys():# 计算到达节点i的所有路径for j in paths_.keys():nows[j+i] = paths_[j]+math.log(nodes[l][i])cur_cal += 1calculations.append(cur_cal)# 选择topk条路径indices = np.argpartition(list(nows.values()), -topk)[-topk:]# 保存topk路径for k in indices:paths[list(nows.keys())[k]] = list(nows.values())[k]print(f'calculation number {calculations}')return pathsnodes = [{'A':0.1, 'B':0.3, 'C':0.6}, {'A':0.2, 'B':0.4, 'C':0.4}, {'A':0.6, 'B':0.2, 'C':0.2},{'A': 0.3, 'B': 0.3, 'C': 0.4}]
print(beam_search(nodes, topk=2))输出结果:
calculation number [9, 6, 6]
{'CBAA': -3.1419147837320724, 'CBAC': -2.854232711280291, 'CCAC': -2.854232711280291}

我们可以看到,在 N = 3 N=3 N=3, b e a m s i z e = 2 beamsize=2 beamsize=2的情况下,每个节点的比较次数为6。

Prefix(前缀)Beam Search (PBS)

在CTC算法中,由于添加了blank以及重复字符串无blank合并的规则,例如ab可能aab,abb,a blank b等多种情况的输入,因此ab的可能性应该为多种情况log概率之和,而不能通过单条beam进行搜索,因此可以采用改进版的prefix beam search,其代码如下

"""
Code from https://gist.github.com/awni/56369a90d03953e370f3964c826ed4b0
Author: Awni Hannun
CTC decoder in python, 简单例子可能不太效率
用于CTC模型的输出的前缀beam search
更多细节参考https://distill.pub/2017/ctc/#inferencehttps://arxiv.org/abs/1408.2873
"""import numpy as np
import math
import collectionsNEG_INF = -float("inf")def make_new_beam():fn = lambda: (NEG_INF, NEG_INF)return collections.defaultdict(fn)def logsumexp(*args):"""Stable log sum exp."""if all(a == NEG_INF for a in args):return NEG_INFa_max = max(args)lsp = math.log(sum(math.exp(a - a_max)for a in args))return a_max + lspdef decode(probs, beam_size=100, blank=0):"""对给定输出概率进行预测Arguments:probs: 输出概率 (e.g. post-softmax) for eachtime step. Should be an array of shape (time x output dim).beam_size (int): Size of the beam to use during inference.blank (int): Index of the CTC blank label.Returns the output label sequence and the corresponding negativelog-likelihood estimated by the decoder."""T, S = probs.shapeprobs = np.log(probs)# 在beam中的元素为(prefix, (p_blank, p_no_blank))# 初始beam为空序列,第一个是前缀,第二个是后接blank的log概率,第三个是后接非blank的log概率# 我们需要后接blank和后接非blank两种情况,来区分重复字符是否应该被合并,对于后接blank的情况,重复字符就不会被合并beam = [(tuple(), (0.0, NEG_INF))]for t in range(T):  # 沿时间维度循环# 存储下一个候选集的预设置字典,每次新的时间节点都会重设next_beam = make_new_beam()for s in range(S):  # 沿词表维度循环p = probs[t, s]# p_b和p_nb分别为在当前时刻下前缀后接blank和非blank的log概率for prefix, (p_b, p_nb) in beam:  # 对beam进行循环# 如果s为blank,那么前缀不会改变# 因为后接的是blank,所以只需要更新前缀不变的情况下后接blank的log概率if s == blank:n_p_b, n_p_nb = next_beam[prefix]n_p_b = logsumexp(n_p_b, p_b + p, p_nb + p)next_beam[prefix] = (n_p_b, n_p_nb)continue# 记录前缀最后一个字符,用于判断当前字符与前缀最后一个字符是否相同end_t = prefix[-1] if prefix else Nonen_prefix = prefix + (s,)  # n_prefix代表next prefixn_p_b, n_p_nb = next_beam[n_prefix]  # n_p_b代表 next probability of blank# 将新的字符s加到prefix后面并将整体加入到beam中# 因为后接的是非blank,所以只需要更新后接非blank的log概率if s != end_t:n_p_nb = logsumexp(n_p_nb, p_b + p, p_nb + p)else:# 如果后接s是重复的,那么我们在更新后接非blank的log概率时,# 不包括上一时刻后接非blank的概率。CTC算法会合并没有用blank分隔的重复字符n_p_nb = logsumexp(n_p_nb, p_b + p)# 这里是加入语言模型分数的好地方next_beam[n_prefix] = (n_p_b, n_p_nb)# 这是合并的情况,如果s重复出现了,前缀也不会改变,我们也更新前缀不变的情况下后接非blank的log概率if s == end_t:n_p_b, n_p_nb = next_beam[prefix]n_p_nb = logsumexp(n_p_nb, p_nb + p)next_beam[prefix] = (n_p_b, n_p_nb)# 在进入下一时间步之前,排序并裁剪beambeam = sorted(next_beam.items(),key=lambda x: logsumexp(*x[1]),reverse=True)beam = beam[:beam_size]best = beam[0]return best[0], -logsumexp(*best[1])if __name__ == "__main__":np.random.seed(3)time = 50output_dim = 20probs = np.random.rand(time, output_dim)probs = probs / np.sum(probs, axis=1, keepdims=True)labels, score = decode(probs)print(labels)print("Score {:.3f}".format(score))

与常规BS不同的地方主要在于, PBS区分了几种情况以及log probability的计算方式

  1. 对于BS来说, l o g l i k e l i h o o d = l o g ( p 1 ) + l o g ( p 2 ) + . . . loglikelihood=log(p1)+log(p2)+... loglikelihood=log(p1)+log(p2)+...,对于PBS来说,由于区分了存在blank和不存在blank的情况,并且其中之一的可能性为0,相加log probability等于负无穷的情况,因此不能直接相加,所以采用了一种稳定的logsumexp的方式来计算loglikelihood
  2. 当前缀后接blank时,前缀不变,更新当前前缀后接blank的log概率:
    n _ p _ b = l o g s u m e x p ( n _ p _ b , p _ b + p , p _ n b + p ) n\_p\_b = logsumexp(n\_p\_b, p\_b + p, p\_nb + p) n_p_b=logsumexp(n_p_b,p_b+p,p_nb+p)
  3. 当前缀后接重复字符且中间没有blank隔开时,前缀也不变,更新当前前缀后接非blank的log概率:
    n _ p _ n b = l o g s u m e x p ( n _ p _ n b , p _ n b + p ) n\_p\_nb = logsumexp(n\_p\_nb, p\_nb + p) n_p_nb=logsumexp(n_p_nb,p_nb+p)
  4. 当前缀后接不同字符时,前缀变化,更新当前前缀后接非blank的log概率:
    n _ p _ n b = l o g s u m e x p ( n _ p _ n b , p _ b + p , p _ n b + p ) n\_p\_nb = logsumexp(n\_p\_nb, p\_b + p, p\_nb + p) n_p_nb=logsumexp(n_p_nb,p_b+p,p_nb+p)
  5. 当前缀后接重复字符,且中间有blank隔开,前缀变化,更新当前前缀后接非blank的log概率:
    n _ p _ n b = l o g s u m e x p ( n _ p _ n b , p _ b + p ) n\_p\_nb = logsumexp(n\_p\_nb, p\_b + p) n_p_nb=logsumexp(n_p_nb,p_b+p)

总结

BS根据不同的场景可以有不同的写法,其主要目的在于在每个时间点选择TOPK的路径继续搜索,达到增加搜索效率的目的,在BS的搜索过程中,如果是生成字符串,我们还可以加入语言模型的分数,得到更好的结果:
Y ∗ = l o g P ( Y ∣ X ) + α l o g ( P l m ( Y ) ) + β l e n ( Y ) Y^*=logP(Y|X)+\alpha log(P_{lm}(Y))+\beta len(Y) Y∗=logP(Y∣X)+αlog(Plm​(Y))+βlen(Y)
语言模型的加入地方一般为字符串扩增时。

参考

Sequence Modeling With CTC

Beam Search与Prefix Beam Search的理解与python实现相关推荐

  1. CTC 解码算法之 prefix beam search

    ctc prefix beam search 算法 CTC 网络的输出 net_out 形状为 T×C,其中 T 是时间长度,C 是字符类别数加1(额外的blank). CTC 的 beam sear ...

  2. php字符串search,js获取location.search每个查询字符串的值

    形如https://www.debug.org/temp/test2.html?a=1&b=2#ddd这样的链接,虽可通过location.search属性获取到问号后的所有查询字符串值,但要 ...

  3. 04-树7. Search in a Binary Search Tree (25)

    04-树7. Search in a Binary Search Tree (25) 时间限制 100 ms 内存限制 65536 kB 代码长度限制 8000 B 判题程序 Standard 作者 ...

  4. pat04-树7. Search in a Binary Search Tree (25)

    04-树7. Search in a Binary Search Tree (25) 时间限制 100 ms 内存限制 65536 kB 代码长度限制 8000 B 判题程序 Standard 作者 ...

  5. python search用法,Python-re中search()函数的用法详解(查找ip)

    1.首先来看一下search()和find()的区别 import re s1 = "2221155" #search 字符串第一次出现的位置 print(re.search(&q ...

  6. android beam华为,Huawei beam是什么 Huawei beam使用方法【图文】

    很多人在使用华为手机时,会发现华为手机有一个Huawei beam的功能,那么Huawei beam是什么呢?Huawei beam怎么用?下面来看看详细介绍吧! Huawei beam是什么: 想要 ...

  7. android beam传输速率,三星S Beam 与Android Beam有什么不同

    前段时间陪朋友去买手机,店里的销售人员向我们推销三星的某台手机,说"这台手机拥有S Beam 功能,只要碰一下别人的手机,就可以通过蓝牙将视频传送到对方的手机中,速度能够达到300MB/秒. ...

  8. python 非线性回归_机器学习入门之菜鸟之路——机器学习之非线性回归个人理解及python实现...

    本文主要向大家介绍了机器学习入门之菜鸟之路--机器学习之非线性回归个人理解及python实现,通过具体的内容向大家展现,希望对大家学习机器学习入门有所帮助. 梯度下降:就是让数据顺着梯度最大的方向,也 ...

  9. 深度学习中IU、IoU(Intersection over Union)的概念理解以及python程序实现

    from: 深度学习中IU.IoU(Intersection over Union)的概念理解以及python程序实现 IoU(Intersection over Union) Intersectio ...

最新文章

  1. 凡客诚品成都研发中心招聘.net开发经理
  2. linkedin总共能加30000个好友
  3. Python1:if / while / for...in / break /continue
  4. JZOJ 5048. 【GDOI2017模拟一试4.11】IQ测试
  5. 智能机器人及其应用ppt课件_机器人视觉技术在建筑智能化生产中的应用
  6. python编程高手教程_写给编程高手的Python教程(01) 数据结构
  7. react前端封装接口弹出错误_react+ts打包发布后报Minified React error ..这种错误
  8. Java 加密扩展(JCE)框架 之 Cipher 加密与解密
  9. windows10系统下设置mtu值的方法
  10. matlab 平滑曲线连接_科研画图-率失真曲线图改进:散点连接成曲线并画出原散点的标记点(基于Matlab)...
  11. 基本积分表的联想记忆
  12. matlab randn 范围,matlab randn 范围
  13. 【python】LOFTER抽奖程序
  14. Java计算每月工作天数
  15. SpringBoot整合Graylog3.0
  16. Easyrecovery13 for mac 易恢复软件 官方中文版下载
  17. oCPC实践录 | 随你千变万化,oCPC PID控制(2)
  18. 江西省中职计算机简答题,江西省中等职业学校第八届技能竞赛节计算机类专业竞赛模拟试题(CAD、CAM软件应用)...
  19. 图像处理之平滑滤波、高斯滤波和中值滤波
  20. 翻译:An Introduction to Feature Extraction 特征提取导论。(如有不当欢迎评论区留言指正)

热门文章

  1. 文明城市测评怎么进行?这里有答案
  2. CEC循环生态社区答疑XAG到底有多好的价值前景
  3. C语言 输入n,输出n各位数字之和
  4. rapidminer员工离职分析_RapidMiner 9从根本上简化了分析团队的数据准备工作
  5. JS人民币小写金额转换为大写(没毛病)
  6. Android Application Fundamentals——Android应用程序基础知识
  7. 跟着老陈学嵌入式-C语言入门之类Linux编译环境搭建
  8. CommandName属性和CommandArgument属性[转]
  9. 如何隐藏控制台程序的窗口
  10. 转载 Package CJK Error: Invalid character code错误