Multi-task中的多任务loss平衡问题
Multi-task learning MTL 中的多任务loss平衡问题
- 背景
- 7 Nov 2017 - GradNorm: Gradient Normalization for Adaptive Loss Balancing in Deep Multitask Networks
- 19 May 2017 - Multi-Task Learning Using Uncertainty to Weigh Losses for Scene Geometry and Semantics
- 10 Oct 2018 - Multi-Task Learning as Multi-Objective Optimization
- 该文的方法
背景
multi-task的损失函数:
L(t)=∑wi(t)Li(t,θ)L(t)=\sum{w_i(t)L_i(t,\theta)}L(t)=∑wi(t)Li(t,θ)
在multi-task训练中存在着:
- 如何平衡各个任务损失的权重,
- 对于不同的任务的loss梯度之间的大小关系如果平衡,
- 各任务学习率如何控制.
这些问题影响mult-task训练的最终效果. 处理不当, 很有可能一个task学的很好, 其他task学的很差.
三个问题是相通的.
7 Nov 2017 - GradNorm: Gradient Normalization for Adaptive Loss Balancing in Deep Multitask Networks
论文地址
GradNorm的解决方法, 将wiw_iwi作为可学习的参数.
目标是:
- 在不考虑学习率的情况下, 尽量使得各个task的wi(t)Li(t,θ)w_i(t)L_i(t,\theta)wi(t)Li(t,θ)对于参数θ\thetaθ的gradient都与平均值接近.
- 学习不充分的task, 给与比其他task更大的学习率.
ttt时间task iii的θ\thetaθ梯度:
Gθ(i)(t)=∥▽θwi(t)Li(t,θ)∥2G_\theta^{(i)}(t)=\Vert \bigtriangledown_{\theta}w_i(t)L_i(t,\theta)\Vert_2Gθ(i)(t)=∥▽θwi(t)Li(t,θ)∥2
所有task对θ\thetaθ的平均梯度:
G‾θ(t)=Etask[Gθ(i)(t)]\overline{G}_\theta(t)=E_{task}[G_\theta^{(i)}(t)]Gθ(t)=Etask[Gθ(i)(t)]
对于学习率, 定义若干个变量:
定义一个loss相对于初始化时的占比, 优化程度.
L~i(t)=Li(t)/Li(0)\widetilde{L}_i(t)=L_i(t)/L_i(0)Li(t)=Li(t)/Li(0)
定义一个比率值, 衡量task iii相对于所有task的优化程度. 值越大, 表明优化程度相对于其他task 优化程度不够.
ri(t)=L~i(t)/Etask[L~i(t)]r_i(t)=\widetilde{L}_i(t)/E_{task}[\widetilde{L}_i(t)]ri(t)=Li(t)/Etask[Li(t)]
最后定义一个损失函数, 用来作为wiw_iwi的优化目标函数:
Lgrad(t;wi(t))=∑i∣Gθ(i)(t)−G‾θ(t)×[ri(t)]α∣1L_{grad}(t; w_i(t)) = \sum_{i}\vert G_\theta^{(i)}(t)-\overline{G}_\theta(t)\times [r_i(t)]^\alpha \vert_1Lgrad(t;wi(t))=i∑∣Gθ(i)(t)−Gθ(t)×[ri(t)]α∣1
每次求上式对wiw_iwi的导数时, 固定G‾θ(t)×[ri(t)]α\overline{G}_\theta(t)\times [r_i(t)]^\alphaGθ(t)×[ri(t)]α的值, 所求的对wiw_iwi的梯度用来更新wiw_iwi的值.在每次更新之后. 需要对wiw_iwi进行normalization, 使得:
∑iwi(t)=T\sum_i w_i(t) = Ti∑wi(t)=T
wiw_iwi初始化时赋值为1, 即wi(0)=1w_i(0) = 1wi(0)=1
19 May 2017 - Multi-Task Learning Using Uncertainty to Weigh Losses for Scene Geometry and Semantics
论文地址
方法比较简单, 假设也很合理:
如果是回归问题, fw(x)f^w(x)fw(x)为模型输出:
则假设label 似然概率yyy符合高斯分布:
p(y∣fw(x))=N(fw(x),σ2)p(y|f^w(x))=\mathcal{N}(f^w(x), \sigma^2)p(y∣fw(x))=N(fw(x),σ2)
如果是分类问题, 最终class yyy的概率为:
p(y∣fw(x))=softmax(fw(x))p(y|f^w(x))=softmax(f^w(x))p(y∣fw(x))=softmax(fw(x))
这里面的yyy可以为一个向量. 比如序列问题中.
多个task(KKK个), 希望整体概率最大.
p(y1,...,yK∣fw(x))=p(y1∣fw(x))×...×p(yK∣fw(x))p(y_1, ...,y_K|f^w(x))=p(y_1|f^w(x))\times...\times p(y_K|f^w(x))p(y1,...,yK∣fw(x))=p(y1∣fw(x))×...×p(yK∣fw(x))
又:
如果是回归问题:
logp(y∣fw(x)∝−12σ2∥y−fw(x)∥2−logσ\log p(y|f^w(x) ∝ −\frac{1}{2\sigma^2}\|y − f^w(x)\|^2 − \log ^\sigma logp(y∣fw(x)∝−2σ21∥y−fw(x)∥2−logσ
如果是分类问题, c∈Cc\in Cc∈C, 假设c^\hat{c}c^为正确分类:
logp(y=c^∣fW(x),σ)=1σ2fc^W(x)−log∑cexp(1σ2fcW(x))\log p(y = \hat{c}|f^W(x), \sigma) =\frac{1}{\sigma^2}f^W_{\hat{c}}(x) − \log \sum_{c}{exp({\frac{1}{\sigma^2}f^W_c(x)})}logp(y=c^∣fW(x),σ)=σ21fc^W(x)−logc∑exp(σ21fcW(x))
分类问题时, 还有一个用来简化优化目标的的近似替换, 具体见论文.
因此最终的log似然损失函数可以将分类和回归问题通过一个公式表达:
logp(y1∣fw(x))+...+logp(yK∣fw(x))\log p(y_1|f^w(x)) +...+\log p(y_K|f^w(x))logp(y1∣fw(x))+...+logp(yK∣fw(x))
10 Oct 2018 - Multi-Task Learning as Multi-Objective Optimization
论文地址
这篇文章的效果声称能够赛过前两篇:
第三个是文中第二个方法.
基本的想法是, 寻找多目标优化的帕累托最优解.
帕累托最优的定义:
帕累托最优解不止一个:
如下图: 由C→AC\rightarrow AC→A 或者由C→BC\rightarrow BC→B都是一个pareto optimize, 但A与B互相不是.
多任务的帕累托最优求解方法:
MGDA (multiple-gradient descent algorithm) 算法.
- 第一个条件是对θsh\theta^{sh}θsh求导.
- 第二个条件是对θt\theta^{t}θt求导.
这个算法参见论文:
Multiple-gradient descent algorithm (MGDA) for multiobjective optimization
但这边论文提出MGDA算法在面对神经网络的庞大参数时存在两个问题:
- 对高纬度梯度不是很试用(not scale gracefully)
- 需要对每个task计算梯度
所以本文提出了一种使用Frank-Wolfe-based optimizer的算法来处理这两个问题. 只需要一次backward explicit task-specific gradients 能计算出MGDA的一个上界, 使用这个上界, 能够得出帕累托最优的结果.
该文的方法
对于公式3, 如果是两个task:
有解析解为:
更一般的解法, 使用Frank wolfe 算法:
这个算法看的不是很懂, 可能需要看一下MGDA算法才能明白.
- et^\mathcal{e}_{\hat{t}}et^ 是什么.
- 第10步和11步是怎么来的.
- M矩阵从哪里冒出来的.
参考:
- 原始frank wolfe算法
Multi-task中的多任务loss平衡问题相关推荐
- Multi task learning多任务学习背景简介
2020-06-16 23:22:33 本篇文章将介绍在机器学习中效果比较好的一种模式,多任务学习(Multi task Learning,MTL).已经有一篇机器之心翻译的很好的博文介绍多任务学习了 ...
- 多智能体强化学习Multi agent,多任务强化学习Multi task以及多智能体多任务强化学习Multi agent Multi task概述
概述 在我之前的工作中,我自己总结了一些多智能体强化学习的算法和通俗的理解. 首先,关于题目中提到的这三个家伙,大家首先想到的就是强化学习的五件套: 状态:s 奖励:r 动作值:Q 状态值:V 策略: ...
- 神经网络中,设计loss function有哪些技巧?
点击上方"小白学视觉",选择加"星标"或"置顶" 重磅干货,第一时间送达 本文转自:视学算法 神经网络中,设计loss function有哪 ...
- multi task训练torch_采用single task模型蒸馏到Multi-Task Networks
论文地址. 这篇论文主要研究利用各个single task model来分别作为teacher model,用knowledge distillation的方法指导一个multi task model ...
- EMNLP 2021 | 多标签文本分类中长尾分布的平衡策略
点击上方"视学算法",选择加"星标"或"置顶" 重磅干货,第一时间送达 作者 | 黄毅 作者简介:黄毅,本文一作,目前为罗氏集团的数据科学家 ...
- Multi Task Learning在工业界如何更胜一筹
摘要: 本文主要介绍多任务学习和单任务学习的对比优势以及在工业界的一些使用.如何从单任务学习转变为多任务学习?怎样使AUC和预估的准确率达到最佳?如何对实时性要求较高的在线应用更加友好?本文将以淘宝实 ...
- JSRE中的多任务与多线程
前言 这几天在爱智官网看了下JSRE其他的Api,看了一个比较有意思的模块 - 多任务模块task,大致看了下他们的接口说明和案例,感觉和多线程差不多,然后就准备去看下实现方式,找了很久没有找到源 ...
- 如何管理和记录 SSIS 各个 Task 的开始执行时间和结束时间以及 Task 中添加|删除|修改的记录数...
开篇语 在这篇日志中 如何在 ETL 项目中统一管理上百个 SSIS 包的日志和包配置框架 我介绍到了包级别的日志管理框架,那么这个主要是针对包这一个层级的 Log 信息,包括包开始执行和结束时间,以 ...
- Debug深度学习中的NAN Loss
深度学习中遇到NAN loss 什么都不改,重新训练一下,有时也能解决问题 学习率减小 检查输入数据(x和y),如果是正常突然变为NAN,有可能是学习率策略导致,也可能是脏数据导致 If using ...
- 图像风格迁移_【论文解读】图像风格迁移中的Contextual Loss
[08/04更新]在前几天的Commit中,Contextual Loss已经支持多GPU训练 1.Background 对于图像风格迁移,最常用的做法就是通过GAN网络实现,然而,如果你没有很强大的 ...
最新文章
- ssm项目集成ftp_SSM开发框架实例(struts+spring+springmvc)
- 一文详解C++文件读写(FileStorage、txt)
- JavaScript实现depth First Search深度优先搜索算法(附完整源码)
- 怎样分辨谁才是朋友圈里的真·贵族?
- 我了解到的面试的一些小内幕!附面试题答案
- 祝融号火星车亮相,每小时仅移动40米,为何比乌龟还慢?
- bzoj1967 [AHOI2005]穿越磁场 离散最短路
- eclipse切换git分支
- ALGO-147_蓝桥杯_算法训练_4-3水仙花数
- 再说变体结构 - 回复 彬 的问题
- STM32的AD通道干扰问题
- [博弈论]JZOJ 3339 wyl8899和法法塔的游戏
- 多目标人工秃鹫优化算法(MATLAB源码分享,智能优化算法) 提出了一种多目标版本的人工秃鹫优化算法(AVOA)
- vue-cli生成的模板各个文件详解(转)
- Markdown表格、单元格合并、快速编辑表格
- 抱歉出现问题:关闭 windows hello,然后尝试再次运行安装程序
- Tendermint KVStore案例解析
- 南加州大学计算机专业研究生录取,南加州大学计算机科学(数据科学)理学硕士研究生申请要求及申请材料要求清单...
- 最大机枪池被黑客攻击,BSC接连被暴击后将走向何方?
- [原创]全面增强版 eXtremeComponents !!!!
热门文章
- 这几天阅读的shadowgun的几个shader
- Win 10 添加多国语言
- 让 WordPress 支持多国语言包
- 《 极秀校园行Windows XP SP3装机专版 》 光盘介绍
- 如何在计算机快速删掉快捷方式,电脑桌面上的网页快捷方式怎么删除?怎么在桌面便签上快速删除网页快捷方式...
- SQLite数据库的CRUD操作
- X8AIP 驱动程序
- 网页设计下拉菜单栏css代码,HTML+CSS实现导航条下拉菜单的示例代码
- cvs转datatable_C# CSV 文件转换成DataTable
- 从移动硬盘安装计算机系统文件,硬盘之前做成了移动硬盘,现在装回电脑上重装系统时分区认不到盘,怎么办?...