©PaperWeekly 原创 · 作者 | 苏剑林

单位 | 追一科技

研究方向 | NLP、神经网络

能提升模型性能的方法有很多,多任务学习(Multi-Task Learning)也是其中一种。简单来说,多任务学习是希望将多个相关的任务共同训练,希望不同任务之间能够相互补充和促进,从而获得单任务上更好的效果(准确率、鲁棒性等)。然而,多任务学习并不是所有任务堆起来就能生效那么简单,如何平衡每个任务的训练,使得各个任务都尽量获得有益的提升,依然是值得研究的课题。

最近,笔者机缘巧合之下,也进行了一些多任务学习的尝试,借机也学习了相关内容,在此挑部分结果与大家交流和讨论。

加权求和

从损失函数的层面看,多任务学习就是有多个损失函数 ,一般情况下它们有大量的共享参数、少量的独立参数,而我们的目标是让每个损失函数都尽可能地小。为此,我们引入权重,通过加权求和的方式将它转化为如下损失函数的单任务学习:

在这个视角下,多任务学习的主要难点就是如何确定各个 了。

初始状态

按道理,在没有任务先验和偏见的情况下,最自然的选择就是平等对待每个任务,即。然而,事实上每个任务可能有很大差别,比如不同类别数的分类任务混合、分类与回归任务混合、分类与生成任务混合等等,从物理的角度看,每个损失函数的量纲和量级都不一样,直接相加是没有意义的。

如果我们将每个损失函数看成具有不同量纲的物理量,那么从“无量纲化”的思想出发,我们可以用损失函数的初始值倒数作为权重,即

其中 表示任务 的初始损失值。该式关于每个 是“齐次”的,所以它的一个明显优点是缩放不变性,即如果让任务 的损失乘上一个常数,那么结果不会变化。此外,由于每个损失都除以了自身的初始值,较大的损失会缩小,较小的损失会放大,从而使得每个损失能够大致得到平衡。

那么,怎么估计 呢?最直接的方法当然是直接拿几个 batch 的数据来估算一下。除此之外,我们可以基于一些假设得到一个理论值。比如,在主流的初始化之下,我们可以认为初始模型(加激活函数之前)的输出是一个零向量,如果加上 softmax 则是均匀分布,那么对于一个“ 分类+交叉熵”问题,它的初始损失就是 ;对于“回归+ L2 损失”问题,则可以用零向量来估计初始损失,即 , 是训练集的全体标签。

先验状态

用初始损失的一个问题是初始状态不一定能很好地反应当前任务的学习难度,更好的方案应该是将“初始状态”改为“先验状态”:

比如,如果 分类中每个类的频率分别是 (先验分布),那么虽然初始状态的预测分布为均匀分布,但我们可以合理地认为模型可以很容易学会将每个样本的结果都预测为 ,此时模型的损失为熵

某种意义上来说,“先验分布”比“初始分布”更能体现出“初始”的本质,它是“就算模型啥都学不会,也知道按照先验分布来随机出结果”的体现,所以此时的损失值更能代表当前任务的初始难度,因此用 代替 应该更加合理;类似地,对于“回归+L2损失”问题,它的先验结果应该是全体标签的期望 ,所以我们用 代替 ,有望取得更合理的结果。

动态调节

不管是用初始状态的式(2)还是先验状态的式(3),它们的任务权重在确定之后就保持不变了,并且它们确定权重的方法不依赖于学习过程。然而,尽管我们可以通过先验分布等信息简单感知一下学习难度,但究竟有多难其实要真正去学习才知道,所以更合理的方案应该是根据训练进程动态地调整权重。

实时状态

纵观前文,式(2)和式(3)的核心思想都是用损失值的倒数来作为任务权重,那么能不能干脆用“实时”的损失值倒数来实现动态调整权重?即:

这里的 是 的简写。在这个方案中,每个任务的损失函数都被调整恒为 1,所以不管是量纲还是量级上都是一致的。由于 算子的存在,虽然损失恒为 1,但梯度并非恒为 0:

简单来说就是加上 算子后,它的值不变,但是导数为 0,所以最终结果就是以动态权重 来实时调整了梯度的比例。很多“民间实验”表明,式(5)确实在多数情况下都可以作为一个相当不错的 baseline。

等价梯度

我们可以从另一个角度来看该方案。从式(6)我们可以得到:

因此从梯度上看,式(5)与 没有实质区别,而我们进一步有:

由于 是单调递增的,所以式(5)与下式在梯度方向上是一致:

