点上方计算机视觉联盟获取更多干货

仅作学术分享,不代表本公众号立场,侵权联系删除

转载于:AI公园   知乎

链接:https://www.zhihu.com/question/268105631

AI博士笔记系列推荐

周志华《机器学习》手推笔记正式开源!可打印版本附pdf下载链接

在实验中,每当涉及到loss修改都比较恼火,存在如下问题:

1、multitask时,如何控制各部分loss的权重?(如目标检测任务中的框回归loss+分类loss)

2、multitask时,起始训练各loss的数量级不同对收敛存在着哪些影响?

3、根据任务修改loss时,如改进tripet loss,时常出现神经网络作弊的情况,该如何设计loss?
看看各位小伙伴的回答吧!

作者:刘诗昆

https://www.zhihu.com/question/268105631/answer/333738561

题主这个问题是 multi-task learning 里相当重要的一个核心问题。我正好在做相关的工作,很多细节将在论文投稿后再更新此答案。此外,我对 re-identfication 相关研究不熟所以无法回答第三个问题请见谅,望其他研究者补充。

理解多任务学习: Understanding Multi-task learning

Multi-task learning 核心的问题通常是可简单分为两类:

  1. How to share: 这里主要涉及到基于 multi-task learning 的网络设计。

  2. Share how much: 如何平衡多任务的相关性使得每个任务都能有比 single-task training 取得更好的结果。

题主的问题主要落在第二类,尽管这两个问题通常同时出现也互相关联。对于 multi-task learning 更加粗略的介绍以及和 transfer learning 的关系请参看我之前的回答:什么是迁移学习 (Transfer Learning)?这个领域历史发展前景如何?https://www.zhihu.com/question/41979241/answer/123545914其中同样包括了 task weighting 的一些讨论,以下再做更加细节的补充。

网络设计和梯度平衡的关系: The Relationship Between Network Design and Gradient Balancing

无论是网络设计还是平衡梯度传播,我们的目标永远是让网络更好的学习到 transferable, generalisable feature representation 以此来缓解 over-fitting。为了鼓励多任务里多分享各自的 training signal 来学泛化能力更好的 feature,之前绝大部分研究工作的重点在网络设计上。直到去年才有陆续一两篇文章开始讨论 multi-task learning 里的 gradient balancing 问题。

再经过大量实验后,我得出的结论是,一个好的 gradient balancing method 可以继续有效增加网络的泛化能力,但是在网络设计本身的提高强度面前,这点增加不足一提。更加直白的表达是:

Gradient balancing method 一定需要建立在网络设计足够好的基础上,不然光凭平衡梯度并不会对网络泛化能力有着显著的改变。

梯度统治: Gradient Domination

在 multi-task learning 里又可根据 training data 的类别再次分为两类:

  • one-to-many (single visual domain): 输入一个数据,输出多个标签。通常是基于 image-to-image 的 dense prediction。一个简单的例子,输入一张图片,输出 semantic segmentation + depth estimation。

  • many-to-many (multi visual domain):输入多个数据,输入各自标签。比如如何同时训练好多个图片分类任务。

由于不同任务之间会有较大的差异,平衡梯度的目标是为了减缓任务本身的由于 variance, scale, complexity 不同而导致的差异。

在训练 multi-task 网络时候则会因为任务复杂度的差异出现一个现象,我把他称之为: Gradient Domination, 通常发生在 many-to-many 的任务训练中。因为图片分类可以因为图片类别和本身数据数量而出现巨大差异。而基于 single visual domain 的 multi-task learning 则不容易出现这个问题因为数据集是固定的。

最极端的例子:MNIST + ImageNet 对于这种极端差异的多任务训练基本可以看成基于 MNIST initialisation 的网络对于 ImageNet 的 finetune。所以这种情况的建议就是:优先训练复杂度高的数据集,收敛之后再训练复杂度低的数据集。当然这种情况下,多任务学习也没有太大必要了。

对于一些差别比较大但是还是可接受范围的比如:SVHN + CIFAR100。这种情况的 gradient balancing 就会出现一定的效果但也取决于你输入数据的方式。输入数据的通常方法,例如在这篇文章里:Incremental Learning Through Deep Adaptation:https://arxiv.org/pdf/1705.04228.pdf就是通过一个 dataset switch 来决定更新哪一个数据集的参数。对于这种方法,起始 learning rate 调的低,网络本身就会有一个较好的下降速率。

