ICML2022论文解读『Sparse Double Descent: Where Network Pruning Aggravates Overfitting』
论文解读『Sparse Double Descent: Where Network Pruning Aggravates Overfitting』
- 1. 研究动机
- 2. 稀疏神经网络中的双下降现象
- 3. 如何解释稀疏神经网络的泛化性能与双下降现象?
- 4. 与彩票假说的区别与联系
- 5. 后记
- 论文与代码连接
- 参考文献:
「Sparse Double Descent: Where Network Pruning Aggravates Overfitting」是ICML2022关于网络剪枝、彩票假说与模型泛化的一个新工作。
这篇论文主要是受模型过参数化(over-parameterization)和彩票假说(lottery tickets)两方面研究的启发,探索分析了剪枝后的稀疏神经网络的泛化性能。
一句话结论:稀疏神经网络的泛化能力受稀疏度的影响,随着稀疏度不断增加,模型的测试准确率会先下降,后上升,最后再次下降。
1. 研究动机
根据传统机器学习的观点,模型难以同时最小化预测时的偏差与方差,因此往往需要权衡两者,才能找到最合适的模型。这便是广为流传的偏差-方差均衡(bias-variance tradeoff)曲线:随着模型容量增加,模型在训练集上的误差不断下降,然而在测试集上的误差却会先下降后上升。
虽然传统观点认为模型参数过多会导致过拟合,但是神奇的是,在深度学习实践中,大模型往往有着更好的表现。
今年来有学者发现,深度学习模型的测试误差和模型容量的关系,并非是U型曲线,而是具备的双下降(Double Descent)的特点,即随着模型参数变多,测试误差是先下降,再上升,然后第二次下降1 2。
也就是说,过参数的神经网络非但不会发生严重的过拟合,反而有可能具有更好的泛化性能!
这究竟是为什么呢?
彩票假说(lottery tickets)3为解释这一现象提供了一个新的思路。
彩票假说认为,一个随机初始化的密集网络(未剪枝过的初始网络),包含着性能良好的稀疏子网络,这个子网络从原初始化(winning ticket)训练时,可以达到媲美原始密集网络的准确率,甚至还有可能更快收敛(而如果让这个子网络从一个新的初始化值开始训练,效果则往往大不如原始网络)。
当一个网络参数越多,它包含这样一个性能良好的子网络的概率就越大,也就是中彩票的可能性越高。
从这个角度出发,一个过参数的神经网络中,真正对优化和泛化起作用可能只有相当少的一部分参数,而其余的参数只是作为冗余备份存在,即使被剪掉也不会对模型训练产生决定性影响。
彩票假说似乎说明,我们可以安全地剪掉模型当中的冗余参数,而不必担心是否会造成不利影响。还有一些其他文献,本着简单最优的奥卡姆剃刀原则,相信剪枝后的稀疏网络会具有更好的泛化能4。目前的剪枝文献也都强调自己的算法可以在剪去大量参数的情况下,仍保持与原模型相媲美的准确率。
但是联想到双下降现象,我们不禁反思一个基本问题:
剪枝剪去的参数真的是完全冗余的吗?难道过参数更优的双下降在稀疏神经网络上并不成立吗?
为了探寻这个问题的答案,我们参考deep double descent[2]的设置,在稀疏神经网络上进行了大量实验。
2. 稀疏神经网络中的双下降现象
通过实验,我们惊讶地发现,网络中所谓"冗余"的参数其实并不完全冗余。
当参数量逐渐减少,稀疏度逐渐上升时,即使模型训练准确率尚未受到影响,其测试准确率可能已经开始明显下降。这时,模型越来越严重地过拟合噪声。
不同数据集上的稀疏双下降现象。左:CIAFR-10,中:CIFAR-100,右:Tiny ImageNet
如果进一步的增加模型稀疏度,可以发现当经过某个拐点后,模型的训练准确率开始快速下降,测试准确率开始上升,此时模型对噪声的鲁棒性逐步提高。
至于当测试准确率达到最高点后,若继续减少模型的参数,则会影响模型的学习能力。此时,模型的训练与测试准确率同时下降,开始变得难以学习。
此外,我们还发现采用不同的标准来剪枝,得到的模型即使参数量相同,其模型容量/复杂度也不同。例如针对同一类拐点,采用基于权重的剪枝的模型稀疏度更高,而随机剪枝则对应着较低的稀疏度。说明随机剪枝对模型表达能力的破坏更大,想取得相同的效果只能剪更少的参数。
不同剪枝标准下的稀疏双下降现象。左:基于权重的剪枝,中:基于梯度的剪枝,右:随机剪枝
虽然我们的大部分实验都采用了彩票假说的retrain方式,但也尝试了其他几种不同的方法。有趣的是,即使是剪枝后微调(Finetuning)也可以观察到明显的双下降。可见稀疏双下降现象并不局限于从初始化训练一个稀疏网络,哪怕沿用剪枝前训练好的参数值也会有相似的结果。
不同retrain方法下的稀疏双下降现象。左:Finetuning,中:Learning rate rewinding,右:Scratch retraining
我们还调整了标签噪声的比例,来观察双下降现象的变化。类似于deep double descent,提高标签噪声的比例,会使得模型训练准确率下降的起始点,向更高模型容量方向移动(即更低的稀疏度)。而另一方面,标签噪声比例越高,为了取得对噪声的鲁棒性,越多的参数需要被剪去以避免过拟合。
不同标签噪声比例下的稀疏双下降现象。左:20%,中:40%,右:80%
3. 如何解释稀疏神经网络的泛化性能与双下降现象?
在这里我们主要检验了两种可能的解释。
其一是极小点平坦度假说(Minima Flatness Hypothesis)。
一些文章指出,剪枝可以为模型增加扰动,这种扰动使得模型更易收敛到平坦的极小点5。由于极小点越平坦,一般会具有更好的泛化能力,因此该文章认为剪枝通过影响极小点的平坦度影响着模型的泛化 5。
那么,极小点平坦度的变化可以解释稀疏双下降吗?
我们对loss进行如图的可视化,间接比较了不同稀疏度下,模型极小点平坦度的大小。
一维loss可视化
遗憾的是,随着稀疏度提高,loss曲线变得越来越陡峭(不平坦)。极小点平坦度与测试准确率之间并没有呈现出相关关系。
另一是学习距离假说(Learning Distance Hypothesis)
已有理论工作证明,深度学习模型的复杂度与参数到初始化的l2距离(学习距离)息息相关6。学习距离越小,说明模型停留在离初始化越近的位置,好比早停时获得的模型参数,此时还没有足够的复杂度记忆噪声;反之,则说明模型在参数空间上的改变就越大,此时复杂度更高,容易过拟合。
那么,学习距离的变化可以反应双下降的趋势吗?
模型学习距离与测试准确率的变化曲线
如图可见,当准确率下降时,学习距离整体呈上升趋势,且最高点恰好对应准确率的最低点;而当准确率上升时,学习距离也相应下降。学习距离的变化与稀疏双下降的变化趋势基本吻合(尽管当测试准确率第二次下降时,由于可训练的参数过少,学习距离难以再次上升了)。
4. 与彩票假说的区别与联系
我们还进行了彩票假说中winning ticket与重新随机初始化的对比实验。有趣的是,在双下降情景下,彩票假说的初始化方式并不总是优于对网络重新初始化的效果。
彩票假说初始化(Lottery)与重新随机初始化(Reinit)的对比
由图可以看出,Reinit的结果相比于Lottery整体左移,也就是说Reinit方式在保留模型的表达能力方面是逊于Lottery的。这也从另一方面验证了彩票假说的思想: 即使模型的结构完全相同,从不同的初始化训练时,模型的性能也可能相差甚远。
5. 后记
在做这项研究的过程中,我们观察到了一些神奇的、反直觉的实验现象,并尝试进行了分析解释。然而,现有的理论工作还无法完全地解释这些现象存在的原因。
比如说在训练准确率接近100%时,测试准确率会随着剪枝逐渐下降。为何此时模型没有遗忘数据中的复杂特征,反而对噪声更加严重的过拟合?
我们还观察到模型的学习距离会随着稀疏度增加先上升后下降,为何剪枝会导致模型学习距离发生这样的变化?
以及深度学习模型的双下降现象往往需要对输入增加标签噪声才可以观察到2,决定双下降是否发生的背后机制是什么?
还有很多问题目前尚无答案。我们现在也在进行一个新的理论工作,以期能对其中的一个或几个问题进行解释。希望可以早日拨开迷雾,探明这一现象背后的本质原因。
论文与代码连接
论文链接:https://arxiv.org/abs/2206.08684
代码链接:https://github.com/hezheug/sparse-double-descent
参考文献:
Belkin, M., Hsu, D., Ma, S., & Mandal, S. (2018). Reconciling modern machine learning and the bias-variance trade-off.stat,1050, 28. ↩︎
Nakkiran, P., Kaplun, G., Bansal, Y., Yang, T., Barak, B., & Sutskever, I. Deep double descent: Where bigger models and more data hurt. ICLR 2020. ↩︎ ↩︎ ↩︎
Frankle J., & Carbin, M. The lottery ticket hypothesis: Finding sparse, trainable neural networks. ICLR 2019. ↩︎
Hoefler, T., Alistarh, D., Ben-Nun, T., Dryden, N., & Peste, A. Sparsity in deep learning: Pruning and growth for efficient inference and training in neural networks. arXiv preprint arXiv:2102.00554, 2021. ↩︎
Bartoldson, B., Morcos, A. S., Barbu, A., and Erlebacher, G. The generalization-stability tradeoff in neural network pruning. NIPS, 2020. ↩︎ ↩︎
Nagarajan, V. and Kolter, J. Z. Generalization in deep networks: The role of distance from initialization. arXiv preprint arXiv:1901.01672, 2019. ↩︎
ICML2022论文解读『Sparse Double Descent: Where Network Pruning Aggravates Overfitting』相关推荐
- 论文解读:ChangeFormer | A TRANSFORMER-BASED SIAMESE NETWORK FOR CHANGE DETECTION
论文地址:https://arxiv.org/pdf/2201.01293.pdf 项目代码:https://github.com/wgcban/ChangeFormer 发表时间:2022 本文提出 ...
- 论文笔记-精读-8.22-Manifold Regularized Dynamic Network Pruning
目录 总结 要解决的问题&解决的情况 问题 方法的优缺点 优点 缺点 实验结果如何 有哪些可以提升 正文 概要 先验知识 流型动态剪枝-Maniprune 复杂性 相似性 总结 关于本篇文所解 ...
- 点云配准的端到端深度神经网络:ICCV2019论文解读
点云配准的端到端深度神经网络:ICCV2019论文解读 DeepVCP: An End-to-End Deep Neural Network for Point Cloud Registration ...
- Paper:论文解读《Adaptive Gradient Methods With Dynamic Bound Of Learning Rate》中国本科生提出AdaBound的神经网络优化算法
Paper:论文解读-<Adaptive Gradient Methods With Dynamic Bound Of Learning Rate>中国本科生(学霸)提出AdaBound的 ...
- 论文解读:《一种利用二核苷酸One-hot编码器识别水稻基因组中N6甲基腺嘌呤位点的卷积神经网络》
论文解读:<A Convolutional Neural Network Using Dinucleotide One-hot Encoder for identifying DNA N6-Me ...
- 论文解读: Double DIP
论文解读:Double-DIP" : Unsupervised Image Decomposition via Coupled Deep-Image-Priors Unsupervised ...
- CVPR2020论文解读:3D Object Detection三维目标检测
CVPR2020论文解读:3D Object Detection三维目标检测 PV-RCNN:Point-Voxel Feature Se tAbstraction for 3D Object Det ...
- CVPR2020论文解读:三维语义分割3D Semantic Segmentation
CVPR2020论文解读:三维语义分割3D Semantic Segmentation xMUDA: Cross-Modal Unsupervised Domain Adaptation for 3D ...
- Paper:《A Unified Approach to Interpreting Model Predictions—解释模型预测的统一方法》论文解读与翻译
Paper:<A Unified Approach to Interpreting Model Predictions-解释模型预测的统一方法>论文解读与翻译 导读:2017年11月25 ...
- YOLOv4论文解读
论文原文: https://arxiv.org/pdf/2004.10934.pdf 代码实现: https://github.com/AlexeyAB/darknet 一.介绍 原文名称:<Y ...
最新文章
- 从算法到产品:NLP技术的应用演变
- 史上最简单的SpringCloud教程 | 第八篇: 消息总线(Spring Cloud Bus)(Finchley版本)
- 2020人工神经网络第一次作业-解答第一部分
- 连接sql sever2008数据库出现了无法连接到数据库引擎问题解决
- OpenCV将现有算法移植到G-API的实例(附完整代码)
- 自定义PopView
- Oracle数据库之间数据同步 -- DBLink
- ios 将随意对象存进数据库
- SVN客户端日志无法显示的解决
- OkHttp 官方Wiki之【使用案例】
- K3/Cloud 用插件打开一张已存在的单据
- asp.net长文章插入分页符^进行分页
- 坐飞机还是尽量早点出发(差点误机)
- socket 源码分析
- java做游戏前端_小游戏——金庸奇侠传(JAVA,对面向对象的进一步了解)
- itpt_TCPL 第一章:C简要教程
- 截流式合流制设计流量计算_截流式合流制管渠的水力计算要点
- 5款宝藏浏览器插件推荐,每一个都真香,一定要看到最后
- https网络编程——DNS域名解析获取IP地址
- 计算机英语第五版翻译,计算机专业英语教程第5版翻译