文 |  Sherry 不是小哀

集成模型(Ensemble)可以提升模型的精度,但往往面临提升计算量的困境,用级联模型(Cascade)在预测时提前中断则可解决计算量的问题。最近,谷歌和CMU的研究者对此进行了深入的分析,他们比较了常见深度神经网络在图像任务上集成学习的效果。他们提出,通过多个轻量级模型集成、级联可以获得相比单个大模型更高效的提分方案。

目前大家大都通过设计模型结构,或是暴力扩大模型规模来提升效果,之后再通过模型剪枝提高效率。本文提出,这些方法费时费力,在实际应用中,可以通过更好的集成、级联模型设计来获取更高效的提分策略。

论文题目:
Multiple Networks are More Efficient than One: Fast and Accurate Models via Ensembles and Cascades

论文链接:
https://export.arxiv.org/pdf/2012.01988.pdf

Arxiv访问慢的小伙伴也可以在 【夕小瑶的卖萌屋】订阅号后台回复关键词 【1223】 下载论文PDF~

高效的提分策略

▲cascade1.png

Xiaofang Wang等人将集成学习的方法应用到常见的图像分类模型上,仅仅使用2-3个弱分类器(例如EfficientNet-B5)就可在同样推理计算量的条件下达到强分类器(例如EfficientNet-B6甚至B7)的准确率。如果进一步加入了级联学习的机制则可进一步降低运算量。

从上图中我们可以看出,集成学习本身(方块)已经相对于单模型(圆点)在精度(Accuracy)-运算量(FLOPS)平面上有提升,而加入了级联方法(五角星)则可进一步提升效果。特别的,尽管经过精心设计的Inception-v4模型(位于(13,80)的黑点)表现优于所有ResNet(下方黑色圆点)模型,但通过级联得到的ResNet(蓝色五角星)可以在准确率-计算量图上获得优于Inception-Net的效果。

群众的眼睛是雪亮的!

集成学习的方法可以为什么可以暴力提高模型预测准确率呢?我们首先训练多个弱分类器(这里拿分类任务来举例子),把每个弱分类器的意见结合起来看,我们就能得到一个更靠谱的分类结果。常见的集成学习方法包括Bagging[2], Boosting[3]AdaBoost[4]。实际应用中,我们使用不同的随机种子初始化模型,将训练好的模型预测概率取平均,或者是简单的投票,就能提升一定的准确率。

Thomas G Dietterich在[5] 中就给出了集成学习能成功的理论解释。用平均值的方法集成模型可以看成在假设空间中找一组点的重心,投票的方法也类似找某个“心”。

统计学上来说,我们使用模型学习假设时,如果训练数据量小于假设空间的大小时,模型就会学到不同的假设。上图的左上角中,外部曲线表示假设空间,内部曲线表示在训练数据上能学到的假设范围,点f是真实的假设;通过平均几个学习到的假设,我们可以找到f的良好近似值。

从随机梯度下降(SGD)的角度而言,我们通常得到的是局部最优解。把从不同初始参数学到的模型集合起来,可以比任何单独的分类器更好地近似真实分布(上图右上角)。

从表示学习角度出发,由于模型和数据的限制,在大多数训练集,学习到整个假设空间的假设,例如上图下半部分。通过平均,可以扩展可表示函数的空间,从而得到这些原本无法学习到的表示。

暴力获得又好又快的模型

实际应用中,我们的资源往往是有限的。在不降低模型精度的条件下减少运算量一直是个重要的命题,很多研究者也对模型效率的提升作出了深入的研究,例如对模型结构进行精细的改造。但这些方法往往要求对下游任务有深入的理解,或者是需要大量的资源来进行网络进化的搜索。我们已经知道集成学习可以获得更好的精度,那么只要能成功降低运算量,是不是就可以做到又好又快了?级联学习就是个很不错的方法。

对于一个很简单的题目,小盆友就可以准确地得出答案,那我们也没有必要让所有砖家都和ta一起做一遍题,对吧?级联学习就利用这样的想法,我们先让一些弱分类器对问题作出预测,如果它有很高的置信度,我们就可以相信他的答案,这样就不需要用其他模型预测,可以大大减少运算量。文中对每个分类器设定了一个置信度阈值,这里他们使用概率最大类的得分作为预测的置信度,当前第k个分类器的置信度超过阈值的时候我们就结束预测并给出前k个分类器集成的答案,否则继续加入下一个分类器的结果。

