作者丨DarkZero@知乎

来源丨https://zhuanlan.zhihu.com/p/25202034

编辑丨极市平台

本文仅用于学术分享。若侵权,请联系后台作删文处理。

相信每一个刚刚入门神经网络(现在叫深度学习)的同学都一定在反向传播的梯度推导那里被折磨了半天。在各种机器学习的课上明明听得非常明白,神经网络无非就是正向算一遍Loss,反向算一下每个参数的梯度,然后大家按照梯度更新就好了。问题是梯度到底怎么求呢?课上往往举的是标量的例子,可是一到你做作业的时候就发现所有的东西都是vectorized的,一个一个都是矩阵。矩阵的微分操作大部分人都是不熟悉的,结果使得很多人在梯度的推导这里直接选择死亡。我曾经就是其中的一员,做CS231n的Assignment 1里面那几个简单的小导数都搞得让我怀疑人生。

我相信很多人都看了不少资料,比如CS231n的讲师Karpathy推荐的这一篇矩阵求导指南http://cs231n.stanford.edu/vecDerivs.pdf,但是经过了几天的折磨以后,我发现事实上根本就不需要去学习这些东西。在神经网络中正确计算梯度其实非常简单,只需要把握好下面的两条原则即可。这两条原则非常适合对矩阵微分不熟悉的同学,虽然看起来并不严谨,但是有效。

1. 用好维度分析,不要直接求导

神经网络中求梯度,第一原则是:如果你对矩阵微分不熟悉,那么永远不要直接计算一个矩阵对另一个矩阵的导数。我们很快就可以看到,在神经网络中,所有的矩阵对矩阵的导数都是可以通过间接的方法,利用求标量导数的那些知识轻松求出来的。而这种间接求导数的方法就是维度分析。我认为维度分析是神经网络中求取梯度最好用的技巧,没有之一。用好维度分析,你就不用一个一个地去分析矩阵当中每个元素究竟是对谁怎么求导的,各种求和完了以后是左乘还是右乘,到底该不该转置等等破事,简直好用的不能再好用了。这一技巧在Karpathy的Course Note上也提到了一点。

什么叫维度分析?举一个最简单的例子。设某一层的Forward Pass为,X是NxD的矩阵,W是DxC的矩阵,b是1xC的矩阵,那么score就是一个NxC的矩阵。现在上层已经告诉你L对score的导数是多少了,我们求L对W和b的导数。

我们已经知道一定是一个NxC的矩阵(因为Loss是一个标量,score的每一个元素变化,Loss也会随之变化),那么就有

现在问题来了,score是一个矩阵,W也是个矩阵,矩阵对矩阵求导,怎么求啊?如果你对矩阵微分不熟悉的话,到这里就直接懵逼了。于是很多同学都出门右转去学习矩阵微分到底怎么搞,看到那满篇的推导过程就感到一阵恶心,之后就提前走完了从入门到放弃,从深度学习到深度厌学的整个过程。

其实我们没有必要直接求score对W的导数,我们可以利用另外两个导数间接地把算出来。首先看看它是多大的。我们知道一定是DxC的(和W一样大),而是NxC的,哦那你瞬间就发现了一定是DxN的,因为(DxN)x(NxC)=>(DxC),并且你还发现你随手写的这个式子右边两项写反了,应该是

那好,我们已经知道了是DxN的,那就好办了。既然score=XW+b,如果都是标量的话,score对W求导,本身就是X;X是NxD的,我们要DxN的,那就转置一下呗,于是我们就得出了:

完事了。

你看,我们并没有直接去用诸如这种细枝末节的一个一个元素求导的方式推导,而是利用再加上熟悉的标量求导的知识,就把这个矩阵求导给算出来了。这就是神经网络中求取导数的正确姿势。

为什么这一招总是有效呢?这里的关键点在于Loss是一个标量,而标量对一个矩阵求导,其大小和这个矩阵的大小永远是一样的。那么,在神经网络里,你永远都可以执行这个“知二求一”的过程,其中的“二”就是两个Loss对参数的导数,另一个是你不会求的矩阵对矩阵的导数。首先把你没法直接求的矩阵导数的大小给计算出来,然后利用你熟悉的标量求导的方法大概看看导数长什么样子,最后凑出那个目标大小的矩阵来就好了。

