UniT:基于统一Transformer的多模态多任务学习 《UniT:Multimodal Multitask Learning with a Unified Transformer》

论文地址:https://arxiv.org/pdf/2102.10772.pdf

相关博客:
【自然语言处理】【多模态】多模态综述:视觉语言预训练模型
【自然语言处理】【多模态】CLIP:从自然语言监督中学习可迁移视觉模型
【自然语言处理】【多模态】ViT-BERT:在非图像文本对数据上预训练统一基础模型
【自然语言处理】【多模态】BLIP:面向统一视觉语言理解和生成的自举语言图像预训练
【自然语言处理】【多模态】FLAVA:一个基础语言和视觉对齐模型
【自然语言处理】【多模态】SIMVLM:基于弱监督的简单视觉语言模型预训练
【自然语言处理】【多模态】UniT:基于统一Transformer的多模态多任务学习
【自然语言处理】【多模态】Product1M:基于跨模态预训练的弱监督实例级产品检索
【自然语言处理】【多模态】ALBEF:基于动量蒸馏的视觉语言表示学习
【自然语言处理】【多模态】VinVL:回顾视觉语言模型中的视觉表示
【自然语言处理】【多模态】OFA:通过简单的sequence-to-sequence学习框架统一架构、任务和模态
【自然语言处理】【多模态】Zero&R2D2:大规模中文跨模态基准和视觉语言框架

一、简介

​ Transformer\text{Transformer}Transformer在各个领域都获得的巨大的成功,包括但不限于自然语言、图像、视频和音频。先前的工作表明,在大规模语料上预训练的Transformer\text{Transformer}Transformer能够学习到有益于下游广泛语言任务的向量表示。在视觉领域,基于Transformer\text{Transformer}Transformer的模型也在图像分类、目标检测和全景分割上实现了很好的效果。除了建模单一模态外,Transformer\text{Transformer}Transformer模型也在VQA\text{VQA}VQA等联合视觉-语言推理任务上实现了很好的表现。

​ 然而,尽管Transformer\text{Transformer}Transformer在特定领域的应用中达成了不错的成就,但是基于Transformer\text{Transformer}Transformer的跨领域链接不同任务的工作并不多。在目睹了Transformer\text{Transformer}Transformer的成功后,各种问题自然会出现:用于自然语言推理训练的Transformer\text{Transformer}Transformer模型是否也能在图像上执行目标检测,或者说基于Transformer\text{Transformer}Transformer的图像编码器是否能用来进行检测文本蕴含?总的来说,是否能够建立单个模型来同时处理不同领域的各种任务,向通用人工智能前进一步?先前的工作尝试解决这些问题,但是存在一定的限制:

  • 仅应用在单个领域或者特定模态的任务上;ViT\text{ViT}ViT和DETR\text{DETR}DETR仅专注在视觉任务上,BERT\text{BERT}BERT以及延伸的工作仅能处理语言任务,而VisualBERT\text{VisualBERT}VisualBERT和VILBERT\text{VILBERT}VILBERT等仅在特定的视觉-语言多模态领域。
  • 对于每个任务涉及任务相关的微调,没有在任务间利用共享参数,通常NNN个任务具有NNN倍的参数,例如:必须使用BERT\text{BERT}BERT分别为每个任务的模型进行微调。
  • 仅在单个领域中的相关或者相似任务上执行多任务,有时会使用硬编码的训练策略;例如,T5\text{T5}T5仅在语言领域任务上工作,而VILBERT-MT\text{VILBERT-MT}VILBERT-MT则仅在相关的视觉-语言任务上。

​ 在本文中,作者构建了一个称为UniT\text{UniT}UniT的统一Transformer\text{Transformer}Transformer模型,其将图像和(或)文本作为输入,然而在视觉感知、自然语言理解和联合视觉-语言推理的各种任务上进行联合训练。UniT\text{UniT}UniT由Transformer\text{Transformer}Transformer编码器构成,其能够将每个输入模态编码为hidden states,在编码后的输入模态上应用一个Transformer\text{Transformer}Transformer解码器,然后在解码器的输出上应用一个任务相关的输出头来对每个任务进行预测。相较于先前基于Transformer\text{Transformer}Transformer的多任务学习工作,UniT\text{UniT}UniT在更广阔的任务上实现了与先前工作相当的效果,不仅仅VQA\text{VQA}VQA这样的视觉语言任务,也有纯视觉和纯语言任务。本文的贡献如下:

  • 提出了UniT\text{UniT}UniT,一个统一的Transformer\text{Transformer}Transformer编码器解码器架构,能够使用较少的参数来同时处理多任务和多领域;
  • 学习视觉领域、文本领域和交叉领域的最突出的任务,包括目标检测、VQA\text{VQA}VQA、视觉蕴含以及GLUE\text{GLUE}GLUE基准上的自然语言理解任务,包括QNLI\text{QNLI}QNLI、MNLI\text{MNLI}MNLI、QQP\text{QQP}QQP和SST-2\text{SST-2}SST-2。证明了这些多样的任务能够同时学习,并且在本文的训练方案下能够适当收敛;
  • 通过对各种任务的分析,展示了像VQA\text{VQA}VQA和视觉蕴含这样的多模态任务能够从多模态多任务训练上收益。

