目录

一、文本生成和翻译的基本流程

翻译类模型的训练和解码

训练过程

解码过程

生成类模型的训练和解码(GPT系列)

训练过程

解码过程

二、解码策略

1、贪心搜索(greedy search)

2、beam_search集束搜索

3、随机sampling

4、Top-K Sampling和Top-p (nucleus) sampling

Top-K Sampling

Top-p (nucleus) sampling

三、transformer中的解码使用


文本生成和文本翻译的效果不仅仅在于模型层面的好坏,同时预测阶段的解码策略也是比较重要,不同的解码策略得出的效果也是不同的。经过学者们多年的研究,目前就我所知的文本生成相关的解码策略主要有贪心搜索(greedy search)、beam_search集束搜索、随机sampling、top-k sampling和Top-p Sampling,今天我们主要聊聊这几种文本解码策略算法。

一、文本生成和翻译的基本流程

翻译类模型的训练和解码

训练过程

翻译类任务的流程是一个src输入对应一个tag输入,一般而言,src长度和tag长度不一样的;一个简单的流程图如下图所示:

模型训练的结果是和tag长度一样的一个向量,output[T,B,D]经过一个分类全连接层得到[T,B]的概率分布,这个就和tag的输入[T,B]计算loss;

解码过程

如下图所示,模型训练好以后,解码的初始就是src的embedding加上tag端的起始字符<cls>等特殊的字符,解码输出得到第一个字符token然后把这个token添加到tag端输入,继续解码得到第二个token......重复不断的解码,每一次解码都是需要过一次模型推理,所以比较耗时;只到碰到结束字符或者最大长度。

生成类模型的训练和解码(GPT系列)

训练过程

GPT模型的训练过程直接输入一段自然文本,然后输出其embedding,然后再经过一个分类器,得到logits[B,L,V];同时把输入文本作为标签,计算交叉熵损失。模型的输入就是inputids [B,L]-------->embedding[B,L,D]------->logits[B,L,V]。

解码过程

同上面类似也是把当前解码结果token和之前的tokens合并起来作为输入解码得到下一个token。

二、解码策略

上面通过示意图简单的解释了一下生成类任务的模型训练和解码过程以及中间的向量维度变化,最后解码的结果好坏出了和模型本身有关,同时也与采用什么样的解码策略也是很相关的。

1、贪心搜索(greedy search)

预测阶段得到的概率分布,连接全连接层后,可以得到一个序列的概率分布[(B*S),vocab_size]——含义就是每个字在词表上的概率分布,共有B*S个字。怎么样通过这个概率分布得到最合理的序列。一种很直观的做法就是从每个字的概率分布中取它的最大概率的那个可能性,直到整个序列完成或者发现终止符[SEP]。简单实现,代码如下:


def gen_nopeek_mask(length):"""Returns the nopeek maskParameters:length (int): Number of tokens in each sentence in the target batchReturns:mask (arr): tgt_mask, looks like [[0., -inf, -inf],[0., 0., -inf],[0., 0., 0.]]"""mask = torch.triu(torch.ones(length, length))mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))return maskdef greedy_search_decode(model, src,src_key_padding_mask, max_len:int = 64, start_symbol:int = 1):""":param model: Transformer model:param src: the encoder input:param max_len: 序列最大长度:return:ys 这个就是预测的具体序列解码的时候这几个mask是不能够少的"""src_mask = gen_nopeek_mask(src.shape[1]).to(device)memory_key_padding_mask = src_key_padding_mask#最开始的字符[CLS]在词表的位置是1ys = torch.ones(1, 1).fill_(start_symbol).type_as(src.data)for i in range(max_len-1):tar_mask = gen_nopeek_mask(ys.shape[1]).to(device)out = model.forward(src, ys, src_key_padding_mask=src_key_padding_mask, tgt_key_padding_mask=None,src_mask=src_mask,tar_mask=tar_mask,memory_key_padding_mask=memory_key_padding_mask)#预测结果out,选取最后一个概率分布out = out[:,-1,:]#得到最大的那个概率的index,就是该次预测的字在词表的index_, next_word = torch.max(out, dim=1)next_word = next_word.data[0]if next_word != 2:#如果没有预测出终止符[SEP]#把这次预测的结果和以前的结果cat起来,再次循环迭代预测ys = torch.cat([ys, torch.ones(1, 1).type_as(src.data).fill_(next_word)], dim=1)else:breakreturn ys

上面实现的缺陷就是不能并行的解码batch>1的情形,可以适当修改适应并行处理,每次batch内的数据每次解码后,做一个判定,是否batch内的每一行数据都出现了结束字符。判定代码就是:

(ys == 2).sum(1).bool().all()

判定ys的每一行是否出现过2(结束符号)这个元素

解码完整代码如下图