动态加权梯度传播: Adaptive Weighting Scheme

即使光对优化网络调参并不能给多任务学习有着本质的改变。在考虑最 straightforward 的 loss:

  我们的目标是学习好一个  能够根据训练效果动态变化使得平衡网络的梯度传播。

这个问题目前只有两篇文章做出了相关成果,

  • Weight Uncertainty: 这个是通过 Gaussian approximation 的方式直接对修改了 loss 的方式,并同时以梯度传播的方式来更新里面的两个参数。实际实验效果也还不错,在我复现的结果来看能有显著的提升但是比较依赖并敏感于一个合适的的 learning rate 的设置。

  • GradNorm : 是通过网络本身 back-propagation 的梯度大小进行 renormalisation。这篇文章写的比较草率并被最近的 ICLR 2018 拒绝收录了。个人期待他的更新作品能对方法本身有着更细节的描述。

  • Dynamic Weight Average: 我对于 GradNorm 一个更加简约且有效的改进,细节将会被补充。

一些总结

平衡梯度问题最近一年才刚刚开始吸引并产出部分深入研究的工作,这个方向对于理解 multi-task learning 来说至关重要,也可以引导我们去更加高效且条理化的训练多任务网络。但在之前,更重要的事情是理解泛化能力本身,个人觉得 multi-task learning 的核心目标不在于训练多个任务并得到超越单任务学习的性能,而是通过理解 multi-task learning 学习的过程重新思考并加深理解深度学习里 generlisation 的真正意义和价值。

作者:柯小波
https://www.zhihu.com/question/268105631/answer/336855757

做 multi-task 很重要的一点是要清楚各个 task 收敛的具体曲线。具体来说,不同 task 快速收敛时的梯度大小可能是不一样的,对不同学习率的敏感程度也可能是不一样的。而且有时候不同 task 是互相影响非常大,比如 task A 的结果是 task B 的输入,task A 不收敛的话,task B 收敛就更难了,也许这时候 task B 在训练一开始还 gradient domination,没法玩了。

我的经验是训练的时候一个一个 task 的加,方便调参。比如 task A 训练得收敛了,再 joint A and B 一起训练到收敛,以此类推。。先训练哪个,在 lr 降到什么时候加,给的 loss weight 是多少,这里就可以根据我们之前说的各个 task 的收敛曲线来调了。所以说最重要的还是要了解你的各个 task 在什么情况下才能收敛得好。这个方法其实也可以看成手动调 loss weight,像是一开始让 task B 的权重变成一个极小值,比如 10e-6。这种方法在某些容易互相影响的 task 是好使的,比如 re-id 的时候先训 softmax 再加入各种 margin learning,比如有时候 faster rcnn joint e2e training 不 work 的时候,先训练 rpn 再加入 rcnn 一起 e2e。

multi-task 最终的目的是 1+1=2,如果真的想要 1+1>2,除了要有一个好的网络设计之外,调参能力也是要的。总之大力出奇迹吧。

作者:王晋东不在家
https://www.zhihu.com/question/268105631/answer/333281876

这个问题太大,需要结合具体问题。一般来说就是原始的网络分类loss加上特定任务下的泛化loss。具体怎么设计,还是多看看文章。其实,非深度学习的那些准则,都可以加进去的。比如iclr 18的minimal entropy correlation alignment就是把传统方法coral和熵最小化加进去了。套路比较深。

作者:张小磊
https://www.zhihu.com/question/268105631/answer/333601828

一直认为设计或者改造loss function是机器学习领域的精髓,好的损失函数定义可以既能够反映模型的训练误差,也能够一定程度反映模型泛化误差,可以很好的指导参数向着模型最优的道路进发。接下来关于设计损失函数提一些自己的看法:

1、设计损失函数之前应该明确自己的具体任务(分类、回归或者排序等等),因为任务不同,具体的损失定义也会有所区别。对于分类问题,分类错误产生误差;对于排序问题,样本的偏序错误才产生误差等。

2、设计损失函数应该以评价指标为导向,因为你的损失函数需要你的评价指标来评判,因此应该做到对号入座,回归问题用均方误差来衡量,那么损失函数应为平方损失;二分类问题用准确率来衡量,那么损失函数应为交叉熵损失,等等。