二、UniT\text{UniT}UniT:跨领域统一Transformer\text{Transformer}Transformer

​ 本工作中,使用统一的单个模型联合学习跨不同模态的多任务。模型UniT\text{UniT}UniT是建立在基于Transformer\text{Transformer}Transformer编码器-解码器架构上的,由每个模态一个编码器和统一的解码器组成。上图是整个UniT\text{UniT}UniT的架构。

​ 本文考虑图像和文本两种输入模态。对于图像上的基于Transformer\text{Transformer}Transformer编码器,首先会应用卷积神经网络来抽取一个视觉feature map\text{feature map}feature map,然后其被Transformer\text{Transformer}Transformer编码器进一步编码为合并了全局上下文信息的hidden state\text{hidden state}hidden state序列。对于语言输入,这里使用12层的uncased版本的BERT\text{BERT}BERT,其将输入的单词序列也编码为hidden state\text{hidden state}hidden state序列。在将输入编码为hidden state\text{hidden state}hidden state序列后,将Transformer\text{Transformer}Transformer解码器应用在单个模态上向量序列上,或者多个模态拼接的向量序列上(这取决于任务是单模态还是多模态)。作者在所有任务上测试了分离解码器和共享解码器。最终,从Transformer\text{Transformer}Transformer解码器获得的表示被传递至任务相关的头,并输出最终的预测值。由于UniT\text{UniT}UniT的简单性,其可以轻易扩展到更多模态和输入上。

​ 作者实验表明,UniT\text{UniT}UniT可以在8个数据集上联合学习7个任务。

2.1 图像编码器

​ 单独的视觉任务和"视觉-语言"任务需要感知和理解图像III。UniT\text{UniT}UniT中使用卷积神经网络后跟一个Transformer\text{Transformer}Transformer编码器来编码图像III,将其转换为视觉编码hidden state\text{hidden state}hidden state列表:hv={h1v,h2v,…,hLv}\textbf{h}^v=\{h_1^v,h_2^v,\dots,h_L^v\}hv={h1v​,h2v​,…,hLv​}。图像编码的过程是受DETR\text{DETR}DETR启发。首先,使用将卷积神经网络BBB应用在输入图像上,抽取出一个尺寸为Hv×Wv×dvbH_v\times W_v\times d_v^bHv​×Wv​×dvb​的feature map\text{feature map}feature map xv\textbf{x}^vxv:
xv=B(I)(1)\textbf{x}^v=B(I) \tag{1} xv=B(I)(1)
在实现中,卷积网络使用ResNet-50\text{ResNet-50}ResNet-50并在目标检测任务上进行预训练。

​ 为了进一步编码出尺寸为L×dveL\times d_v^eL×dve​的视觉hidden state\text{hidden state}hidden state hv\textbf{h}^vhv,在xv\textbf{x}^vxv上应用一个具有NvN_vNv​层且hidden size\text{hidden size}hidden size为dved_v^edve​的Transformer\text{Transformer}Transformer编码器EvE_vEv​,其中L=Hv×WvL=H_v\times W_vL=Hv​×Wv​是视觉hidden state\text{hidden state}hidden state的长度。此外,给定不同的任务可能需要抽取不同类型的信息,因此在Transformer\text{Transformer}Transformer编码器中添加一个任务相关嵌入向量wvtaskw_v^{task}wvtask​,这允许抽取任务相关的信息
hv={h1v,h2v,…,hLv}=Ev(Pb→e(xv),wvtask)(2)\textbf{h}^v=\{h_1^v,h_2^v,\dots,h_L^v\}=E_v(P_{b\rightarrow e}(\textbf{x}^v),w_v^{task}) \tag{2} hv={h1v​,h2v​,…,hLv​}=Ev​(Pb→e​(xv),wvtask​)(2)
Pb→eP_{b\rightarrow e}Pb→e​是一个将视觉特征维度dvbd_v^bdvb​投影至编码器hidden\text{hidden}hidden维度dved_v^edve​的线性投影层。视觉Transformer\text{Transformer}Transformer编码器EvE_vEv​的构建遵循DETR\text{DETR}DETR,其中位置编码会被添加至feature map\text{feature map}feature map。任务相关的token wtaskw^{task}wtask是一个维度为dved_v^edve​的可学习参数,其被合并至视觉特征序列Pb→eP_{b\rightarrow e}Pb→e​的开始。

