前言
本文是对文章 Learning To Retrieve Prompts for In-Context Learning (NAACL, 2022) 的阅读笔记,论文代码:链接。

文章目录

  • 1. in-context learning
  • 2. 本文工作
  • 3. 模型训练和推理
    • 1)如何产生标记数据
    • 2)如何给候选集合打分
    • 3)训练打分模型
    • 4) 训练推断模型
  • 4. 实验
    • 1)数据集
    • 2)评价指标
    • 3)基线和标准做法
    • 4)模型测试的两种模式
    • 5)实验结果分析
      • LM - as - a - service
      • LM - as - a - proxy
  • 5. 总结
  • 6. 相关知识

1. in-context learning

(此介绍中部分是转载)
in-context learning 是2020年下半年兴起的一个概念。以下是incontext learning的逻辑图。

in-context learning 是一种新的训练模式。在进行测试的时候,将提示(通常是几个与输入相似的example)与输入句子一起输入,然后得到输出。

之前的模型学习到的是一个类似函数的映射,给定输入x,得到输出 y=f(x)。而 in-context learning 学到的不是一个单纯的映射函数,而是要掌握给出答案的 “能力”。也就是,这个模型,通过提示,知道了答案应该是什么样的。

in-context使得模型有着天然的泛化能力和实际部署的潜质,在很多领域,通过合理的构建prompt和选取example,in-context的水平已经接近于比他自己小但是本身不小的模型的能力了(比如说你在GPT3 175B上做in-context learning的性能基本和T5-large 770M全数据训练finetune持平)。追平这件事可以说是大模型落地的福音。

得到的 prompt 将会和测试句子拼接,作为测试句子的前缀输入。如果得到的 prompt 良好,那么通过推断模型的解码就应该得到目标输出。

效果很大程度上取决于 prompt 的质量。

In-context learning is a recent paradigm in natural language understanding, where a large pre-trained language model (LM) observes a test instance and a few training examples as its input, and directly decodes the output without any update to its parameters.

An attractive property of in-context learning is that it provides a single model for multiple language understanding tasks.

2. 本文工作

在过去的工作中,对于 prompt 的选取往往基于相似度,无论是直接通过相似度计算选取,或者训练专门的提取器来提取,都是依据相似度。

本文中不依靠相似度,而是利用一个语言模型来给提示打分,本文认为利用语言模型给提示打分是优于之前的相似度的。

3. 模型训练和推理

1)如何产生标记数据

在训练集中针对每一个训练数据有哪些最适合作为其 prompt 的方法,代价太高。本文针对每一个训练用例,先从测试集中选出一个候选集,然后在候选集中选取 positive examples 和 negative examples,标记之后用于对比学习。

为了选择一个好的候选集,使用无监督的提取器。

提取器 来源 介绍
BM25 Robertson and Zaragoza, 2009 a sparse retriever that relies on surface text similarity
SBERT Reimers and Gurevych, 2019 based on dense sentence encoding

For both, we experimented with passing the retriever the training pair (x, y) or the target sequence y only, and found that using y leads to slightly higher performance.
对于这两种方法,我们都尝试只给检索器传递训练对(x, y)或目标序列y,发现使用y会导致略高的性能。

2)如何给候选集合打分

针对训练集中的每一个数据对(x,y),对于其选出的候选集合 ϵˉ={eˉ1,eˉ2,…,eˉL}\bar{\epsilon} = \{\bar{e}_1, \bar{e}_2, \dots, \bar{e}_L\}ϵˉ={eˉ1,eˉ2,,eˉL},对集合中的每一个候选都利用打分模型进行打分,打分模型如下:
对于候选集中的所有实例,与对应的训练数据(x,y)相似度越高得分越高。最终取候选集中的 top-k 作为 positive examples ,其中的 bottom-k 作为 negative examples 。

3)训练打分模型

训练过程类似DPR (Karpukhin et al., 2020)

得到的输出

  1. Ex(.)E_x(.)Ex(.): input encoder, receives the sequence of input tokens
  2. Ep(.)E_p(.)Ep(.): prompt encoder, receives a candidate prompt, namely, a concatenation of the tokens in an input-output pair

所有的encoder 都是使用 BERT-base 初始化,所有的输出向量都是以 CLS token 的形式给出。

一个训练实例的表示
其中,batch size 是 B。 xix_ixi 是测试,ei+e_i^+ei+是从 xix_ixi 对应的 正例集 ϵpos\epsilon_{pos}ϵpos 中抽样得到,其余 2B−12B-12B1 个均是 xix_ixi 的负例,在这些负例中,有一个从 xix_ixi 对应的 负例集 ϵneg\epsilon_{neg}ϵneg 中抽样得到,其余 2B−22B-22B2 个中,有 B−1B-1B1 个是同一 batch 的其他实例的正例,有 B−1B-1B1 个是同一 batch 的其他实例的负例。(是每个实例各取一个,还是总共取 B−1B-1B1 个?)