def greedy_search_decode(model, src, src_key_padding_mask, max_len: int = 64, start_symbol: int = 1, bs:int=32):""":param model: Transformer model:param src: the encoder input:param max_len: 序列最大长度:return:ys 这个就是预测的具体序列解码的时候这几个mask是不能够少的"""src_mask = gen_nopeek_mask(src.shape[1]).to(device)memory_key_padding_mask = src_key_padding_mask# 最开始的字符[CLS]在词表的位置是1ys = torch.ones(bs, 1).fill_(start_symbol).type_as(src.data)for i in range(max_len - 1):tar_mask = gen_nopeek_mask(ys.shape[1]).to(device)out = model.forward(src, ys, src_key_padding_mask=src_key_padding_mask, tgt_key_padding_mask=None,src_mask=src_mask, tar_mask=tar_mask, memory_key_padding_mask=memory_key_padding_mask)# 预测结果out,选取最后一个概率分布out = out[:, -1, :]# 得到最大的那个概率的index,就是该次预测的字在词表的index_, next_word = torch.max(out, dim=1)next_word = next_word.data[0]ys = torch.cat([ys, next_word], dim=1)#判定一个batch内是不是所有的都解码完成了if (ys == 2).sum(1).bool().all():breakreturn ys

解码举例如下

the nice woman 是每个时间步当前的最佳选择概率为0.5*0.4=0.2,但是从图上看概率最大的结果并不是这个the dog has 才具有整句最大的概率0.4*0.9 = 0.36;很明显的贪心搜索(greedy search)的缺点就是得出的序列并不一定具有整句最大概率,它很有可能遗漏掉一个比较小的当前概率后面的非常大概率的序列。为了避免这种情况,学者们提出了beam_search算法。

2、beam_search集束搜索

为了避免上述贪心搜索遗漏掉后面大概率的序列,beam search算法提出每次都保留当前最大的beam_num个结果。把当前beam_num个结果分别输入到模型中进行解码,每个序列又新生成v个新结果,共计beam_num*v个结果,排序选择最佳的beam_num个结果;然后重复上述过程,直到解码完成,最后从beam_num个结果选择出概率积最大的那个序列。——即每一步解码过程中都是保留前beam_num个最大的结果,最后才得出概率最大的那个。

以beam_num为2进行举例,图片来自——(全面了解Beam Search 1)

第一步解码,我们选择概率最大的两个单词[A, C],然后分别带入第二步解码,分别得到[AA, AB, AC, AD, AE, CA, CB, CC, CD, CE] 10种情况,这里仅保留最优的两种情况[AB, CE],然后再继续带入第三步解码,以此类推.....最后得到整体概率最大的序列。

bs=1时,实现beam search还是比较简单的,直接在贪心搜索的代码上做修改,记录当前最佳的beam_num个序列以及得分,然后每一步结果从beam_num*v的结果中做排序得到新的beam_num个结果。

当bs>1的时候,要实现一个高效的beam search还是比较麻烦的,参考了全面了解Beam Search 1和世界第一NLP实现库huggingface的transformers中的源码,修改如下的beam search代码:

