Transformers中用于语言生成的不同解码方法

原文地址:https://huggingface.co/blog/how-to-generate

相关博客
【自然语言处理】【ChatGPT系列】InstructGPT:遵循人类反馈指令来训练语言模型
【自然语言处理】【ChatGPT系列】大模型的涌现能力
【自然语言处理】【文本生成】CRINEG Loss:学习什么语言不建模
【自然语言处理】【文本生成】使用Transformers中的BART进行文本摘要
【自然语言处理】【文本生成】Transformers中使用约束Beam Search指导文本生成
【自然语言处理】【文本生成】Transformers中用于语言生成的不同解码方法
【自然语言处理】【文本生成】BART:用于自然语言生成、翻译和理解的降噪Sequence-to-Sequence预训练
【自然语言处理】【文本生成】UniLM:用于自然语言理解和生成的统一语言模型预训练
【自然语言处理】【多模态】OFA:通过简单的sequence-to-sequence学习框架统一架构、任务和模态

一、简介

​ 近些年来,随着大型预训练语言模型的兴起,人们对开发式语言生成越来越感兴趣。之所以开放式语言生成效果令人印象深刻,除了Transformer\text{Transformer}Transformer架构的改善和大量的无监督训练数据,更好的解码方式也扮演着重要的角色。本文对不同的解码方法进行了简单的介绍并展示如何使用transformers库进行实现。

​ 以下所有的功能都是用于自回归语言生成。简单来说,自回归语言生成是基于以下假设:一个单词序列的概率分布可以被分解为下一个词分布的乘积:
P(w1:T∣W0)=∏t=1TP(wt∣w1:t−1,W0),withw1:0=∅(1)P(w_{1:T}|W_0)=\prod_{t=1}^T P(w_t|w_{1:t-1},W_0),\text{with}\quad w_{1:0}=\empty \tag{1} P(w1:T​∣W0​)=t=1∏T​P(wt​∣w1:t−1​,W0​),withw1:0​=∅(1)
其中,W0W_0W0​是初始上下文。单词序列的长度TTT通常是动态的,相当于P(wt∣w1:t−1,W0)P(w_t|w_{1:t-1},W_0)P(wt​∣w1:t−1​,W0​)在时间步t=Tt=Tt=T生成EOS token。

​ 本文将介绍目前最著名的解码方法,主要有:Greedy searchBeam searchTop-k samplingTop-p sampling

二、加载模型

from transformers import GPT2LMHeadModel, GPT2Tokenizermodel = GPT2LMHeadModel.from_pretrained("gpt2")
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = model.to("cuda")

三、Greedy Search

​ Greedy search\text{Greedy search}Greedy search简单选择概率最高的单词作为下一个单词:wt=argmaxwP(w∣w1:t−1)w_t=\text{argmax}_wP(w|w_{1:t-1})wt​=argmaxw​P(w∣w1:t−1​)。下图是Greedy search\text{Greedy search}Greedy search的示图。

​ 起始于单词The,算法贪心的选择概率最高的下个单词nice,最终生成的单词序列(The,nice,woman)(\text{The,nice,woman})(The,nice,woman)具有整体概率0.5×0.4=0.20.5\times 0.4=0.20.5×0.4=0.2。

input_ids = tokenizer.encode('I enjoy walking with my cute dog', return_tensors='pt').to("cuda")
greedy_output = model.generate(input_ids, max_length=50)
print("Output:\n" + 100 * '-')
print(tokenizer.decode(greedy_output[0], skip_special_tokens=True))

输出:

Output:
----------------------------------------------------------------------------------------------------
I enjoy walking with my cute dog, but I'm not sure if I'll ever be able to walk with my dog. I'm not sure if I'll ever be able to walk with my dog.I'm not sure if I'll

​ 根据上下文生成的单词是合理的,但是模型很快就开始重复!这在语言生成中是非常常见的问题,在greedy searchbeam search中更是如此。

greed search的主要缺点是会忽略低概率单词后面的高概率单词。如上面的图所示,单词has具有最高的条件概率0.9,其是在第二高概率单词dog之后,所以greedy search忽略了单词序列(The,dog,has)(\text{The,dog,has})(The,dog,has)。beam search在一定程度上缓解了这个问题。

四、Beam Search

beam search降低了丢失隐藏高概率单词序列的风险,通过在每个时间步中保留最可能的num_beams,并最终选择总体概率最高的假设。这里基于num_beams=2进行解释:

