如果有n个任务(传统的深度学习方法旨在使用一种特定模型仅解决一项任务),而这n个任务或它们的一个子集彼此相关但不完全相同,则称为多任务学习(MTL) 通过使用所有n个任务中包含的知识,将有助于改善特定模型的学习。

单任务学习:一次只学习一个任务(task),大部分的机器学习任务都属于单任务学习。
多任务学习:把多个相关(related)的任务放在一起学习,同时学习多个任务。

单任务学习时,各个任务之间的模型空间(Trained Model)是相互独立的。
多任务学习时,多个任务之间的模型空间(Trained Model)是共享的。

1 基本模型框架

通常将多任务学习方法分为:hard parameter sharingsoft parameter sharing

1.1 hard parameter sharing

无论最后有多少个任务,底层参数统一共享,顶层参数各个模型各自独立。由于对于大部分参数进行了共享,模型的过拟合概率会降低,共享的参数越多,过拟合几率越小,共享的参数越少,越趋近于单个任务学习分别学习。形象理解为:几个人在一张桌子上吃几盘菜,自己碗里有自己的饭,共享的就是桌子、几盘菜,不共享的就是自己碗里的,桌子上菜越多,自己碗里的越少,吃腻的概率更小;自己碗里一自己的饭,桌子上没几个菜,一会儿饭就吃腻了。

在所有任务之间共享隐藏层,同时保留几个特定任务的输出层来实现。降低了过拟合的风险。

1.2 soft parameter sharing

每个任务都有自己的模型,自己的参数。

底层共享一部分参数,自己还有独特的一部分参数不共享顶层有自己的参数

底层共享的、不共享的参数如何融合到一起送到顶层,也就是研究人员们关注的重点啦。

这里可以放上经典的MMOE模型结构,大家也就一目了然了。

(a)hard sharing(底层参数统一共享,顶层参数各个模型各自独立)

(b)和(c)先对Expert 0-2进行加权求和之后再送入Tower A和B,通过Gate来决定到底加权是多少。(每个expert,Tower A和B,Gate都可以理解为一个隐层神经网络)

(超纲部分:聪明的小伙伴看到这个加权求和,是不是立刻就想到Attention啦?要不咱们把这个Gate改为一种Attention?对不同Expert的Attention来决定求和权重,那你得想办法设计Attention的query啦,是个有趣的点。)

把多个/单个输入送到一个大模型里(参数如何共享根据场景进行设计),预测输出送个多个不同的目标,最后放一起(比如直接相加)进行统一优化。

2 多任务学习改进的方向

我们先假设多个任务适合放在一起,对于这些适合放在一起的任务,我们有哪些方向呢?

2.1模型结构设计:哪些参数共享,哪些参数不共享?

把模型共享参数部分想象成榴莲千层,想象一下我们是竖着切了吃,还是横着一层侧概念拨开了吃。

竖着切了吃

对共享层进行区分,也就是想办法给每个任务一个独特的共享层融合方式。MOE和MMOE模型就是竖着切了吃的例子。另外MMOE在MOE的基础上,多了一个GATE,意味着:多个任务既有共性(关联),也必须有自己的独特性(Task specific)。共性和独特性如何权衡:每个任务搞一个专门的权重学习网络(GATE),让模型自己去学,学好了之后对expert进行融合送给各自任务的tower,最后给到输出。

一层层拿来吃

对不同任务,不同共享层级的融合方式进行设计。如果共享网络有多层,那么通常我们说的高层神经网络更具备语义信息,那我们是一开始就把所有共享网络考虑进来,还是从更高层再开始融合呢?如图6最右边的PLE所示,Input上的第1层左边2个给了粉色G,右边2个给了绿色G,3个都给了蓝色G,然后第2层左边2块给左边的粉色G和Tower,右边两块输出到绿色G和Tower。

2.2 MTL的目标loss设计和优化改进

既然多个任务放在一起,往往不同任务的数据分布、重要性也都不一样,大多数情况下,直接把所有任务的loss直接求和然后反响梯度传播进行优化,是不是不合适呢?
我们需要仔细平衡所有任务的联合训练过程,以避免一个或多个任务在网络权值中具有主导影响的情况。极端情况下,当某个任务的loss非常的大而其它任务的loss非常的小,此时多任务近似退化为单任务目标学习,网络的权重几乎完全按照大loss任务来进行更新,逐渐丧失了多任务学习的优势。

假设任务特定权重的优化目标wi和任务特定损失函数Li,通常多任务学习的loss function可以写为:

那么对于共享参数Wsh在梯度下降优化时,使用随机梯度下降来尽量减少上图方程的总目标函数值,对共享层Wshare中的网络权值通过以下规则进行更新:

从上图的方程可以看出:

1、loss大则梯度更新量也大;

2、不同任务的loss差异大导致模型更新不平衡的本质原因在于梯度大小

3、通过调整不同任务的loss权重wi可以改善这个问题;

4、直接对不同任务的梯度进行处理也可以改善这个问题;