import torch
import torch.nn.functional as F
from einops import rearrange"""
batch_size为n  这样的处理
"""class BeamHypotheses(object):def __init__(self,num_beams,max_length,length_penalty):self.max_length=max_length-1  # ignoringbos_tokenself.length_penalty=length_penalty  # 长度惩罚的指数系数self.num_beams=num_beams  # beamsizeself.beams=[]  # 存储最优序列及其累加的log_probscoreself.worst_score=1e9  # 将worst_score初始为无穷大。def __len__(self):return len(self.beams)def add(self,hyp,sum_logprobs):score=sum_logprobs / len(hyp) ** self.length_penalty  # 计算惩罚后的scoreif len(self) < self.num_beams or score > self.worst_score:# 如果类没装满num_beams个序列# 或者装满以后,但是待加入序列的score值大于类中的最小值# 则将该序列更新进类中,并淘汰之前类中最差的序列self.beams.append((score, hyp))if len(self) > self.num_beams:sorted_scores=sorted([(s,idx)for idx, (s, _) in enumerate(self.beams)])del self.beams[sorted_scores[0][1]]self.worst_score = sorted_scores[1][0]else:# 如果没满的话,仅更新worst_scoreself.worst_score = min(score, self.worst_score)def is_done(self,best_sum_logprobs,cur_len):# 当解码到某一层后,该层每个结点的分数表示从根节点到这里的log_prob之和# 此时取最高的log_prob,如果此时候选序列的最高分都比类中最低分还要低的话# 那就没必要继续解码下去了。此时完成对该句子的解码,类中有num_beams个最优序列。if len(self) < self.num_beams:return Falseelse:cur_score = best_sum_logprobs / cur_len ** self.length_penaltyret = self.worst_score >= cur_scorereturn retdef gen_nopeek_mask(length):"""Returns the nopeek maskParameters:length (int): Number of tokens in each sentence in the target batchReturns:mask (arr): tgt_mask, looks like [[0., -inf, -inf],[0., 0., -inf],[0., 0., 0.]]"""mask = rearrange(torch.triu(torch.ones(length, length)) == 1, 'h w -> w h')mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))return maskdef beam_sizing(num_beams,src,src_key_padding_mask):#为了满足beam_search 算法在解码的时候的使用,需要进行数据复制——按行进行复制,复制num_beams份temp1 = srctemp2 = src_key_padding_maskfor i in range(num_beams-1):temp1 = torch.cat([temp1,src],dim=0)temp2 = torch.cat([temp2,src_key_padding_mask],dim=0)index = 0for i in range(src.shape[0]):for _ in range(num_beams):temp1[index,...] = src[i,...]temp2[index,...] = src_key_padding_mask[i,...]index += 1src = temp1src_key_padding_mask = temp2return src,src_key_padding_maskdef beam_search(device,model,src,src_key_padding_mask,sos_token_id:int=1,pad_token_id:int=0,eos_token_id:int = 2,max_length:int = 20,num_beams:int =6,vocab_size:int=5993):batch_size = src.shape[0]src_mask = gen_nopeek_mask(src.shape[1]).to(device)src,src_key_padding_mask = beam_sizing(num_beams,src,src_key_padding_mask)memory_key_padding_mask = src_key_padding_maskbeam_scores = torch.zeros((batch_size, num_beams)).to(device)  # 定义scores向量,保存累加的log_probsbeam_scores[:, 1:] = -1e9  # 需要初始化为-infbeam_scores = beam_scores.view(-1)  # 展开为(batch_size * num_beams)done = [False for _ in range(batch_size)]  # 标记每个输入句子的beam search是否完成generated_hyps = [BeamHypotheses(num_beams, max_length, length_penalty=0.7)for _ in range(batch_size)]  # 为每个输入句子定义维护其beam search序列的类实例# 初始输入: (batch_size * num_beams, 1)个sos tokeninput_ids = torch.full((batch_size * num_beams, 1), sos_token_id, dtype=torch.long).to(device)cur_len = 1while cur_len < max_length:tar_mask = gen_nopeek_mask(input_ids.shape[1]).to(device)memory_key_padding_mask = src_key_padding_maskoutputs,_= model.forward(src, input_ids, src_key_padding_mask=src_key_padding_mask, tgt_key_padding_mask=None,src_mask=src_mask,tar_mask=tar_mask,memory_key_padding_mask=memory_key_padding_mask)# 取最后一个timestep的输出 (batch_size*num_beams, vocab_size)next_token_logits = outputs[:, -1, :]scores = F.log_softmax(next_token_logits, dim=-1)  # log_softmaxnext_scores = scores + beam_scores[:, None].expand_as(scores)  # 累加上以前的scoresnext_scores = next_scores.view(batch_size, num_beams * vocab_size)  # 转成(batch_size, num_beams * vocab_size), 如上图所示# 取topk,这里一定要取2*num_beams个最大值,才能保证后续下一批次每个batch内会有num_beams个需要处理的next_scores, next_tokens = torch.topk(next_scores, 2*num_beams, dim=1, largest=True, sorted=True)# 下一个时间步整个batch的beam列表# 列表中的每一个元素都是三元组# (分数, token_id, beam_id)next_batch_beam = []for batch_idx in range(batch_size):if done[batch_idx]:# 当前batch的句子都解码完了,那么对应的num_beams个句子都继续padnext_batch_beam.extend([(0, pad_token_id, 0)] * num_beams)  # pad the batchcontinuenext_sent_beam = []  # 保存三元组(beam_token_score, token_id, effective_beam_id)for beam_token_rank, (beam_token_id, beam_token_score) in enumerate(zip(next_tokens[batch_idx], next_scores[batch_idx])):beam_id = beam_token_id // vocab_size  # 1token_id = beam_token_id % vocab_size  # 1# 上面的公式计算beam_id只能输出0和num_beams-1, 无法输出在(batch_size, num_beams)中的真实id# 如上图, batch_idx=0时,真实beam_id = 0或1; batch_idx=1时,真实beam_id如下式计算为2或3# batch_idx=1时,真实beam_id如下式计算为4或5effective_beam_id = batch_idx * num_beams + beam_id# 如果遇到了eos, 则讲当前beam的句子(不含当前的eos)存入generated_hypif (eos_token_id is not None) and (token_id.item() == eos_token_id):is_beam_token_worse_than_top_num_beams = beam_token_rank >= num_beamsif is_beam_token_worse_than_top_num_beams:continuegenerated_hyps[batch_idx].add(input_ids[effective_beam_id].clone(), beam_token_score.item(),)else:# 保存第beam_id个句子累加到当前的log_prob以及当前的token_idnext_sent_beam.append((beam_token_score, token_id, effective_beam_id))if len(next_sent_beam) == num_beams:break# 当前batch是否解码完所有句子done[batch_idx] = done[batch_idx] or generated_hyps[batch_idx].is_done(next_scores[batch_idx].max().item(), cur_len)  # 注意这里取当前batch的所有log_prob的最大值# 每个batch_idx, next_sent_beam中有num_beams个三元组(假设都不遇到eos)# batch_idx循环后,extend后的结果为num_beams * batch_size个三元组next_batch_beam.extend(next_sent_beam)# 如果batch中每个句子的beam search都完成了,则停止if all(done):break# 准备下一次循环(下一层的解码)# beam_scores: (num_beams * batch_size)# beam_tokens: (num_beams * batch_size)# beam_idx: (num_beams * batch_size)# 这里beam idx shape不一定为num_beams * batch_size,一般是小于等于# 因为有些beam id对应的句子已经解码完了 (下面假设都没解码完)# print('next_batch_beam',len(next_batch_beam))beam_scores = beam_scores.new([x[0] for x in next_batch_beam])beam_tokens = input_ids.new([x[1] for x in next_batch_beam])beam_idx = input_ids.new([x[2] for x in next_batch_beam])# 取出有效的input_ids, 因为有些beam_id不在beam_idx里面,# 因为有些beam id对应的句子已经解码完了# print('beam_idx',beam_idx)# print('next_scores.shape',next_scores.shape)#以下代码是核心的必须添加上input_ids = input_ids[beam_idx, :]  # (num_beams * batch_size, seq_len)src = src[beam_idx,...]src_key_padding_mask = src_key_padding_mask[beam_idx,...]# (num_beams * batch_size, seq_len) ==> (num_beams * batch_size, seq_len + 1)input_ids = torch.cat([input_ids, beam_tokens.unsqueeze(1)], dim=-1)cur_len = cur_len + 1# 注意有可能到达最大长度后,仍然有些句子没有遇到eos token,这时done[batch_idx]是falsefor batch_idx in range(batch_size):if done[batch_idx]:continuefor beam_id in range(num_beams):# 对于每个batch_idx的每句beam,都执行加入add# 注意这里已经解码到max_length长度了,但是并没有遇到eos,故这里全部要尝试加入effective_beam_id = batch_idx * num_beams + beam_idfinal_score = beam_scores[effective_beam_id].item()final_tokens = input_ids[effective_beam_id]generated_hyps[batch_idx].add(final_tokens, final_score)# 经过上述步骤后,每个输入句子的类中保存着num_beams个最优序列# 下面选择若干最好的序列输出# 每个样本返回几个句子output_num_return_sequences_per_batch = num_beams  #一定要小于num_beamsoutput_batch_size = output_num_return_sequences_per_batch * batch_size# 记录每个返回句子的长度,用于后面padsent_lengths = input_ids.new(output_batch_size)best = []best_score = []# retrieve best hypothesesfor i, hypotheses in enumerate(generated_hyps):# x: (score, hyp), x[0]: scoresorted_hyps = sorted(hypotheses.beams, key=lambda x: x[0])for j in range(output_num_return_sequences_per_batch):effective_batch_idx = output_num_return_sequences_per_batch * i + jtemp = sorted_hyps.pop()best_hyp = temp[1]best_s = temp[0]sent_lengths[effective_batch_idx] = len(best_hyp)best.append(best_hyp)best_score.append(best_s)if sent_lengths.min().item() != sent_lengths.max().item():sent_max_len = min(sent_lengths.max().item() + 1, max_length)# fill paddecoded = input_ids.new(output_batch_size, sent_max_len).fill_(pad_token_id)# 填充内容for i, hypo in enumerate(best):decoded[i, : sent_lengths[i]] = hypoif sent_lengths[i] < max_length:decoded[i, sent_lengths[i]] = eos_token_idelse:# 否则直接堆叠起来decoded = torch.stack(best).type(torch.long)# (output_batch_size, sent_max_len) ==> (batch_size*output_num_return_sequences_per_batch, sent_max_len)best_score = torch.tensor(best_score).type_as(next_scores)return decoded,best_score

