文章目录

  • 反向传播法求梯度
  • 梯度下降求最小值

反向传播法求梯度

利用计算图求梯度是一种比较方便又快速的方法,如何利用计算图求梯度?先回忆一下计算图:

以 z=x2+y2z=x^2+y^2z=x2+y2 为例:

  • 计算图以箭头和节点构成,正向传播时,得到的结果是 z=x2+y2z=x^2+y^2z=x2+y2

  • 反向传播时,得到的结果是:∂L∂z×2x\frac{\partial L}{\partial z}\times 2x∂z∂L​×2x 和 ∂L∂z×2y\frac{\partial L}{\partial z}\times 2y∂z∂L​×2y

  • 仔细一看,令 ∂L∂z=1\frac{\partial L}{\partial z}=1∂z∂L​=1 不就得到梯度为 (2x,2y)(2x,2y)(2x,2y) 了吗!(注意这里的x, y是正向传播的x和y)
    为什么令∂L∂z=1\frac{\partial L}{\partial z}=1∂z∂L​=1 就可以得到梯度了呢?实际上在计算图中,L才是输出函数,而z只是中间变量,但在这里,z也是输出函数,所以 L=z, 因此有∂L∂z=1\frac{\partial L}{\partial z}=1∂z∂L​=1

  • 认识到这点,现在就可以写代码实现求梯度了,这里用一个类来实现:

class SqrtWithAdd:def __init__(self):self.x = Noneself.y = Nonedef forward(self, x, y):self.x = xself.y = yout = x**2 + y**2return outdef backward(self, dout=1): dx = dout * 2*self.xdy = dout * 2*self.yreturn dx,dy           #(dx, dy)就是梯度

现在利用这个类来求一下在点 (2, 3) 处的梯度:

sqrt_with_add = SqrtWithAdd()    #实例化类
sqrt_with_add.forward(2,3)       #先进行正向传播
grad = sqrt_with_add.backward()  #求梯度print(grad)#输出:
(4, 6)

经手算验证,结果正确。

梯度下降求最小值

既然用反向传播的方法求出了梯度值,那么现在就想用这个方法结合梯度下降法来求一下函数最小值。

先把公式写一下:
x=x−η∂f∂xy=y−η∂f∂yx = x - \eta \frac{\partial f}{\partial x}\\y = y - \eta \frac{\partial f}{\partial y} x=x−η∂x∂f​y=y−η∂y∂f​
之前讲过了,η\etaη 是学习率。

代码如下:

def gradient_descent(init_x,init_y, lr=0.01, step_num=100):x = init_xy = init_ysqrt_with_add = SqrtWithAdd()    #创建实例for i in range(step_num):sqrt_with_add.forward(x,y)   #正向传播dx,dy = sqrt_with_add.backward()     #反向传播求梯度x -= lr * dxy -= lr * dyreturn x,y#调用函数求最小值位置
x,y = gradient_descent(init_x= 3,init_y= 4) #设置开始起点(3,4)print('[{}, {}]'.format(x,y))

输出:

[0.39785866768425965, 0.5304782235790126]

结果可以通过调整学习率 lr 和 学习次数 step_num 来接近最佳位置。

这里举的例子是二维的,可以通过调整上面的类的输入值的维数来使其变为可以求 n维的。


最后总结一下反向传播法求梯度与数值微分法求梯度的区别

反向传播求梯度实际上用的是解析法来求导,跟自己手算求梯度是一样的,记得求导公式就可以;

而数值微分法求的梯度实际上用的是导数定义法来求

