©PaperWeekly 原创 · 作者 | 苏剑林

单位 | 追一科技

研究方向 | NLP、神经网络

在《多任务学习漫谈:以损失之名》中,我们从损失函数的角度初步探讨了多任务学习问题,最终发现如果想要结果同时具有缩放不变性和平移不变性,那么用梯度的模长倒数作为任务的权重是一个比较简单的选择。我们继而分析了,该设计等价于将每个任务的梯度单独进行归一化后再相加,这意味着多任务的“战场”从损失函数转移到了梯度之上:看似在设计损失函数,实则在设计更好的梯度,所谓“以损失之名,行梯度之事”。

那么,更好的梯度有什么标准呢?如何设计出更好的梯度呢?本文我们就从梯度的视角来理解多任务学习,试图直接从设计梯度的思路出发构建多任务学习算法。

整体思路

我们知道,对于单任务学习,常用的优化方法就是梯度下降,那么它是怎么推导的呢?同样的思路能不能直接用于多任务学习呢?这便是这一节要回答的问题。

下降方向

其实第一个问题,我们在《从动力学角度看优化算法(三):一个更整体的视角》就回答过。假设损失函数为 ,当前参数为 ,我们希望设计一个参数增量 ,它使得损失函数更小,即 。为此,我们考虑一阶展开:

假设这个近似的精度已经足够,那么 意味着 ,即更新量与梯度的夹角至少大于 90 度,而其中最自然的选择就是

这便是梯度下降,即更新量取梯度的反方向,其中 即为学习率。

无一例外

回到多任务学习上,如果假设每个任务都同等重要,那么我们可以将这个假设理解为每一步更新的时候 都下降或保持不变。如果参数到达 后,不管再怎么变化,都会导致某个 上升,那么就说 是帕累托最优解(Pareto Optimality)。说白了,帕累托最优意味着我们不能通过牺牲某个任务来换取另一个任务的提升,意味着任务之间没有相互“内卷”。

假设近似(1)依然成立,那么寻找帕累托最优意味着我们要寻找 满足

注意到它存在平凡解 ,所以上述不等式组的可行域肯定非空,我们主要关心可行域中是否存在非零解:如果有,则找出来作为更新方向;如果没有,则有可能已经达到了帕累托最优(必要不充分),我们称此时的状态为帕累托稳定点(Pareto Stationary)。

求解算法

方便起见,我们记 ,我们寻求一个向量 ,使得对所有的 都满足 ,那么我们就可以像单任务梯度下降那样取 作为更新量。如果任务数只有两个,可以验证 自动满足 和 ,也就是说,双任务学习时,前面说的梯度归一化可以达到帕累托稳定点。

当任务数大于 2 时,问题开始变得有点复杂了,这里介绍两种求解方法,其中第一种思路是笔者自己给出的推导结果,第二种思路则是《Multi-Task Learning as Multi-Objective Optimization》[1] 给出的“标准答案”。

问题转化

首先我们对问题进行进一步的转化。留意到

所以我们只需要尽量最大化最小的那个 ,就能找出理想的 ,即问题变成了

不过这有点危险,因为一旦真的存在非零的 使得 ,那么让 的模长趋于正无穷,那么最大值便会趋于正无穷。所以为了结果的稳定性,我们需要加个正则项,考虑

这样无穷大模长的 就不可能是最优解了。注意到代入 后有 ,所以假设对 取 的最优解为 ,那么必然有

所以问题(6)的解必然是满足条件(4)的解,并且如果是非零解,那么其反方向必然是使得所有任务损失都下降的方向。

光滑近似

现在介绍问题(6)的第一种求解方案,它假设读者像笔者一样不熟悉 min-max 问题的求解,那么我们可以将第一步的 用光滑近似代替(参考《寻求一个光滑的最大值函数》[2] ),即

于是我们就可以先求解

然后再让 。这样我们就将问题转化为了单个函数的无约束最大化问题,直接求梯度然后让梯度为零得到

假设各个 的差距大于 量级,那么当 时,上式实际上是

