利用Group Conv在单个CNN中实现集成学习
论文:Group Ensemble: Learning an Ensemble of ConvNets in a single ConvNet
地址:https://arxiv.org/pdf/2007.00649.pdf
0. 动机
集成学习通过融合多个模型得到更准确的结果,在深度神经网络模型上应用集成学习可以提高网络性能。
通常来讲,在深度神经网络中使用的集成策略可以分为2类:
其一为显式的集成策略(explicit ensembling strategy),训练多个模型,通过专家委员会(a committee of experts)或其他方法得到最终结果。由于需要使用多个神经网络模型,这种方法计算量比较大。
其二为隐式的集成策略(implicit ensembling strategy),这种方法又能分为2类:
在单个模型中引入随机操作模仿集成,比如DropOut、DropConnect、DropBlock、StochDepth、Shake-Shake;
在单个模型中使用多条路径模仿集成,比如ResNet、ResNeXt、DenseNet、Inception系列。
作者提出了Group Ensemble方法,在单个模型中融合显式和隐式的集成方法,能够在几乎不提升原模型计算需求的基础上提高模型性能。
1. Group Ensemble
一些研究表明,对同一架构的网络使用不同方法(比如改变权重初始值、改变数据集的划分策略)训练多次,得到的多个网络中浅层的表达很类似;甚至不同的网络结构,比如VGG和ResNet系列网络,基于同一任务训练,得到的网络浅层表达也很类似。基于上述发现,作者认为,可以通过共享浅层网络、独立多个深层网络来达到显式集成的效果,以减少计算量。
作者提出了名为GENet(Group Ensemble Network)的网络结构,用于在单个神经网络中进行显示集成学习,如下图所示:
从上图中可以看出,作者使用Group卷积的形式达到“独立多个深层网络”的效果,每个分类器的输出结果融合得到最终结果。
多个独立的深层网络提供了用于集成学习的多个分类器,并且引入了多个分类器的差异性;共享浅层网络,相比于传统的使用多个模型集成的方法,有效地减少了计算量。多个独立的深层网络使用同一个浅层网络,因此共享的浅层网络也可以理解为对共享参数的一种正则化手段。
假设深层的Group卷积中共有n个组,整个模型的损失函数为:
上式中的表示第m个Group卷积的损失函数。
推理时,这些独立的深层网络分别输出各自的结果,通过取平均值得到整个网络的最终结果。
2. Aggregation Strategies
作者介绍了3种训练时给样本分配权重的方法,它们分别为Group Averaging、Group Wagging、Group Boosting,可以使用这些方法提高深层网络的多样性,达到更好的集成效果。
第m组的损失函数 可以表示为:
上式中,i表示样本的索引,b 表示 batch size,表示第i个样本的损失,表示第i个样本对应的权重。下图说明了该公式的含义:
Group Averaging、Group Wagging、Group Boosting这3种方法决定了上式中的取值策略。
Group Averaging:所有的取值为1,不同组的多样性由它们的权重初始值、各自的损失和各自的反向传播决定。
Group Wagging:是服从均匀分布或者高斯分布的随机值,若服从高斯分布,表示为。随机选取能在样本层面增加不同组的多样性。
Group Boosting 当前组的取决于上一组对该样本的识别效果,当上一组对该样本识别错误时,在当前组中取较大值,即让当前组重点关注上一组识别错误的样本。为了在训练时就能得到某个组对样本的识别效果,作者使用了online boosting方法,定义在第m组中某个样本的权重为:
上式中的为第m-1组对该样本的识别概率,为参数。
这三种方法的示意图如下所示:
作者探讨了多个组之间保持独立性对于最终分类结果的影响:
假设一个简单的2分类任务,多个模型的输出服从于同样的正态分布,即
上式中表示对于某个样本第i个模型输出的正确类别的预测得分,n表示模型数量,和分别表示均值和标准差。
如下图所示:
左上图为单个模型的输出分布,右上图为2个独立模型的输出分布,左下图为4个独立模型的输出分布,从这3个图可以看出,融合多个模型的输出结果会更准确,且模型个数越多越好。
右下图为2个相关系数的模型的输出分布,对比右上图和右下图可以看出,模型越独立,融合后输出结果越准确。因此可知保持多个模型独立的重要性。
3. 实验
3.1 分类任务
作者在CIFAR和ImageNet数据集上测试GENet。使用Group Averaging策略分配样本权重,使用Pytorch实现。
对于CIFAR数据集,使用ResNet-29和ResNeXt-29作为Backbone,网络最后一层为用于分类的FC层,使用center-crop的方法在验证集上测试。在CIFAR-10和CIFAR-100数据集上的Top1错误率如下图所示:
上图中的Ensemble表示使用传统的多个模型集成方法,“2x”和“3x”分别表示使用2个和3个模型。从上图中可以看出GENet保持和baseline基本一致的计算量,性能与传统的模型集成方法接近。
对于ImageNet数据集,使用ResNet-50、ResNeXt-50、ResNeXt-101作为Backbone,在验证集上使用224x224尺寸的center-crop进行测试,错误率如下图所示:
从上图中可以看出,GENet和传统的模型集成方法性能接近,但是有更少的参数量和计算需求。
3.2 目标检测任务
在COCO2017数据集上使用Faster R-CNN测试,Backbone为ResNet50并且使用了FPN。保持Faster R-CNN中的RPN不变,只将R-CNN head部分改为group ensemble形式,测试结果如下图所示:
从上图中可以看出,使用group ensemble能够在保持参数量不变的前提下达到多个模型集成的性能,验证了GENet对baseline性能的提升,以及在不同领域的有效性。
4. 总结
提出了GENet,通过共用浅层网络、独立深层网络的形式在单个CNN中实现显式模型集成;
说明了保持多组网络独立性的意义,以及通过aggregation strategy提高多组网络独立性和多样性的方法;
通过实验证明了GENet在几乎不改变baseline计算需求的前提下提升性能。
仅用于学习交流!
END
备注:部署
模型压缩与应用部署交流群
模型压缩、网络压缩、神经网络加速、轻量级网络设计、知识蒸馏、应用部署、MNN、NCNN等技术,
扫码备注拉你入群。
我爱计算机视觉
微信号:aicvml
QQ群:805388940
微博知乎:@我爱计算机视觉
投稿:amos@52cv.net
网站:www.52cv.net
在看,让更多人看到
利用Group Conv在单个CNN中实现集成学习相关推荐
- CNN中卷积的学习笔记
1 致谢 感谢赵老师的讲述~ 2 前言 今天在学习CNN~ 记得很久以前,小伙伴曾经问过我一个问题,为什么CNN网络要使用卷积运算作为神经元的输入, 那时候我还没怎么开始学深度学习,觉得这是一个很玄妙 ...
- 机器学习中的集成学习模型实战完整讲解
2019-12-03 13:50:23 集成学习模型实践讲解 --沂水寒城 无论是在机器学习领域还是深度学习领域里面,通过模型的集成来提升整体模型的性能是一件非常有效的事情,当前我们所接触到的比较成熟 ...
- 基于机器学习中集成学习的stacking方式进行的金线莲质量鉴别研究(python进行数据处理并完成建模,对品种进行预测)
1.前言 金线莲为兰科开唇兰属植物,别名金丝兰.金丝线.金耳环.乌人参.金钱草等,是一种名贵中药材,国内主要产地为较低纬度地区如:福建.台湾.广东.广西.浙江.江西.海南.云南.四川.贵州以及西藏南部 ...
- CNN中各类卷积总结:残差、shuffle、空洞卷积、变形卷积核、可分离卷积等
CNN从2012年的AlexNet发展至今,科学家们发明出各种各样的CNN模型,一个比一个深,一个比一个准确,一个比一个轻量.我下面会对近几年一些具有变革性的工作进行简单盘点,从这些充满革新性的工作中 ...
- CNN中十大拍案叫绝的操作
目录 一.卷积只能在同一组进行吗?- Group convolution 二.卷积核一定越大越好?-3×3卷积核 四.怎样才能减少卷积层参数量?- Bottleneck 五.越深的网络就越难训练吗?- ...
- (CVPR-2022)将内核扩展到31x31:重新审视cnn中的大型内核设计
将内核扩展到31x31:重新审视cnn中的大型内核设计 paper题目:Scaling Up Your Kernels to 31x31: Revisiting Large Kernel Design ...
- Winograd,GEMM算法综述(CNN中高效卷积实现)(上)
高效卷积实现算法和应用综述(上) 在下一篇文章会介绍Winograd算法的应用,在ICLR,CVPR,FPGA,FCCM等机器学习和FPGA领域的定会上的 ...
- 谈谈CNN中的位置和尺度问题
来自 | 知乎 作者 | 黄飘 链接 | https://zhuanlan.zhihu.com/p/113443895 编辑 | 深度学习这件小事公众号 本文经作者授权转载,作学术交流,请勿二次转载 ...
- CNN中的即插即用小模块汇总
文章目录 前言 1 STN 2. ASPP 3. Non-local 4. SE 5. CBAM 6 DCN v1&v2 7 CoordConv 8 Ghost 9 BlurPool 10 R ...
最新文章
- PocketPC 全屏的实现
- keras回调监控函数
- 时间序列因果关系_分析具有因果关系的时间序列干预:货币波动
- 英特尔助力金山云带你畅游云端的游戏世界
- python与html5搭建聊天室_html5 websocket 新版协议聊天室 服务端(python版)
- 电商商城系统活动设计
- 解决CsrfFilter与Rest服务Post方式的矛盾
- RHEL6.3更换163 centos源或本地源(适用rhel7)
- 瑞典皇家理工学院工程类表
- 基于mysql的可视化日志管理——loganalyzer
- oracle报错imp报错00008,imp导入时遭遇IMP-00032,IMP-00008错误.
- FogROS2 使用 ROS 2 的云和雾机器人的自适应和可扩展平台
- Linux man指令查询文档设定成中文
- 7个月时间“从零到亿”,社交电商靠谱好物为何总能占据行业“C位”?
- Bug的级别,按照什么划分
- CSS3选择器及优先级
- leetcode 6 z字型变换
- python关联规则apriori算法_Python --深入浅出Apriori关联分析算法(二) Apriori关联规则实战...
- 变频器是如何节能的?
- 引入CSS样式表的三种方法
热门文章
- Opencv--学习Opencv比较好的网址
- Spring框架 简述
- matlab中ahp方法,AHP及matlab程序.doc
- python 聚类_聚类算法中的四种距离及其python实现
- python导入同一文件夹下的类_Python模块导入机制与规范
- anaconda创建新环境_【创建社会主义新农村】怀城街道:转变整治理念 农村人居环境换新颜...
- 工程安全cso千人千面计算机,千人一面变为千人千面 自适应教育助力因材施教...
- linux脚本是什么语言,Linux学习之Shell脚本语言的优势是什么?
- html小球跳跃技术原理,HTML5在文本上跳跃的小球
- python语音信号时频分析_librosa-madmom:音频和音乐分析