本文用两个弱分类器集成做实验。他们发现当第一个分类器的退出阈值不断提高,在某个阈值之后集成模型的效果将达到平台(可以认为这个平台是不加入提前退出的集成模型效果),而平台的最左端与最右端比,平均运算量有50%左右的降低。同时,在用B3, B5, B5, 和 B5集成获得B7模型准确率的实验中,他们发现这些模型的退出比例依次 67.3%, 21.6%, 5.6% 和 5.5%。也就是说对67.3%的情况,我们只需要用一个B3模型就运算量可以获得B7模型的准确率;而只有5.5%的情况需要运算所有四个模型来集成。这正说明了级联学习可以有效降低集成模型的预测运算量。

▲cascade3.png

准确率和运算量的精准控制

仅仅减少运算量还不够,模型上线的时候往往对准确率和运算量有着严格的要求。我们还可以用优化算法在满足一些条件的情况下找到最佳级联模型的设定。例如:

在满足运算量上限的同时获得更高的准确率。除了限定运算量之外,还可以选择最低准确率,最差情况运算量作为优化问题的限制条件。本文由于只选择较少的弱分类器,使用暴力搜索来解这个优化这个问题。我们还可以通过更有效率的方法得到级联方案,参考[6].

没有多种模型?可以自级联!

上述集成和级联方法都要求我们有多种设定的不同模型,那如果我们只能训练一个模型呢?借鉴(Hugo Touvron, Andrea Vedaldi, Matthijs Douze, and Herve ́ Je ́gou. Fixing the train-test resolution discrepancy. In NeurIPS, 2019.)的想法,在预测的时候,我们将不同清晰度的图片输入同一个模型,从而达到多模型集成的效果。例如在下图表格的第一行(B2)中,我们有一张图片,使用240*240和300*300的两种分辨率的图片输入,结果看作两个模型集成。从实验结果可以发现,通过自级联的方法后,在保持相似准确率的同时,我们可以获得1.2-1.7倍的加速。

总结

本文探究并分析了结合集成和级联的方法,简单有效地在提升模型准确度的同时降低了运算量。除了分类任务之外,本文同样也验证了该方法在视频分类和图像分割任务上的有效性。

整体而言,本文并没有提出新的算法,但是为我们提供了工程上线时低成本获得高精度模型的一种方案。个人认为本文的一大缺点在于如此级联预测会给并行提速增加难度,原文作者也承认了这一点并指出该方法对离线预测更有效。

本文虽然是在图像数据上做的实验,但是集成和级联不局限于CNN,迁移到NLP同样适用。

萌屋作者:Sherry 不是小哀

本科毕业于复旦数院,转行NLP目前在加拿大滑铁卢大学读CS PhD。经历了从NOIer到学数学再重回CS的转变,却坚信AI的未来需要更多来数学和自认知科学的理论指导。主要关注问答,信息抽取,以及有关深度模型泛化及鲁棒性相关内容。

作品推荐:

  1. 无需人工!无需训练!构建知识图谱 BERT一下就行了!

  2. Google Cloud TPUs支持Pytorch框架啦!

后台回复关键词【入群

加入卖萌屋NLP/IR/Rec与求职讨论群

后台回复关键词【顶会

获取ACL、CIKM等各大顶会论文集!

 