虽然解决上贪心搜索的缺陷,但是beam search解码策略也有它的缺陷。从实际使用效果来看,beam search很容易重复的出现之前的字符,尤其是在文本生成任务上,机器翻译上效果还行。

How to generate text: using different decoding methods for language generation with Transformers中给出的例子可以看出在生成很短的一句话后,就开始重复了。为了解决这个问题,学者们提出了随机sampling的算法

3、随机sampling

随机采样顾名思义就是对在解码的时候,在下一个token生成的时候,直接随机的进行采样。对于greedy方法的好处是,我们生成的文字开始有了一些随机性,不会总是生成很机械的回复了。存在的问题就很明显了——生成的话术上下文不连贯,语义上可能相互矛盾、也是容易出现一些奇怪的词。

4、Top-K Sampling和Top-p (nucleus) sampling

论文The Curious Case of Neural Text Degeneration中提出一个很有意思的语言现象——

人类的语言总是出人意料的,并不是如同beam search中选择语言模型中概率最大的序列。就是beam search解码策略的结果less surprising!为此论文就基于Top-K Sampling改进得到了核采样Top-p (nucleus) sampling,下面就来聊一聊Top-K Sampling和Top-p (nucleus) sampling。

Top-K Sampling

这个是在随机sampling的基础上改进而来,既然在整个loghits概率分布上做随机采样会导致上下文不连贯,语义上可能相互矛盾、出现奇怪词语等问题,那能不能选取概率最大的K个token,重新形成概率分布,然后再做多项式分布抽样。思想很简单,torch实现起来也不困难。实际使用效果在GPT2模型上得到了很高的提升,GPT2生成的语句非常通顺流利,且重复token大幅度减少。

