ChatGPT模型采样算法详解

ChatGPT所使用的模型——GPT(Generative Pre-trained Transformer)模型有几个参数,理解它们对文本生成任务至关重要。其中最重要的一组参数是temperaturetop_p。二者控制两种不同的采样技术,用于因果语言模型(Causal language models)中预测给定上下文情景中下一个单词出现的概率。本文将重点讲解temperaturetop_p的采样原理,以及它们对模型输出的影响。

文章目录

  • 理解因果语言模型中的采样
  • Top-k采样
    • Top-p采样
  • 温度采样
    • 典型用例
  • 总结

理解因果语言模型中的采样

假设我们训练了一个描述个人生活喜好的模型,我们想让它来补全“我喜欢漂亮的___”这个句子。一般语言模型会按照下图的流程来工作:

模型会查看所有可能的单词,并根据其概率分布从中采样,以预测下一个词。为了方便起见,假设模型的词汇量不大,只有:“大象”、“西瓜”、“鞋子”和“女孩”。通过下图的词汇概率我们可以发现,“女孩”的选中概率最高( p = 0.664 p=0.664 p=0.664),“西瓜”的选中概率最低( p = 0.032 p=0.032 p=0.032)。

上面的例子中,很明显“女孩”最可能被选中。因为人类对于单一问题在心智上习惯采用 “贪心策略”,即选择概率最高的事件。

永远选择分数或概率最大的token,这种策略叫做“贪心策略”。
贪心策略符合人类的心智,但是存在严重缺陷。

但是上面这种策略用在频繁交互的场景下会有一个显著缺陷——如果我们总是选择最可能的单词,那么这个词会反复不断被强化,因为现代语言模型中大多数模型的注意力只集中在最近的几个词(Token)上。这样生成的内容将非常的生硬和可预测,人们一眼就能看出是机器生成的且一点也不智能。

如何让我们的模型不那么具有确定性,让它生成的内容用词更加活跃呢?为此,我们引入了基于分布采样的生成采样算法。但是传统的采样方法会遇到了一个问题:如果我们有5万个候选词(Token),即使最后2.5万个极不可能出现的长尾词汇,它们的概率质量也可能会高达30%。这意味着,对于每个样本,我们有1/3的机会完全偏离原来的“主题”。又由于上面提到的注意力模型倾向于集中在最近出现的词上,这将导致不可恢复的错误级联,因为下一个词严重依赖于最近的错误词。

为了防止从尾部采样,最流行的方法是Top-k采样温度采样

Top-k采样

Top-k采样是对前面“贪心策略”的优化,它从排名前k的token种进行抽样,允许其他分数或概率较高的token也有机会被选中。在很多情况下,这种抽样带来的随机性有助于提高生成质量。

添加一些随机性有助于使输出文本更自然。
上图示例中,我们首先筛选似然值前三的token,然后根据似然值重新计算采样概率。

通过调整k的大小,即可控制采样列表的大小。“贪心策略”其实就是k=1的top-k采样。

Top-p采样

ChatGPT实际使用的不是Top-k采样,而是其改进版——Top-p采样。

Top-k有一个缺陷,那就是“k值取多少是最优的?”非常难确定。于是出现了动态设置token候选列表大小策略——即核采样(Nucleus Sampling)。下图展示了top-p值为0.9的Top-p采样效果:

在top-p中,根据达到某个阈值的可能性得分之和动态选择候选名单的大小。

top-p值通常设置为比较高的值(如0.75),目的是限制低概率token的长尾。我们可以同时使用top-k和top-p。如果kp同时启用,则pk之后起作用。

温度采样

温度采样受统计热力学的启发,高温意味着更可能遇到低能态。在概率模型中,logits扮演着能量的角色,我们可以通过将logits除以温度来实现温度采样,然后将其输入Softmax并获得采样概率。

越低的温度使模型对其首选越有信心,而高于1的温度会降低信心。0温度相当于argmax似然,而无限温度相当于于均匀采样。

