前言

想要理解神经网络的工作原理,反向传播(BP)是必须搞懂的东西。BP其实并不难理解,说白了就是用链式法则(chain rule)算算算。本文试图以某个神经网络为例,尽可能直观,详细,明了地说明反向传播的整个过程。

正向传播

在反向传播之前,必然是要有正向传播的。正向传播时的所有参数都是预先随机取的,没人能说这样的参数好不好,得要试过才知道,试过之后,根据得到的结果与目标值的差距,再通过反向传播取修正各个参数。下图就是一个神经网络,我们以整个为例子来说明整个过程

图1:神经网络图

我懒,此图取自参考文献[1],图中的各个符号说明如下(顺序从下往上):
xix_ixi​:输入样本中的第iii个特征的值
vihv_{ih}vih​:xix_ixi​与隐层第hhh个神经元连接的权重
αh\alpha_hαh​:第h个隐层神经元的输入,αh=∑i=1dvihxi\alpha_h=\sum_{i=1}^d v_{ih}x_iαh​=∑i=1d​vih​xi​
bhb_hbh​:第h个隐层神经元的输出,某个神经元的输入和输出有关系f(αh)=bhf(\alpha_h)=b_hf(αh​)=bh​,其中f(x)f(x)f(x)为激活函数,比如Sigmoid函数f(x)=11+e−xf(x)=\dfrac{1}{1+e^{-x}}f(x)=1+e−x1​
whjw_hjwh​j:隐层第hhh个神经元和输出层第jjj个神经元连接的权重
βj\beta_jβj​:输出层第jjj个神经元的输入,βj=∑h=1qwhjbh\beta_j=\sum_{h=1}^q w_{hj}b_hβj​=∑h=1q​whj​bh​
yjy_jyj​:第jjj个输出层神经元的输出,f(βj)=yjf(\beta_j)=y_jf(βj​)=yj​,f(x)f(x)f(x)为激活函数
为了方便书写,我们假设截距项bias已经在参数www和vvv之中了,也就是说在输入数据的时候,我们增添了一个x0=1x_0=1x0​=1,由于我懒,图中没有画出来,但心里要清楚这一点。
相信看了图之后,神经网络的正向传播就相当简单明了了,不过,这里我还是啰嗦一句,举个例子,比如输出yjy_jyj​的计算方法为

yj=f(βj)=f(∑h=1qwhjbh)=f(∑h=1qwhjf(αh))=f(∑h=1qwhjf(∑i=1dvihxi))y_j=f(\beta_j)=f(\sum_{h=1}^q w_{hj}b_h)=f(\sum_{h=1}^q w_{hj}f(\alpha_h))=f(\sum_{h=1}^q w_{hj}f(\sum_{i=1}^d v_{ih}x_i))yj​=f(βj​)=f(h=1∑q​whj​bh​)=f(h=1∑q​whj​f(αh​))=f(h=1∑q​whj​f(i=1∑d​vih​xi​))

反向传播

好了,通过正向传播,我们就已经得到了lll个yyy的值了,将它们与目标值ttt,也就是我们期望它们成为的值作比较,并放入损失函数中,记作LLL。
损失LLL可以自行选择,比如常见的均方误差L=12∑j=1l(yj−tj)2L=\dfrac{1}{2}\sum_{j=1}^l (y_j - t_j)^2L=21​∑j=1l​(yj​−tj​)2
利用这个误差,我们将进行反向传播,以此来更新参数www和vvv。更新时,我们采用的是梯度下降法,也就是