2.2 文本编码器

​ 像QNLI,MNLI,QQP,SST-2GLUE基准,以及VQAvisual entailment视觉语言推理任务都会提供文本输入。这里使用BERT来编码文本输入。

​ 给定输入文本,以与BERT相同的方法将其转换为长度为S的token序列{w1,…,wS}\{w_1,\dots,w_S\}{w1​,…,wS​},其中w1=[CLS]w_1=\text{[CLS]}w1​=[CLS]。这个token序列会被输入至预训练BERT中来抽取尺寸为S×dteS\times d_t^eS×dte​的hidden state\text{hidden state}hidden state ht\textbf{h}^tht,其中dted_t^edte​是BERThidden size。类似于图像编码器,文本编码器也会token序列前添加一个可学习任务嵌入向量wttaskw_t^{task}wttask​。
ht={h1t,h2t,…,hSt}=BERT({w1,…,wS},wttask)(3)\textbf{h}^t=\{h_1^t,h_2^t,\dots,h_S^t\}=\text{BERT}(\{w_1,\dots,w_S\},w_t^{task}) \tag{3} ht={h1t​,h2t​,…,hSt​}=BERT({w1​,…,wS​},wttask​)(3)
然而,在实践中发现仅保留ht\textbf{h}^tht中[CLS]对应的向量来作为解码器的输入就能达到同样的效果。

​ 在本文的实现中,使用BERT-base-uncased,其dte=768d_t^e=768dte​=768且Nt=12N_t=12Nt​=12。

2.3 领域不可知UniT\text{UniT}UniT解码器

​ 在将输入模态编码后,应用一个hidden size为dtdd_t^ddtd​且具有 NdN_dNd​层的Transformer\text{Transformer}Transformer解码器DDD,该解码器会输出一个hidden state序列hdec\textbf{h}^{dec}hdec,然后用于每个任务的预测。不同于文本和图像编码器,每个模态都有一个具体的架构,解码器在所有任务上都使用相同的领域不可知Transformer\text{Transformer}Transformer解码器。

​ 对于纯视觉任务,解码器应用在编码后的图像henc=hv\textbf{h}^{enc}=\textbf{h}^vhenc=hv;对于纯语言任务,解码器应用在编码后 的文本henc=ht\textbf{h}^{enc}=\textbf{h}^thenc=ht;对于视觉语言联合任务,将两种模态合并至单个输入henc=concat(hv,ht)\textbf{h}^{enc}=\text{concat}(\textbf{h}^v,\textbf{h}^t)henc=concat(hv,ht)。

​ Transformer\text{Transformer}Transformer解码器DDD将编码后的输入序列henc\textbf{h}^{enc}henc和一个长度为qqq的任务相关的query嵌入序列qtask\textbf{q}^{task}qtask。Transformer\text{Transformer}Transformer解码器第lll层会输出一个解码序列hdec,l\textbf{h}^{dec,l}hdec,l,其长度与qtask\textbf{q}^{task}qtask相同为qqq
{hdec,l}=D(henc,qtask)(4)\{\textbf{h}^{dec,l}\}=D(\textbf{h}^{enc},\textbf{q}^{task}) \tag{4} {hdec,l}=D(henc,qtask)(4)
​ 解码器的架构同DETR中实现的解码器。在解码器的第lll层,自注意力机制被应用在解码的hdec,l\textbf{h}^{dec,l}hdec,l,交叉注意力被用于编码输入模态henc\textbf{h}^{enc}henc。

​ 在实现时,要么对所有任务使用单个共享的解码器DsharedD^{shared}Dshared,或者为每个具体的任务ttt使用分离解码器DtsepD_t^{sep}Dtsep​。