如图显示的就是K=6的时候,解码第一步6个token占据了整体tokens的三分之二,第二步则占用了99%,并且这些token都是比较合理的,同时采样的时候也采用了多项式随机采样——这样的话就会得到比较通顺流利的话语,也没有重复的词和奇怪的词。

该方法的难点在于K值如何选取

每一步解码过程中,logits的概率分布都是不一样的,在动态改变,固定的K值有可能造成取到的token是低概率的不合理的token;另外K取值过大又会和之前的随机sampling一样生成的话术上下文不连贯,语义上可能相互矛盾、也是容易出现一些奇怪的词;K过小的话,又会导致生成的语句多样性变差,less surprising!最好是K能动态的适应每一步解码的logits!为此有学者提出了核采样Top-p (nucleus) sampling

Top-p (nucleus) sampling

和Top-K Sampling不同的去一个固定的K值,Top-p (nucleus) sampling对整个logits从大到小累积概率,只要累积概率大于一个阈值,就把这些选取的token构成新的分布,然后采取多项式抽样,得到解码的next token!

示例中累积概率阈值p = 0.92 ,第一步解码中采样从9个token中进行;第二步解码从3个token中进行;这样就可以动态的适应logtis,采取不同的K值。不过有一个点就是累积概率阈值P也是不溶于确定的,大多采用经验值。

当然从使用效果上来讲,Top-K Sampling和Top-p (nucleus) sampling都是比较不错的;当然实际使用过程中也是可以把Top-p (nucleus) sampling和Top-K Sampling结合起来,避免概率很小的token作为候选者,同时也保持动态性。

top-k和top-p 过滤代码:

def top_k_top_p_filtering_batch(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):""" Filter a distribution of logits using top-k and/or nucleus (top-p) filteringArgs:logits: logits distribution shape (vocabulary size)top_k > 0: keep only top k tokens with highest probability (top-k filtering).top_p > 0.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317"""top_k = min(top_k, logits.size(-1))  # Safety checkif top_k > 0:# Remove all tokens with a probability less than the last token of the top-k# torch.topk()返回最后一维最大的top_k个元素,返回值为二维(values,indices)# ...表示其他维度由计算机自行推断for i in range(logits.shape[0]):indices_to_remove = logits[i] < torch.topk(logits[i], top_k)[0][..., -1, None]logits[i][indices_to_remove] = filter_value  # 对于topk之外的其他元素的logits值设为负无穷if top_p > 0.0:for i in range(logits.shape[0]):sorted_logits, sorted_indices = torch.sort(logits[i], descending=True)  # 对logits进行递减排序cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)# Remove tokens with cumulative probability above the thresholdsorted_indices_to_remove = cumulative_probs > top_p# Shift the indices to the right to keep also the first token above the thresholdsorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()sorted_indices_to_remove[..., 0] = 0indices_to_remove = sorted_indices[sorted_indices_to_remove]logits[i][indices_to_remove] = filter_valuereturn logits

然后直接调用该过滤算法进行解码