Wsh 的优化受到所有loss的影响,那么优化思路自然而然可以为:1、在权重wi上做文章;2、在梯度上做文章。

  • loss的权重进行设计,最简单的权重设计方式是人为给每一个任务一个权重;
  • 根据任务的Uncertainty对权重进行计算,读者可参考经典的:Multi-Task Learning Using Uncertainty to Weigh Losses for Scene Geometry and Semantics。

1 数据依赖性(异方差不确定性)依赖于输入数据,模型预测结果的残差的方差即随着数据的输入发生变化;

2、任务依赖性(同方差不确定性)是不依赖于输入数据的任意不确定性,它与模型输出无关,是一个在所有输入数据保持不变的情况下,在不同任务之间变化的量,因此,它可以被描述为与任务相关的不确定性,但是作者并没有详细解释在多任务深度学习中的同方差不确定性的严格定义,而是认为同方差不确定性是由于任务相关的权重引起的。

  • 由于不同loss取值范围不一致,可以尝试通过调整loss的权重wi让每个loss对共享Wsh参数贡献平等呢?GradNorm: Gradient Normalization for Adaptive Loss Balancing in Deep Multitask Networks,以及另外一篇相似思路的文章End-to-end multi-task(希望不同任务loss的量级接近,纳入梯度计算权重,优点是可以考虑loss的量级,缺点是每一步都要额外算梯度)
  • learning with attention 提出一种Dynamic Weight Averaging的方法来平衡不同的学习任务。(记录每步的loss,loss缩小快的任务权重会变小,缺点是没有考虑量级)
  • Multi-Task Learning as Multi-Objective Optimization对于MTL的多目标优化理论分析也十分有趣,对于MTL优化理论推导感兴趣的同学值得一看。

2.3 直接设计更合理的辅助任务

前面的方法一个要设计网络结构,一个要设计网络优化方式,听起来其实实践上对于很多新手不是很友好,那这里再介绍一个友好的方式!对于MTL优化的一个方向为什么不是找一个更合适的辅助任务呢?只要任务找的好,辅助loss直接加上去,人为设计一下权重调个超参数,模型结构几乎不变都有可能效果更好的!

辅助任务设计的常规思路:

  • 找相关的辅助任务!不想关的任务放一起反而会损害效果的!如何判断任务是否想关呢?当然对特定领域需要有一定的了解,比如视频推荐里的:是否点击+观看时长+是否点赞+是否转发+是否收藏等等。

  • 对于相关任务不太好找的场景可以尝试一下对抗任务,比如学习下如何区分不同的domain的内容。

  • 预测数据分布,如果对抗任务,相关任务都不好找,用模型预测一下输入数据的分布呢?比如NLP里某个词出现的频率?推荐系统里某个用户对某一类iterm的点击频率。

  • 正向+反向。以机器机器翻译为例,比如英语翻译法语+法语翻英语,文本纠错任务中也有类似的思想和手段。

  • Pre-train,某种程度上这属于transfer learning,但是放这里应该其实是可以的。比如有一种pretrain的方式是:先train任务A再联合train任务B。预训练本质上是在更好的初始化模型参数,所以想办法加一个帮助初始化模型参数的辅助任务也是可以的。

你需要搭建一个网络模型来完成一个特定的图像分类的任务。首先,你需要随机初始化参数,然后开始训练网络,不断调整直到网络的损失越来越小。在训练的过程中,一开始初始化的参数会不断变化。当你觉得结果很满意的时候,你就可以将训练模型的参数保存下来,以便训练好的模型可以在下次执行类似任务时获得较好的结果。这个过程就是pre−training。

  • 预测一下要做的任务该不该做,句子中的词位置对不对,该不该放这里,点击序列中该不该出现这个iterm?这也是一个有趣的方向。比如文本纠错任务,可不可以先预测一下这个文本是不是错误的呢?

主要参考:

https://zhuanlan.zhihu.com/p/348873723?utm_source=wechat_session&utm_medium=social&utm_oi=1125847523901984769&utm_campaign=shareopn

https://imzhanghao.com/2020/10/25/multi-task-learning/
https://jishuin.proginn.com/p/763bfbd57673
https://zhangkaifang.blog.csdn.net/article/details/89320108?spm=1001.2101.3001.6661.1&utm_medium=distribute.pc_relevant_t0.none-task-blog-2%7Edefault%7ECTRLIST%7ERate-1.pc_relevant_antiscanv2&depth_1-utm_source=distribute.pc_relevant_t0.none-task-blog-2%7Edefault%7ECTRLIST%7ERate-1.pc_relevant_antiscanv2&utm_relevant_index=1
https://blog.csdn.net/qq_27590277/article/details/115535372