然而,如果直接按照 的格式迭代,那么大概率是会振荡的,因为它要我们找到让 最小的 作为 ,假设为 ,那么下一步让 最小的 就很可能不再是 了,反而 可能是最大的那个。

直观来想,上述算法虽然振荡,但应该也是围绕着最优点 振荡的,所以如果我们把振荡过程中的所有结果都平均起来,就应该能得到最优点了,这意味着收敛到最优点的迭代格式是

留意到每次叠加上去的都是某个 ,所以最终的 必然是各个 的加权平均,即存在 且 ,使得

我们也可以将 理解为各个 的当前最优权重分配方案。

对偶问题

光滑近似技巧的好处是比较简单直观,不需要太多的优化算法基础,不过它终究只是“非主流”思路,有颇多不严谨之处(但结果倒是对的)。下面我们来介绍基于对偶思想的“标准答案”。

首先,定义 为所有 元离散分布的集合,即

那么容易检验

因此问题(6)等价于

上述函数关于 是凹的,关于 是凸的,并且 的可行域都是凸集(集合中任意两点的加权平均仍然在集合中),所以根据冯·诺依曼的 Minimax 定理 [3],式 (16)的 和 是可以交换的,即等价于

等号右边是因为 部分只是一个无约束的二次函数最大值问题,可以直接算出 ,因此最后只剩下 ,问题变成了求 的一个加权平均,使得其模长最小。

当 时,问题的求解比较简单,相当于作三角形的高,如下图所示:

▲ 当 时的求解算法及几何意义

当 时,我们可以用 Frank-Wolfe 算法 [4] 将它转化为多个 的情形进行迭代。对于 Frank-Wolfe 算法,我们可以将它理解为带约束的梯度下降算法,适合于参数的可行域为凸集的情形,但展开来介绍篇幅太大,这里就不详说了,请读者自行找资料学习。简单来说,Frank-Wolfe 算法先通过线性化目标,找到下一步更新的方向为 ,其中 而 为 位置为 1 的 one hot 向量,然后求解在 与 之间进行插值搜索,找出最优者作为迭代结果。所以,它的迭代过程为

其中 的求解,正是 的特例,用上述截图中的算法即可。如果 不通过搜索而得,而是固定为 ,那么结果则等价于(12),这也是 Frank-Wolfe 算法的一个简化版本。也就是说,我们通过光滑近似得到的结果,跟简化版 Frank-Wolfe 算法的结果是等价的。

去约束化

其实对于问题(17)的求解,理论上我们也可以通过去约束的方式直接用梯度下降求解。比如直接设参数 以及

那么就可以转化为

这是个无约束的优化问题,常规的梯度下降算法就可以求解。然而不知道为什么,笔者似乎没看到这样处理的(难道是不想调学习率?)。

一些技巧

在前一节中,我们给出了寻找帕累托稳定点的更新方向的两种方案,它们都要求我们在每一步的训练中,都要先通过另外的多步迭代来确定每个任务的权重,然后才能更新模型参数。由此不难想象,实际计算的时候计算量还是颇大的,所以我们需要想些技巧降低计算量。

梯度内积

可以看到,不管哪种方案,其关键步骤都有 ,这意味着我们要遍历梯度算内积。然而在深度学习场景下,模型参数量往往很大,所以梯度是一个非常大维度的向量,如果每一步迭代都要算一次内积,计算量很大。这时候我们可以利用展开式

每次迭代其实只有 不同,所以其实在每一步训练中 只需要计算一次存下来就行了,不用重复这种大维度向量内积的计算。

共享编码

然而,当模型大到一定程度的时候,要把每个任务的梯度都分别算出来然后进行迭代计算是难以做到的。如果我们假设多任务的各个模型共用同一个编码器,那么我们还可以进一步近似地简化算法。

具体来说,假设 batch_size 为 ,第 个样本的编码输出为 ,那么由链式法则我们知道:

记 ,那么就得到 ,利用矩阵范数不等式得到

不难想到,如果我们最小化 ,那么计算量就会明显减少,因为这只需要我们对最后输出的编码向量的梯度,而不需要对全部参数的梯度。而上式告诉我们,最小化 实际上就是在最小化式(17)的上界,像很多难以直接优化的问题一样,我们期望最小化上界也能获得类似的结果。

