智源导读:本文主要介绍清华大学黄高团队被ICLR2021接收的一篇文章:Revisiting Locally Supervised Learning: an Alternative to End-to-End Training。

论文链接:https://openreview.net/forum?id=fAbkE6ant2

代码链接:https://github.com/blackfeather-wang/InfoPro-Pytorch

撰文 | 王语霖,清华大学

首发于知乎专栏:炼丹学徒的笔记

本文研究了一种比目前广为使用的端到端训练模式显存开销更小、更容易并行化的训练方法:将网络拆分成若干段、使用局部监督信号进行训练。

我们指出了这一范式的一大缺陷在于损失网络整体性能,并从信息的角度阐明了,其症结在于局部监督倾向于使网络在浅层损失对深层网络有很大价值的任务相关信息。

为有效解决这一问题,我们提出了一种局部监督学习算法:InfoPro。在图像识别和语义分割任务上的实验结果表明,我们的算法可以在不显著增大训练时间的前提下,有效节省显存开销,并提升性能。

01

研究动机

一般而言,深度神经网络以端到端的形式训练。以一个13层的简单卷积神经网络为例,我们会将训练数据输入网络中,逐层前传至最后一层,输出结果,计算损失值(End-to-End Loss),再从损失求得梯度,将之逐层反向传播以更新网络参数。

图1 端到端训练(End-to-End Training)

尽管端到端训练在大量任务中都稳定地表现出了良好的效果,但其效率至少在以下两方面仍然有待提升。其一,端到端训练需要在网络前传时将每一层的输出进行存储,并在逐层反传梯度时使用这些值,这造成了极大的显存开销,如下图所示。

图2 端到端训练具有较大的的显存开销

其二,对整个网络进行前传-->反传的这一范式是一个固有的线性过程。前传时深层网络必须等待浅层网络的计算完成后才能开始自身的前传过程;同理,反传时浅层网络需要等待来自深层网络的梯度信号才能进行自身的运算。这两点线性的限制使得端到端训练很难进行并行化以进一步的提升效率。

图3 端到端训练难以并行化

为了解决或缓解上述两点低效的问题,一个可能的方案是使用局部监督学习,即将网络拆分为若干个局部模块(local module),并在每个模块的末端添加一个局部损失,利用这些局部损失产生监督信号分别训练各个局部模块,注意不同模块间没有梯度上的联通。下图给出了一个将网络拆分为两段的例子。

图4 局部监督学习(Locally Supervised Learning)

相较于端到端训练的两点不足,局部监督学习在效率上先天具有显著优势。其一,我们一次只需保存一个局部模块内的中间层输出值,待此模块完成反向传播后,即可释放存储空间,进而复用同样的空间用以存储下一个局部模块的中间层输出值,如下图所示。简言之,理论上显存开销随局部模块数呈指数级下降。

图5 局部监督学习可有效降低显存开销

其二,不同局部模块的反向传播过程并没有必然的前后依赖关系,在工程实现上,不同模块的训练可以自然的并行完成,例如分别使用不同的GPU,如下图所示。

图6 局部监督学习易于并行化完成

02

问题分析与假设

相信大家看到这里,都会有一个问题:既然局部监督学习的效率自然地高于端到端训练,为什么它现在没有被大规模应用呢?其问题在于,局部监督学习往往会损害网络的整体性能。

以图片识别为例,考虑一种简单自然的情况,我们使用标准的线性分类器+SoftMax+交叉熵作为每个局部模块的损失函数,在CIFAR-10数据集上使用局部监督学习训练ResNet-32,结果如下所示,其中  代表局部模块的数目。可以看出随着值的增长,网络的测试误差急剧上升。

图7 局部监督学习倾向于损害网络性能

若能解决性能下降的问题,局部监督学习就有可能作为一种更为高效的训练范式而取代端到端训练。出于这一点,我们探究和分析了这一问题的原因。

上述局部监督学习和端到端训练的一个显著的不同点在于,前者对网络的中间层特征直接加入了与任务直接相关的监督信号,从这一点出发,一个自然的疑问是,由此引发的中间层特征在任务相关行为上的区别是怎样的呢?因此,我们固定了图7中得到的模型,使用网络每层的特征训练了一个线性分类器,其测试误差如下图右侧所示。其中,横轴代码取用特征的网络层数,纵轴代表测试误差,不同的曲线对应于不同的取值,表示端到端的情形。

图8 中间层特征的线性可分性

从结果中可以观察到一个明显的现象:局部监督学习所得到的中间层特征在浅层时就体现出了极好的线性可分性,但当特征进一步经过更深的网络层时,其线性可分性却没有得到进一步的增长;相比而言,尽管在浅层时几乎线性不可分,端到端训练得到的中间层特征随着层数的加深可分性逐渐增强,最终取得了更低的测试误差。于是便产生了一个非常有趣的问题:局部监督学习中,深层网络使用了分辨性远远强于端到端训练的特征,为何它得到的最终效果却逊于端到端训练?难道基于可分性已经很强的特征,训练网络以进一步提升其线性可分性,不应该得到更好的最终结果吗?这似乎与一些之前的观察(例如deeply supervised net)矛盾。