呢?我们来看看,是NxC的,是1xC的,看起来像1,那聪明的你肯定想到其实就是1xN个1了,因为(1xN)x(NxC)=>(1xC)。其实这也就等价于直接对d_score的第一维求个和,把N降低成1而已。

多说一句,这个求和是怎么来的?原因实际上在于所谓的“广播”机制。你会发现,XW是一个NxC的矩阵,但是b只是一个1xC的矩阵,按理说,这俩矩阵形状不一样,是不能相加的。但是我们都知道,实际上我们想做的事情是让XW的每一行都加上b。也就是说,我们把b的第一维复制了N份,强行变成了一个NxC的矩阵,然后加在了XW上(当然这件事实际上是numpy帮你做的)。那么,当你要回来求梯度的时候,既然每一个b都参与了N行的运算,那就要把每一份的梯度全都加起来求个和的。因为求导法则告诉我们,如果一个变量参与了多个运算,那就要把它们的导数加起来。这里借用一下@午后阳光的图,相信大家可以看得更明白。

总之,不要试图在神经网络里面直接求矩阵对矩阵的导数,而要用维度分析间接求,这样可以为你省下很多不必要的麻烦。

2. 用好链式法则,不要一步到位

我曾经觉得链式法则简直就是把简单的问题搞复杂,复合函数求导这种东西高考的时候我们就都会了,还用得着一步一步地往下拆?比如,我一眼就能看出来,还用得着先把当成一个中间函数么?

不幸的是,在神经网络里面,你会发现事情没那么容易。上面的这些推导只在标量下成立,如果w,x和b都是矩阵的话,我们很容易就感到无从下笔。还举上面这个例子,设,我们要求,那么我们直接就可以写出

L对H的导数,是反向传播当中上一层会告诉你的,但问题是H对W的导数怎么求呢?