温度采样中的温度与玻尔兹曼分布有关,其公式如下所示:
ρ i = 1 Q e − ϵ i / k T = e − ϵ i / k T ∑ j = 1 M e − ϵ j / k T \rho_i = \frac{1}{Q}e^{-\epsilon_i/kT}=\frac{e^{-\epsilon_i/kT}}{\sum_{j=1}^M e^{-\epsilon_j/kT}} ρi​=Q1​e−ϵi​/kT=∑j=1M​e−ϵj​/kTe−ϵi​/kT​
其中 ρ i \rho_i ρi​ 是状态 i i i 的概率, ϵ i \epsilon_i ϵi​ 是状态 i i i 的能量, k k k 是波兹曼常数, T T T 是系统的温度, M M M 是系统所能到达的所有量子态的数目。

有机器学习背景的朋友第一眼看到上面的公式会觉得似曾相识。没错,上面的公式跟Softmax函数 S o f t m a x ( z i ) = e z i ∑ c = 1 C e z c Softmax(z_i) = \frac{e^{z_i}}{\sum_{c=1}^Ce^{z_c}} Softmax(zi​)=∑c=1C​ezc​ezi​​ 很相似,本质上就是在Softmax函数上添加了温度(T)这个参数。Logits根据我们的温度值进行缩放,然后传递到Softmax函数以计算新的概率分布。

上面“我喜欢漂亮的___”这个例子中,初始温度 T = 1 T=1 T=1,我们直观看一下 T T T 取不同值的情况下,概率会发生什么变化:

通过上图我们可以清晰地看到,随着温度的降低,模型愈来愈越倾向选择”女孩“;另一方面,随着温度的升高,分布变得越来越均匀。当 T = 50 T=50 T=50时,选择”西瓜“的概率已经与选择”女孩“的概率相差无几了。

通常来说,温度与模型的“创造力”有关。但事实并非如此。温度只是调整单词的概率分布。其最终的宏观效果是,在较低的温度下,我们的模型更具确定性,而在较高的温度下,则不那么确定。

典型用例

temperature = 0.0

temperature=0会消除输出的随机性,这会使得GPT的回答稳定不变。

较低的温度适用于需要稳定性、最可能输出(实际输出、分类等)的情况。

temperature = 1.0

temperature=1每次将产生完全不同的输出,且有时输出的结果会非常搞笑。因此,即便是开放式任务,也应该谨慎使用temperature=1。对于故事创作或创意文案生成等任务,温度值设为0.7到0.9之间更为合适。

temperature = 0.75

通常,温度设在0.70–0.90之间是创造性任务最常见的温度。

虽然存在一些关于温度设置的一般性建议,但没有什么是一成不变的。作为GPT-3最重要的设置之一,实际使用中建议多一试下,看看不同设置对输出效果的影响。

总结

本文详细为大家阐述了temperaturetop_p的采样原理,以及它们对模型输出的影响。实际使用中建议只修改其中一个的值,不要两个同时修改。

temperature可以简单得将其理解为“熵”,控制输出的混乱程度(随机性),而top-p可以简单将其理解为候选词列表大小,控制模型所能看到的候选词的多少。实际使用中大家要多尝试不同的值,从而获得最佳输出效果。

另外还有两个参数——frequency_penaltypresence_penalty 对生成输出也有较大影响,请参考《ChatGPT模型中的惩罚机制》。