为了解答这个疑问,我们进一步从信息的角度探究网络特征在可分性之外的区别。我们分别估计了中间层特征  与输入数据  和任务标签  之间的互信息  和,并以此作为中包含的全部信息和任务相关信息的度量指标。

图9 估算互信息

其结果如下图所示,其中横轴为取用信息的层数,纵轴表示估计值。从中不难看出,端到端训练的网络中,特征所包含的总信息量逐层减少,但任务相关信息维持不变,说明网络逐层剔除了与任务无关的信息。与之形成鲜明对比的是,局部监督学习得到的网络在浅层就丢失了大量的任务相关信息,特征所包含的总信息量也急剧下降。我们猜测,这一现象的原因在于,仅凭浅层网络难以如全部网络一般有效分离和利用所有任务相关信息,因此索性去丢弃部分无法利用的信息换取局部训练损失的降低。而在这种情况下,网络深层接收到的特征相较网络原始输入本就缺少关键信息,自然难以基于其建立更有效的表征,也就难以取得更好的最终性能。

图10 中间层特征包含的信息

基于上述观察,我们可以总结得到:局部监督学习之所以会损害网络的整体性能,是因为其倾向于使网络在浅层丢失与任务相关的信息,从而使得深层网络空有更多的参数和更大的容量,却因输入特征先天不足而无用武之地。

03

方法详述

为了解决损失信息的问题,本文提出了一种专为局部监督学习定制的损失函数:InfoPro。首先,我们引入一个基本模型。如下图所示,我们假设训练数据受到两个随机变量影响,其一是任务标签 ,决定我们所关心的主体内容;其二是无关变量,用于决定数据中与任务无关的部分,例如背景、视角、天气等。

图11 变量作用关系假设

基于上述变量设置,我们将InfoPro损失函数定义为下面的结合形式。它用于作为局部监督信号训练局部模块,由两项组成。第一项用于推动局部模块向前传递所有信息;在第二项中,我们使用一个满足特殊条件无关变量来建模中间层特征中的全部任务无关信息(无用信息),在此基础上迫使局部模块剔除这些与任务无关的信息。

图12 InfoPro损失函数

InfoPro与端到端训练和在 2. Analysis 中所述的简单局部监督学习(Greedy Supervised Learning)的对比如下图所示。简言之,InfoPro的目标是使得局部模块能够在保证向前传递全部有价值信息的条件下,尽可能丢弃特征中的无用信息,以解决局部监督学习在浅层丢失任务相关信息、影响网络最终性能的问题。事实上,这也是我们前面观察到的、端到端训练对网络浅层的影响形式。InfoPro与其它局部学习方法最大的区别在于它是非贪婪的,并不直接对局部的任务相关行为(如Greedy Supervised Learning中基于局部特征的分类损失)做出直接约束。

图13 3种训练算法的对比

在具体实现上,由于InfoPro损失的第二项比较难以估算,我们推导出了其的一个易于计算的上界,如下图所示:

图14 InfoPro损失的一个易于计算的上界

关于这一上界的具体推导过程、一些数学性质和其实际上的计算方式,由于流程比较复杂且不关键,不在此赘述,欢迎感兴趣的读者参阅我们的文章~

04

实验结果

1、在不同局部模块数目的条件下,稳定胜过baseline。

2、大量节省显存,且不引入显著的额外计算/时间开销,效果相较端到端训练略有提升;

3、ImageNet大规模图像识别任务上的结果,节省显存的效果同样显著,效果略有提升。

4、Cityscapes语义分割实验结果,除节省显存方面的作用外,我们还证明了,在相同的显存限制下,InfoPro可以使用更大的batch size或更大分辨率的输入图片。

05

结语

总结来说,这项工作的要点在于:(1)从效率的角度反思端到端训练范式;(2)指出了局部监督学习相较于端到端的缺陷在于损失网络性能,并从信息的角度分析了其原因;(3)在理论上提出了初步解决方案,并探讨了具体实现方法。

欢迎大家follow我们的工作~