2.4 任务相关的输出头

​ 每个任务ttt的预测头被应用在解码hidden state {hdec,l}\{\textbf{h}^{dec,l}\}{hdec,l}。对于目标检测任务,使用分类头来产生分类概率输出,以及一个box头来为{1,…,q}\{1,\dots,q\}{1,…,q}中的每个位置产生bounding box。分类头和box头的实现如同DETR。对于每个box上具有属性标签的数据集,实现类似BUTD中的属性分类头 。

​ 类别头和box头的输出会被后处理为object bounding box。对解码器所有层lll的hidden state hdec,l\textbf{h}^{dec,l}hdec,l上都会应用这些头
cl=class_head(hdec,l)bl=box_head(hdec,l)al=attr_head(hdec,l,cl)\begin{aligned} \textbf{c}^l&=\text{class\_head}(\textbf{h}^{dec,l}) \\ \textbf{b}^l&=\text{box\_head}(\textbf{h}^{dec,l}) \\ \textbf{a}^l&=\text{attr\_head}(\textbf{h}^{dec,l},\textbf{c}^l) \\ \end{aligned} clblal​=class_head(hdec,l)=box_head(hdec,l)=attr_head(hdec,l,cl)​
其中,cl,bl,al\textbf{c}^l,\textbf{b}^l,\textbf{a}^lcl,bl,al是类别、box和属性的输出序列,所有的长度均为qqq,与query嵌入qtask\textbf{q}^{task}qtask相同。

​ 在测试时,仅使用从解码器顶层得到的预测值hdec,Nd\textbf{h}^{dec,N_d}hdec,Nd​。因此不同的检测数据集通常有不同数量的类别,每个数据集都有自己的类别头、box头和属性头。在cl\textbf{c}^lcl和bl\textbf{b}^lbl上应用的损失函数同DETR,在al\textbf{a}^lal上的属性损失函数同BUTD

​ 本文中所有的任务,包括:视觉问答、visual entailment和自然语言理解(QNLI,QQP,MNLI,SST-2)\text{(QNLI,QQP,MNLI,SST-2)}(QNLI,QQP,MNLI,SST-2)等,都能被转换为任务ttt上的ctc_tct​类别分类任务。在解码器顶层的第1个hidden state h1dec,Nd\textbf{h}_1^{dec,N_d}h1dec,Nd​​上应用任务相关的分类器,并为任务ttt输出一个尺寸为ctc_tct​的分类预测值p\textbf{p}p。

​ 为了预测输出类别,使用具有GeLU激活函数的两层MLP\text{MLP}MLP,且输出维度等于解码器hidden size。使用预测值p\textbf{p}p和真实标签t\textbf{t}t计算交叉熵损失函数来训练模型
p=W1⋅GeLU(W2⋅h1dec,Nd+b2)+b1loss=CrossEntropyLoss(p,t)\begin{aligned} \textbf{p}&=W_1\cdot\text{GeLU}(W_2\cdot\textbf{h}_1^{dec,N_d}+b_2)+b_1 \\ \text{loss}&=\text{CrossEntropyLoss}(\textbf{p,t}) \end{aligned} ploss​=W1​⋅GeLU(W2​⋅h1dec,Nd​​+b2​)+b1​=CrossEntropyLoss(p,t)​

2.5 训练

​ 在多个任务上联合训练UniT\text{UniT}UniT。在训练中的每次迭代,随机的选择一个任务和数据集来填充batch。根据数据集的大小和经验来人工指定每个任务的抽样概率。在本文的实现中,模型在64块Nvidia Volta V100-SXM2-32GBGPU上进行训练,batch size为64。使用具有学习率为5e-5的加权Adam优化器。

三、实验

四、总结

  • 单纯将两个模态的模型进行联合训练,理论上没有太多可以借鉴的;
  • 实验结果以及训练过程具有借鉴意义。