curr_input_tensor = input_ids.to(device)generated = []for index in range(args.max_len):outputs = model(input_ids=curr_input_tensor)next_token_logits = outputs[0][:,-1:]# 对于已生成的结果generated中的每个token添加一个重复惩罚项,降低其生成概率if index>=1:for i in range(gen_finall.shape[0]):gen_token_ids = gen_finall[i].clone()gen_token_ids = list(set(gen_token_ids.detach().cpu().tolist()))for id in gen_token_ids:next_token_logits[i:i+1,:,id:id+1] /= args.repetition_penaltynext_token_logits = next_token_logits / args.temperature# 对于[UNK]的概率设为无穷小,也就是说模型的预测结果不可能是[UNK]这个tokentoken_unk_id = tokenizer.convert_tokens_to_ids('[UNK]')next_token_logits[:,:,token_unk_id:token_unk_id+1] = -float('Inf')#进行top-k和top-p过滤filtered_logits = top_k_top_p_filtering_batch(next_token_logits, top_k=args.topk, top_p=args.topp)# torch.multinomial表示从候选集合中无放回地进行抽取num_samples个元素,权重越高,抽到的几率越高,返回元素的下标next_token = curr_input_tensor[:,-1:].clone()for i in range(next_token.shape[0]):next_token[i] = torch.multinomial(F.softmax(filtered_logits[i].squeeze(0), dim=-1), num_samples=1)generated.append(next_token)gen_finall = torch.cat(generated,dim=1)# print('gen_finall',gen_finall)# print('tokenizer.sep_token_id',tokenizer.sep_token_id)# print((gen_finall==tokenizer.sep_token_id))# print((gen_finall==tokenizer.sep_token_id).sum(1))# print((gen_finall==tokenizer.sep_token_id).sum(1).bool())# print((gen_finall==tokenizer.sep_token_id).sum(1).bool().all())#batch内所有都解码完成if (gen_finall==tokenizer.sep_token_id).sum(1).bool().all():breakcurr_input_tensor = torch.cat((curr_input_tensor, next_token), dim=1)

三、transformer中的解码使用

前文聊了文本生成和翻译的基本流程、解码策略的一些基本原理和思想以及解码策略的实现,当然更优雅的用法就是直接调用世界第一NLP实现库huggingface的transformers中关于文本翻译类或者生成类的解码函数。generation_utils.py提供了多种解码方式greedy search、beam search、sampling(直接随机sampling、top-K和Top-P)、beam_sample(beam_search+top-K和Top-P)和group_beam。至于其他的一些功能,需要读者自己去阅读源码。

解码很简单,代码如下,加载模型,喂入数据,解码,得到结果。

from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '1'
from tqdm import tqdm
from torch.utils.data import DataLoader
import torch
from data_reader.dataReader_zh2en import DataReader
if __name__ == '__main__':tokenizer = AutoTokenizer.from_pretrained("./pretrained_models/MarianMTModel_zh2en")model = AutoModelForSeq2SeqLM.from_pretrained("./pretrained_models/MarianMTModel_zh2en")dataset = DataReader(tokenizer, filepath='data/test_sample.csv')test_dataloader = DataLoader(dataset=dataset,batch_size=4)device = 'cuda' if torch.cuda.is_available() else 'cpu'model.to(device)finanl_result = []for batch in tqdm(test_dataloader,desc='translation prediction'):for k, v in batch.items():batch[k] = v.to(device)batch = {'input_ids': batch['input_ids'], 'attention_mask': batch['attention_mask']}# Perform the translation and decode the outputtranslation = model.generate(**batch, top_k=5, num_return_sequences=1,num_beams=1)batch_result = tokenizer.batch_decode(translation, skip_special_tokens=True)finanl_result.extend(batch_result)print(len(finanl_result))for res in finanl_result:print(res.replace('[','').replace(']',''))

下文以翻译类任务为例,采用基于transformer架构的MarianMT模型,MarianMTModel_zh2en中文到英文的模型参数。

完整代码如下