​ 在时间步1,除了最可能的假设(The, nice)(\text{The, nice})(The, nice),beam search也会跟踪第二可能的(The,dog)(\text{The,dog})(The,dog)。在时间步2,beam search发现单词序列(The,dog,has)(\text{The,dog,has})(The,dog,has)的概率0.36高于(The,nice,woman)(\text{The,nice,woman})(The,nice,woman)的概率0.2。beam search发现的输出序列概率高于greedy search,但是不能保证是概率最高的序列。

​ 下面是transformers中的beam search。设置num_beams>1并且early_stopping=True,当所有的beam hypotheses达到EOS token则生成完成。

beam_output = model.generate(input_ids, max_length=50, num_beams=5, early_stopping=True
)print("Output:\n" + 100 * '-')
print(tokenizer.decode(beam_output[0], skip_special_tokens=True))

输出:

Output:
----------------------------------------------------------------------------------------------------
I enjoy walking with my cute dog, but I'm not sure if I'll ever be able to walk with him again.I'm not sure if I'll ever be able to walk with him again. I'm not sure if I'll

​ 虽然结果更加的流畅,但是输出仍然包含重复的单词序列。一种简单的解决方案是引入n-grams的惩罚项。最常见的n-grams惩罚项是确保没有n-gram出现两次,通过设置已经出现的n-gram在下一次生成中概率为0。transformers\text{transformers}transformers中设置no_repeat_ngram_size=2

beam_output = model.generate(input_ids, max_length=50, num_beams=5, no_repeat_ngram_size=2, early_stopping=True
)print("Output:\n" + 100 * '-')
print(tokenizer.decode(beam_output[0], skip_special_tokens=True))

输出:

Output:
----------------------------------------------------------------------------------------------------
I enjoy walking with my cute dog, but I'm not sure if I'll ever be able to walk with him again.I've been thinking about this for a while now, and I think it's time for me to take a break

​ 效果看起来不错,没有再出现重复。然而,n-gram惩罚在使用时必须小心。生成一篇关于城市New York的文章则不应该使用2-gram惩罚,否则城市的名字在全文中仅出现一次。

beam search的另一个重要特征是,beam search能够在生成后比较各个beams并选择出最符合目标的beam。在transformers中,可以简单的设置参数num_return_sequences来指定返回的beams数量。

beam_outputs = model.generate(input_ids, max_length=50, num_beams=5, no_repeat_ngram_size=2, num_return_sequences=5, early_stopping=True
)print("Output:\n" + 100 * '-')
for i, beam_output in enumerate(beam_outputs):print("{}: {}".format(i, tokenizer.decode(beam_output, skip_special_tokens=True)))

输出:

Output:
----------------------------------------------------------------------------------------------------
0: I enjoy walking with my cute dog, but I'm not sure if I'll ever be able to walk with him again.I've been thinking about this for a while now, and I think it's time for me to take a break
1: I enjoy walking with my cute dog, but I'm not sure if I'll ever be able to walk with him again.I've been thinking about this for a while now, and I think it's time for me to get back to
2: I enjoy walking with my cute dog, but I'm not sure if I'll ever be able to walk with her again.I've been thinking about this for a while now, and I think it's time for me to take a break
3: I enjoy walking with my cute dog, but I'm not sure if I'll ever be able to walk with her again.I've been thinking about this for a while now, and I think it's time for me to get back to
4: I enjoy walking with my cute dog, but I'm not sure if I'll ever be able to walk with him again.I've been thinking about this for a while now, and I think it's time for me to take a step

可以看到,5个输出的beam的差别微小。

​ 在开放生成中,最近提出了beam search不是最好选择的原因:

  • beam search在机器翻译或者摘要这种结果的长度或多或少可以被预测的任务中表现很好。但是在开放生成中期望输出的长度变化非常大,例如:对话和故事生成。

  • beam search会有严重的重复生成问题。在故事生成中使用n-gram或者其他惩罚都很难控制,在强制"不重复"和重复特定的n-gram上寻找一个好的平衡需要大量的微调。

  • 高质量的人类语言是不会遵循下一个词的高概率分布。换句话说,作为人类期望生成的文本是有趣的,而不是无聊且可预测的。下图是beam search和人类生成结果的有趣性对比

五、Sampling