3、设计损失函数应该明确模型的真实误差和模型复杂度(有种说法是,经验误差最小化和结构误差最小化),既要保证损失函数能够很好的反映训练误差,又要保证模型不至于过度繁琐(过拟合的风险),也就是奥卡姆剃刀原理,如无必要,勿增实体。

4、设计损失函数时我们应该善于变通、善于借鉴、善于迁移。以2017年WWW上的Collaborative metric learning为例,该文将SVM的hinge loss引入到了metric learning里边,使得越相近的类里的越近,不相近的类距离越远,同时会有一个最大边界来处理分类错误的点(软间隔),最后将该损失函数又引入到了推荐系统中的协同过滤算法(CF)中。可以看出对于自己的研究领域,我们可以借鉴经典的损失函数来为我所用,以此来提升该领域的性能。

当然,以上说的更多的是普适思路,适用于传统机器学习,相信对于深度学习同样有借鉴意义。至于对于深度学习其他的技巧,应该还需要考虑深度学习模型独有的一些问题,比如模型相对复杂以至于极易过拟合的风险,以及涉及参数众多需要简化调参等。

作者:时代的一粒灰
https://www.zhihu.com/question/268105631/answer/338696341

以我目前做的语义分割项目来讲,主要面临的问题之一是类别不平衡(class-imbalance)。

除了在数据增强阶段对特定区域进行过采样以外,另外一个解决方法就是在原有交叉熵的基础上引入代价矩阵(cost matrix),换句话说原有的交叉熵中每一个label都是equally weighted,现在我们通过调整权重来解决样本不平衡问题。方法很简单但是取得的效果不错,当然也依赖于大量的调参工作。

-------------------

END

--------------------

我是王博Kings,985AI博士,华为云专家、CSDN博客专家(人工智能领域优质作者)。单个AI开源项目现在已经获得了2100+标星。现在在做AI相关内容,欢迎一起交流学习、生活各方面的问题,一起加油进步!

我们微信交流群涵盖以下方向(但并不局限于以下内容):人工智能,计算机视觉,自然语言处理,目标检测,语义分割,自动驾驶,GAN,强化学习,SLAM,人脸检测,最新算法,最新论文,OpenCV,TensorFlow,PyTorch,开源框架,学习方法...

这是我的私人微信,位置有限,一起进步!

王博的公众号,欢迎关注,干货多多

王博Kings的系列手推笔记(附高清PDF下载):

博士笔记 | 周志华《机器学习》手推笔记第一章思维导图

博士笔记 | 周志华《机器学习》手推笔记第二章“模型评估与选择”

博士笔记 | 周志华《机器学习》手推笔记第三章“线性模型”

博士笔记 | 周志华《机器学习》手推笔记第四章“决策树”

博士笔记 | 周志华《机器学习》手推笔记第五章“神经网络”

博士笔记 | 周志华《机器学习》手推笔记第六章支持向量机(上)

博士笔记 | 周志华《机器学习》手推笔记第六章支持向量机(下)

博士笔记 | 周志华《机器学习》手推笔记第七章贝叶斯分类(上)

博士笔记 | 周志华《机器学习》手推笔记第七章贝叶斯分类(下)

博士笔记 | 周志华《机器学习》手推笔记第八章集成学习(上)

博士笔记 | 周志华《机器学习》手推笔记第八章集成学习(下)

博士笔记 | 周志华《机器学习》手推笔记第九章聚类

博士笔记 | 周志华《机器学习》手推笔记第十章降维与度量学习

博士笔记 | 周志华《机器学习》手推笔记第十一章稀疏学习

博士笔记 | 周志华《机器学习》手推笔记第十二章计算学习理论

博士笔记 | 周志华《机器学习》手推笔记第十三章半监督学习

博士笔记 | 周志华《机器学习》手推笔记第十四章概率图模型

点分享

点收藏

点点赞

点在看

