首先介绍一下链式法则

假如我们要求z对x1的偏导数,那么势必得先求z对t1的偏导数,这就是链式法则,一环扣一环

BackPropagation(BP)正是基于链式法则的,接下来用简单的前向传播网络为例来解释。里面有线的神经元代表的sigmoid函数,y_1代表的是经过模型预测出来的,y_1 = w1 * x1 + w2 * x2,而y^1代表的是实际值,最后是预测值与实际值之间的误差,l_1 = 1/2 * (y_1 - y^1)^2,l_2同理。总的错误是E = l_1 + l_2。

在神经网络中我们采用梯度下降(Gradient Descent)来进行参数更新,最终找到最优参数。可是离E最近的不是w1,首先我们需要求出E对l_1的偏导,接着求l_1对于最近神经元sigmoid中变量的导数,最后再求y_0对于w1的偏导,进行梯度更新。

这便是神经网络中的BP算法,与以往的正向传播不同,它应该是从反向的角度不断优化

这里只是用了一层隐含层,可以看的出来一个参数的梯度往往与几个量产生关系:

  • 最终y被预测的值。这往往取决于你的激活函数,如这里采用sigmoid

  • 中间对激活函数进行求导的值

  • 输入的向量,即为x

推广到N层隐含层,只是乘的东西变多了,但是每个式子所真正代表的含义是一样的。

换个角度说,在深度学习梯度下降的时候会出现比较常见的两类问题,梯度消失以及梯度爆炸很可能就是这些量之间出了问题,对模型造成了影响。

1、梯度消失(Gradient Vanishing)。意思是梯度越来越小,一个很小的数再乘上几个较小的数,那么整体的结果就会变得非常的小。那么导致的可能原因有哪些呢?我们由靠近E的方向向后分析。

  • 激活函数。y_1是最后经过激活函数的结果,如果激活函数不能很好地反映一开始输入时的情况,那么就很有可能出问题。sigmoid函数的性质是正数输出为大于0.5,负数输出为小于0.5,因为函数的值域为(0,1),所以也常常被用作二分类的激活函数,用以表示概率。但是,当x比较靠近原点的时候,x变化时,函数的输出也会发生明显的变化,可是,当x相当大的时候,sigmoid几乎已经是无动于衷了,x相当小的时候同理。这里不妨具体举个二分类的例子,比如说用0,1代表标签,叠了一层神经网络,sigmoid函数作为激活函数。E对l_1的偏导极大程度将取决于y_1,因为标签就是0,1嘛。就算输入端的x,w都比较大,那么经过sigmoid压缩之后就会变得很小,只有一层的时候其实还好,但是当层数变多之后呢,sigmoid函数假如说每一层都当做是激活函数,那么最后E对l_1的偏导将是十分地小,尽管x,w代表着一些信息,可经过sigmoid压缩之后信息发生了丢失,梯度无法完成真正意义上的传播,乘上一个很小的数,那么整个梯度会越来越小,梯度越小,说明几乎快收敛了。换句话说,几乎没多久就已经收敛了。

另一种思路是从公式(4)出发,无论y_0取何值,公式(4)的输出值总是介于(0,1/4](当然具体边界处是否能取到取决于具体变量的取值),证明:

因为不断乘上一个比较小的数字,所以层数一多,那么整个梯度的值就会变得十分小,而且这个是由sigmoid本身导致,应该是梯度消失的主因。

解决方法可以是换个激活函数,比如RELU就不错,或者RELU的变种。

2、梯度爆炸(Gradient Exploding)。意思是梯度越来越大,更新的时候完全起不到优化的作用。其实梯度爆炸发生的频率远小于梯度消失的频率。如果发生了,可以用梯度削减(Gradient Clipping)。

  • 梯度削减。首先设置一个clip_gradient作为梯度阈值,然后按照往常一样求出各个梯度,不一样的是,我们没有立马进行更新,而是求出这些梯度的L2范数,注意这里的L2范数与岭回归中的L2惩罚项不一样,前者求平方和之后开根号而后者不需要开根号。如果L2范数大于设置好的clip_gradient,则求clip_gradient除以L2范数,然后把除好的结果乘上原来的梯度完成更新。当梯度很大的时候,作为分母的结果就会很小,那么乘上原来的梯度,整个值就会变小,从而可以有效地控制梯度的范围。

另外,观察公式可知,其实到底对谁求偏导是看最近的一次是谁作为自变量,就会对谁求,不一定都是对权重参数求,也有对y求的时候。

接着我们用PyTorch来实操一下反向传播算法,PyTorch可以实现自动微分,requires_grad 表示这个参数是可学习的,这样我们进行BP的时候,直接用就好。不过默认情况下是False,需要我们手动设置。

import torchx = torch.ones(3,3,requires_grad = True)
t = x * x + 2
z = 2 * t + 1
y = z.mean()

接下来我们想求y对x的微分,需要注意这种用法只能适用于y是标量而非向量

y.backward()
print(x.grad)

所以当y是向量形式的时候我们可以自己来算,如

x = torch.ones(3,3,requires_grad = True)
t = x * x + 2
y = t - 9