定义一个 input 与一个 input-output pair 的相似度为
进而使用对比学习的目标函数:

4) 训练推断模型

在训练了输入编码器和提示编码器之后,我们使用FAISS对整个训练样本集进行了EP(·)编码。

Faiss是Facebook AI团队开源的针对聚类和相似性搜索库,为稠密向量提供高效相似度搜索和聚类,支持十亿级别向量的搜索,是目前最为成熟的近似近邻搜索库。

测试的时候,将 xtestx_{test}xtest 编码为 EX(xtest)E_X(x_{test})EX(xtest),然后从训练数据集中选取 L 个最相似的训练数据,然后将这些编码后的训练数据按照其与编码后的测试数据的内积值的大小顺序排列。构成的提示集 P=(e1,…,eL)\mathcal{P} = (e_1, \dots, e_L)P=(e1,,eL)

L的大小如何确定?
∑i=1L′∣ei∣+∣xtest∣+∣y′∣≤C\sum_{i=1}^{L^{\prime}}|e_i| + |x_{test}| + |y^{\prime}| \le Ci=1Lei+xtest+yCL′≤LL^{\prime} \le LLL
其中,CCC是 inference model 可以接受的最大 token 数,∣y′∣|y^{\prime}|y是期望输出的最大长度。在满足以上条件的情况下,取最大的L′L^{\prime}L

最终,以 greedy decoding 的方式输出为 g([eL′;eL′−1;…;e1;xtest])g([e_{L^{\prime}}; e_{L^{\prime}-1};\dots;e_1;x_{test}])g([eL;eL1;;e1;xtest])
prompt 是作为 xtestx_{test}xtest 的前缀的,也就是说,在它的前缀中,单词的排列方式是按照概率从大到小来的。

greedy decoding,每次选择概率值最大的对应的单词。

4. 实验

模型的两大优势情况:

  1. 当打分模型比推断模型小时,这种小体量的打分模型非常的高效轻量
  2. 当打分模型和推断模型是同一个模型时,即使两个模型相同,此方法也是适用的,当我们无法得到模型参数的时候,这个模型的优势就体现的更加明显。

1)数据集

模型将针对三个 Seq2seq 任务进行测试:

  • BREAKB_{REAK}BREAK: 将复杂的自然语言问题映射到基于语言的意义表示的数据集,其中问题被分解为原子步骤的有序列表。
  • MTOPMT_{OP}MTOP: 语义分析数据集,专注于面向任务的对话,其中命令映射到11个域的复杂嵌套查询。
  • SMCALFLOWSMC_{AL}F_{LOW}SMCALFLOW: 一个面向任务的大型英语数据集,涵盖日历、天气、地点和人员等任务。语义表示是一个数据流程序,它包括API调用、函数组合和复杂的约束。

2)评价指标

EM: Exact Match, 评估推断语言模型的输出和参考输出是否相同。
NEM: Normalized Exact Match, 通过一个基于规则的程序将预测结果和目标结果归一化,然后在归一化后的结果上计算正确字符串匹配。
LF-EM(logical form - exact match): 评估两个含义表达式是否在语义上等价。

3)基线和标准做法

无监督模型

基线模型 描述
RANDOM 随机从训练集中抽样出 prompt
SBERT 利用 paraphrase-mpnet-base-v2 来编码测试语料,并且从训练集中抽取跟测试语料最相似的例子作为 prompt
BM25 经典的 sparse retrieval 方法,是 TF-IDF 的拓展,用其抽取 prompt
BRUTE FORCE 从训练集中随机抽取 200 个训练实例 (x,y) 作为候选集,然后比较 x 与 xtestx_{test}xtest 的相似度,选择相似度高的作为 prompt

有监督模型
有监督基线测试通用的模版:

  • 用BM25抽候选集,候选集大小为 L=50。
  • 使用一些打分函数选出正例集和负例集,正例和负例集的大小均为5。
  • 不同的有监督方法不同在于其自身的打分函数。
基线模型 描述
DR-BM25 使用BM25本身的打分函数来打分,训练的分类器是 dense retriever(向量检索)
CASE-BASED REASONING(CBR) 采用 F1 值弱标记数据,F1 值根据输出yiy_iyiyjy_jyj 中的 token 集合计算
EFFICIENT PROMPT RETRIEVAL 本文的模型,使用 Ru((x,y),D)\mathcal{R}_u((x,y),\mathcal{D})Ru((x,y),D) 抽取候选集, 使用 GPT-NEO 打分

标准模型

