梯度中心化,一行代码加速训练并提升泛化能力
来源 | 晓飞的算法工程笔记
优化器(Optimizer)对于深度神经网络在大型数据集上的训练是十分重要的,如SGD和SGDM,优化器的目标有两个:加速训练过程和提高模型的泛化能力。目前,很多工作研究如何提高如SGD等优化器的性能,如克服训练中的梯度消失和梯度爆炸问题,有效的trick有权值初始化、激活函数、梯度裁剪以及自适应学习率等。而一些工作则从统计的角度对权值和特征值进行标准化来让训练更稳定,比如特征图标准化方法BN以及权值标准化方法WN。
与在权值和特征值进行标准化方法不同,论文提出作用于权值梯度的高性能网络优化算法梯度中心化(GC, gradient centralization),能够加速网络训练,提高泛化能力以及兼容模型fine-tune。如图a所示,GC的思想很简单,零均值化梯度向量,能够轻松地嵌入各种优化器中。论文主要贡献如下:
提出新的通用网络优化方法,梯度中心化(GC),不仅能平滑和加速训练过程,还能提高模型的泛化能力。
分析了GC的理论属性,表明GC能够约束损失函数,标准化权值空间和特征值空间,提升模型的泛化能力。另外,约束的损失函数有更好的Lipschitzness(抗扰动能力,函数斜率恒定小于一个Lipschitze常数),让训练更稳定、更高效。
梯度中心化
Motivation
BN和WS使用Z-score标准化分别操作于特征值和权重,实际是间接地对权值的梯度进行约束,从而提高优化时损失函数的Lipschitz属性。受此启发,论文直接对梯度操作,首先尝试了Z-score标准化,但实验发现并没有提升训练的稳定性。之后,尝试计算梯度向量的均值,对梯度向量进行零均值化,实验发现能够有效地提高损失函数的Lipschitz属性,使网络训练更稳定、更具泛化能力,得到梯度中心化(GC)算法。
Notations
定义一些基础符号,使用
统一表示全连接层的权值矩阵和卷积层的权值张量,为权值矩阵的第列,为目标函数,和为对和的梯度,与的大小一样。定义为输入特征图,则为输出特征图,为位单位向量(unit vector),为单位矩阵(identity matrix)。
Formulation of GC
对于卷积层或全连接层的权值向量
,通过反向传播得到其梯度,然后如图b所示计算其均值,GC操作定义如下:
也可以将公式1转换为矩阵形式:
由单位矩阵以及单位向量形成矩阵构成,分别负责保留原值以及求均值。
Embedding of GC to SGDM/Adam
GC能够简单地嵌入当前的主流网络优化算法中,如SGDM和Adam,直接使用零均值化的梯度
进行权值的更新。
算法1和算法2分别展示了将GC嵌入到SGDM和Adam中,基本上不需要对原优化器算法进行修改,仅需加入一行梯度零均值化计算即可,大约仅需0.6sec。
梯度中心化的性质
下面从理论的角度分析GC为何能提高模型的泛化能力以及加速训练。
Improving Generalization Performance
GC有一个很重要的优点是提高模型的泛化能力,主要得益于权值空间正则化和特征值空间正则化。
Weight space regularization
首先介绍
的物理意义,经过推算可以得到:
即
可以看作映射矩阵,将映射到空间向量中法向量为的超平面,为映射梯度。
以SGD优化为例,权值梯度的映射能够将权值空间约束在一个超平面或黎曼流形(Riemannian manifold)中,如图2所示,梯度首先映射到
的超平面中,然后跟随映射梯度的方向进行更新。从可以得到,目标函数实际变为:
这是一个权值空间
的约束优化问题,正则化的解空间,降低了过拟合的可能性(过拟合通常是学习了复杂的权值来适应训练数据),能够提升网络的泛化能力,特别是当训练样本较少的情况下。
WS对权值进行
的约束,当初始权值不满足约束时,会直接修改权值来满足约束条件。假设进行fine-tune训练,WS则会完全丢弃预训练模型的优势,而GC可以适应任何初始权值。
Output feature space regularization
以SGD优化方法为例,权值更新
,可以推导得到。对于任何输入特征向量,有以下定理:
相关证明可以看原文附录,定理4.1表明输入特征的常量变化会造成输出的变化,而输出的变化量仅与标量
和相关,与当前权值无关。为初始化权值向量缩放后的均值,假设接近0,则输入特征值的常量变化将几乎不会改变输出特征值,意味着输出特征空间对训练样本的变化更鲁棒。
对ResNet50的不同初始权值进行可视化,可以看到权值都非常小(小于
),这说明如果使用GC来训练,输出特征不会对输入特征的变化过于敏感。这个属性正则化输出特征空间,并且提升网络训练的泛化能力。
Accelerating Training Process
Optimization landscape smoothing
前面提到BN和WS都间接地对权值梯度进行约束,使损失函数满足Lipschitz属性,
和(的Hessian矩阵)都有上界。GC直接对梯度进行约束,也有类似于BN和WS的属性,对比原损失函数满足以下定理:
相关证明可以看原文附录,定理4.2表明GC比原函数有更好的Lipschitzness,更好的Lipschitzness意味着梯度更加稳定,优化过程也更加平滑,能够类似于BN和WS那样加速训练过程。
Gradient explosion suppression
GC的另一个优点是防止梯度爆炸,使得训练更加稳定,作用原理类似于梯度裁剪。过大的梯度会导致损失严重震荡,难以收敛,而梯度裁剪能够抑制大梯度,使得训练更稳定、更快。
对梯度的
norm和最大值进行了可视化,可以看到使用GC后的值均比原函数要小,这也与定理4.2一致,GC能够让训练过程更平滑、更快。
实验
与BN和WS结合的性能对比。
Mini-ImageNet上的对比实验。
CIFAR100上的对比实验。
ImageNet上的对比实验。
细粒度数据集上的性能对比。
检测与分割任务上的性能对比。
结论
梯度中心化GC对权值梯度进行零均值化,能够使得网络的训练更加稳定,并且能提高网络的泛化能力,算法思路简单,论文的理论分析十分充分,能够很好地解释GC的作用原理。
论文地址:https://arxiv.org/abs/2004.01461
论文代码:
https://github.com/Yonghongwei/Gradient-Centralization
推荐阅读
30名工程师,历时1300天打造,又一“国产”AI框架开源了
PyTorch 1.6、TensorFlow 2.3、Pandas 1.1同日发布!都有哪些新特性?
Python 还能实现图片去雾?FFA 去雾算法、暗通道去雾算法用起来! | 附代码
程序员必备基础:Git 命令全方位学习
微软直播马上开始,近百岗位等你来,快戳进直播间
梯度中心化,一行代码加速训练并提升泛化能力相关推荐
- c语言一行代码太长,C语言修改一行代码,运行效率居然提升数倍,这个技巧你知道吗...
对编译.链接.OS内核.系统调优等技术感兴趣的童鞋,不妨右上角关注一下吧,近期会持续更新相关方面的专题文章!引言 近日,网上看到一篇文章,分析数组访问的性能问题.文章经过一系列"有理有据&q ...
- 【Python】Modin,只需一行代码加速你的Pandas
本文翻译自:Shrivarsheni的博客 Modin是一个Python第三方库,可以通过并行来处理大数据集.它的语法和pandas非常相似,因其出色的性能,能弥补Pandas在处理大数据上的缺陷. ...
- 9 行代码提高少样本学习泛化能力,代码已开源
点击上方"视学算法",选择加"星标"或"置顶" 重磅干货,第一时间送达 转自 | 新智元 来源 | 知乎 作者 | 杨朔 本文介绍一篇最新发 ...
- 利用一个样本估计类别数据分布,9行代码提高少样本学习泛化能力
© 作者|杨朔 学校|悉尼科技大学博士生 研究方向|少样本学习,噪音标签学习 介绍一篇我们刚刚发表在 ICLR 2021 Oral 上的一篇少样本学习工作,简单有效. 论文标题: Free Lunch ...
- ICLR2021 Oral |9行代码提高少样本学习泛化能力
文 | 杨朔@知乎 本文已获作者授权,禁止二次转载 介绍一篇我们刚刚发表在ICLR2021 Oral上的一篇少样本学习工作,简单有效. 题目: Free Lunch for Few-shot Lear ...
- 9行代码提高少样本学习泛化能力! ICLR2021 Oral,代码已开源
本文转载自知乎,已获作者授权转载. 链接:https://zhuanlan.zhihu.com/p/344531704 介绍一篇我们刚刚发表在ICLR2021 Oral上的一篇少样本学习工作,简单有效 ...
- 一行代码加速 sklearn 运算上千倍
作者 | 费弗里 来源 | Python大数据分析 ❞ 1.简介 scikit-learn作为经典的机器学习框架,从诞生至今已发展了十余年,但其运算速度一直广受用户的诟病.熟悉scikit-learn ...
- 一行代码加速你的Pandas数据探索分析
本文3分钟,大幅提升分析数据效率 我们知道,pandas库为EDA提供了许多非常有用的功能.但是,在能够应用大多数功能之前,通常必须先从更通用的功能开始,例如df.describe()函数. 比如以分 ...
- 本周AI热点回顾:一行代码提升训练速度、PyTorch核心技术涉嫌抄袭、bAbI又被屠榜
01 180所高校获批新增人工智能专业 3 月 3 日,教育部官方网站更新了「关于公布 2019 年度普通高等学校本科专业备案和审批结果的通知」.各高校新增备案专业 1672 个.审批专业 181 个 ...
最新文章
- js进阶 12-1 jquery的鼠标事件有哪些
- [WebApp]定宽网页设计下,固定宽度布局开发WebApp并实现多终端下WebApp布局自适应...
- 图解SQLite教程
- NLP之WordCloud:基于jieba+matplotlib库对一段文本生成词云图~~情人节最好的礼物(给你一张过去的词云图,看看那时我们的爱情)
- 第5章 Python 数字图像处理(DIP) - 图像复原与重建4 - 指数噪声
- 配置管理-CMMI的五个等级
- python利用Excel读取和存储测试数据完成接口自动化
- KMP算法代学习之(二)代码深入学习
- 字符编码笔记:ASCII,Unicode和UTF-8(转)
- pageoffice 骑缝章_Java 集成PageOffice自带印章配置连接MySQL
- picasa csdn_如何将发送到Facebook的功能添加到Picasa
- Android 动态更换app图标
- wildcard 的理解
- 1383: 手机短号 (多实例)
- 实时视频直播平台的技术要点详解
- win10企业版激活(自测有效)
- 【爱码物联】“颜值经济”当道,你选的化妆品能溯源么?
- 脑电波实时数据收集——RDA—数据包定义
- Intellij IDEA 去除警告波浪线(Weak Warning)
- 笔记:尚学堂Java300集 第一章
热门文章
- 「深度」线下大数据正成为构建精准“用户画像”的最大助力
- 【书籍下载链接】_2_第二轮_计算机专业书籍
- 关于Android方法数量限制的问题
- PHP获取时间排除周六、周日的两个方法
- 爱上MVC3系列~开发一个站点地图(俗称面包屑)
- AS1.0(2.0)中的XML示例
- 理解HTTP消息头【很完整,例子也很丰富】
- [THUWC2017]随机二分图
- django教程目录
- GM Tech 2 works with Hummer Yes or No