@inproceedings{wang2021revisiting,title = {Revisiting Locally Supervised Learning: an Alternative to End-to-end Training},author = {Yulin Wang and Zanlin Ni and Shiji Song and Le Yang and Gao Huang},booktitle = {International Conference on Learning Representations (ICLR)},year = {2021},url = {https://openreview.net/forum?id=fAbkE6ant2}
}

ICLR2021 | 清华大学黄高团队:显存不够?不妨抛弃端到端训练相关推荐

  1. RANet:MSDNet加强版!清华黄高团队提出分辨率自适应的高效推理网络RANet!

    关注公众号,发现CV技术之美 本文分享论文『Resolution Adaptive Networks for Efficient Inference』,由清华黄高团队提出分辨率自适应的高效推理网络RA ...

  2. 延迟渲染G-buffer所占显存带宽计算(解决移动端和抗锯齿的若干疑问)

    延迟渲染需要在前面阶段,将计算的内容保留在N张G-buffer中,但是网上的文章只是提及了G-buffer应该压缩,并且尽量少用,没有说明G-buffer所占带宽应该是多少,我将在下面介绍G-buff ...

  3. 圆形的CNN卷积核?华中科大清华黄高团队康奈尔提出圆形卷积,进一步提升卷积结构性能!

    作者丨小马 编辑丨极市平台 写在前面 目前正常卷积的感受野大多都是一个矩形的,因为矩形更有利于储存和计算数据的方便.但是,人类视觉系统的感受野更像是一个圆形的.因此,作者就提出,能不能将CNN卷积核的 ...

  4. 如何做高质量研究、写高水平论文?| 黄高、王兴刚等共话科研与论文写作

    如何产生好的研究思路?如何撰写一篇高质量论文?如何从浩如烟海的论文中寻找好的科研灵感?如何通过Rebuttal为自己的文章扳回一城?导师跟学生之间怎样才能形成更好的合作关系? 在ECCV 2022中国 ...

  5. 10大游戏显存占用率测试

    请注意这里的游戏全是开的最高效果!~ 近几年,显卡的发展速度可以说快的惊人,几乎隔几个月,NVIDIA和AMD就会推出性能更高的新品.很显然,CPU的发展速度早已经不能和显卡相提并论了,NVIDIA的 ...

  6. 综述:PyTorch显存机制分析

    作者 | Connolly@知乎(已授权) 来源 | https://zhuanlan.zhihu.com/p/424512257 编辑 | 极市平台 导读 作者最近两年在研究分布式并行,经常使用Py ...

  7. tensorflow 显存 训练_【他山之石】训练时显存优化技术——OP合并与gradient checkpoint...

    作者:bindog 地址:http://bindog.github.io/ 01 背景 前几天看到知乎上的文章FLOPs与模型推理速度[1],文中提到一个比较耗时又占显存的pointwise操作x * ...

  8. 计算机改显存会有啥影响,显卡显存越大越好吗?显存对电脑速度的影响有哪些?...

    对于刚接触DIY领域的小白玩家来说,衡量显卡性能的指标就是GPU芯片和其频率,这也确实是显卡性能的决定性因素.但除了GPU,还有一个对显卡性能影响较大的部分,那就是显存. 显卡显存越大越好吗?显存对电 ...

  9. 深度学习中GPU和显存

    GPU状态的监控 nvidia-smi: 是Nvidia显卡命令行管理套件,基于NVML库,旨在管理和监控Nvidia GPU设备.nvidia-smi命令的输出中最重要的两个指标:显存占用和GPU利 ...

最新文章

  1. 埃森哲、亚马逊和万事达卡抱团推出的区块链项目有何神通?
  2. 为了不复制粘贴,我被逼着学会了JAVA爬虫
  3. 音视频解决方案之二次开发
  4. leetCode 50.Pow(x, n) (x的n次方) 解题思路和方法
  5. SQL Server配置delegation实现double-hop
  6. html js 如何判断页面是第一次访问还是重复刷新访问,使用JS判断页面是首次被加载还是刷新...
  7. 解析.sens数据集
  8. Qt:Windows编程—DLL注入与卸载
  9. matlab fopen wt,matlab的fopen和fprintf
  10. 面向对象程序设计中“超类”和“子类”概念的来历
  11. 检查eth是否到账_税务检查视角:高新技术企业核查要点
  12. 什么样的博文才能上首页呢?『博客使用技巧』
  13. 解决算法问题的思路 —— 从问题描述到数学表达
  14. Node rabbitmq 入门就够了
  15. Kotlin typealias属性
  16. 十种QQ在线客服代码
  17. IE浏览器设置默认文档模式
  18. 常见的接口测试 开源网站
  19. Windows 11 找不到文件C:\ProgramData\Package Cache\{xxxx}xxx.exe。请确定文件名是否正确后,再试一次。
  20. 手把手教你电脑图片转文字怎么操作,助你提高工作效率

热门文章

  1. php qq对话,用php聊QQ
  2. python全局变量有缩进吗_Python全局变量和局部变量的问题 400 请求报错 -问答-阿里云开发者社区-阿里云...
  3. mysql更改锁机别_MYSQL调度与锁定问题(转)
  4. python的socket编程_Python Socket编程详细介绍
  5. java面试 拦截器问题_面试必问:给我说一下Spring MVC拦截器的原理?
  6. Kotlin尾递归优化
  7. uboot2014.10移植(一)
  8. Determine whether an integer is a palindrome. Do this without extra space.
  9. Azure中继摆脱了WCF的桎梏,走向跨平台
  10. 计算机缺失缺少mfc110.dll等相关文件的解决办法