​ 在最基本的形式中,sampling意味着根据条件概率分布随机挑选下一个单词wtw_twt​:
wt∼P(w∣w1:t−1)w_t\sim P(w|w_{1:t-1}) wt​∼P(w∣w1:t−1​)
以上面的例子为例,下图可视化了sampling时的语言生成。

显然,使用sampling的语言生成不再是确定的。单词car是从条件概率分布P(w∣The)P(w|The)P(w∣The)采样的,drives则是从P(w∣The,car)P(w|The,car)P(w∣The,car)采样的。

​ 在transformers中,可以设置do_sample=True并通过top-k=0禁用Top-K sampling\text{Top-K sampling}Top-K sampling来实现。

torch.random.set_seed(0)sample_output = model.generate(input_ids, do_sample=True, max_length=50, top_k=0
)print("Output:\n" + 100 * '-')
print(tokenizer.decode(sample_output[0], skip_special_tokens=True))

输出:

Output:
----------------------------------------------------------------------------------------------------
I enjoy walking with my cute dog. He just gave me a whole new hand sense."But it seems that the dogs have learned a lot from teasing at the local batte harness once they take on the outside."I take

生成的文章看起来很好,但是仔细看可以发现不是很连贯。new hand senselocal batte harness非常的奇怪,听起来不像是人类写的。最大的问题是采样的单词顺序:模型经常产生不连贯的胡言乱语。

​ 一种解决方法是通过softmaxtemperature参数使得分布P(w∣w1:t−1)P(w|w_{1:t-1})P(w∣w1:t−1​)更加的陡峭。应用temperature的例子如下。

在时间步t=1t=1t=1的下个单词概率变的更加的陡峭,使得单词car几乎不会被选择。下面是通过设置temperature=0.7来使得分布更加陡峭:

torch.random.set_seed(0)sample_output = model.generate(input_ids, do_sample=True, max_length=50, top_k=0, temperature=0.7
)print("Output:\n" + 100 * '-')
print(tokenizer.decode(sample_output[0], skip_special_tokens=True))

输出:

Output:
----------------------------------------------------------------------------------------------------
I enjoy walking with my cute dog, but I don't like to be at home too much. I also find it a bit weird when I'm out shopping. I am always away from my house a lot, but I do have a few friends

六、Top-K Sampling

​ Top-K Sampling\text{Top-K Sampling}Top-K Sampling是非常有效的采样方案。在Top-K Sampling\text{Top-K Sampling}Top-K Sampling中,概率最高的KKK个下一个单词被挑选出来,概率权重仅在这KKK个单词中重新分配。GPT2\text{GPT2}GPT2就采用这种采样策略,这也是其在故事生成上成功的原因之一。

​ 为了更好的解释Top-K\text{Top-K}Top-K采样,将上面示例中的两个采样步骤的单词范围从3个扩展至10个。

​ 设置K=6K=6K=6,在两个采样步骤中均现在采样大小为6个词。6个最可能的词定义为Vtop−K\text{V}_{top-K}Vtop−K​,在第一个时间步其包含了整个概率的2/3,在第二个时间步其包含了绝大多数的概率。可以看到其在第二个时间步成功忽略了一些奇怪的候选词notthesmalltold

​ 下面展示通过设置top_k=50来使用Top-K Sampling\text{Top-K Sampling}Top-K Sampling:

torch.random.set_seed(0)sample_output = model.generate(input_ids, do_sample=True, max_length=50, top_k=50
)print("Output:\n" + 100 * '-')
print(tokenizer.decode(sample_output[0], skip_special_tokens=True))

输出:

Output:
----------------------------------------------------------------------------------------------------
I enjoy walking with my cute dog. It's so good to have an environment where your dog is available to share with you and we'll be taking care of you.We hope you'll find this story interesting!I am from

结果还不错。生成的文本目前是最像人类撰写的文本。然而,Top-K sampling\text{Top-K sampling}Top-K sampling不能动态地调整需要过滤的下一个单词的数量。这会带来一些问题,因为一些单词可能是从非常陡峭的分布中采样的,而另一些则是从更加平坦分布采样的。

​ 在时间步t=1t=1t=1,Top-K\text{Top-K}Top-K消除了采样peoplebighousecat的可能,但这些词似乎也是合理的。另一方面,在时间步t=2t=2t=2中,方法在候选单词中包含了downa等不合适的单词。因此,限制采样池为固定大小的KKK会使模型在陡峭分布时胡言乱语,并且在平坦分布时限制模型的创造力。