ChatGPT模型采样算法详解相关推荐

  1. LDA主题模型(算法详解)

    LDA主题模型(算法详解) http://blog.csdn.net/weixin_41090915/article/details/79058768?%3E 一.LDA主题模型简介 LDA(Late ...

  2. 经典算法详解--CART分类决策树、回归树和模型树

    Classification And Regression Tree(CART)是一种很重要的机器学习算法,既可以用于创建分类树(Classification Tree),也可以用于创建回归树(Reg ...

  3. python直线拟合_RANSAC算法详解(附Python拟合直线模型代码)

    之前只是简单了解RANSAC模型,知道它是干什么的.然后今天有个课程设计的报告,上去讲了一下RANSAC,感觉这个东西也没那么复杂,所以今天就总结一些RASAC并用Python实现一下直线拟合. RA ...

  4. 【机器学习】【隐马尔可夫模型-3】后向算法:算法详解+示例讲解+Python实现

    0.前排提示 csdn有些数学公式编辑不出来,所以本博用容易书写的表达式来表示专业数学公式,如: (1)  在本博客中用α<T>(i)来表示 (2)在本博客中用[i=1, N]∑来表示 注 ...

  5. 隐马尔可夫模型之Baum-Welch算法详解

    隐马尔可夫模型之Baum-Welch算法详解 前言 在上篇博文中,我们学习了隐马尔可夫模型的概率计算问题和预测问题,但正当要准备理解学习问题时,发现学习问题中需要EM算法的相关知识,因此,上一周转而学 ...

  6. GrabCut算法详解:从GMM模型说起

    GrabCut算法详解:从GMM模型说起 最近正在学习CG,争取有时间就看点论文写写代码. GrabCut在OpenCv里面是有内置函数的,不过我还是用C++纯手工实现了一边原汁原味的论文,githu ...

  7. 自然语言处理NLP星空智能对话机器人系列:第21章:基于Bayesian Theory的MRC文本理解基础经典模型算法详解

    自然语言处理NLP星空智能对话机器人系列: 第21章:基于Bayesian Theory的MRC文本理解基础经典模型算法详解 1,Bayesian prior在模型训练时候对Weight控制.训练速度 ...

  8. YOLOv5算法详解

    目录 1.需求解读 2.YOLOv5算法简介 3.YOLOv5算法详解 3.1 YOLOv5网络架构 3.2 YOLOv5实现细节详解 3.2.1 YOLOv5基础组件 3.2.2 输入端细节详解 3 ...

  9. CenterNet算法详解

    Objects as Points-论文链接-代码链接 目录 1.需求解读 2.CenterNet算法简介 3.CenterNet算法详解 3.1 CenterNet网络结构 3.2 CenterNe ...

最新文章

  1. spec 2016使用
  2. 关于appcan自动升级功能
  3. 数据库范式的思考以及数据库的设计
  4. 【转】pdf 中如何把几页缩小成一页打印
  5. webRTC+coturn穿透服务器的安装与搭建
  6. 也可以看看GCD(杭州电2504)(gcd)
  7. python_sting字符串的方法及注释
  8. Spring-SpringMVC父子容器
  9. 怎么在电脑安装php文件夹在哪个文件夹,php进行文件上传时找不到临时文件夹怎么办,电脑自动保存的文件在哪里...
  10. flowable6.4.2流程审批后涉及到的表
  11. Hadoop概念学习系列之Hadoop 是什么?(一)
  12. github地址持续收集
  13. webhooks php,GitHub和WebHooks自动部署PHP项目
  14. 分类问题损失函数的信息论解释
  15. 双曲正切函数(tanh)
  16. 用Python写几个小游戏(附源码)
  17. Java多线程系列--“JUC锁”03之 公平锁(一) (r)
  18. 惊!Adam效果不好居然是因为……,Decouple Weight Decay Regulaization阅读笔记
  19. Java——(九)IO流
  20. JPA Specification 自定义查询

热门文章

  1. 【批处理】attrib
  2. 关于xode 12使用cocoapods-packager制作成framework后,编译报错ld: framework not found xxx
  3. 十进制数转换为二进制数 C++
  4. 简单CAD户型图制作过程
  5. liunx系统下载VSCODE教程
  6. 超炫button按钮动画效果
  7. 关于实现订单的相关功能实现
  8. 模拟退火解决旅行商问题
  9. 微信小程序如何获取屏幕宽度
  10. 基于智慧灯杆的智慧社区解决方案,看看智慧灯杆到底能做什么?