收藏 | 神经网络中,设计loss function有哪些技巧?相关推荐

  1. 神经网络中,设计loss function有哪些技巧?

    点击上方"小白学视觉",选择加"星标"或"置顶" 重磅干货,第一时间送达 本文转自:视学算法 神经网络中,设计loss function有哪 ...

  2. html 点击收藏效果,收藏Javascript中常用的55个经典技巧

    收藏Javascript中常用的55个经典技巧 更新时间:2007年08月12日 15:39:21   作者: 1. οncοntextmenu="window.event.returnVa ...

  3. 神经网络常用损失函数Loss Function

    深度学习神经网络常用损失函数 损失函数--Loss Function 1. MSE--均方误差损失函数 2. CEE--交叉熵误差损失函数 3. mini-batch版交叉熵误差损失函数 损失函数–L ...

  4. 人工神经网络中的activation function的作用具体是什么?为什么ReLu要好过于tanh和sigmoid function?

    转自:https://www.zhihu.com/question/29021768 附:双曲函数类似于常见的(也叫圆函数的)三角函数.基本双曲函数是双曲正弦"sinh",双曲余弦 ...

  5. 损失函数(Loss Function)在实际应用中如何合理设计

    目录 1 前言 2 回归(Regression)任务 2.1 均方误差MSE(mean squared error) 2.2 平均绝对误差MAE( mean absolute error) 2.3 H ...

  6. Semantic Instance Segmentation with a Discriminative Loss Function【论文详解】

    PAPER:https://arxiv.org/abs/1708.02551 CODE:https://github.com/DavyNeven/fastSceneUnderstanding 一.整体 ...

  7. TensorFlow损失函数(loss function) 2017-08-14 11:32 125人阅读 评论(0) 收藏 举报 分类: 深度学习及TensorFlow实现(10) 版权声明:

    TensorFlow损失函数(loss function) 2017-08-14 11:32 125人阅读 评论(0) 收藏 举报  分类: 深度学习及TensorFlow实现(10)  版权声明:本 ...

  8. 经验 | 深度学习中常见的损失函数(loss function)总结

    点上方蓝字计算机视觉联盟获取更多干货 在右上方 ··· 设为星标 ★,与你不见不散 仅作分享,不代表本公众号立场,侵权联系删除 转载于:机器学习算法与自然语言处理出品    单位 | 哈工大SCIR实 ...

  9. [概念]医学图像分割中常用的Loss function(损失函数) + 从loss处理图像分割中类别极度不均衡

    目录 一.前言 二.损失函数 2.1 根据像素正确与否设计的loss function 2.1.1  Log Loss 2.1.2 WCE Loss 2.1.3 Focal Loss 2.2 根据评测 ...

最新文章

  1. sql的nvl()函数
  2. 提领指向不完全类型的指针_望远镜不完全指南:望远镜原理、类型和配件
  3. .NET开源MSSQL、Redis监控产品Opserver之Redis配置
  4. java 通过反射得到命名空间_利用反射得到一个命名空间下的所有类,并调用?...
  5. hp_jetdirect 9100漏洞检测
  6. pythonlauncher是干什么用的_python launcher是什么
  7. eact Native开发IDE安装及配置
  8. CMP SUB 区别
  9. java 怎样判断拼图是否可还原_拼图游戏可解性判断,自动生成可解拼图
  10. 本门藏经阁 - AndroidX
  11. aps自动排程助企业缩短制造周期
  12. 某Y易盾滑块acToken、data逆向分析
  13. C语言的 = 和 ==、!=
  14. WeiKuCMS多功能微信营销服务系统
  15. 华为eNSP:ACL的配置-访问控制技术
  16. 「技术工具」阿里开源Java在线诊断工具 Arthas 进阶教程
  17. Unix波澜壮阔的发展史
  18. js 计算N年后日期
  19. 报错解决:Before you can run VMware, several modules must be compiled ...
  20. JS判断网页是否在微信中打开

热门文章

  1. jboss怎么连接Oracle数据库,如何在Jboss中配置数据源
  2. python抓取贴吧_python抓取百度贴吧-校花吧,网页图片
  3. java.util.zip.zipexception_Java 压缩zip异常,java.util.zip.ZipException: duplicate entry: 问题...
  4. python构建简单神经网络_Python构建一个简单的神经网络,Pytorch,搭建
  5. PHP群发300万,mysql 300万数据查询500多秒如何优化
  6. windows搭建SFTP服务器
  7. cd返回上一 git_使用Git实现自动化部署项目
  8. linux exfat分区格式化,技术|如何在 Linux 上将 USB 盘格式化为 exFAT
  9. c++ socket线程池原理_一篇文章看懂 ThreadLocal 原理,内存泄露,缺点以及线程池复用的值传递问题...
  10. flask get 参数_Python web 用它5分钟以后,我放弃用了四年的 Flask