如果计算y.backward()会报错,因为是向量,所以我们需要手动算v,这里就是y对t嘛,全为1,注意v的维度就行。

v = torch.ones(3,3)
y.backward(v)
print(x.grad)

往期精彩回顾适合初学者入门人工智能的路线及资料下载机器学习及深度学习笔记等资料打印机器学习在线手册深度学习笔记专辑《统计学习方法》的代码复现专辑
AI基础下载机器学习的数学基础专辑黄海广老师《机器学习课程》视频课

本站qq群851320808,加入微信群请扫码:

【机器学习】详解 BackPropagation 反向传播算法!相关推荐

  1. 【机器学习笔记】神经网络反向传播算法 推导

    神经网络反向传播算法 推导 (一) 概念及基本思想 (二)信息的前向传播 (三)误差反向传播 (1)输出层的权重参数更新 (2)隐藏层的权重参数更新 (3)输出层与隐藏层的偏置参数更新 (4)反向传播 ...

  2. Backpropagation 反向传播算法

    当我们搭建好一个神经网络后,我们的目标都是:将网络的权值和偏置训练为一个好的值,这个值可以让我们的输入得到理想的输出. 我们经常使用梯度下降算法(Gradient Descent)来最小化损失函数,得 ...

  3. Batch Normalization函数详解及反向传播中的梯度求导

    摘要 本文给出 Batch Normalization 函数的定义, 并求解其在反向传播中的梯度 相关 配套代码, 请参考文章 : Python和PyTorch对比实现批标准化Batch Normal ...

  4. L2正则化Regularization详解及反向传播的梯度求导

    摘要 本文解释L2正则化Regularization, 求解其在反向传播中的梯度, 并使用TensorFlow和PyTorch验证. 相关 系列文章索引 : https://blog.csdn.net ...

  5. 机器学习--多标签softmax + cross-entropy交叉熵损失函数详解及反向传播中的梯度求导

    https://blog.csdn.net/oBrightLamp/article/details/84069835 正文 在大多数教程中, softmax 和 cross-entropy 总是一起出 ...

  6. 神经网络——反向传播算法

    一.多元分类 之前讨论的神经网络都是以二元分类为目的进行介绍的. 当我们有不止两种分类时(也就是y=1,2,3-.y=1,2,3-.y=1,2,3-.),比如以下这种情况,该怎么办?如果我们要训练一个 ...

  7. 【 反向传播算法 Back-Propagation 数学推导以及源码详解 深度学习 Pytorch笔记 B站刘二大人(3/10)】

    反向传播算法 Back-Propagation 数学推导以及源码详解 深度学习 Pytorch笔记 B站刘二大人(3/10) 数学推导 BP算法 BP神经网络可以说机器学习的最基础网络.对于普通的简单 ...

  8. 反向算法_10分钟带你了解神经网络基础:反向传播算法详解

    作者:Great Learning Team deephub.ai 翻译组 1.神经网络 2.什么是反向传播? 3.反向传播是如何工作的? 4.损失函数 5.为什么我们需要反向传播? 6.前馈网络 7 ...

  9. 机器学习(深度学习)中的反向传播算法与梯度下降

    这是自己在CSDN的第一篇博客,目的是为了给自己学习过的知识做一个总结,方便后续温习,避免每次都重复搜索相关文章. 一.反向传播算法 定义:反向传播(Backpropagation,缩写为BP)是&q ...

最新文章

  1. 面试被问烂的 Spring IOC(求求你别再问了)
  2. pytorch微调bert_北大、人大联合开源工具箱UER,3 行代码完美复现BERT、GPT
  3. USACO Section 4.2 题解
  4. 窗宽窗位改变图像_CT、MRI图像的影像诊断4大原则、5个步骤、3大阅片方法
  5. 一台服务器多个oracle启动
  6. QQ IDKey生成--一键加群
  7. JavaScript开发必备!这四款静态代码分析工具你了解吗
  8. 谈中国分布式数据库商业之路:OSM与DB-Inside
  9. 网站推广第一周总结和反思
  10. VideoShow -视频编辑 v8.8.4rc (更新版)
  11. 入职前的背景调查到底在查什么?
  12. 移动硬盘显示成cd驱动器解决办法
  13. provide和inject 用法
  14. Cesium-定位至entity的位置
  15. ScriptManager.RegisterStartupScript()方法
  16. matlab做二元garch m,多元garch模型的matlab程序如何运行?能否举例说明下啊,希望高手指点...
  17. 公共WIFI上网短信认证解决方案
  18. ubuntu16下安装opencv3.4.10
  19. 机器学习实战(一):k-近邻算法
  20. 第一类公民(First-class Citizen)

热门文章

  1. 可变数组NSMutableArray
  2. ORACLE经常使用的命令
  3. 在GLSurfaceView上添加Layout控件(android)
  4. Gromacs文件-Chapter1
  5. Linux CentOS服务启动
  6. 【java】java开发中的23种设计模式详解
  7. 常规操作中浏览器缓存检测与服务器请求机制总结
  8. 新建虚拟机Ubuntu16.4安装搜狗输入法的问题
  9. scanf_s 发送访问冲突_程序员如何解决并发冲突的难题?
  10. 数学学习--最小二乘法案例剖析