来源:机器学习与生成对抗网络
本文约2000字,建议阅读5分钟
本文为你介绍神经网络的5种常见求导方式。

01 derivative of softmax

1.1 derivative of softmax

一般来说,分类模型的最后一层都是softmax层,假设我们有一个  分类问题,那对应的softmax层结构如下图所示(一般认为输出的结果  即为输入  属于第i类的概率):

假设给定训练集  ,分类模型的目标是最大化对数似然函数  ,即

通常来说,我们采取的优化方法都是gradient based的(e.g., SGD),也就是说,需要求解  。而我们只要求得  ,之后根据链式法则,就可以求得  ,因此我们的核心在于求解  。

由上式可知,我们只需要知道各个样本  的  ,即可通过求和求得  ,进而通过链式法则求得  。因此下面省略样本下标j,仅讨论某个样本  。

实际上对于如何表示  属于第几个类,有两种比较直观的方法:

  • 一种是直接法(i.e., 用  来表示x属于第3类),则  ,其中  为指示函数;

  • 另一种是one-hot法(i.e., 用  来表示x属于第三类),则  ,其中  为向量  的第  个元素。

  • p.s., 也可以将one-hot法理解为直接法的实现形式,因为one-hot向量实际上就是  。

为了方便,本文采用one-hot法。于是,我们有:

1.2 softmax & sigmoid

再补充一下softmax与sigmoid的联系。当分类问题是二分类的时候,我们一般使用sigmoid function作为输出层,表示输入  属于第1类的概率,即:

然后利用概率和为1来求解  属于第2类的概率,即:

乍一看会觉得用sigmoid做二分类跟用softmax做二分类不一样:

  • 在用softmax时,output的维数跟类的数量一致,而用sigmoid时,output的维数比类的数量少;

  • 在用softmax时,各类的概率表达式跟sigmoid中的表达式不相同。

但实际上,用sigmoid做二分类跟用softmax做二分类是等价的。我们可以让sigmoid的output维数跟类的数量一致,并且在形式上逼近softmax。

通过上述变化,sigmoid跟softmax已经很相似了,只不过sigmoid的input的第二个元素恒等于0(i.e., intput为  ),而softmax的input为  ,下面就来说明这两者存在一个mapping的关系(i.e., 每一个  都可以找到一个对应的  来表示相同的softmax结果。不过值得注意的是,反过来并不成立,也就是说并不是每个  仅仅对应一个  )。

因此,用sigmoid做二分类跟用softmax做二分类是等价的。

02 backpropagation

一般来说,在train一个神经网络时(i.e., 更新网络的参数),我们都需要loss function对各参数的gradient,backpropagation就是求解gradient的一种方法。

假设我们有一个如上图所示的神经网络,我们想求损失函数  对  的gradient,那么根据链式法则,我们有:

而我们可以很容易得到上述式子右边的第二项,因为  ,所以有:

其中,  是上层的输出。

而对于式子右边的的第一项,可以进一步拆分得到:

我们很容易得到上式右边第二项,因为  ,而激活函数  (e.g., sigmoid function)是我们自己定义的,所以有:

其中,  是本层的线性输出(未经激活函数)。

观察上图,我们根据链式法则可以得到:

其中,根据 可知:

和  的值是已知的,因此,我们离目标  仅差  和  了。接下来我们采用动态规划(或者说递归)的思路,假设下一层的  和  是已知的,那么我们只需要最后一层的graident,就可以求得各层的gradient了。而通过softmax的例子,我们知道最后一层的gradient确实可求,因此只要从最后一层开始,逐层向前,即可求得各层gradient。

因此我们求  的过程实际上对应下图所示的神经网络(原神经网络的反向神经网络):

综上,我们先通过神经网络的正向计算,得到  以及  ,进而求得  和  ;然后通过神经网络的反向计算,得到  和  ,进而求得  ;然后根据链式法则求得  。这整个过程就叫做backpropagation,其中正向计算的过程叫做forward pass,反向计算的过程叫做backward pass。

03 derivative of CNN

卷积层实际上是特殊的全连接层,只不过:

  • 神经元中的某些 w 为 0 ;

  • 神经元之间共享 w 。

具体来说,如下图所示,没有连线的表示对应的w为0:

如下图所示,相同颜色的代表相同的  :

因此,我们可以把loss function理解为  ,然后求导的时候,根据链式法则,将相同w的gradient加起来就好了,即

在求各个  时,可以把他们看成是相互独立的  ,那这样就跟普通的全连接层一样了,因此也就可以用backpropagation来求。

04 derivative of RNN

RNN按照时序展开之后如下图所示(红线表示了求gradient的路线):

跟处理卷积层的思路一样,首先将loss function理解为  ,然后把各个w看成相互独立,最后根据链式法则求得对应的gradient,即

由于这里是将RNN按照时序展开成为一个神经网络,所以这种求gradient的方法叫Backpropagation Through Time(BPTT)。

05 derivative of max pooling

一般来说,函数  是不可导的,但假如我们已经知道哪个自变量会是最大值,那么该函数就是可导的(e.g., 假如知道y是最大的,那对y的偏导为1,对其他自变量的偏导为0)。

