ICML-2015


在 CIFAR-10 上的小实验可以参考如下博客:
【Keras-Inception v2】CIFAR-10


文章目录

  • 1 Background and Motivation
  • 2 Advantages
  • 3 Innovations
  • 4 Method
    • 4.1 Batch Normalization
    • 4.2 Inception-v2
  • 5 Datasets
  • 6 Experiments
  • 7 Conclusion / Future work
    • 7.1 Conclusion
    • 7.2 Future work
  • 8 Other normalizations
  • 9 附录 sigmoid
  • 参考

1 Background and Motivation

机器学习中存在一个training data 和 testing data 分布不一致的问题,叫做 Covariate Shift

我们讲的规范一点(【机器学习】covariate shift现象的解释):
假设q1(x)是测试集中一个样本点的概率密度,q0(x)是训练集中一个样本点的概率密度。最终我们估计一个条件概率密度p(y|x,θ),它由x和一组参数θ={θ1,θ2…θm}所决定。对于一组参数来说,对应loss(θ)函数评估性能的好坏
综上,当我们找出在q0(x)分布上最优的一组θ’时,能否保证q1(x)上测试时也最好呢?
传统机器学习假设训练集和测试集是独立同分布的,即q0(x)=q1(x),所以可以推出最优θ’依然可以保证q1(x)最优。但现实当中这个假设往往不成立,伴随新数据产生,老数据会过时,当q0(x)不再等于q1(x)时,就被称作covariate shift

作者注意到
The distribution of each layer’s inputs changes during training(这也就是 Internal Covariate Shift

缺点如下:

  • slows down the training by requiring lower learning rates and careful parameter initialization
  • hard to train models with saturating nonlinearities

作者的解决方法:normalization for each training mini-batch

代码见附录

初始化不好,Wx+b 偏大或者偏小,Sigmoid’(Z)很小,梯度消失,网络越深越明显(反向传播的时候,激活函数的导数和W会随着网络的向前深入一直连乘下去)

解决办法:

  • Relu
  • careful initialization
  • small learning rate

作者的解决办法
ensure that the distribution of nonlinearity inputs remains more stable as the network trains

注意一个前提
It is based on the premise that convariate shift also applies to sub-network and layers.

2 Advantages

  • reaching 4.9% top-5 validation error (and 4.8% test error), exceeding the accuracy of human raters.

3 Innovations

注意到了 Internal Covariate Shift 的问题,并提出BN,大大加速训练速度

4 Method

4.1 Batch Normalization

x = Wu+b


可导性证明

yyy 和 lll 的关系就是损失函数,用链式求导法则感觉怪怪的,比如∂l∂μB\frac{\partial l}{\partial \mu_{B}}∂μB​∂l​ 中,x^i=f1(μB,σB2)\hat{x}_{i} =f_1(\mu_{B}, \sigma_{B}^{2})x^i​=f1​(μB​,σB2​) , σB2=f2(μB)\sigma_{B}^{2} = f_2(\mu_{B})σB2​=f2​(μB​) ,按链式求导法则应该是

∂l∂μB=∂l∂x^i⋅(∂x^i∂σB2⋅∂σB2∂μB+∂x^i∂μB)+∂l∂σB2⋅∂σB2∂μB\frac{\partial l}{\partial \mu_{B}} = \frac{\partial l}{\partial \hat{x}_{i} } \cdot (\frac{\partial \hat{x}_{i} }{\partial \sigma_{B}^{2} } \cdot \frac{\partial \sigma_{B}^{2} }{\partial \mu_{B} }+ \frac{\partial \hat{x}_{i} }{\partial \mu_{B} }) + \frac{\partial l}{\partial \sigma_{B}^{2} } \cdot \frac{\partial \sigma_{B}^{2} }{\partial \mu_{B} }∂μB​∂l​=∂x^i​∂l​⋅(∂σB2​∂x^i​​⋅∂μB​∂σB2​​+∂μB​∂x^i​​)+∂σB2​∂l​⋅∂μB​∂σB2​​

而论文的做法是

∂l∂μB=∂l∂x^i⋅∂x^i∂μB+∂l∂σB2⋅∂σB2∂μB\frac{\partial l}{\partial \mu_{B}} = \frac{\partial l}{\partial \hat{x}_{i} } \cdot \frac{\partial \hat{x}_{i} }{\partial \mu_{B} } + \frac{\partial l}{\partial \sigma_{B}^{2} } \cdot \frac{\partial \sigma_{B}^{2} }{\partial \mu_{B} }∂μB​∂l​=∂x^i​∂l​⋅∂μB​∂x^i​​+∂σB2​∂l​⋅∂μB​∂σB2​​

∂l∂xi\frac{\partial l}{\partial x_i}∂xi​∂l​ 也是这样处理的,不晓得是我记错了还是神经网路的求导不一样,有待我进一步否定自己


注意测试的时候 μ\muμ 和 σ\sigmaσ 要固定下来,固定的方法如下(参考 【深度学习】深入理解Batch Normalization批标准化):

因为每次做Mini-Batch训练时,都会有那个Mini-Batch里m个训练实例获得的均值和方差,现在要全局统计量,只要把每个Mini-Batch的均值和方差统计量记住,然后对这些均值和方差求其对应的数学期望即可得出全局统计量。也即步骤10所示!

Q:疑问?为什么要来个 mm−1\frac{m}{m-1}m−1m​?mmm 是 batch size
A:在概率统计里,当计算方差时,如果期望已知,则除以 mmm,如果期望需要用样本的均值估计出来,则除以 m−1m-1m−1

训练的时候方差已知,除以 mmm,现在方差是估计值,除以 m−1m-1m−1

11 由如下表达式推导得来
y=γ⋅x−E[x]Var[x]+ε+βy = \gamma \cdot \frac{x-E[x]}{\sqrt{Var[x]+\varepsilon }} + \betay=γ⋅Var[x]+ε​x−E[x]​+β


4.2 Inception-v2

有些人总结 v2和v3都出自一篇论文(【Inception-v3】《Rethinking the Inception Architecture for Computer Vision》),结构一样,只是穿上了钢铁侠的盔甲,v3 = v2 + RMSProp + LSR + BN auxiliary

也有些人认为 v2 是出自本片博客讲述的论文(《Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift》),此论文对 Inception-v1 做了一点点小改进。

  • 5×5 to 两个 3×3
  • parameters 增加 25%,computational cost 增加 30%
  • 28×28 的 Inception 数量 2 to 3
  • average pooling + max pooling 交替,v1中最后用 average pooling,中间都是 max pooling
  • 有 pass through 结构

个人还是倾向于 BN+小改进的 Inception-v1 = Inception-v2,结构如下

【Inception-v1】《Going Deeper with Convolutions》 的结构见表1

5 Datasets

LSVRC 2012

6 Experiments

1)小试牛刀——MINIST

