正文共3862个字,4张图,预计阅读时间10分钟。

本文介绍关于GoogLeNet的续作,习惯称为inception v2,如下:

[v2] Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift,top5 error 4.8%

这篇文章做出的贡献不是一般的大,它提出了Batch Normalization(BN),以至于网上关于它的介绍铺天盖地,但中文优秀原创没几个,都是转载来转载去,挑几个好的比如:这个(http://blog.csdn.net/hjimce/article/details/50866313)、这个(http://blog.csdn.net/u012816943/article/details/51691868)、这个(http://blog.csdn.net/happynear/article/details/44238541)。我之前也写过一个谈谈Tensorflow的Batch Normalization(https://www.jianshu.com/p/0312e04e4e83),讲了讲BN在Tensorflow中的实现。

前人关于BN介绍的已经太详细了,我就不再重复的了。本文就是想讲一讲BN的反向传播,BN需要调节的参数有两个,γ 和 β,反向传播的计算方式就是下面这张图:

Batch Normalization反向传播

又是令人作呕的公式。

几乎所有介绍BN的文章都把这部分略过了,估计是怕讲不清楚,或者作者根本就不明白也不想深究。BN的理念很好理解,它的优良效果也很好理解,可BN的训练到底是怎么回事?怎么反向传播?Szegedy在论文原文里也只是一句话带过了:

During training we need to backpropagate the gradient of loss ℓ through this transformation, as well as compute the gradients with respect to the parameters of the BN transform. We use chain rule...

上面那一坨公式对于深度学习的老鸟们应该不会构成理解障碍,但对于接触不久的人群,简直就是天书!鉴于此,参考xiaia的cs231n_2016_winter(https://github.com/xiaia/cs231n_2016_winter)作业,捋一捋BN的反向传播到底是怎么实现的,好有个直观理解。

下面的介绍基于cs231n_2016_winter/assignment2的全连接网络,隐藏层5个,每个100个神经元(hidden_dims = [100, 100, 100, 100, 100]),激活函数ReLU,每个隐藏层激活函数前都加了BN层,输出层是softmax-10,optimizer是adam。

Batch Normalization反向传播实现

根据上面那一坨公式,写出来的代码是这样子的:

def batchnorm_backward(dout, cache):  
"""  
Backward pass for batch normalization.    
For this implementation, you should write out a computation graph for

batch normalization on paper and propagate gradients backward through  intermediate nodes.

Inputs:  
- dout: Upstream derivatives, of shape (N, D)  
- cache: Variable of intermediates from batchnorm_forward.

Returns a tuple of:  
- dx: Gradient with respect to inputs x, of shape (N, D)  
- dgamma: Gradient with respect to scale parameter gamma, of shape (D,)  
- dbeta: Gradient with respect to shift parameter beta, of shape (D,)  
"""  
dx, dgamma, dbeta = None, None, None

x, gamma, beta, var, miu, x_hat, eps = cache  
m = len(x)  
dx_hat = dout * gamma  
dvar = np.sum(dx_hat * (x-miu), axis=0) * -0.5 * (var + eps) ** (-1.5)  
dmiu = np.sum(dx_hat * (-1) / np.sqrt(var+eps), axis=0) + dvar * np.mean(-2 * (x - miu), axis=0)  
dx = dx_hat / np.sqrt(var + eps) + dvar * 2 * (x - miu) / m + dmiu / m  
dgamma = np.sum(dout * x_hat, axis=0)  
dbeta = np.sum(dout, axis=0)

return dx, dgamma, dbeta

Tensorflow的源码里应该也会有相应的实现,以后我再找找看。

上面的batchnorm_backward函数就是BN反向传播的python实现版本,仅仅是把公式改写成了python语言而已,这篇博文对代码做了一些解释,可以参考,这里不再赘述。

问题就来了,dout是个什么东西?作为函数的输入,它怎么来的?我再翻一翻源码,找到了这个函数:

def softmax_loss(x, y):  """  Computes the loss and gradient for softmax classification.  Inputs:  - x: Input data, of shape (N, C) where x[i, j] is the score for the jth class    for the ith input.  - y: Vector of labels, of shape (N,) where y[i] is the label for x[i] and    0 <= y[i] < C  Returns a tuple of:  - loss: Scalar giving the loss  - dx: Gradient of the loss with respect to x  """  probs = np.exp(x - np.max(x, axis=1, keepdims=True))  probs /= np.sum(probs, axis=1, keepdims=True)  N = x.shape[0]  loss = -np.sum(np.log(probs[np.arange(N), y])) / N  dx = probs.copy()  dx[np.arange(N), y] -= 1  dx /= N  return loss, dx

softmax_loss用来计算最后softmax层的loss和gradient,函数返回两个值,一个是loss,一个是dx(gradient),这个dx就是dout的源头!也是反向传播的最最最开始的地方!它是这么得来的:

dx = probs.copy() dx[np.arange(N), y] -= 1

注:其中probs是softmax的输出结果。

上面的程序代码是如此的简洁!让人完全蒙圈!逼得我重温了一下反向传播算法,输出层的残差是这么算的:

sigmoid输出层残差计算

代码里的f'(z)去哪儿了???或者这种计算方式是softmax独有?深深的感觉到了自己基础知识的薄弱。我又查阅了Neural Networks and Deep Learning(http://neuralnetworksanddeeplearning.com/chap3.html#problems_68177),终于找到了,其中的公式 (84) 是 softmax 层的残差计算方法,如下:

softmax 残差计算

可是作者让读者自己推倒公式!又蒙圈了,有兴趣的可以自己推倒试一试。

简而言之,dx就是最后一层的gradient,这个dx要一层一层的反向传播回去,不同层的反向传播计算方式也不同,比如ReLU的反向传播计算是这样的:

def relu_backward(dout, cache):  
"""  
Computes the backward pass for a layer of rectified linear units (ReLUs).  
Input:  
- dout: Upstream derivatives, of any shape  
- cache: Input x, of same shape as dout  Returns:  
- dx: Gradient with respect to x  
"""

dx, x = None, cache  dx = dout  dx[x <= 0] = 0

return dx

当然还有 dropout_backward、affine_backward(全连层) 还有上面的 batchnorm_backward 计算函数,不再一一列举。反向传播其实就是把gradient作为输入,按照前向传播相反的方向再计算一遍而已。

总的来讲,加入BN层的反向传播没有发生根本的改变,只是多了一个反向计算过程(batchnorm_backward函数)而已,上述网络的最后几层的前向和反向传播示意图如下:

正反传播

图也画了,代码也给了,公式还是没明白,不深究了。

总之,加入BN层的网络,反向传播的时候也相应的多了BN-back,其中的dgamma、dbeta会根据反向传播的gradient(或者叫残差)计算出来,再利用 optimizer 更新 γ 和 β。

原文链接:https://www.jianshu.com/p/4270f5acc066

查阅更为简洁方便的分类文章以及最新的课程、产品信息,请移步至全新呈现的“LeadAI学院官网”:

www.leadai.org

请关注人工智能LeadAI公众号,查看更多专业文章

大家都在看


LSTM模型在问答系统中的应用

基于TensorFlow的神经网络解决用户流失概览问题

最全常见算法工程师面试题目整理(一)

最全常见算法工程师面试题目整理(二)

TensorFlow从1到2 | 第三章 深度学习革命的开端:卷积神经网络

装饰器 | Python高级编程

今天不如来复习下Python基础

GoogLeNet的心路历程(二)相关推荐

  1. GoogLeNet的心路历程(四)

    正文共1216个字 2张图,预计阅读时间5分钟. 今年年初的时候,Szegedy写了GoogLeNet的第三篇续作,如下: [v4] Inception-v4, Inception-ResNet an ...

  2. GoogLeNet的心路历程(一)

    正文共2964个字,4张图,预计阅读时间8分钟. 这一段时间撸了几篇论文,当我撸到GoogLeNet系列论文的时候,真是脑洞大开!GoogLeNet绝对可以称为已公开神经网络的复杂度之王!每当我看到它 ...

  3. GoogLeNet的心路历程(三)

    正文共2965个字,估计阅读时间10分钟. 本文主要介绍GoogLeNet续作二,inception v3.说实话,Szegedy这哥们真的很厉害,同一个网络他改一改就改出了4篇论文,这是其中第3篇, ...

  4. 接口测试 - 从0不到1的心路历程 (二)

    前段时间我发布了一篇有关自己做接口测试的实践经验,发出后受到了很多小伙伴的关注,也收到了很多佬儿哥的指点,很是开心,TesterHome真是一个温暖的地方.在众多建议中,频率最高的就是"py ...

  5. 前端证券项目_非科班二本前端大厂面试的心路历程和总结(腾讯、头条、阿里、京东)...

    现状和背景 个人背景 我是17年毕业的,大三升大四的暑假期间开始学习前端:在这之前一直在小公司打滚:而且至今已经换了四家公司了(算上接下来入职的公司),可谓跳槽非常频繁(其实是小公司容易倒闭).如果说 ...

  6. 【回眸】Study with me!计算机二/三 级(物联网)刷题的心路历程

    计算机二级(物联网)刷题的心路历程 1.NB-IoT的具体应用不包括__________. A.智能水表 B.共享单车 C.智慧门锁 D.高清视频监控 这题比较难的地方就在于NB-IoT是什么东西,笔 ...

  7. 【博客话题】我的linux心路历程

    2011年的某一天,程程同学在QQ上跟我说"linux 20周年了,有没有关于linux话题的好点子",回神一想,是啊,linux都20周年了,是应该搞一个有意义的话题了,我就建议 ...

  8. android checkbox监听另一个checkbox选中和不选中_一个真正0基础小白学习前端开发的心路历程...

    摘要:真正的0基础小白学习前端开发的心路历程. 距离第一阶段的结束敲响了末尾的声音,抱着初心从开始8号的学习到第一阶段的结束这期间要应付期末考试应付自己的各种事情学习时间总是挤出来的这次学习让我受益匪 ...

  9. 一个真正0基础小白学习前端开发的心路历程

    摘要:真正的0基础小白学习前端开发的心路历程. 距离第一阶段的结束敲响了末尾的声音,抱着初心从开始8号的学习到第一阶段的结束这期间要应付期末考试应付自己的各种事情学习时间总是挤出来的这次学习让我受益匪 ...

最新文章

  1. mariadb 10.1查看per connection内存消耗
  2. 【Linux】ps命令
  3. 06_Android中ArrayAdapter的使用
  4. Hybris Commerce里和Tomcat相关的一些配置信息
  5. 为什么会出现docker
  6. HTML5标签用法及描述
  7. 训练日志 2019.2.14
  8. 过拟合解决方法之L2正则化和Dropout
  9. 比iOS还流畅!国产手机最优秀90Hz手机发布,2999元起
  10. 【排序算法复习备忘】冒泡、选择、插入、归并、快排、堆排序
  11. 华为畅享max支持鸿蒙,华为手机怎么升级鸿蒙?华为鸿蒙系统支持手机型号大全...
  12. Linux Spark安装教程
  13. win的名词_英语语法系列:名词性从句
  14. linkedin 分享_如何在保持电子邮件私密性的同时导入LinkedIn联系人
  15. 【组织架构】中国铁路武汉局集团有限公司
  16. excel批量改名字(含识别区分)
  17. 使用Teleport Ultra批量克隆网站,使用Easy CHM合并生成chm文件
  18. 阿里云域名解析利用accesskey变动态域名DDNS,简易shell脚本型
  19. CSS scroll-behavior 属性 — 纯 CSS 平滑滚动
  20. 爱发猫自动建站程序,自动发布,自动推送,自动收录

热门文章

  1. k8s容器内的东西复制出来_容器 | Docker 如此之好,你为什么还要用k8s
  2. linux内核安装教程,Linux内核5.9的最重要功能及安装方法
  3. plsql修改表名称_Excel教程:常见的工作表技巧(内有冻结拆分窗格)Excel神技巧...
  4. 出现无效字符_网站出现死链的原因分析 - 最蜘蛛池租用
  5. php 10进制位数保持,php 任意进制的数转换成10进制功能实例
  6. python之做一个简易的翻译器(一)
  7. spring boot: 支持jsp,支持freemarker
  8. [转载]网络编辑必知常识:什么是PV、UV和PR值 zz
  9. Develop系列-API Guides-简介-应用基础
  10. WPF——Expander控件(转)