引言

本文是Poly-encoder1的阅读笔记。对Bi-encoder/Cross-encoder进行了一个较好的阐述,同时提出了综合两者优点的Poly-encoder。

输入表示

作者使用的预训练输入是 输入句子(INPUT)和标签句子(LABEL )的拼接,它们都以特殊字符[S]包围。

在Reddit数据集上训练时,输入句子是历史对话语句的拼接,对话语句之间以[NEWLINE]特殊字符分隔,标签句子是该段对话中的下一个句子。

核心思想

对于句子对比较任务来说,有两种常用的途径:Cross-encoder和Bi-encoder。

Cross-encoder基于给定的输入句子和标签句子(组成一个句子对,将它们拼接在一起作为输入)进行交叉自注意,通常能获得较高的准确率,但速度较慢。

而Bi-encoder单独地对句子对中的句子进行自注意,分别得到句子编码。由于这种独立性,Bi-encoder可以对候选句子进行缓存,从而在推理时只需要计算输入句子的编码表示即可,大大加快推理速度。但是表现没有Cross-encoder好。

本文作者提出了一种新的Transformer结构,Poly-encoder,学习全局级而不是单词级的自注意特征。

Poly-encoder比Cross-encoder快,同时比Bi-encoder更准确。

同时作者证明选择与下游任务更相关的数据集进行预训练能获得较大的效果提升。

Bi-encoder

Bi-encoder独立地编码上下文和候选句

Bi-encoder允许快速、实时地推理。此时,输入上下文(输入句子)和候选标签句子都被编码成向量:
yctxt=red(T1(ctxt))ycand=red(T2(cand))y_{ctxt}= red(T_1(ctxt))\\ y_{cand} = red(T_2(cand)) yctxt​=red(T1​(ctxt))ycand​=red(T2​(cand))

bi-是英文里面的两个的意思,比如binary(二进制)。

其中T1T_1T1​和T2T_2T2​是两个预训练好的Transformer。它们以同样的权重初始化,但是允许独立更新。red(⋅)red(\cdot)red(⋅)是将Transformer产生的向量序列压缩成一个向量的函数。

假设T(x)=h1,⋯,hnT(x)=h_1,\cdots,h_nT(x)=h1​,⋯,hn​是TransformerTTT的输出。由于输入和标签句子都由特殊字符[S][S][S]包围,因此这里h1h_1h1​对应的就是[S]

作者考虑了三种压缩输出序列到一个向量的方法:

  • 选择Transformer输出的第一个向量(即[S])
  • 计算所有输出向量的均值
  • 计算前MMM个输出的均值

作者通过实验证明,选择输出的第一个向量结果更好。

打分 候选句子通过与上下文句子进行点积计算得分:s(ctxt,candi)=yctxt⋅ycandis(ctxt,cand_i)=y_{ctxt} \cdot y_{cand_i}s(ctxt,candi​)=yctxt​⋅ycandi​​。该网络基于最小化一个交叉熵损失,其中logtis是yctxt⋅ycand1,⋯,yctxt⋅ycandny_{ctxt} \cdot y_{cand_1},\cdots,y_{ctxt} \cdot y_{cand_n}yctxt​⋅ycand1​​,⋯,yctxt​⋅ycandn​​,其中cand1cand_1cand1​是正确标签,剩下的是从训练集中随机选择的负样本。

推理速度 Bi-encoder允许预先计算所有可能候选句子的嵌入表示。只要计算出了上下文嵌入yctxty_{ctxt}yctxt​,就可以利用GPU进行加速计算,同时可以使用FAISS快速找到最相近的句子。

Cross-encoder

Cross-encoder同时在一个Transformer中编码上下文和候选句,产生它们之间更丰富的表示信息但增加了计算量

Cross-encoder允许同时编码输入上下文和候选标签语句,得到一个最终的表示。只需要一个Transformers。我们使用输出的第一个向量作为上下文-候选嵌入:
yctxt,cand=h1=first(ctxt,cand)y_{ctxt,cand} = h_1 = first(ctxt,cand) yctxt,cand​=h1​=first(ctxt,cand)
firstfirstfirst是从输出的向量序列中抽取第一个向量的函数。基于这种设定,Transformer可以在上下文和候选句之间进行交叉的自注意,从而相比Bi-encoder可以得到更丰富的表示。因为候选句子可以和上下文句子在所有编码器层中进行自注意,从而产生一个候选句相关的输入表示。