分布看不懂可以想想成是高斯分布,面积的15%,50%,80%

2)加速版 BN

加速技术

  • Increase learning rate
  • Remove Dropout
  • Reduce the L2 weight regularization
  • Accelerate the learning rate decay
  • Remove Local Response Normalization
  • Shuffle training examples more thoroughly
  • Reduce the photometric distortions

加速版本

  • Inception-v1
  • BN-baseline = Inception-v1 + BN,initial learning rate is 0.0015
  • BN-x5:learning rate*5 = 0.0075
  • BN-x30:learning rate*30 = 0.045
  • BN-x5-Sigmoid:Sigmoid instead of Relu

3)华山论剑(掌门solo)

不得不佩服的是 BN使得 Sigmoid 都能收敛,而且精度相当高(Without Batch Normalization, Inception with sigmoid never achieves better than 1/1000 accuracy.)霸气如斯!!!

需要注意的是 BN-x30 反而比 BN-x5 慢,但是最终精度高

BN 大大的提升了训练的速度

4)决战光明顶(门派混战——ensemble)

7 Conclusion / Future work

7.1 Conclusion

BN 是加在 activation function 之前的,而且是以 mini-batch 来计算,不像之前的有些工作加在activation function 之后。

BN 能使得训练更稳定,网络能用更大的 learning rate(会快许多),不那么 care initial 的质量,有些 regularization 的味道,in some case eliminating the need for Dropout.

Q1:链式求导法则的疑惑
Q2:Algorithm 2 的通透理解
Q3:归一化后 γ\gammaγ (scale)and β\betaβ (shift)的作用以及意义
Q4:Batch Normalization also makes training more resilient to the parameters scale. 公式证明的理解

7.2 Future work

  • BN+RNN
  • domain adaptation

We believe that further theoretical analysis of the algorithm would allow still more improvements and applications.

8 Other normalizations


更形象化的理解,四种 normalization 如下所示

9 附录 sigmoid

import matplotlib.pyplot as plt
import numpy as np
import math
x = np.arange(-10,10,1)
y = 1/(1+ math.e**-x)
z = y*(1-y)
plt.plot(x,y)
plt.plot(x,z)
plt.xlabel("z = Wx+b")
plt.ylabel("Sigmoid")
ax = plt.gca()
#去掉边框
ax.spines['top'].set_color('none')
ax.spines['right'].set_color('none')
#移位置 设为原点相交
ax.xaxis.set_ticks_position('bottom')
ax.spines['bottom'].set_position(('data',0))
ax.yaxis.set_ticks_position('left')
ax.spines['left'].set_position(('data',0))plt.legend(["Sigmoid (z)", "Sigmoid' (z)"])# 图例
plt.grid()#网格
plt.savefig('1.jpg')
plt.show()