import pandas as pd
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '1'
from tqdm import tqdm
from torch.utils.data import DataLoader
import torch
from data_reader.dataReader_zh2en import DataReaderif __name__ == '__main__':tokenizer = AutoTokenizer.from_pretrained("./pretrained_models/MarianMTModel_zh2en")model = AutoModelForSeq2SeqLM.from_pretrained("./pretrained_models/MarianMTModel_zh2en")dataset = DataReader(tokenizer, filepath='data/test_sample.csv')test_dataloader = DataLoader(dataset=dataset,batch_size=4)device = 'cuda' if torch.cuda.is_available() else 'cpu'model.to(device)finanl_result = []for batch in tqdm(test_dataloader,desc='translation prediction'):for k, v in batch.items():batch[k] = v.to(device)batch = {'input_ids': batch['input_ids'], 'attention_mask': batch['attention_mask']}# Perform the translation and decode the output#greedygreedy_translation = model.generate(**batch,num_return_sequences = 1)greedy_batch_result = tokenizer.batch_decode(greedy_translation, skip_special_tokens=True)finanl_result.append(greedy_batch_result)#beam_searchbeam_translation = model.generate(**batch, num_return_sequences=1, num_beams=5)beam_batch_result = tokenizer.batch_decode(beam_translation, skip_special_tokens=True)finanl_result.append(beam_batch_result)#samplingsample_translation = model.generate(**batch, do_sample=True, num_return_sequences=1)sample_batch_result = tokenizer.batch_decode(sample_translation, skip_special_tokens=True)finanl_result.append(sample_batch_result)#top-ktopk_translation = model.generate(**batch, top_k=5, num_return_sequences=1)topk_batch_result = tokenizer.batch_decode(topk_translation, skip_special_tokens=True)finanl_result.append(topk_batch_result)# top-ptopp_translation = model.generate(**batch, top_p=0.92, num_return_sequences=1)topp_batch_result = tokenizer.batch_decode(topp_translation, skip_special_tokens=True)finanl_result.append(topp_batch_result)# top-k和top-ptopktopp_translation = model.generate(**batch, top_k=5, top_p=0.92, num_return_sequences=1)topktopp_batch_result = tokenizer.batch_decode(topktopp_translation, skip_special_tokens=True)finanl_result.append(topktopp_batch_result)# top-k和top-p+beam_searchbeamtopktopp_translation = model.generate(**batch, top_k=5, top_p=0.92, num_return_sequences=1, num_beams=5)beamtopktopp_batch_result = tokenizer.batch_decode(beamtopktopp_translation, skip_special_tokens=True)finanl_result.append(beamtopktopp_batch_result)decodes_policys = ['greedy search','beam_search','sampling','top-k','top-p','top-k和top-p','top-k和top-p+beam_search']test_sample = ['【由富氏隐孢子虫引起的皮肤真菌病】。','[十二指肠转换手术中的减肥手术:体重变化和相关的营养缺乏]。','[宫腔镜研究数字图像的观察者间诊断协议]。']print(len(finanl_result))for i in range(3):print(test_sample[i])for ele,de_ty in zip(finanl_result,decodes_policys):print(ele[i].replace('[','').replace(']',''))print('*'*100)

翻译src文本

【由富氏隐孢子虫引起的皮肤真菌病】。
[十二指肠转换手术中的减肥手术:体重变化和相关的营养缺乏]。
[宫腔镜研究数字图像的观察者间诊断协议]。

不同解码策略得到的结果对比

【由富氏隐孢子虫引起的皮肤真菌病】。
Skin fungi caused by Fung's Invisible Spores.
Skin fungus disease caused by Fung's Invisible Spores.
Skin fungi caused by Fung's spores.
Skin fungi caused by Fung's Invisible Spores.
Skin fungi caused by Fung's Invisible Spores.
Skin fungi caused by Fung's Invisible Spores.
Skin fungus disease caused by Fung's Invisible Spores.
****************************************************************************************************
[十二指肠转换手术中的减肥手术:体重变化和相关的营养缺乏]。
Twelve reference to fertility reduction in intestinal conversion operations: changes in body weight and associated nutritional deficiencies.
Twelve reference to fertility reduction in intestinal conversion operations: changes in body weight and associated nutritional deficiencies.
Liith finger intestinal conversion operations with dietary loss: weight changes and associated nutritional deficiencies.
Twelve reference to fertility reduction in intestinal conversion operations: changes in body weight and associated nutritional deficiencies.
Twelve reference to fertility reduction in intestinal conversion operations: changes in body weight and associated nutritional deficiencies.
Twelve reference to fertility reduction in intestinal conversion operations: changes in body weight and associated nutritional deficiencies.
Twelve reference to fertility reduction in intestinal conversion operations: changes in body weight and associated nutritional deficiencies.
****************************************************************************************************
[宫腔镜研究数字图像的观察者间诊断协议]。
Observer-to-observer protocol for the study of digital images in the court cavity mirrors.
Observer-to-observer protocol for the study of digital images in the court cavity mirrors.
Observatorial protocol for the study of digital images in the uterine cavity mirror.
Observer-to-observer protocol for the study of digital images in the court cavity mirrors.
Observer-to-observer protocol for the study of digital images in the court cavity mirrors.
Observer-to-observer protocol for the study of digital images in the court cavity mirrors.
Observer-to-observer protocol for the study of digital images in the court cavity mirrors.

翻译任务来看结果差异不是很大,不过也有一些差异。

参考文献

How to generate text: using different decoding methods for language generation with Transformers

Nucleus Sampling与文本生成中的不同解码策略比较

Seq2Seq解码策略-概念

全面了解Beam Search

The Curious Case of Neural Text Degeneration