【自然语言处理】【多模态】UniT:基于统一Transformer的多模态多任务学习相关推荐

  1. IJCAI 2019 | 为推荐系统生成高质量的文本解释:基于互注意力机制的多任务学习模型...

    编者按:在个性化推荐系统中,如果能在提高推荐准确性的同时生成高质量的文本解释,将更容易获得用户的"芳心".然而,现有方法通常将两者分开优化,或只优化其中一个目标.为了同时兼顾二者, ...

  2. 【自然语言处理】【多模态】OFA:通过简单的sequence-to-sequence学习框架统一架构、任务和模态

    OFA:通过简单的sequence-to-sequence学习框架统一架构.任务和模态 <Unifying Architectures, Task, and Modalities through ...

  3. 图模型+Bert香不香?完全基于注意力机制的图表征学习模型Graph-Bert

    作者 | Jiawei Zhang.Haopeng Zhang.Congying Xia.Li Sun 译者 | 凯隐 编辑 | Jane 出品 | AI科技大本营(ID:rgznai100) [导读 ...

  4. Facebook AI的多任务多模态的统一Transformer

    点击上方"视学算法",选择加"星标"或"置顶" 重磅干货,第一时间送达 仅作学术分享,不代表本公众号立场,侵权联系删除 转载于:作者:Syn ...

  5. 来自Facebook AI的多任务多模态的统一Transformer:向更通用的智能迈出了一步

    点击上方,选择星标或置顶,不定期资源大放送! 阅读大概需要5分钟 Follow小博主,每天更新前沿干货 作者:Synced 编译:ronghuaiyang 导读 一个模型完成了CV,NLP方向的7个任 ...

  6. 【自然语言处理】【文本生成】UniLM:用于自然语言理解和生成的统一语言模型预训练

    UniLM:用于自然语言理解和生成的统一语言模型预训练 <Unified Language Model Pre-training for Natural Language Understandi ...

  7. 没有什么多模态任务是一层Transformer解决不了的!

    文 | 子龙 曾几何时,多模态预训练已经不是一个新的话题,各大顶会诸多论文仿佛搭上Visual和BERT,就能成功paper+=1,VisualBERT.ViLBERT层出不穷,傻傻分不清楚..... ...

  8. 2022最新!基于深度transformer的时间序列异常检测模型

    点击上方"python与机器智能",选择"星标"公众号 重磅干货,第一时间送达 论文:TranAD: Deep Transformer Networks for ...

  9. 内容 AI:建立统一的跨媒体多模态内容理解内核

    作者:zixunsun@tencent.com Jeff Dean 谈 2020 年机器学习趋势:多任务和多模式学习将成为突破口 2019 年下半年,CDG 广告. CSIG 音视频,IEG 内容推荐 ...

最新文章

  1. C语言随机字母生成,C++ 随机数字以及随机数字加字母生成的案例
  2. 人工智能的三大教父,谱写了一段关于勇气的寓言
  3. 性能测试场景设计之用户模式设置
  4. git的简单操作命令
  5. hdu2147 kiki's game(巴什博弈java)
  6. 你的专属云资源管家!阿里云正式对外发布云解析PrivateZone!
  7. leetcode - 688. “马”在棋盘上的概率
  8. sqlerror.java.1055,请问大佬,eclipse连接数据库出现这个错误怎么办
  9. load,initialize方法
  10. 从java库学设计模式_java I/O库的设计模式
  11. 关于 javadoc
  12. 编译原理第一章笔记--绪论
  13. STM32 IIC协议详解
  14. python图像风格迁移教程_Python+OpenCV图像风格迁移的实现方法讲解
  15. r语言 html 变为ppt,如何用R来定制个性化PPT
  16. 【LoRa点对点通信与控制】
  17. Vijos P1008 篝火晚会
  18. 婴儿体重不用计算机怎么算,测量宝宝体重计算器
  19. Excel表格防止重复录入数据
  20. 虚幻引擎图文笔记:调整网格的光照贴图分辨率(Light Map Res)改善光照烘焙质量

热门文章

  1. redeclared as different kind of symbol
  2. 【微信小程序控制硬件②】 开始微信小程序之旅,导入小程序Mqtt客户端源码,实现简单的验证和通讯于服务器.(附带源码)
  3. javaul材质包下载_我的世界手机版0.11.0faithul原版高清材质包
  4. 精品软件 推荐 Office 文档专用压缩工具 - NXPowerLite 6.0.5 中文便携版
  5. 2023麦肯锡中国汽车消费者洞察:消费大潮引领智能电动汽车创新, 车企加速转型应对产业格局重塑...
  6. OpenCV开发笔记(一):OpenCV介绍、编译
  7. 关于s:if的使用记载
  8. 微型USB电缆的全球与中国市场2022-2028年:技术、参与者、趋势、市场规模及占有率研究报告
  9. CVPR2020论文和代码
  10. Deep Learning for Automated Contouring of Primary Tumor Volumes by MRI for Nasopharyngeal Carcinoma