收藏 | 常见的神经网络求导总结!
来源:机器学习与生成对抗网络本文约1700字,建议阅读5分钟本文为你总结常见的神经网络求导。
derivative of softmax
1.1 derivative of softmax
一般来说,分类模型的最后一层都是softmax层,假设我们有一个 分类问题,那对应的softmax层结构如下图所示(一般认为输出的结果 即为输入 属于第i类的概率):
假设给定训练集 ,分类模型的目标是最大化对数似然函数 ,即
通常来说,我们采取的优化方法都是gradient based的(e.g., SGD),也就是说,需要求解 。而我们只要求得 ,之后根据链式法则,就可以求得 ,因此我们的核心在于求解 ,即由上式可知,我们只需要知道各个样本 的 ,即可通过求和求得 ,进而通过链式法则求得 。因此下面省略样本下标j,仅讨论某个样本 。
实际上对于如何表示 属于第几个类,有两种比较直观的方法:
一种是直接法(i.e., 用 来表示x属于第3类),则 ,其中 为指示函数;
另一种是one-hot法(i.e., 用 来表示x属于第三类),则 ,其中 为向量 的第 个元素。
p.s., 也可以将one-hot法理解为直接法的实现形式,因为one-hot向量实际上就是 。
为了方便,本文采用one-hot法。于是,我们有:
1.2 softmax & sigmoid
再补充一下softmax与sigmoid的联系。当分类问题是二分类的时候,我们一般使用sigmoid function作为输出层,表示输入 属于第1类的概率,即
然后利用概率和为1来求解 属于第2类的概率,即
乍一看会觉得用sigmoid做二分类跟用softmax做二分类不一样:
在用softmax时,output的维数跟类的数量一致,而用sigmoid时,output的维数比类的数量少;
在用softmax时,各类的概率表达式跟sigmoid中的表达式不相同。
但实际上,用sigmoid做二分类跟用softmax做二分类是等价的。我们可以让sigmoid的output维数跟类的数量一致,并且在形式上逼近softmax。
通过上述变化,sigmoid跟softmax已经很相似了,只不过sigmoid的input的第二个元素恒等于0(i.e., intput为 ),而softmax的input为 ,下面就来说明这两者存在一个mapping的关系(i.e., 每一个 都可以找到一个对应的 来表示相同的softmax结果。不过值得注意的是,反过来并不成立,也就是说并不是每个 仅仅对应一个 )。
因此,用sigmoid做二分类跟用softmax做二分类是等价的。
02 backpropagation
一般来说,在train一个神经网络时(i.e., 更新网络的参数),我们都需要loss function对各参数的gradient,backpropagation就是求解gradient的一种方法。
假设我们有一个如上图所示的神经网络,我们想求损失函数 对 的gradient,那么根据链式法则,我们有
而我们可以很容易得到上述式子右边的第二项,因为 ,所以有
其中, 是上层的输出。而对于式子右边的的第一项,可以进一步拆分得到
我们很容易得到上式右边第二项,因为 ,而激活函数 (e.g., sigmoid function)是我们自己定义的,所以有
,
其中, 是本层的线性输出(未经激活函数)。
观察上图,我们根据链式法则可以得到
其中,根据 可知
和 的值是已知的,因此,我们离目标 仅差 和 了。接下来我们采用动态规划(或者说递归)的思路,假设下一层的 和 是已知的,那么我们只需要最后一层的graident,就可以求得各层的gradient了。而通过softmax的例子,我们知道最后一层的gradient确实可求,因此只要从最后一层开始,逐层向前,即可求得各层gradient。
因此我们求 的过程实际上对应下图所示的神经网络(原神经网络的反向神经网络):
综上,我们先通过神经网络的正向计算,得到 以及 ,进而求得 和 ;然后通过神经网络的反向计算,得到 和 ,进而求得 ;然后根据链式法则求得 。这整个过程就叫做backpropagation,其中正向计算的过程叫做forward pass,反向计算的过程叫做backward pass。
03 derivative of CNN
卷积层实际上是特殊的全连接层,只不过:
神经元中的某些 为 ;
神经元之间共享 。
具体来说,如下图所示,没有连线的表示对应的w为0:
如下图所示,相同颜色的代表相同的 :
因此,我们可以把loss function理解为 ,然后求导的时候,根据链式法则,将相同w的gradient加起来就好了,即
在求各个 时,可以把他们看成是相互独立的 ,那这样就跟普通的全连接层一样了,因此也就可以用backpropagation来求。
04 derivative of RNN
RNN按照时序展开之后如下图所示(红线表示了求gradient的路线):
跟处理卷积层的思路一样,首先将loss function理解为 ,然后把各个w看成相互独立,最后根据链式法则求得对应的gradient,即
由于这里是将RNN按照时序展开成为一个神经网络,所以这种求gradient的方法叫Backpropagation Through Time(BPTT)。
05 derivative of max pooling
一般来说,函数 是不可导的,但假如我们已经知道哪个自变量会是最大值,那么该函数就是可导的(e.g., 假如知道y是最大的,那对y的偏导为1,对其他自变量的偏导为0)。
而在train一个神经网络的时候,我们会先进行forward pass,之后再进行backward pass,因此我们在对max pooling求导的时候,已经知道哪个自变量是最大的,于是也就能够给出对应的gradient了。
references:
http://speech.ee.ntu.edu.tw/~tlkagk/courses_ML17_2.html
http://www.wildml.com/2015/10/recurrent-neural-networks-tutorial-part-3-backpropagation-through-time-and-vanishing-gradients/
编辑:于腾凯
校对:林亦霖
收藏 | 常见的神经网络求导总结!相关推荐
- 常见的神经网络求导总结!
↑↑↑关注后"星标"Datawhale每日干货 & 每月组队学习,不错过Datawhale干货 作者:Criss,来源:机器学习与生成对抗网络 derivative of ...
- 常见激活函数及其求导相关知识
文章目录 Sigmoid函数 Sigmoid函数介绍 Sigmoid函数求导 tanh 函数 tanh 函数介绍 tanh 函数求导 Relu函数 Relu函数介绍 Relu函数求导 Softmax函 ...
- 神经网络求导与不能求导的情况
关于神经网络的求导和不可求导 ,目前主要是两个地方遇到过,一个是karpathy在Policy Gradient的文章中有一节专门讲了 [1: Non-differentiable computati ...
- 手推卷积神经网络参数(卷积核)求导
手推卷积神经网络求导(卷积链式法则如何理解) 对于卷积如何求参数的导数问题(特别是对多个卷积层如何对初始层数的参数如何求导)困扰我许久了,也一直没有找到这方面的资料,所以自己研究了一下,在这里与大家分 ...
- [机器学习-数学] 矩阵求导(分母布局与分子布局),以及常用的矩阵求导公式
一, 矩阵求导 1,矩阵求导的本质 矩阵A对矩阵B求导: 矩阵A中的每一个元素分别对矩阵B中的每个元素进行求导. A1×1A_{1\times1}A1×1, B1×1B_{1\times1}B1×1 ...
- 机器学习中的线性代数之矩阵求导
前面针对机器学习中基础的线性代数知识,我们做了一个常用知识的梳理.接下来针对机器学习公式推导过程中经常用到的矩阵求导,我们做一个详细介绍. 矩阵求导(Matrix Derivative)也称作矩阵微分 ...
- 【转载】矩阵求导、几种重要的矩阵及常用的矩阵求导公式
一.矩阵求导 一般来讲,我们约定x=(x1,x2,-xN)Tx=(x1,x2,-xN)T,这是分母布局.常见的矩阵求导方式有:向量对向量求导,标量对向量求导,向量对标量求导. 1.向量对向量求导 Nu ...
- 两边同时取对数求复合函数_e2x求导(复合函数求导例题大全)
ln(ex +√(1 +e2x))'=1/[e^内x+√容(1+e^2x)]*[e^x+√(1+e^2x]'=[e^x+e^2x/√(1+e^2x)]/[e^x+√(1+e^2x)]==[e^x+e^ ...
- 二元函数对xy同时求导_【“数”你好看】求导
微积分的核心是极限(Limit),求导(Derivative)是微积分的重要内容,本质就是求极限.导数公式有很多, 靠死记还是比较麻烦的,但这又是微积分的基础,不然接下去导数的应用(求切线.求法线.增 ...
最新文章
- MySQL 5.5.19 GA 发布 修复多个Bug
- mix2s android p内测,历时一个月,MIX2S成小米首款Android P公测机型
- 第30天:项目时间管理相关错题整理
- python实现使用最近最久未使用算法的请求分页存储管理_答疑(存储管理)之一...
- 华为交换机STP的配置实例
- 带你玩转Visual Studio——带你高效开发
- 三菱gt3的序列号_WinXP sp3序列号大全
- eclipse断点调试(方立勋老师)
- 超市微信小程序怎么做_小程序怎么做的 超市微信小程序怎么做
- nanomsg安装和测试
- 设计配色的基本知识以及原理
- 无法下载文件或程序时的解决方法
- oracle报1653解决办法,oracle 建立查询账号ORA 1653和ORA 01502错误处理方法
- 2011年上半年五大臭名昭著的数据库泄密事件--转载
- win10任务栏卡死重启也没用
- 计算机关闭远程桌面,windows 远程桌面关闭 运行程序退出
- Spacy分词php,Spacy简单入门
- VarianceDeviation Tradeoff(方差、偏差权衡)
- 过滤器:管道过滤器技术特点及性能分析
- BGP高防是什么意思呢?
热门文章
- 【c语言】蓝桥杯算法训练 大小写转换
- 鸿蒙os系统的iphonexr,iPhoneXS/XR终极防水测试:iPhoneXR不幸阵亡
- 开源博客QBlog开发者视频教程:生命周期Page_Load介绍及简洁传递参数的重构方式(四)...
- 聊聊Cassandra的FailureDetector
- 软件测试2019:第四次作业—— 性能测试(含JMeter实验)
- 基于bs4+requests的豆瓣电影爬虫
- maven的tomcat插件如何进行debug调试
- mysql构架,索引,视图,查询语句
- python编码(六)
- Fedora15使用笔记