不过,这个上界虽然效率更高,但也有其局限性,它一般只适用于每一个样本都有多种标注信息的多任务学习,不适用于不同任务的训练数据无交集的场景(即每个任务是对不同的样本进行标注的,单个样本只有一种标注信息),因为对于后者来说,各个 是相互正交的,此时任务之间没有交互,上界没有体现出任务之间的相关性,也就是过于宽松而失去意义了。

错误证明

前面所提到的“标准答案”以及关于共享编码器时优化上界的结果,都来自论文《Multi-Task Learning as Multi-Objective Optimization》[1]。接下来原论文试图证明当 满秩时,优化上界也能找到帕累托稳定点。但是很遗憾,原论文的证明是错误的。

证明位于原论文的附录 A,里边用到了一个错误的结论:

如果 是对称正定矩阵,那么 当且仅当 。

很容易举例证明该结论是错的,比如 ,此时 但 。

经过思考,笔者认为原论文中的证明是难以修复的,即原论文的推测是不成立的,换言之,即便 满秩,优化上界得出的更新方向未必是能使得所有任务损失都不上升的方向,从而未必能找到帕累托稳定点。至于原论文中优化上界的实验效果也不错,只能说深度学习模型参数空间太大,可供“挪腾”的空间也很大,从而上界近似也能获得不错的结果了。

文本小结

在这篇文章中,我们从梯度的视角来理解多任务学习。在梯度视角下,多任务学习的主要工作是寻找一个尽可能与每个任务的梯度都反向的方向作为更新方向,从而使得每个任务的损失都能尽量下降,而不能通过牺牲某个任务来换取另一个任务的提升。这是任务之间无“内卷”的理想状态。

参考文献

[1] https://arxiv.org/abs/1810.04650

[2] https://kexue.fm/archives/3290

[3] https://en.wikipedia.org/wiki/Minimax_theorem

[4] https://en.wikipedia.org/wiki/Frank–Wolfe_algorithm

特别鸣谢

感谢 TCCI 天桥脑科学研究院对于 PaperWeekly 的支持。TCCI 关注大脑探知、大脑功能和大脑健康。

更多阅读

#投 稿 通 道#

 让你的文字被更多人看到 

如何才能让更多的优质内容以更短路径到达读者群体,缩短读者寻找优质内容的成本呢?答案就是:你不认识的人。

总有一些你不认识的人,知道你想知道的东西。PaperWeekly 或许可以成为一座桥梁,促使不同背景、不同方向的学者和学术灵感相互碰撞,迸发出更多的可能性。

PaperWeekly 鼓励高校实验室或个人,在我们的平台上分享各类优质内容,可以是最新论文解读,也可以是学术热点剖析科研心得竞赛经验讲解等。我们的目的只有一个,让知识真正流动起来。