{w:=w+Δwv:=v+Δv\begin{cases}w := w + \Delta w \\ v := v + \Delta v\end{cases}{w:=w+Δwv:=v+Δv​

其中,Δw=−η∂L∂w\Delta w = -\eta \dfrac{\partial L}{\partial w}Δw=−η∂w∂L​,Δv=−η∂L∂v\Delta v = -\eta \dfrac{\partial L}{\partial v}Δv=−η∂v∂L​,η\etaη为学习率。
下面要做的工作就是计算出每个参数的梯度,这也就是链式法则发挥作用的地方了。
比如,我们要计算whjw_{hj}whj​。从网络结构中不难看出whjw_{hj}whj​影响了βj\beta_jβj​从而影响了yjy_jyj​,最终影响了LLL所以我们有

Δwhj=−η∂βj∂whj∂yj∂βj∂L∂yj\Delta w_{hj}=-\eta \dfrac{\partial \beta_j}{\partial w_{hj}} \dfrac{\partial y_j}{\partial \beta_j} \dfrac{\partial L}{\partial y_j}Δwhj​=−η∂whj​∂βj​​∂βj​∂yj​​∂yj​∂L​

只要确定了损失函数LLL和激活函数f(x)f(x)f(x),上面所有的都是可以算的,而且∂βh∂whj=bh\dfrac{\partial \beta_h}{\partial w_{hj}} = b_h∂whj​∂βh​​=bh​这点是显而易见的。并且,∂yj∂βj=∂f(βj)∂βj\dfrac{\partial y_j}{\partial \beta_j} = \dfrac{\partial f(\beta_j)}{\partial \beta_j}∂βj​∂yj​​=∂βj​∂f(βj​)​就是激活函数的导数。
同理,vihv_{ih}vih​影响了αh\alpha_hαh​,从而影响了bhb_hbh​,从而影响了β1\beta_{1}β1​,β2\beta_{2}β2​,…,βl\beta_{l}βl​,从而影响了y1y_1y1​,y2y_2y2​,…,yly_lyl​,最终影响了LLL。

Δvih=−η∂αh∂vih∂bh∂αh∑j=1l(∂βj∂bh∂yj∂βj∂L∂yj)\Delta v_{ih} = -\eta \dfrac{\partial \alpha_h}{\partial v_{ih}} \dfrac{\partial b_h}{\partial \alpha_h}\sum_{j=1}^l (\dfrac{\partial \beta_j}{\partial b_h} \dfrac{\partial y_j}{\partial \beta_j} \dfrac{\partial L}{\partial y_j})Δvih​=−η∂vih​∂αh​​∂αh​∂bh​​j=1∑l​(∂bh​∂βj​​∂βj​∂yj​​∂yj​∂L​)

其中,∂αh∂vih=xi\dfrac{\partial \alpha_h}{\partial v_{ih}}=x_i∂vih​∂αh​​=xi​,∂βj∂bh=whj\dfrac{\partial \beta_j}{\partial b_h} = w_{hj}∂bh​∂βj​​=whj​,∂yj∂βj=∂f(βj)∂βj\dfrac{\partial y_j}{\partial \beta_j} = \dfrac{\partial f(\beta_j)}{\partial \beta_j}∂βj​∂yj​​=∂βj​∂f(βj​)​和∂bh∂αh=∂f(αh)∂αh\dfrac{\partial b_h}{\partial \alpha_h} = \dfrac{\partial f(\alpha_h)}{\partial \alpha_h}∂αh​∂bh​​=∂αh​∂f(αh​)​是激活函数的导数。
至此,我们已经可以算出Δw\Delta wΔw和Δv\Delta vΔv,从而更新参数了。

关于激活函数的几点说明

从推出的公式中不难看出,随着反向传播向输出层这个方向的推进,激活函数的影响也就越来越来了。通俗一点来说,在计算Δwhj\Delta w_{hj}Δwhj​,我们只乘了一个激活函数的导数,然而在计算Δvih\Delta v_{ih}Δvih​时,我们乘了多个激活函数的导数。

Δwhj=−η∂βj∂whjf′(βj)∂L∂yj\Delta w_{hj}=-\eta \dfrac{\partial \beta_j}{\partial w_{hj}} f'(\beta_j) \dfrac{\partial L}{\partial y_j}Δwhj​=−η∂whj​∂βj​​f′(βj​)∂yj​∂L​

Δvih=−η∂αh∂vihf′(αh)∑j=1l(∂βj∂bhf′(βj)∂L∂yj)\Delta v_{ih} = -\eta \dfrac{\partial \alpha_h}{\partial v_{ih}} f'(\alpha_h) \sum_{j=1}^l (\dfrac{\partial \beta_j}{\partial b_h} f'(\beta_j) \dfrac{\partial L}{\partial y_j})Δvih​=−η∂vih​∂αh​​f′(αh​)j=1∑l​(∂bh​∂βj​​f′(βj​)∂yj​∂L​)

不难推断出,如果隐层的层数更多的话,激活函数的影响还要更大。
一个比较传统的激活函数时Sigmoid函数,其图像如下所示。

图2:Sigmoid函数

不难发现,当xxx比较大的时候,或比较小的时候,f′(x)f'(x)f′(x)是趋近于0的,当神经网络的层数很深的时候,这么多个接近0的数相乘就会导致传到输出层这边的时候已经没剩下多少信息了,这时梯度对模型的更新就没有什么贡献了。那么大多数神经元将会饱和,导致网络就几乎不学习。这其实也是Sigmoid函数现在在神经网络中不再受到青睐的原因之一。
另一个原因是Sigmoid 函数不是关于原点中心对称的,这会导致梯度在反向传播过程中,要么全是正数,要么全是负数。导致梯度下降权重更新时出现 Z 字型的下降。
所以,就出现了ReLU这个激活函数 f(x)=max⁡(0,x)f\left( x\right) =\max \left( 0,x\right)f(x)=max(0,x),其图像如下图所示。

图3:ReLU函数

ReLU 对于 SGD 的收敛有巨大的加速作用,而且只需要一个阈值就可以得到激活值,而不用去算一大堆复杂的(指数)运算。
不过,由于它左半边的状态,ReLU在训练时比较脆弱并且可能“死掉”。
因此,人们又研究出了Leaky ReLU,PReLU等等的激活函数。这里不展开讨论。

参考文献

[1] 周志华. 机器学习 : = Machine learning[M]. 清华大学出版社, 2016.
[2] http://cs231n.github.io/neural-networks-1/
[2] http://www.jianshu.com/p/6df4ab7c235c

神经网络中BP(back propagation)到底在干些什么相关推荐

  1. bp神经网络中bp是什么意思,bp神经网络是什么网络

    神经网络plotperform三条不同颜色的曲线表示什么意思 比较随便的截图,纵坐标是误差平方的均值,绿色指的是验证集,红色指的是测试集,蓝色指的是训练集. 一般是用来观察训练集.验证集和测试集的最小 ...

  2. 神经网络中BP算法的推导

    神经网络BP算法的推导 有关BP算法推导的文章数不胜数,但是好多只讲到了单样本的BP算法的推导,有些虽然讲到了多样本的BP算法的推导,但是没讲清楚甚至是讲错了的都有一些. 1. 单样本BP算法推导 关 ...

  3. Attention!神经网络中的注意机制到底是什么?

    原作:Adam Kosiorek 安妮 编译自 GitHub 量子位 出品 | 公众号 QbitAI 神经网络的注意机制(Attention Mechanisms)已经引起了广泛关注.在这篇文章中,我 ...

  4. Vivado IP中的Shared Logic到底是干嘛的?

      在很多Vivado的高速接口的IP中,比如Ethernet.PCIe.SRIO的设置中,都会有个Shared Logic的页面: 可能很多同学并没有很关注这个页面,直接默认设置就完事了.   但其 ...

  5. 透过源码领悟GCC到底在干些什么

    GCC源码分析(一)--介绍与安装 GCC源码分析(一)--介绍与安装 目录(?)[-] 一GCC的作用和运行机制 二GCC的安装 上半年一直在做有关GCC和LD的项目,到现在还没做完.最近几天编程的 ...

  6. 透过源码领悟GCC到底在干些什么(收集整理)

    GCC源码分析(一)--介绍与安装 目录(?)[-] 一GCC的作用和运行机制 二GCC的安装 上半年一直在做有关GCC和LD的项目,到现在还没做完.最近几天编程的那台电脑坏了,所以趁此间隙写一点相关 ...

  7. Dropout层到底在干些什么(Pytorch实现)

    Dropout 一.基本概念 二.Dropout工作原理 1.Dropout工作流程 2.Dropout如何缓解过拟合 3.Dropout实际实现 三.Pytorch实现 1.实际实现方式(训练模式下 ...

  8. JavaScript中的 new 操作符到底做了些什么?

    new做了什么? 使用new关键字在调用函数时,函数的内部自动创建一个新对象 将函数的作用域赋给新的对象(this会指向新的对象); 执行函数的代码(添加属性和方法) 返回新对象(实例化对象) 如果返 ...

  9. python bp神经网络 异或_【神经网络】BP算法解决XOR异或问题MATLAB版

    第一种 %% %用神经网络解决异或问题 clear clc close ms=4;%设置4个样本 a=[0 0;0 1;1 0;1 1];%设置输入向量 y=[0,1,1,0];%设置输出向量 n=2 ...

最新文章

  1. 老司机 iOS 周报 #24 | 2018-06-25
  2. mysql connect 500_MySQL连接问题【mysql_connect和mysql_pconnect区别】
  3. Java高并发编程(十二):Executor框架
  4. Model和ViewModel
  5. OPPM 一页纸项目管理 One-Page Project Management
  6. 江阳职高计算机应用教改实验,计算机应用课程教改模式
  7. Python 02 编写代码
  8. 【BurpSuite学习篇】四:Scanner 漏洞扫描模块
  9. 用R做GLM的Summary相关指标解释——以Poission regression为例
  10. 深入理解拉格朗日乘子法(Lagrange Multiplier) 和KKT条件
  11. spark streaming 整合kafka 报错 KafkaConsumer is not safe for multi-threaded access
  12. 谁能最后享受到胜利成果?
  13. 2018年算法工程师秋招经验贴(微软、华为、网易游戏、阿里offer)
  14. 史上最通俗计算机网络分层详解,附架构师必备技术详解
  15. (HDU)1718 -- Rank (段位)
  16. 微信小程序获取当前位置及地图选点功能
  17. 【Bio】基础生物学 - 基本氨基酸 amino acids
  18. 软件项目计划管理:三级计划管理体系
  19. c++强引用与弱引用
  20. 数通安全工程师 7-18K/月

热门文章

  1. xmind可以画流程图吗_新娘妆可以自己画吗?临夏化妆学校告诉你答案!
  2. JAVA读取2g数据的速度_Java 读取大容量excel
  3. scp选择二进制_二进制传输与文本传输区别
  4. 【MySQL】Linux端-实现Mysql数据定时自动备份
  5. STL set和multiset
  6. 软件测试工程师核心竞争力(转)
  7. sql server常用函数积累
  8. TZOJ--1518: 星星点点 (二进制模拟)
  9. asp.net登录状态验证
  10. Nimbus三Storm源码分析--Nimbus启动过程