基线模型 描述
BM25-ORACLE 在测试的时候,不依据输入向量和训练集实例向量内积来排序,直接用 BM25 以目标输出为参数寻找最相似的实例组成prompt
LM-ORACLE 在测试的时候,对待每个测试实例就像训练实例一样,利用BM25抽取候选集,然后使用 Scoring LM 打分,最终得到prompt

4)模型测试的两种模式

(a) LM - as - a - service (scoring LM 和 inference LM 相同)
scoring LM 和 inference LM 都是 GPT-NEO,在BREAK, MTOP, SMACALFLOW 的全部数据上评估。

(b) LM - as - a - proxy (scoring LM 小于 inference LM)
从 GPT-3 和 CODEX 中随机抽取 1000 个实例,在这个子集上进行评估。

模型 (C=2048C = 2048C=2048)

scoring LM inference LM
GPT-NEO GPT-NEO
GPT-NEO GPT-J
GPT-NEO GPT-3
GPT-NEO CODEX

5)实验结果分析

LM - as - a - service

Table 2

  • 每一列来看,EPR都是最好的
  • BM25 超过 SBERT 说明利用BM25 来提取候选集比用 SBERT 好。
  • 随机抽取效果很差
  • BruteForce 表现很差可能是因为随机抽取 200 候选覆盖面太窄,信息太少
  • EPR 的效果和 BM25-ORACLE 不相上下,甚至更好,说明了这种用语言模型打分的形式比用文本表面的相似度要更好。
  • LM-ORACLE 的效果比 EPR 好,说明打分语言模型提供的监督很强,依照此监督信号训练出的更好的提取器可以提升表现。
  • Table 3 佐证了table 2 的结论

Table 4

one-shot setup 测试:prompt 只取得分最高的那个例子
ANYCORRECT-ORACLE:测试 BM25提供的所有候选,是否在其提示下得到了正确的输出。

  • 通过实验发现,得到了很好的效果,EPR 比 CBR 高了 8.5%,比BM25-ORACLE 也高了5%。
  • ANYCORRECT-ORACLE 的得分高于 50%,说明 BM25 提供的候选质量很高。同时,它的得分比 LM - ORACLE 高很多,说明通过更好的 scoring model,可以提升整体表现。

LM - as - a - proxy

scoring LM 是GPT-NEO,inference LM 是一个更大的 LM。

Table 5

相较于其他模型,EPR基本都有一定提升
同时使用 GPT-J 作为打分LM和推理LM,31.5 -> 33.6,对CODEX来说 29.5 -> 29.3。因此,用更小的 LM(GPT-NEO)效率更高。

使用不同的模型作为 推理LM 时,表现有所不同,这是因为预训练模型的差异。

Table 6

主要观察第三个例子,CBR 提取的没有出现 code,且没有体现出最多或者最少。

Figure 3

对利用 EPR 模型从 BREAK 数据集中学到的 embeddings 进行可视化展示。

t-SNE 是一种非线性降维算法,非常适用于高维数据降维到 2 维或者 3 维,进行可视化。在实际应用中,t-SNE很少用于降维,主要用于可视化

OPTICS算法也是一种基于密度的聚类算法

对聚类的研究表明,EPR既能捕捉词汇相似性,又能捕捉结构相似性。

Table 7

研究输出的结果究竟是直接复制了 prompt 中的输出,还是说组合了 prompt 中不同例子的输出。

定义了两种复制

  1. exact copying:产生的输出完全匹配 prompt 中的一个实例的输出。
  2. abstract copying:输出的结构是否和 prompt 中的某个实例的结构相同。对目标输出以及 prompt 中实例的逻辑形式,将其出现的实体和函数参数用 [masked] 标志替换。在替换之后,如果目标输出在 prompt 的实例中出现,那么就产生了复制。
  • 在 MTOP 和 SMACL 数据集上,Abstract copying 达到了80%以上,而且,出现了 copy 现象的部分准确率大大高于未出现 copy 现象的部分。可随机举例。
  • 同时值得关注的是,对于没有出现copy现象的部分,其准确率也难以忽视,这说明输出生成了新的结构。

Figure 4

求证,当出现复制情况的时候,容易被复制的是高得分的实例,还是说是全局性的。

为了得到此数据,将出现复制情况的实例对应的 prompt 中的实例定义一个举例,按照得分从高到底排列,然后除以提示的数量,就得到了一个归一化的距离,通过实验发现,距离越近,也就是得分越高的实例被复制的概率越高。

5. 总结

6. 相关知识

dense retriever & sparse retriever

CLS编码介绍