广义平均

显然,上式正是 的“几何平均”,而如果我们约定 恒等于 ,那么原始的式(1)就是 的“代数平均”。也就是说,我们发现这一系列的推导其实隐藏了从代数平均到几何平均的转变,这启发我们或许可以考虑“广义平均”:

也就是将每个损失函数算 次方后再平均最后再开 次方,这里的 可以是任意实数,代数平均对应 ,而几何平均对应 (需要取极限)。可以证明, 是关于 的单调递增函数,并且有:

这就意味着,当 增大时,模型愈发关心损失中的最大值,反之则更关心损失中的最小值。这样一来,虽然依然存在超参数 要调整,但是相比于原始的式(1),超参数的个数已经从 个变为只有 1 个,简化了调参过程。

平移不变

重新回顾式(2)、式(3)和式(5),它们都是通过每个任务损失除以自身的某个状态来调节权重,并且获得了缩放不变性。然而,尽管它们都具备了缩放不变性,但却失去了更基本的“平移不变性”,也就是说,如果每个损失都加上一个常数,(2)、式(3)和式(5)的梯度方向是有可能改变的,这对于优化来说并不是一个好消息,因为原则上来说常数没有带来任何有意义的信息,优化结果不应该随之改变。

理想目标

一方面,我们用损失函数(的某个状态)的倒数作为当前任务的权重,但损失函数的导数不具备平移不变性;另一方面,损失函数可以理解为当前模型与目标状态的距离,而梯度下降本质上是在寻找梯度为 0 的点,所以梯度的模长其实也能起到类似作用,因此我们可以用梯度的模长来替换掉损失函数,从而将式(5)变成:

跟损失函数的一个明显区别是,梯度模长显然具备平移不变性,并且分子分母关于 依然是齐次的,所以上式还保留了缩放不变性。因此,这是一个能同时具备平移和缩放不变性的理想目标。

梯度归一

对式(12)求梯度,我们得到:

可以看到,式(12)本质上是将每个任务损失的梯度进行归一化后再把梯度累加起来。它同时也告诉了我们一种实现方案,即可以让每个任务依次训练,每次只训练一个任务,然后将每个任务的梯度归一化后累积起来再更新,这样就免除了在定义损失函数的时候就要算梯度的麻烦了。

关于梯度归一化,笔者能找到相关工作是《GradNorm: Gradient Normalization for Adaptive Loss Balancing in Deep Multitask Networks》[1],它本质上是式(2)和式(13)的混合,里边也包含了对梯度模长重新标定的思想,但却要通过额外的优化来确定任务权重,个人认为显得繁琐和冗余了。

本文小结

在损失函数的视角下,多任务学习的关键问题是如何调节每个任务的权重来平衡各自的损失,本文从缩放不变和平移不变两个角度介绍了一些参考做法,并补充了“广义平均”的概念,将多个任务的权重调节转化为单个参数的调节问题,可以简化调参难度。

参考文献

[1] https://arxiv.org/abs/1711.02257

特别鸣谢

感谢 TCCI 天桥脑科学研究院对于 PaperWeekly 的支持。TCCI 关注大脑探知、大脑功能和大脑健康。

更多阅读

#投 稿 通 道#

 让你的文字被更多人看到 

如何才能让更多的优质内容以更短路径到达读者群体,缩短读者寻找优质内容的成本呢?答案就是:你不认识的人。

总有一些你不认识的人,知道你想知道的东西。PaperWeekly 或许可以成为一座桥梁,促使不同背景、不同方向的学者和学术灵感相互碰撞,迸发出更多的可能性。

PaperWeekly 鼓励高校实验室或个人,在我们的平台上分享各类优质内容,可以是最新论文解读,也可以是学术热点剖析科研心得竞赛经验讲解等。我们的目的只有一个,让知识真正流动起来。