13-反向传播法求梯度相关推荐

  1. 一步一步教你反向传播,求梯度(A Step by Step Backpropagation Example)

    本文是我在学习反向传播时翻译的一篇文章.原文链接如下. A Step by Step Backpropagation Example 实例学习 在这个例子里,我们将制作一个小型神经网络.它有两个输入, ...

  2. 误差反向传播法(二)【神经网络以层的方式实现】

    我们来看激活函数层的实现,对于激活函数,大家初学神经网络的时候就经常听到,准确来说是在接触感知机的时候熟悉的,它是进入神经网络大门的钥匙,是现代神经网络快速发展的源头. ReLU层(Rectified ...

  3. 深度学习——误差反向传播法

    前言 通过数值微分的方法计算了神经网络中损失函数关于权重参数的梯度,虽然容易实现,但缺点是比较费时间,本章节将使用一种高效的计算权重参数梯度的方法--误差方向传播法 本文将通过①数学式.②计算图,这两 ...

  4. 梯度下降法与反向传播法

    梯度下降法与反向传播法 梯度下降法 参考资料:推荐系统玩家 之 随机梯度下降(Stochastic gradient descent)求解矩阵分解 - 知乎 (zhihu.com) 什么是梯度? 首先 ...

  5. 【深度学习的数学】2×3×1层带sigmoid激活函数的神经网络感知机对三角形平面的分类训练预测(绘制出模型结果三维图展示效果)(梯度下降法+最小二乘法+激活函数sigmoid+误差反向传播法)

    文章目录 训练数据 数据示意 训练数据生成及绘制三维图像代码 训练数据三维图像 搭建神经网络结构 网络结构 利用梯度下降法和误差反向传播法计算损失函数损失值 代码 [灾难降临]代码出现严重问题,已将其 ...

  6. 计算梯度的三种方法: 数值法,解析法,反向传播法

    # coding=gbk""" function : f(x,y,z) = (x+y)z """ # first method 解析法 de ...

  7. Batch Normalization函数详解及反向传播中的梯度求导

    摘要 本文给出 Batch Normalization 函数的定义, 并求解其在反向传播中的梯度 相关 配套代码, 请参考文章 : Python和PyTorch对比实现批标准化Batch Normal ...

  8. 深度学习入门-基于python的理论与实现(五)误差反向传播法

    目录 回顾 1 计算图 1.1局部计算 1.2 计算图的优点是什么 1.3 反向传播的导数是怎么求? 1.3.1加法节点的反向传播 1.3.2 乘法节点的反向传播 1.3.3 购买苹果的反向传播 1. ...

  9. 一文弄懂神经网络中的反向传播法——BackPropagation

    https://www.cnblogs.com/charlotte77/p/5629865.html 最近在看深度学习的东西,一开始看的吴恩达的UFLDL教程,有中文版就直接看了,后来发现有些地方总是 ...

最新文章

  1. 列表组件之ListView
  2. mongodb拆库分表脚本
  3. SAP Spartacus Cart UI 修改 quantity 字段后的 Patch 请求遇到 400 错误 - IllegalArgumentError
  4. do filtering will real delete note in DB
  5. [Leetcode][第20题][JAVA][有效的括号][栈][HashMap]
  6. 关于Oracle与MySQL的使用总结
  7. JAVA入门级教学之(参数传递)
  8. js验证家庭住址_手摇充电电筒、多功能组合剪刀……官方清单建议上海家庭储备13种应急物资...
  9. 页面的加载与渲染顺序
  10. PHP include语句和require语句
  11. html留言页面设计,html的留言板制作(js)
  12. Linux 常用命令 Updating
  13. 手机python代码查询四六级准考证_四六级查准考证号的网站是什么
  14. java pem 签名_如何在Java中验证PEM格式证书
  15. 如何分析数据建立数据表
  16. ORACLE 数据、表误删恢复(转)
  17. 转:Google论文之一----Bigtable学习翻译
  18. 疫情之下的远程办公解决方案
  19. 回收站恢复的文件找不到了怎么办?竟然还有这3种靠谱的方案
  20. Smart-Link配置

热门文章

  1. STM32与Flash AT45DB321D之间读写数据
  2. 容联云通讯_提供网络通话、视频通话、视频会议、云呼叫中心、IM等融合通讯能力开放平台。...
  3. linux 析构函数地址获取_c语言中有析构函数吗
  4. 华为笔记本在linux下越狱苹果设备(2022.2.27更新)
  5. supermap gis
  6. Android 从WebServer 获取PDF转图片
  7. zabbix简介及部署安装(邮件报警)
  8. 关于python赋值语句下列选项中描述正确的是_关于 Python 语句 P = –P,以下选项中描述正确的是________...
  9. PayPal开发之IPN的使用
  10. pci总线定时协议_PCI总线标准及协议