七、Top-p(nucleus) Sampling

​ 相较于从最可能的KKK个词中采样,Top-p sampling\text{Top-p sampling}Top-p sampling是累计概率超过概率ppp的最小可能单词集合中选择。然后,概率分布在这组单词中重新分布。这个方法中单词集合的尺寸是根据下个单词概率分布动态增减的。

这里设p=0.92p=0.92p=0.92,Top-p sampling\text{Top-p sampling}Top-p sampling挑选概率累计超过p=92%p=92\%p=92%的最小数量单词,定义为Vtop-pV_{\text{top-p}}Vtop-p​。在第一个例子中,其包含9个最可能的词;而在第二个例子中,最高概率的3个词就超过了90%90\%90%。可以看到,在下一个单词难以预测时保留了更多的单词,例如P(w∣The)P(w|The)P(w∣The);当下一个单词更加明确时则保留更少的单词,例如P(w∣The,car)P(w|The,car)P(w∣The,car)。

​ 在transformers中通过设置0<top_p<1来激活Top-p sampling\text{Top-p sampling}Top-p sampling。

torch.random.set_seed(0)sample_output = model.generate(input_ids, do_sample=True, max_length=50, top_p=0.92, top_k=0
)print("Output:\n" + 100 * '-')
print(tokenizer.decode(sample_output[0], skip_special_tokens=True))

输出:

Output:
----------------------------------------------------------------------------------------------------
I enjoy walking with my cute dog. He will never be the same. I watch him play.Guys, my dog needs a name. Especially if he is found with wings.What was that? I had a lot o

上面的结果看起来像人类写的,但是还不够。理论上Top-p\text{Top-p}Top-p比Top-K\text{Top-K}Top-K更优雅,两种方法在实践中表现都不错。Top-p\text{Top-p}Top-p还可以和Top-K\text{Top-K}Top-K组合使用,其可以避免排名非常低的单词并允许动态选择。

最后,为了得到多个独立的采样输出,可以设置参数num_return_sequences > 1

torch.random.set_seed(0)sample_outputs = model.generate(input_ids,do_sample=True, max_length=50, top_k=50, top_p=0.95, num_return_sequences=3
)print("Output:\n" + 100 * '-')
for i, sample_output in enumerate(sample_outputs):print("{}: {}".format(i, tokenizer.decode(sample_output, skip_special_tokens=True)))

输出:

Output:
----------------------------------------------------------------------------------------------------
0: I enjoy walking with my cute dog. It's so good to have the chance to walk with a dog. But I have this problem with the dog and how he's always looking at us and always trying to make me see that I can do something
1: I enjoy walking with my cute dog, she loves taking trips to different places on the planet, even in the desert! The world isn't big enough for us to travel by the bus with our beloved pup, but that's where I find my love
2: I enjoy walking with my cute dog and playing with our kids," said David J. Smith, director of the Humane Society of the US."So as a result, I've got more work in my time," he said.

八、结论

​ 作为解码方法,在开发语言生成中Top-p\text{Top-p}Top-p和Top-K\text{Top-K}Top-K似乎能够比传统的greedy search\text{greedy search}greedy search和beam search\text{beam search}beam search产生更加流程的文本。最近的研究表明greedy searchbeam search的明显缺陷,即生成重复的单词序列,是由于模型导致的,而不是解码方法。因此,Top-p\text{Top-p}Top-p和Top-K\text{Top-K}Top-K也会遭受生成重复单词序列的问题。