[1]Multiple Networks are More Efficient than One: Fast and Accurate Models via Ensembles and Cascades (https://export.arxiv.org/pdf/2012.01988.pdf)

[2]Bagging predictors. by Leo Breiman. P123–140, 1996.

[3]The strength of weak learnability. by Robert E Schapire. P197–227, 1990.

[4]A decision-theoretic generalization of on-line learning and an application to boosting. by Yoav Freund and Robert E Schapire.

[5]Ensemble Methods in Machine Learning (https://web.engr.oregonstate.edu/~tgd/publications/mcs-ensembles.pdf)

[6]Approximation Algorithms for Cascading Prediction Models (http://proceedings.mlr.press/v80/streeter18a/streeter18a.pdf)

[7]知乎:关于为什么要使用集成学习 https://zhuanlan.zhihu.com/p/323789069

谷歌、CMU发文:别压榨单模型了!集成+级联上分效率更高!相关推荐

  1. 多路复用 I/O 模型详解, 为什么他能支持更高的并发

    阻塞 I/O 在这种 IO 模型的场景下,我们是给每一个客户端连接创建一个线程去处理它.不管这个客户端建立了连接有没有在做事(发送读取数据之类),都要去维护这个连接,直到连接断开为止.创建过多的线程就 ...

  2. 亚马逊:自动选择AI模型,进化论方法效率更高!

    [新智元导读]亚马逊称,进化论可以帮助AI模型的选择.选择架构是构建AI模型的关键步骤.研究人员表示,鉴定遗传算法和协同进化算法的性能指标取决于彼此之间的相互作用,是寻找最佳(或接近最佳)AI模型架构 ...

  3. 单表查询和多表连接查询哪个效率更快?

    这段时间在做项目的过程中,遇到一个模块,数据之间的联系很复杂,在建表的时候就很纠结,到底该怎么去处理这些复杂的数据呢,是单表查询,然后在业务层去处理数据间的关系,还是直接通过多表连接查询来处理数据关系 ...

  4. 谷歌发布最新看图说话模型,可实现零样本学习,多类型任务也能直接上手

    点击上方"视学算法",选择加"星标"或"置顶" 重磅干货,第一时间送达 兴坤 发自 凹非寺 量子位 报道 | 公众号 QbitAI 谷歌新推 ...

  5. 斩获VCR竞赛榜第一,腾讯微视推出BLENDer单模型,超越多模型最好效果

    出品 | CSDN(ID:CSDNnews) 视觉常识推理VCR (Visual Commonsense Reasoning )是人工智能领域的前沿热点问题,我国<新一代人工智能发展规划> ...

  6. 《预训练周刊》第27期:谷歌发布最新看图说话模型、GitHub:平台上30%的新代码受益于AI助手Copilot...

    No.27 智源社区 预训练组 预 训 练 研究 观点 资源 活动 关于周刊 本期周刊,我们选择了9篇预训练相关的论文,涉及少样本理解.图像检测.决策图.大模型微调.对话微调.分子建模.蛋白质结构预测 ...

  7. 滴滴KDD2017论文:基于组合优化的出租车分单模型 By 机器之心2017年8月14日 10:29 数据挖掘顶会 KDD 2017 已经开幕,国内有众多来自产业界的论文被 KDD 2017 接收。

    滴滴KDD2017论文:基于组合优化的出租车分单模型 By 机器之心2017年8月14日 10:29 数据挖掘顶会 KDD 2017 已经开幕,国内有众多来自产业界的论文被 KDD 2017 接收.本 ...

  8. 谷歌I/O走进TensorFlow开源模型世界:从图像识别到语义理解

    谷歌I/O走进TensorFlow开源模型世界:从图像识别到语义理解 2017-05-23 16:13:11    TensorFlow    2 0 0 一年一度的谷歌开发者大会 Google I/ ...

  9. 为了压榨CNN模型,这几年大家都干了什么

    如果从2006年算,深度学习从产生到火爆已经十年了,在工业界已经产生了很多落地的应用.现在网络的深度已经可达1000层以上,下面我们关注一个问题: 这些年大家是怎么"压榨"CNN模 ...

最新文章

  1. SpringBoot 2.3 新特性之优雅停机,这波操作太秀了!
  2. shell 中柏开机显示efi_中柏 ezpad 平板安装Fedora 21 (Linux)
  3. 接口转发和重定向区别(一)
  4. PHP 单元测试工具 SimpleTest
  5. attachRouteMatched analysis
  6. 小米蓝牙左右互联_解决不同品牌智能家居的兼容问题,小米米家智能多模网关发布...
  7. 工业定焦镜头的选型公式
  8. oracle regr,oracle 分析函数
  9. Linux下MySQL忘记root密码及解决办法
  10. 从安全和不安全两个角度,教你如何发布对象(含各种单例代码)
  11. 学校计算机机房台账,机房工作
  12. 还在头痛被黑客劫持? 五步帮你摆脱烦恼!
  13. Excel转Json 绿色工具
  14. 用命令打开文件服务器资源管理器,Windows10使用命令参数打开文件资源管理器的方法...
  15. SQL Server高级编程
  16. c语言数组众数,众数问题 (C语言代码)
  17. [转]中英文停止词表(stopword)
  18. 计算机毕业设计论文该怎么写?软件工程毕设选题推荐有哪些;计算机毕业设计不会做怎么办;怎么做什么简单;电子信息工程毕业设计要做到什么程度
  19. 微信小程序全选,微信小程序checkbox,微信小程序购物车
  20. 关于Qt 5-MSVC 2015 64位在 win7 64位系统debug程序崩溃的问题

热门文章

  1. Codeforces Round #299 (Div. 2) D. Tavas and Malekas kmp
  2. 华为OJ平台——整形数组合并
  3. 反射,System.Type类
  4. 微机原理8086CPU
  5. 啥叫旁路电容?啥叫去耦?可以不再争论了吗
  6. 一个小码农对嵌入式的理解
  7. mysql连接池_数据库技术:数据库连接池,Commons DbUtils,批处理,元数据
  8. pypinyin 获取多音字的拼音组合
  9. LeetCode 1684. 统计一致字符串的数目(哈希)
  10. LeetCode 848. 字母移位(前缀和+取模)