多任务学习漫谈:以损失之名相关推荐

  1. 多任务学习漫谈:分主次之序

    ©PaperWeekly 原创 · 作者 | 苏剑林 单位 | 追一科技 研究方向 | NLP.神经网络 多任务学习是一个很宽泛的命题,不同场景下多任务学习的目标不尽相同.在<多任务学习漫谈(一 ...

  2. ​多任务学习漫谈:行梯度之事

    ©PaperWeekly 原创 · 作者 | 苏剑林 单位 | 追一科技 研究方向 | NLP.神经网络 在<多任务学习漫谈:以损失之名>中,我们从损失函数的角度初步探讨了多任务学习问题, ...

  3. PAMTRI:用于车辆重新识别的姿势感知多任务学习

    Today, we will discuss an unorthodox paper by NVIDIA Labs on Vehicle Re Identification. 今天,我们将讨论NVID ...

  4. ICASSP2023 | 基于多任务学习的保留背景音的语音转换

    在影视.有声书内容中,背景音是一种表现丰富的艺术形式.语音转换(Voice Conversion)如能将源说话人语音转换成目标说话人语音的同时,保留源语音中的背景音,将会提供更沉浸的语音转换体验.之前 ...

  5. 最新NLP架构的直观解释:多任务学习– ERNIE 2.0(附链接)| CSDN博文精选

    作者 | Michael Ye 翻译 | 陈雨琳,校对 | 吴金笛 来源 | 数据派THU(ID:DatapiTHU) 百度于今年早些时候发布了其最新的NLP架构ERNIE 2.0,在GLUE基准测试 ...

  6. 独家 | 最新NLP架构的直观解释:多任务学习– ERNIE 2.0(附链接)

    作者:Michael Ye 翻译:陈雨琳 校对:吴金笛 本文约1500字,建议阅读7分钟. 本文将介绍多任务学习. 科技巨头百度于今年早些时候发布了其最新的NLP架构ERNIE 2.0,在GLUE基准 ...

  7. 多任务学习,如何设计一个更好的参数共享机制?| AAAI 2020

    2019-12-26 05:44:43 作者 | 孙天祥 编辑 | 刘萍 原文标题:稀疏共享:当多任务学习遇见彩票假设 本文介绍了复旦大学邱锡鹏团队在AAAI 2020 上录用的一篇关于多任务学习的工 ...

  8. Multi-task Learning(Review)多任务学习概述

    https://www.toutiao.com/a6707402838705701383/ 背景:只专注于单个模型可能会忽略一些相关任务中可能提升目标任务的潜在信息,通过进行一定程度的共享不同任务之间 ...

  9. 当AI实现多任务学习,它究竟能做什么?

    来源:脑极体 提到AI领域的多任务学习,很多人可能一下子就想到通用人工智能那里了.通俗意义上的理解,就像<超能陆战队>里的大白这样一种护理机器人,既能进行医疗诊断,又能读懂人的情绪,还能像 ...

最新文章

  1. 业务代码解构利器--SWAK
  2. jar容器部署成功无法访问_Spring Boot 应用程序五种部署方式
  3. 银行柜员网申计算机水平要求高吗,银行网申没通过,是因为你水平差吗?
  4. ibatis调用mysql带OUT类型参数的存储过程并获取返回值
  5. BI开发之——ETL注意细节
  6. vue set方法_Vue 数据响应式
  7. 响应头中content-type常用的类型有哪些?
  8. C# 之 Win32 Api使用
  9. udp push java ddpush_DDPush首页、文档和下载 - 任意门推送 - OSCHINA - 中文开源技术交流社区...
  10. Linux 常用操作命令大全(最后更新时间:2022年1月)
  11. 数据挖掘-贡献度分析
  12. javaSE之多线程vip插队
  13. Windows查看电脑ip地址方法(用于连接远程桌面)
  14. 「硬核讲解」通达信跨周期引用均线指标公式
  15. C++计算绝对值的函数
  16. c语言写plc程序正反转,西门子PLC控制电机正反转编程实例!
  17. Missing Tag Identification in COTS RFID Systems: Bridging the Gap between Theory and Practice 翻译
  18. ap计算机课程的内容,AP系列七|解读AP计算机课程与考试
  19. pycharm报错提示:无法加载文件\venv\Scripts\activate.ps1,因为在此系统上禁止运行脚本。
  20. android动画光影效果图,光影游戏(二):用手机 App 制作电影海报风格图片

热门文章

  1. 百度智能云大数据全景架构图如何赋能企业数字化
  2. 一个很SB的方法,来开始调一个刚启动就SB的程序
  3. 创建javascript对象的几种方式
  4. 学习okhttp wiki--Connections.
  5. Magento后台表单字段添加备注
  6. html图片渐隐渐显,js实现图片切换效果渐隐渐显
  7. linux安装php pgsql,Linux下apache php+phppgadmin+postgresql安装配置
  8. centos7 怎么封装自己的镜像_「10」-CentOS7.5(1804)
  9. linux下运行hadoop,Linux环境下hadoop运行平台的搭建
  10. linux 播放器系统,在Linux上安装和使用开源视频播放器MPlayer