【自然语言处理】【文本生成】Transformers中用于语言生成的不同解码方法相关推荐

  1. 文本生成 计算机语言,自然语言处理 -- 文本生成概述

    文本生成是自然语言处理中一个重要的研究领域,具有广阔的应用前景.本文主要介绍了文本生成的定义.任务.评价指标和实现方法.重点介绍了目前正在成为文本生成技术主流的数据驱动方法. 1.文本生成定义 自然语 ...

  2. 【天池月饼活动】基于自然语言处理文本生成与轮询问答与依图生文与中秋月饼配图

    文章目录 活动要求 项目演示 自然语言处理文本生成之对联上下联对答: 自然语言处理文本生成之依图生文: 多级轮询月饼对话 自动生成月饼 代码讲解 自从我前天进入阿里天池实验室,我就被他吸粉了,嫖了他的 ...

  3. python返回绝对值的函数_Python中用于返回绝对值的abs()方法

    Python中用于返回绝对值的abs()方法 方法abs() 返回x的绝对值,-x-零之间的(正极)的距离. 语法 以下是abs()方法的语法: abs( x ) 参数 x -- 这是一个数值表达式 ...

  4. 复制文本到word中时产生底色的去除方法

    复制文本到word中时产生底色的去除方法 有时复制一些文本到word中会出现这种情况 方法: 在布局里面找边框然后找底纹,填充改成无颜色 就好了

  5. Python中用于计算对数的log()方法

    本文转载至:http://www.jb51.net/article/66130.htm 这篇文章主要介绍了Python中用于计算对数的log()方法,是Python入门基础中的必会的方法,需要的朋友可 ...

  6. 1507四舍五入c语言,JavaScript中用于四舍五入的Math.round()方法讲解

    此方法返回一个数四舍五入为最接近的整数的值. 语法 Math.round( x ) ; 下面是参数的详细信息: x: 一个数字 返回值: 返回数字四舍五入为最接近的整数的值. 例子: JavaScri ...

  7. python指数运算函数_分享Python中用于计算指数的exp()方法实例教程

    exp()方法返回指数x: ex. 语法 以下是exp()方法的语法:import math math.exp( x ) 注意:此函数是无法直接访问的,所以我们需要导入math模块,然后需要用math ...

  8. 富文本转换html,在百度富文本编辑器UEditor中增加word转html的方法

    1.需求 在一个项目中有个需求:复制word的内容到编辑器中.但是在复制过程中图片不能成功的复制过来,需要安装flash插件,但是吧又不能要求每个客户都安装上,这就比较麻烦了.所以考虑是不是可以把wo ...

  9. 【自然语言处理】【文本生成】Transformers中使用约束Beam Search指导文本生成

    Transformers中使用约束Beam Search指导文本生成 原文地址:https://huggingface.co/blog/constrained-beam-search 相关博客 [自然 ...

最新文章

  1. webpack打包路径更改_扫盲: Webpack 从扫盲到手撸(上)
  2. Base64编码原理与实现
  3. Java IO学习笔记(四)打印流
  4. 双曲线和直线联立公式_高中圆锥曲线解题技巧之齐次化联立(四)
  5. php数组函数及用法,php数组函数 in_array 的用法及注意事项
  6. 论文浅尝 - AAAI2020 | 多轮对话系统中的历史自适应知识融合机制
  7. jsp中String path = request.getContextPath()的作用
  8. Ubuntu系统运行darknet出OSError: /libdarknet.so: cannot open shared object file: No such file or directory
  9. keychron k8 连接切换蓝牙方案
  10. SAP License:SAP ECC6安装系列一:安装前硬件和软件准备
  11. 操作系统中的用户空间和内核空间
  12. PHP查看内存使用量
  13. ct扫描方式有哪些_日联科技x-ray:工业CT是怎么进行X射线的断层扫描的
  14. 无线数字信息传送服务器,无线数字远程监控管理及网站实时推广项目方案.doc...
  15. Roguelike到底是啥?讲讲和Roguelike 相关知识(搬运)
  16. 控制台报 [WDS] Disconnected!不影响代码运行。
  17. Oracle存储过程实现X日均线计算
  18. 高新技术企业申请,申请高新技术企业需要什么材料
  19. 八路扫描式抢答器设计
  20. CentOS 7 快速搭建JavaWeb开发环境并部署Spring boot项目(纯干货、详细)

热门文章

  1. EMNLP 2021图相关论文合集
  2. 三星手机怎么上传文件到云服务器,三星Quick Share快传功能曝光,基于云端服务实现文件传输...
  3. 从0到1搭建自己的网站保姆级教程 · 上篇 · 域名与云服务器的准备【网站建设】
  4. 计算机学院王国胤,王国胤-中国科学院大学-UCAS
  5. uniapp -- 扫码记录(针对app场景)
  6. 背包问题——01背包/完全背包/多重背包
  7. 23-Openwrt switch vlan配置
  8. mysql怎么跑代码_MySQL菜鸟入门指南_mysql
  9. 搭建机器人电控系统——通信协议——串口通信USART/UART、RS232、RS485及其实例
  10. ****项目压测方案