如果你学会了刚才的维度分析法,那么你可能会觉得是一个DxN的矩阵。然后就会发现没有任何招可以用了。事实上,卡壳的原因在于,根本不是一个矩阵,而是一个4维的tensor。对这个鬼玩意的运算初学者是搞不定的。准确的讲,它也可以表示成一个矩阵,但是它的大小并不是DxN,而且它和  的运算也不是简单的矩阵乘法,会有向量化等等的过程。有兴趣的同学可以参考这篇文章,里面有一个例子讲解了如何直接求这个导数:矩阵求导术(下)(https://zhuanlan.zhihu.com/p/24863977)。

这是一个刚学完反向传播的初学者很容易踩到的陷阱:试图不设中间变量,直接就把目标参数的梯度给求出来。如果这么去做的话,很容易在中间碰到这种非矩阵的结构,因为理论上矩阵对矩阵求导求出来是一个4维tensor,不是我们熟悉的二维矩阵。除非你完全掌握了上面那篇reference当中的数学技巧,不然你就只能干瞪眼了。

但是,如果你不直接求取对W的导数,而把当做一个中间变量的话,事情就简单的多了。因为如果每一步求导都只是一个简单二元运算的话,那么即使是矩阵对矩阵求导,求出来也仍然是一个矩阵,这样我们就可以用维度分析法往下做了。

,则有

利用维度分析:dS是NxC的,dH是NxC的,考虑到,那么容易想到也是NxC的,也就是,这是一个element-wise的相乘;所以

再求,用上一部分的方法,很容易求得,所以就求完了。

有了这些结果,我们不妨回头看看一开始的那个式子:,如果你错误地认为是一个DxN的矩阵的话,再往下运算:

我们已经知道,这两个矩阵一个是NxC的,一个是DxN的,无论怎么相乘,也得不出DxN的矩阵。矛盾就是出在H对W的导数其实并不是一个矩阵。但是如果使用链式法则运算的话,我们就可以避开这个复杂的tensor,只使用矩阵运算和标量求导就搞定神经网络中的梯度推导。

借助这两个技巧,已经足以计算任何复杂的层的梯度。下面我们来实战一个:求Softmax层的梯度。

Softmax层往往是输出层,其Forward Pass公式为:

,,

假设输入X是NxD的,总共有C类,那么W显然应该是DxC的,b是1xC的。其中就是第i个样本预测的其正确class的概率。关于softmax的知识在这里就不多说了。我们来求Loss关于W, X和b的导数。为了简便起见,下面所有的d_xxx指的都是Loss对xxx的导数。

我们首先把Loss重新写一下,把P代入进去:

不要一步到位,我们把前面一部分和后面一部分分开看。设, rowsum就是每一行的score指数和,因此是Nx1的,那么就有

先看d_score,其大小与score一样,是NxC的。你会发现如果扔掉前面的1/N不看d_score其实就是一堆0,然后在每一行那个正确的class那里为-1;写成python代码就是

d_score = np.zeros_like(score)
d_score[range(N),y] -= 1

然后看d_rowsum,其实就是,非常简单。

现在我们关注,需要注意的是我们不要直接求是什么,两个都是矩阵,不好求;相反,我们求是多少。我们会发现上面我们求了一个d_score,这里又求了一个d_score,这说明score这个矩阵参与了两个运算,这是符合这里Loss的定义的。求导法则告诉我们,当一个变量参与了两部分运算的时候,把这两部分的导数加起来就可以了。

这一部分的d_score就很好求了:

,左边是NxC的,右边已知的是Nx1的,那么剩下的有可能是1xC的,也有可能是NxC的。这个时候就要分析一下了。我们会发现右边应该是NxC的,因为每一个score都只影响一个rowsum的元素,因此我们不应该求和。NxC的矩阵就是自己,所以我们就很容易得出:

# 实际上,d_rowsum往往是一个长度为N的一位数组,因此我们先用np.newaxis把它的shape由N升维到Nx1,
# 这样就可以使用广播机制(Nx1 * NxC)
# 然后用乘号做element wise相乘。
d_score += d_rowsum[:, np.newaxis] * (np.exp(score))
d_score /= N #再把那个1/N给补上

这样我们就完成了对score的求导,之后score对W, X和b的求导,相信你也就会了。

当然,如果你注意一下的话,你会发现其实第二部分的那个式子就是P矩阵。不过如果你没有注意到这一点也无所谓,用这套方法也可以求出d_score是多少。

利用同样的方法,现在看看那个卡住无数人的Batch Normalization层的梯度推导,是不是也感到不那么困难了?

希望本文可以为刚刚入门神经网络的同学提供一些帮助,如有错漏欢迎指出。

觉得有用麻烦给个在看啦~  

道理我都懂,但是神经网络反向传播时的梯度到底怎么求?相关推荐

  1. 混淆矩阵怎么看_道理我都懂,但是神经网络反向传播时的梯度到底怎么求?

    ↑ 点击蓝字 关注极市平台作者丨DarkZero@知乎来源丨https://zhuanlan.zhihu.com/p/25202034编辑丨极市平台本文仅用于学术分享.若侵权,请联系后台作删文处理. ...

  2. 吴恩达|机器学习作业4.0神经网络反向传播(BP算法)

    4.0.神经网络学习 1)题目: 在本练习中,您将实现神经网络的反向传播算法,并将其应用于手写数字识别任务.在之前的练习中,已经实现了神经网络的前馈传播,并使用Andrew Ng他们提供的权值来预测手 ...

  3. 神经网络反向传播梯度计算数学原理

    [神经网络]反向传播梯度计算数学原理 1 文章概述 本文通过一段来自于Pytorch官方的warm-up的例子:使用numpy来实现一个简单的神经网络.使用基本的数学原理,对其计算过程进行理论推导,以 ...

  4. 学习笔记84—[深度学习]神经网络反向传播(BackPropagation)

    神经网络反向传播实例推导过程: 说到神经网络,大家看到这个图应该不陌生: 这是典型的三层神经网络的基本构成,Layer L1是输入层,Layer L2是隐含层,Layer L3是隐含层,我们现在手里有 ...

  5. 【机器学习笔记】神经网络反向传播算法 推导

    神经网络反向传播算法 推导 (一) 概念及基本思想 (二)信息的前向传播 (三)误差反向传播 (1)输出层的权重参数更新 (2)隐藏层的权重参数更新 (3)输出层与隐藏层的偏置参数更新 (4)反向传播 ...

  6. 7.神经网络反向传播

    文章目录 1.神经网络反向传播的作用 2.神经网络反向传播的流程 3.梯度检测 4.随机初始化 总结 1.神经网络反向传播的作用   在上一节中学到的输入通过隐藏层汇总计算最后到输出层的流程是神经网络 ...

  7. 为什么道理我都懂,却仍过不好一生 | 认知突破

    每个人都有证实偏差 简单来说就是,当你的头脑中预设立场或当你倾向于得到某个结果时,你就更容易在搜寻证据的途中不知不觉偏离公平.而我们之所以没有发现,是因为我们更喜欢自我创造的那个自己. 认清真实的自己 ...

  8. 神经网络反向传播为什么快

    神经网络反向传播为什么快 为什么从后往前算梯度而不是从前往后算梯度? 首先要有链式法则的认知.然后举个最简单的例子,每个神经元都会有激活函数,当系数(比如w)更新,根据链式求导法则,最终的loss对w ...

  9. Batch Normalization函数详解及反向传播中的梯度求导

    摘要 本文给出 Batch Normalization 函数的定义, 并求解其在反向传播中的梯度 相关 配套代码, 请参考文章 : Python和PyTorch对比实现批标准化Batch Normal ...