打分 为了对候选句打分,一个线性层WWW应用到yctxt,candy_{ctxt,cand}yctxt,cand​从而得到一个标量:
s(ctxt,candi)=yctxt,candiWs(ctxt, cand_i) = y_{ctxt, cand_i}W s(ctxt,candi​)=yctxt,candi​​W
该网络也是基于最小化一个交叉熵损失来进行训练,其中logits为s(ctxt,cand1),⋯,s(ctxt,candn)s(ctxt,cand_1),\cdots,s(ctxt,cand_n)s(ctxt,cand1​),⋯,s(ctxt,candn​),cand1cand_1cand1​是正确标签,剩下的是从训练集中随机选取的负样本。

推理速度 在推理时,每个候选句必须与输入上下文进行拼接,然后经过整个模型。这样导致计算速度较慢。

Poly-encoder

Poly-encoder综合了Bi-encoder和Cross-encoder,可以缓存候选句子的表示,同时增加了一个额外的注意力机制从候选句中抽取更多信息

Poly-encoder旨在充分利用Bi-encoder和Cross-encoder两者的优点:

  • 给定的候选标签由Bi-encoder中的一个向量表示,允许预先计算缓存计算结果。
  • 输入上下文与候选语句在Cross-encoder中同时编码,可以抽取更多信息。

Poly-encoder和Bi-encoder一样,使用两个独立的Transformer来计算上下文和候选标签的编码,候选标签可以提前被编码到向量ycandiy_{cand_i}ycandi​​。这样,Poly-encoder可以利用缓存进行加速。

然而,输入上下文通常比候选句子长的多,由多个向量(yctxt1,⋯,yctxtm)(y_{ctxt}^1,\cdots,y_{ctxt}^m)(yctxt1​,⋯,yctxtm​)表示,而不是Bi-encoder中的一个向量。为了提取每个候选句输入上下文的相关部分,作者使用一个以ycandiy_{cand_i}ycandi​​作为query的注意力层:
yctxt=∑iwiyctxtiy_{ctxt} = \sum_i w_i y^i_{ctxt} yctxt​=i∑​wi​yctxti​
其中
(w1,⋯,wm)=softmax(ycand⋅yctxt1,⋯,ycand⋅yctxtm).(w_1,\cdots,w_m) = \text{softmax}(y_{cand}\cdot y_{ctxt}^1, \cdots, y_{cand} \cdot y_{ctxt}^m). (w1​,⋯,wm​)=softmax(ycand​⋅yctxt1​,⋯,ycand​⋅yctxtm​).
由于m<Nm \lt Nm<N,其中NNN是单词的总数,而这里的上下文-候选注意里只在最顶层计算,远快于Cross-encoder所有编码器层的交叉注意力。

问题是如何得到这mmm个上下文向量(yctxt1,⋯,yctxtm)(y_{ctxt}^1,\cdots,y_{ctxt}^m)(yctxt1​,⋯,yctxtm​);作者最终简单地选择输出向量序列中的前mmm个。

Reference


  1. Poly-encoders: architectures and pre-training strategies for fast and accurate multi-sentence scoring ↩︎