浅谈文本生成或者文本翻译解码策略相关推荐

  1. 浅谈文本生成或者文本翻译解码策略《转》

    原文链接,感谢原作者 目录 一.文本生成和翻译的基本流程 翻译类模型的训练和解码 训练过程 解码过程 生成类模型的训练和解码(GPT系列) 训练过程 解码过程 二.解码策略 1.贪心搜索(greedy ...

  2. python计算现场得分_浅谈用 Python 计算文本 BLEU 分数

    浅谈用 Python 计算文本 BLEU 分数 BLEU, 全称为 Bilingual Evaluation Understudy(双语评估替换), 是一个比较候选文本翻译与其他一个或多个参考翻译的评 ...

  3. 浅谈用Python计算文本BLEU分数

    在本教程中,你探索了BLEU评分,根据在机器翻译和其他语言生成任务中的参考文本对候选文本进行评估和评分. 具体来说,你学到了: BLEU评分的简单入门介绍,并直观地感受到到底是什么正在被计算. 如何使 ...

  4. 【200+论文】深度强化学习、对话系统、文本生成、文本摘要、阅读理解等文献列表

    [导读]本文收录了深度强化学习.对话系统.文本生成.文本摘要.阅读理解.因果推理.记忆网络.推荐系统.神经表示学习等一系列领域参考文献大合集! https://cloud.tencent.com/de ...

  5. 浅谈图像生成模型 Diffusion Model 原理

    重磅推荐专栏: <AI 大模型之美> 揭开 ChatGPT 面纱,拥抱 AI 新潮流 重磅推荐专栏: <Transformers自然语言处理系列教程> 手把手带你深入实践Tra ...

  6. 浅谈问题生成(Question Generation)

    ©作者 | 刘璐 学校 | 北京邮电大学 研究方向 | 问题生成与QA 问题生成(Question Generation)是文本生成中的重要子任务,旨在根据输入数据(文本.知识库.图像等类型,本文仅聚 ...

  7. 浅谈人工智能生成内容(AIGC)

    兴趣了解 [OpenAI ]人工智能绘画产品 DALL·E: 在计算机上输入一句话,DALL·E 就能够理解这句话.然后自动生成一幅意思相应的图像,且该图像是全网首发.独一无二. [谷歌 ] 5400 ...

  8. 浅谈语音助手的对话管理与策略制定

    本篇文章首先梳理了对话系统中的对话管理的原理,包括中控系统的分发.各类bot处理Query的逻辑.候选回复融合和排序的功能,其中也包含了垂直领域知识图谱的构建.最后从PM角度思考,为了提升bot的表现 ...

  9. 浅谈航天防伪开票文本接口文件的解析

    航天信息防伪税控开票系统开票文本接口文件格式对外是公开的,只要你安装了防伪税控开票系统,就可以在其安装的目录如:"C:/Program Files/航天信息/防伪开票/DOC/接口文件示例& ...

最新文章

  1. P3512 [POI2010]PIL-Pilots(单调队列+二分)
  2. 从一个表查询数据插入另一个表
  3. 肚子上挂张画就能隐身:AI完全看不出我在哪,更看不出我是人类了 | 开源
  4. JAVA——网络编程
  5. Micropython教程之TPYBoard DIY金属探测仪实例演示(萝卜学科编程教育)
  6. Swagger使用总结
  7. matlab 生成dbc文件,simulink中使用dbc文件实现CAN消息发送与代码生成
  8. pandas——数据透视表
  9. ArduPilot简介
  10. MATLA 复制文件到指定文件夹
  11. 我的抗争:一个中年编外程序员的挣扎
  12. 【饭谈】【超详细】的资深测开的招聘要求,大家看看这符合了值多少钱?
  13. C. Xenon's Attack on the Gangs(树形dp)
  14. 清华集训2014 day1 task1 玛里苟斯
  15. 来自菜鸡的前端权限简单实现
  16. RISC-V Linux kernel debug 环境搭建
  17. 浅谈Redis面试热点之工程架构篇[1]
  18. YT8511芯片手册 解析|CSDN创作打卡
  19. 从键盘中输入年、月、日,判断这一天为当年的第几天(考虑闰年和非法输入的情况)
  20. 使用原生JS在Vue实例中动态插入元素

热门文章

  1. omf多路径 oracle_【OMF】使用Oracle的OMF 特性
  2. 2021上半年的软考证书,最快要什么时候才能到手?
  3. GISer入门指南 第二季(PPTX)
  4. 【20保研】清华-伯克利深圳学院2019年暑期夏令营招募通知
  5. ​字创未来 方正字库第十二届“方正奖”设计大赛正式来袭
  6. 任何一个做计算机软件的人的梦想:墨绿,我们能做到吗?或者说,什么时候能做到?
  7. 在线教育APP的发展历程
  8. java freepascal_Lazarus一个开源的跨平台FreePasscal集成开发环境
  9. 前端学习便捷软件,插件
  10. 网站端服务器返回错误,报税网站端服务器错误 如何建立网站服务器