参考

【1】当我们的经验无法适应新环境的时候该怎么办? Covariate Shift(★★★)
【2】深度学习中 Batch Normalization为什么效果好? - 我不坏的回答 - 知乎
https://www.zhihu.com/question/38102762/answer/391649040 (★★★★★,防止梯度消失爆炸)

【BN】《Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift》相关推荐

  1. 批归一化《Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift》

    批归一化<Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift ...

  2. 《Batch Normalization Accelerating Deep Network Training by Reducing Internal Covariate Shift》阅读笔记与实现

    今年过年之前,MSRA和Google相继在ImagenNet图像识别数据集上报告他们的效果超越了人类水平,下面将分两期介绍两者的算法细节. 这次先讲Google的这篇<Batch Normali ...

  3. 读文献——《Batch Normalization Accelerating Deep Network Training by Reducing Internal Covariate Shift》

    在自己阅读文章之前,通过网上大神的解读先了解了一下这篇文章的大意,英文不够好的惭愧... 大佬的文章在https://blog.csdn.net/happynear/article/details/4 ...

  4. 【论文泛读】 Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift

    [论文泛读] Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift ...

  5. Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift 论文笔记

    Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift 论文链接: h ...

  6. Batch normalization:accelerating deep network training by reducing internal covariate shift的笔记

    说实话,这篇paper看了很久,,到现在对里面的一些东西还不是很好的理解. 下面是我的理解,当同行看到的话,留言交流交流啊!!!!! 这篇文章的中心点:围绕着如何降低  internal covari ...

  7. 论文阅读:Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift

    文章目录 1.论文总述 2.Why does batch normalization work 3.BN加到卷积层之后的原因 4.加入BN之后,训练时数据分布的变化 5.与BN配套的一些操作 参考文献 ...

  8. 深度学习论文--Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift

    本文翻译论文为深度学习经典模型之一:GoogLeNet-BN 论文链接:https://arxiv.org/abs/1502.03167v3 摘要:训练深度神经网络的难度在于:前一层网络参数的变化,导 ...

  9. Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift

    机器学习领域有个很重要的假设:IID独立同分布假设,就是假设训练数据和测试数据是满足相同分布的,这是通过训练数据获得的模型能够在测试集获得好的效果的一个基本保障.BatchNorm就是在深度神经网络训 ...

最新文章

  1. P1541 乌龟棋 题解(洛谷,动态规划递推)
  2. kingcms php 排序 标签,修改PHPCMS V9列表排序,listorder、order排序功能的方法
  3. Hibernate事实:如何“断言” SQL语句计数
  4. dofilter在java中_在Filter的doFilter中进行重定向 出现异常
  5. sqlserver 清除日志
  6. python 数据分析 库_Python数据分析库
  7. mongodb 复制(副本集)
  8. GPS经纬度坐标与XY坐标相互转换的python程序
  9. signature=946b61359fb7b919b57e636da83bf538,X-ray tube.
  10. Tegra X2 系统上安装 openpose
  11. 通过c#打开pdf文件
  12. 考研数学随笔(2)——微分积分关系,中值定理
  13. 【SpringBoot】Error creating bean with name ‘methodValidationPostProcessor‘ defined in class path reso
  14. GPS卫星的导航电文和卫星信号
  15. 电工电子自动控制实验设备QY-DG328B
  16. 【来日复制粘贴】数据透视表分类不同账龄
  17. git fork的使用
  18. redis数据备份与恢复
  19. 《提问的艺术》读书笔记
  20. 审批保单信息java_policy-1 统一保单信息查询管理平台接口 - 下载 - 搜珍网

热门文章

  1. 监控设备乐橙连接linux,最近在做乐橙的监控设备,第一步通过http post json获取accessToken都失败了,请问如何解决?...
  2. 获取你的WIFI密码-fluxion(附操作视频)
  3. php中加入 空格的代码,在HTML中插入空格的几种方法
  4. 最近在看一本不错的书~推荐给大家
  5. 嵌入式软件工程师介绍
  6. 安卓开发入门小程序!一个本科渣渣是怎么逆袭从咸鱼到Offer收割机的?灵魂拷问
  7. php临时文件夹,PHP上传 找不到临时文件夹的解决方法
  8. 线上线下英文词典工具、在线翻译全搜罗
  9. 三款好用的软件代码检测工具
  10. 爱奇艺发布一支特别的视频:除了人全是“假”的