​多任务学习漫谈:行梯度之事相关推荐

  1. 多任务学习漫谈:分主次之序

    ©PaperWeekly 原创 · 作者 | 苏剑林 单位 | 追一科技 研究方向 | NLP.神经网络 多任务学习是一个很宽泛的命题,不同场景下多任务学习的目标不尽相同.在<多任务学习漫谈(一 ...

  2. 多任务学习漫谈:以损失之名

    ©PaperWeekly 原创 · 作者 | 苏剑林 单位 | 追一科技 研究方向 | NLP.神经网络 能提升模型性能的方法有很多,多任务学习(Multi-Task Learning)也是其中一种. ...

  3. 多任务学习中的网络架构和梯度归一化

    在计算机视觉中的单任务学习已经取得了很大的成功.但是许多现实世界的问题本质上是多模态的.例如为了提供个性化的内容,智能广告系统应该能够识别使用的用户并确定他们的性别和年龄,跟踪他们在看什么,等等.多任 ...

  4. 独家 | 最新NLP架构的直观解释:多任务学习– ERNIE 2.0(附链接)

    作者:Michael Ye 翻译:陈雨琳 校对:吴金笛 本文约1500字,建议阅读7分钟. 本文将介绍多任务学习. 科技巨头百度于今年早些时候发布了其最新的NLP架构ERNIE 2.0,在GLUE基准 ...

  5. Multi-task Learning(Review)多任务学习概述

    https://www.toutiao.com/a6707402838705701383/ 背景:只专注于单个模型可能会忽略一些相关任务中可能提升目标任务的潜在信息,通过进行一定程度的共享不同任务之间 ...

  6. [译]深度神经网络的多任务学习概览(An Overview of Multi-task Learning in Deep Neural Networks)...

    译自:http://sebastianruder.com/multi-task/ 1. 前言 在机器学习中,我们通常关心优化某一特定指标,不管这个指标是一个标准值,还是企业KPI.为了达到这个目标,我 ...

  7. 3.2.4 迁移学习和多任务学习

    迁移学习 总结一下,什么时候迁移学习是有意义的?如果你想从任务A学习并迁移一些知识到任务B,那么当任务A和任务B都有同样的输入时,迁移学习是有意义的.在第一个例子中,A和B的输入都是图像,在第二个例子 ...

  8. 2.8 多任务学习-深度学习第三课《结构化机器学习项目》-Stanford吴恩达教授

    ←上一篇 ↓↑ 下一篇→ 2.7 迁移学习 回到目录 2.9 什么是端到端的深度学习 多任务学习 (Multi-task Learning) 在迁移学习中,你的步骤是串行的,你从任务 AAA 里学习只 ...

  9. 多目标机器学习_NIPS2018 - 用多目标优化解决多任务学习

    题外话: 多任务学习可以说是机器学习的终极目标之一, 就像物理学家在追求统一所有力一样, 个人认为机器学习也在追求一个模型解决几乎所有问题. 虽然我们现在还离这个目标很远, 但是多任务学习在实际应用中 ...

最新文章

  1. LINQ系列:Linq to Object分区操作符
  2. MySql各引擎特点和性能测试
  3. leetcode 227. Basic Calculator II | 227. 基本计算器 II(中缀表达式求值)
  4. 编译安装python3.6_编译安装Python3.6及以上
  5. java模式设计视频教程_全新JAVA设计模式详解视频教程 完整版课程
  6. OpenCV3学习(2.3)——图像读取与鼠标截图
  7. ios实现图片动画效果
  8. Code Snippets for Windows Mobile 5 in C#
  9. android 阅读器字体,为 Android 换上任意喜欢的字体,你可以试试这个 Magisk 模块...
  10. 软考-系统分析师知识大纲及分数
  11. matlab矩阵的白化,白化原理及Matlab实现
  12. kubernetes v1.11 生产环境 二进制部署 全过程
  13. Python图像处理库PIL的基本概念介绍(一)
  14. JavaScript - 将 Allegro 坐标文件转为嘉立创坐标文件(CSV 格式)的工具
  15. 通过银行卡号获取所属银行
  16. cad指北针lisp_cad指北针命令(CAD如何绘制一个最简单的指北针)
  17. 免费不限速跨平台文件传输神器—文件疯巢
  18. 那些年啊,那些事——一个程序员的奋斗史 ——121
  19. 阿里云服务器设置IPV6通过AppStore审核
  20. Activiti判断流程是否结束

热门文章

  1. 《数据安全法》今日实施,中国信通院联合百度等企业发起“数据安全推进计划”
  2. 前端常见知识点三之HTML
  3. python如何爬虫股票数据_简单爬虫:东方财富网股票数据爬取(python_017)
  4. TFS2013 微软源代码管理工具 安装与使用图文教程
  5. exe文件添加为服务
  6. 组件设计实战--组件之间的关系 (Event、依赖倒置、Bridge)
  7. 北京林大计算机科技应为abc哪类,北京林业大学新生入学要准备什么?
  8. linux lw3m多行文本使用,linux常用命令以及一些常见问题和解决方法教程.docx
  9. python 下载公众号文章_python3下载公众号历史文章
  10. 笔记本电脑键盘切换_全球首款折叠屏笔记本电脑ThinkPad X1 Fold:5G高速互联拥抱PC场景融合时代...