整理学习之多任务学习相关推荐

  1. ICML2018见闻 | 迁移学习、多任务学习领域的进展

    作者 | Isaac Godfried 译者 | 王天宇 编辑 | Jane 出品 | AI科技大本营 [导读]如今 ICML(International Conference on Machine ...

  2. 花书+吴恩达深度学习(十八)迁移学习和多任务学习

    目录 0. 前言 1. 迁移学习 2. 多任务学习 如果这篇文章对你有一点小小的帮助,请给个关注,点个赞喔~我会非常开心的~ 花书+吴恩达深度学习(十八)迁移学习和多任务学习 花书+吴恩达深度学习(十 ...

  3. 多分类学习、多标签学习、多任务学习的区别

    Multi-class. Multi-label . Multi-task 三者之间的区别与相同之处 1.直观解释 多分类学习(Multi-class) 一个分类器,但分的类别是包含多个的.例如:分类 ...

  4. 3.2.4 迁移学习和多任务学习

    迁移学习 总结一下,什么时候迁移学习是有意义的?如果你想从任务A学习并迁移一些知识到任务B,那么当任务A和任务B都有同样的输入时,迁移学习是有意义的.在第一个例子中,A和B的输入都是图像,在第二个例子 ...

  5. 模型独立学习:多任务学习与迁移学习

    导读:机器学习的学习方式包括监督学习和无监督学习等.针对一个给定的任务,首先要准备一定规模的训练数据,这些训练数据需要和真实数据的分布一致,然后设定一个目标函数和优化方法,在训练数据上学习一个模型.此 ...

  6. 3.2 实战项目二(手工分析错误、错误标签及其修正、快速地构建一个简单的系统(快速原型模型)、训练集与验证集-来源不一致的情况(异源问题)、迁移学习、多任务学习、端到端学习)

    手工分析错误 手工分析错误的大多数是什么 猫猫识别,准确率90%,想提升,就继续猛加材料,猛调优?     --应该先做错误分析,再调优! 把识别出错的100张拿出来, 如果发现50%是"把 ...

  7. 多核学习、多视图学习、多任务学习和集成学习的区别和联系

    多核学习既可以用在多任务学习,也可以用在多视图学习,也有研究同时对多任务和多视图同时采用多核的,目前已经有通用多任务多核学习方法.如果将多核用在多任务学习,相当于不同任务共享子空间的同时,还有各自特有 ...

  8. 深度学习之----多任务学习

    介绍 在机器学习(ML)中,通常的关注点是对特定度量进行优化,度量有很多种,例如特定基准或商业 KPI 的分数.为了做到这一点,我们通常训练一个模型或模型组合来执行目标任务.然后,我们微调这些模型,直 ...

  9. 【深度学习】多任务学习概览(An Overview of Multi-task Learning in Deep Neural Networks)

    1. 前言 在机器学习中,我们通常关心优化某一特定指标,不管这个指标是一个标准值,还是企业KPI.为了达到这个目标,我们训练单一模型或多个模型集合来完成指定得任务.然后,我们通过精细调参,来改进模型直 ...

  10. Multi-task Learning in LM(多任务学习,PLE,MT-DNN,ERNIE2.0)

    提升模型性能的方法有很多,除了提出过硬的方法外,通过把神经网络加深加宽(深度学习),增加数据集数目(预训练模型)和增加目标函数(多任务学习)都是能用来提升效果的手段.(别名Joint Learning ...

最新文章

  1. java多线程系类:基础篇:10生产者消费者的问题
  2. 《Python数据分析》-Ch01 Python 程序库入门
  3. 决定equipment download到CRM后是否执行save的因素
  4. maven项目和普通项目转换
  5. 屠呦呦3年后再上热搜:女先生,世无双!
  6. x86基础之数与数据类型
  7. scrapy 在迭代爬取时被拒 offsite 增加dont_filter=True
  8. 电力电子技术笔记-逆变电路
  9. 宝藏软件:“小狼毫” 一款开源牛叉输入法
  10. 电信光猫破解 (打开无线wifi及路由功能)
  11. 开源许可证 有人管吗_4个令人困惑的开源许可证场景以及如何浏览它们
  12. MATLAB等值线绘制
  13. Android对应颜色值代码
  14. 润乾报表导出pdf问题
  15. Android程序水印效果
  16. Chrome开发工具Network没有显示完整的http request和response对话
  17. HMACSHA512
  18. Spring Cloud 异常“ Caused by: java.net.UnknownHostException: discovery.host ”
  19. Map map=new HashMap(); 为什么是这样
  20. CRM系统主要包含什么内容

热门文章

  1. 一文搞懂数据结构之 递归-八皇后问题
  2. 移动硬盘格式化后如何数据恢复?
  3. 怎么注册tk域名_新.tk域名免费注册教程
  4. PHP将PPT文件转成图片
  5. lisp 图层字体式样替换_ps将不同图层字体修改成相同字体的方法
  6. tensorflow获取tensor的shape
  7. allegro铜皮倒圆角
  8. Photoshop教程:超全的PS快捷键大全分享
  9. 给视频加水印的软件有哪些?推荐两种软件快速加水印
  10. iOS 上架App Store 遇到的坑