最新文章

  1. Windows 08R2_AD图文详解
  2. 数组去重是面试中经常问到的问题
  3. 2018,微软可能要在方方面面融入进企业
  4. Perl输出复杂数据结构:Data::Dumper,Data::Dump,Data::Printer
  5. Flink SQL Client的datagen的用法(转载+自己验证)
  6. Selector 实现原理
  7. 重温设计模式之 Factory
  8. 大学学了一个学期的 C 语言,我们应该明白哪些知识点?别像没学一样!
  9. 程序员情人节送这些!
  10. LINUX矩阵键盘简单介绍,stm32矩阵键盘原理图及程序介绍
  11. python classmethod static_【python】classmethod staticmethod 区别
  12. Spring中事务管理的几种配法
  13. 解析FL Studio冻结小技巧
  14. MacOS开发必备工具brew,安装nginx反向代理,替代linux工具 apt-get和 yum
  15. 以太坊虚拟机 EVM(2)Solidity运行原理
  16. javascript 遍历数组的常用方法(迭代、for循环 、for… in、for…of、foreach、map、filter、every、some,findindex)
  17. 怎么使用水经注万能地图下载器制作百度个性化道路地图
  18. 中望cad文字显示问号怎么办_中望CAD图纸显示乱码怎么办?
  19. 【操作系统】—I/O设备的基本概念和分类
  20. dfuse 为你提供定制网络服务

热门文章

  1. 【数据结构】二叉树的应用。
  2. 爱耳日腾讯天籁行动再升级 助力100位青年听障人才打破“屏障”
  3. 再一次输给了AI,弯道急速超车、登上 Nature 封面
  4. AnimeGANv2 实现动漫风格迁移,简单操作
  5. “编程能力差的程序员,90%会输在这点上”谷歌AI专家:其实都是瞎努力
  6. 免费直播:1小时带你体验Python车牌识别实战
  7. 从Ops到NoOps,阿里文娱智能运维的关键:自动化应用容量管理
  8. Bert时代的创新:Bert应用模式比较及其它 | 技术头条
  9. 程序员,快通知你们老板上吴恩达的最新AI课
  10. AI一分钟 | 谷歌CEO承诺在中国组建更大团队;苹果与清华大学成立研究中心,并将帮助30万名贫困学生