[论文笔记]Poly-encoders: architectures and pre-training strategies for fast and accurate multi-sentence相关推荐

  1. Transformer Architectures and Pre-training Strategies for Fast and Accurate Multi-sentence Scoring

    目录 论文导读 课前基础知识 学习目标 知识树 研究背景 初始BERT 研究成果 研究意义 论文结构 摘要 论文精读 模型总览BERT.Poly-encoder BERT的出现 BERT结构 BERT ...

  2. AAAI2018-Long Text Generation via Adversarial Training with Leaked Information论文笔记

    这篇文章主要是名为 LeakGAN 的模型结构,同时处理 D 反馈信息量不足和反馈稀疏的两个问题.LeakGAN 就是一种让鉴别器 D 提供更多信息给生成器 G 的新方式,我自己的笔记: 转自:htt ...

  3. 论文笔记目录(ver2.0)

    1 时间序列 1.1 时间序列预测 论文名称 来源 主要内容 论文笔记:DCRNN (Diffusion Convolutional Recurrent Neural Network: Data-Dr ...

  4. 论文笔记【A Comprehensive Study of Deep Video Action Recognition】

    论文链接:A Comprehensive Study of Deep Video Action Recognition 目录 A Comprehensive Study of Deep Video A ...

  5. 【论文笔记-NER综述】A Survey on Deep Learning for Named Entity Recognition

    本笔记理出来综述中的点,并将大体的论文都列出,方便日后调研使用查找,详细可以看论文. 神经网络的解释: The forward pass com- putes a weighted sum of th ...

  6. 光流 速度_[论文笔记] FlowNet 光流估计

    [论文笔记] FlowNet: Learning Optical Flow with Convolutional Networks 说在前面 个人心得: 1. CNN的光流估计主要是速度上快,之后的v ...

  7. 论文笔记:Distilling the Knowledge

    原文:Distilling the Knowledge in a Neural Network Distilling the Knowledge 1.四个问题 要解决什么问题? 神经网络压缩. 我们都 ...

  8. GAN for NLP (论文笔记及解读

    GAN 自从被提出以来,就广受大家的关注,尤其是在计算机视觉领域引起了很大的反响."深度解读:GAN模型及其在2016年度的进展"[1]一文对过去一年GAN的进展做了详细介绍,十分 ...

  9. 神经稀疏体素场论文笔记

    论文地址:https://proceedings.neurips.cc/paper/2020/file/b4b758962f17808746e9bb832a6fa4b8-Paper.pdf Githu ...

  10. GAN学习历程之CycleGAN论文笔记

    GAN目前发展的很快,成果也很多,从GAN->Pix2pix->CycleGAN 本来是准备看一篇19年一月份ICLR发表的conference paper INSTAGAN,发现这篇论文 ...

最新文章

  1. 为什么需要批判性思维 -- 读《学会提问》
  2. eth0,eth1,eth2,lo是什么
  3. Linux最佳聊天软件:Skype 4.3轻体验
  4. 电子书下载:Illustrated C# 2012 4th
  5. 全国计算机等级考试题库二级C操作题100套(第65套)
  6. leetcode 554. 砖墙
  7. guido python正式发布年份_Python语言适合哪些领域的计算问题? (1.3分)_学小易找答案...
  8. python 打印皮卡丘_用python打印你的宠物小精灵吧
  9. RMAN catalog 的创建和使用
  10. 【容器云】十分钟快速构建 Influxdb+cadvisor+grafana 监控
  11. 禁止“挖矿”!谷歌杀了所有的 Chrome 扩展应用
  12. NLP --- 条件随机场CRF详解
  13. sql 替换字段中的部分字符,替换指定字符
  14. ARP欺骗的艺术 | 断网与监听
  15. zuc算法代码详解_ZUC算法原理及实现过程.doc
  16. android 音乐播放器评测,Android平台四大音乐播放器对比评测
  17. java程序中,如何设置周一为一周的开始?如何设置周一为一周的第一天? 或者说,如何理解java的setFirstDayOfWeek()方法?
  18. linux双网卡连不上网,linux 双网卡配置问题
  19. 全国天气预报信息数据接口 API
  20. Laravel本地Sail开发环境下Phpstorm+浏览器+Postman调试配置

热门文章

  1. Oracle 进程 说明
  2. 【小技巧】自定义asp.net mvc的WebFormViewEngine修改默认的目录结构
  3. Laser Reflections solutions
  4. 20190813 On Java8 第一章 对象的概念
  5. Sql优化之Mysql表分区
  6. 第四章 consul cluster
  7. Blue Jeans - POJ 3080(多串的共同子串)
  8. Android自定义之流式布局
  9. 在windows server 2003服务器上提供NTP时间同步服务
  10. Gatech OMSCS的申请和学习之奥妙