【KBQA-2】 Learning To Retrieve Prompts for In-Context Learning相关推荐

  1. 【论文阅读】DouZero: Mastering DouDizhu with Self-Play Deep Reinforcement Learning

    [论文阅读]DouZero: Mastering DouDizhu with Self-Play Deep Reinforcement Learning 1 本文解决了什么问题? 斗地主是一个非常具有 ...

  2. 【机器学习笔记】可解释机器学习-学习笔记 Interpretable Machine Learning (Deep Learning)

    [机器学习笔记]可解释机器学习-学习笔记 Interpretable Machine Learning (Deep Learning) 目录 [机器学习笔记]可解释机器学习-学习笔记 Interpre ...

  3. 【文献学习】Complex-Valued Convolutions for Modulation Recognition using Deep Learning

    目录 1 简介和创新点 1.1 DL中复数的处理综述 1.2 DL对于调制分类的综述 2 系统模型 2.1 二维实数卷积 2.2 整合到现有的DL架构中 3 模型参数 4 实验分析 5 思考和收获哦 ...

  4. 【论文笔记】Progressive Layered Extraction (PLE): A Novel Multi-Task Learning (MTL) Model for Personalized

    论文 论文题目:Progressive Layered Extraction (PLE): A Novel Multi-Task Learning (MTL) Model for Personaliz ...

  5. 【Paper Reading】BatchCrypt: Efficient Homomorphic Encryption for Cross-Silo Federated Learning

    BatchCrypt: Efficient Homomorphic Encryption for Cross-Silo Federated Learning 原文来源:[ATC 2020] Batch ...

  6. 【深度学习】多任务学习概览(An Overview of Multi-task Learning in Deep Neural Networks)

    1. 前言 在机器学习中,我们通常关心优化某一特定指标,不管这个指标是一个标准值,还是企业KPI.为了达到这个目标,我们训练单一模型或多个模型集合来完成指定得任务.然后,我们通过精细调参,来改进模型直 ...

  7. 【步态识别】GLN 算法学习《Gait Lateral Network: Learning Discriminative and Compact Representations for Gait R》

    目录 1. 论文&代码源 2. 论文亮点 3. 框架解读 3.1 横向连接☆ 3.2 紧凑块 3.3 训练策略 3.3.1 三元组损失 3.3.2 交叉熵损失 3.3.3 总损失函数 4. 实 ...

  8. 【论文精读】Improving Extreme Low-Light Image Denoising via Residual Learning

    通过残差学习改善极低光图像去噪 摘要 1.引言 2.相关文献 2.1.图像去噪 2.2.低光图像增强 3.我们的方法 4.实验 4.1.数据集和实验设置 4.2.主观质量 4.2.1.去噪 4.2.2 ...

  9. 【资源分享】639页《深度学习:Deep Learning》硬核课程PPT

    关注上方"深度学习技术前沿",选择"星标公众号", 资源干货,第一时间送达!    课程名称 Deep Learning    课程地址 https://git ...

最新文章

  1. 番茄时间管理和四象限工作法完美搭配造就职场神器
  2. Access 字段拼接(UPDATE 数据追加)
  3. UVA10970大块巧克力
  4. 【MM配置】Delivery Costs 交货成本
  5. IDEA中添加tomcat服务器和创建一个新的web项目
  6. navicat运行sql文件慢_SQL进阶之路——入门
  7. 商家 APP 如何接入新版支付宝支付,老版本商家如何升级
  8. java多线程实例_要把Java吃透您得先吃透这些基本概念
  9. 姿态坐标c语言,判断 AR 中坐标系的姿态和位置的简单方法
  10. CSS技巧之数字美化为机械字体样式
  11. android长截图工具下载,一键长截屏下载-一键长截屏 安卓版v1.0.0-PC6安卓网
  12. linuxi下的做图工具——gnuplot安装
  13. TP5在json入库多出来反斜杠
  14. javascript 百度百科
  15. 计算机usb无法读取u盘启动,电脑不能识别U盘PE的解决方法
  16. ​区块链公链“三元悖论”专题系列之去中心化(Decentralization)
  17. 人工智能专业应不应该读博士?
  18. python画五角星代码_Python GUI 编程tkinter--画五角星和简单的动画制作
  19. Linux和Redis的自学笔记总结
  20. Cordova打包Scratch为APP

热门文章

  1. 阿里云EMAS 4月产品动态
  2. MediaPlayer属性大全
  3. 程序员找工作的各种坑…… 及防坑指南
  4. 算法--100盏灯问题
  5. H3C交换机bootroom菜单清除Console密码
  6. 云辅助隐私集合求交(Server-Aided PSI)协议介绍
  7. python matplotlib 画图 不显示中文 中文乱码 设置中文字体
  8. 美国经济危机日趋严重,科技行业裁员已达10万
  9. SVG公众号排版『适配深色模式图片二维码可识别可点击』模板代码
  10. RFIC4463_F1