而在train一个神经网络的时候,我们会先进行forward pass,之后再进行backward pass,因此我们在对max pooling求导的时候,已经知道哪个自变量是最大的,于是也就能够给出对应的gradient了。

references:

http://speech.ee.ntu.edu.tw/~tlkagk/courses_ML17_2.html

http://www.wildml.com/2015/10/recurrent-neural-networks-tutorial-part-3-backpropagation-through-time-and-vanishing-gradients/

版权申明:内容来源网络,版权归原创者所有。除非无法确认,都会标明作者及出处,如有侵权烦请告知,我们会立即删除并致歉。谢谢!

编辑:黄继彦

校对:林亦霖

神经网络的5种常见求导,附详细的公式过程相关推荐

  1. 收藏 | 神经网络的 5 种常见求导,附详细的公式过程

    来源:机器学习与生成对抗网络 本文约1800字,建议阅读5分钟 本文为你介绍5种常见求导的详细过程! 01 derivative of softmax 1.1 derivative of softma ...

  2. 光线跟踪的几种常见求交运算

    光线跟踪的几种常见求交运算 我们知道光线跟踪中最昂贵的就是和几何对象的求交运算了.这里就记录几个比较常见的光线和几何对象求交运算. 球 射线与球体相交可能是射线几何相交测试的最简单形式,这就是为什么这 ...

  3. 手推卷积神经网络参数(卷积核)求导

    手推卷积神经网络求导(卷积链式法则如何理解) 对于卷积如何求参数的导数问题(特别是对多个卷积层如何对初始层数的参数如何求导)困扰我许久了,也一直没有找到这方面的资料,所以自己研究了一下,在这里与大家分 ...

  4. 线性代数之 矩阵求导(2)标量函数求导基本法则与公式

    线性代数之 矩阵求导(2)基本法则与公式 前言 基本约定 标量对向量求导 基本法则 公式 标量对矩阵求导 基本法则 公式 后记 前言 上篇矩阵求导(1)解决了求导时的布局问题,也是矩阵求导最基础的求导 ...

  5. 复杂函数求导/对数指数幂公式

    指数.对数公式 https://wenku.baidu.com/view/69653d53f01dc281e53af0ba.html 求导公式 https://wenku.baidu.com/view ...

  6. Java 枚举(enum) 详解7种常见的用法<详细>

    JDK1.5引入了新的类型--枚举.在 Java 中它虽然算个"小"功能,却给我的开发带来了"大"方便. 大师兄我[大师兄]又加上自己的理解,来帮助各位理解一下 ...

  7. 用numpy、PyTorch自动求导、torch.nn库实现两层神经网络

    用numpy.PyTorch自动求导.torch.nn库实现两层神经网络 1 用numpy实现两层神经网络 2 用PyTorch自动求导实现两层神经网络 2.1 手动求导 2.2 gradient自动 ...

  8. 神经网络求导与不能求导的情况

    关于神经网络的求导和不可求导 ,目前主要是两个地方遇到过,一个是karpathy在Policy Gradient的文章中有一节专门讲了 [1: Non-differentiable computati ...

  9. 李群与李代数2:李代数求导和李群扰动模型

    李群与李代数2:李代数求导和李群扰动模型 1. 整体误差最小化引出求导问题 2. BCH公式与近似形式 2.1 BCH公式 2.2 BCH线性近似 2.3 BCH近似的意义 3. 微分模型--李代数求 ...

最新文章

  1. worktools-源码下拉问题
  2. 7.16 T1 礼物
  3. 一步一步实现自己的模拟控件(6)——控件树及控件区域
  4. 【渝粤题库】国家开放大学2021春3897商务英语1题目
  5. 云原生人物志|华为云CTO张宇昕:云原生已经进入深水区
  6. 作者:陈昕(1982-),女,博士,中国科学院计算机网络信息中心研究员
  7. Linux Ubuntu 18.04安装JDK、Hadoop、Hbase以及图形界面
  8. python如何输入多行数据合并_关于Python中的合并字典,这些问题必须搞清楚!
  9. 邮件系统IP被CBL列黑,怎么样里面申诉呢?
  10. mysql 日期与索引问题
  11. D5M数据手册英文版
  12. python arma_Python实现ARMA模型
  13. python通过文件头识别音频格式
  14. Ubuntu: Firefox 的profile missing解决
  15. mysql查询最近三个月数据方法
  16. vs无法产生pdb文件,也就无法断点调试
  17. dump文件,windbg
  18. lterator,Listlterator
  19. [opencv入门]1.2.6像素处理RGB三颜色数组图
  20. vue-pdf使用+分页预览(踩坑 + 使用本地字体库)

热门文章

  1. 使用指针输入输出一维数组
  2. Python读取内容UnicodeDecodeError错误
  3. SAP产品的Field Extensibility
  4. 七本书籍带你打下机器学习和数据科学的数学基础
  5. Python学习小结---粗略列表解析
  6. 难道他们说的都是真的?
  7. Vue.js 2.0 学习重点记录
  8. 《数据科学家修炼之道》一2.2 新规则
  9. OC之@property和@synthesize
  10. ansible 非root 用户 批量修改用户密码