反向传播算法 Back-Propagation 数学推导以及源码详解 深度学习 Pytorch笔记 B站刘二大人(3/10)

数学推导

BP算法 BP神经网络可以说机器学习的最基础网络。对于普通的简单的神经网络层,我们还能通过推导计算得到梯度表达式,但是当网络结构如下图所示

此时梯度grad就变成了非常庞大的计算量,对于复杂的多层级网络,权重w个数多,无法直接对权重w进行解析式求导。面对这种情况我们引入数据结构中的图的概念,通过形成计算图在图上传播梯度利用链式法则对各个节点的梯度进行求解。

需要注意的是,基础线性单元(一层)的构成,应该以矩阵思维看待,w权重矩阵+b偏置量,将参数和输入输出都视为向量或者矩阵。Matrix cookbook 是主要矩阵运算的参考资料,可以去查阅。

由于各个层都是线性关系,而线性映射之间可以进行线性拼接和化简,会导致多个线性层直接连接与单一线性层的功能相同无法表示足够的网络复杂程度。在此引入激活函数概念,在各个层之间连接处加入激活函数Sigmoid

在每一线性层后加上激活函数Nonliner Function,激活函数的本质是非线性映射
eg: sigmoid: x -> 1/1+e^(-x1)
之后通过链式法则,累计求导,实际上就是高数中的复合函数求导和求偏导的相关知识,如下图

梯度计算过程,首先前馈计算出loss函数,之后根据loss函数,之后反向求loss与输出Z导数,由于Z由输入x和权重w的复合组成,因此可以求出loss与x和w的导数,根据loss的意义,loss最小则模型达到最优,而x输入为固定,则根据w关于loss导数动态调整w进行更新即可得到最新的loss
细节:在多层的运算过程中一般会将求导的导数存储在层单元中,pytorch是将导数存储在输入单元x中,而非运算模块f=x*w中
下图是Forward与backward具体流程推导,其中注意wx与wx+b的区别


Pytorch基本数据类型tensor,用于储存所有数值(标量,向量,矩阵,高维矩阵),主要成员:data保存权重本身值+grad损失函数对权重的导数

源码解读与实现


编程细节,在进行数据类型定义的时候将w的tensor数据类型中求导的标识符定义为真**(tensor数据类型默认不进行梯度求导以节省运算)**

此时,运算符重载,进行tensor与tensor之间的乘法运算,将x自动进行数据转化变换为tensor类型,同时该计算模块由于内部成员w是需要进行梯度计算的tensor类型x,则该计算模块x*w也自动将梯度计算的表示符转化为true
需要注意的是,在tensor的计算中是按照图的形式进行生成,每进行一次调用和运行,就动态生成一次计算图

  1. .backward()函数将整条计算链上的梯度全部进行计算并存储**,在进行一次backward后,将之前生成的计算图进行清除释放**
  2. .data运算,通过.data运算是直接运算其中的存储数据进行标量计算,而不构建计算图,如果使用w直接进行计算将在运算过程中构建计算图,产生大量冗余计算。
    3.在计算中不可以定义sum使用sum += l将loss值累加,同样因为l为张量,在与标量sum进行计算的过程中将生成计算图,产生冗余运算。如使用需要使用语句 sum += l.item()

整体代码

import torchx_data = [1.0, 2.0, 3.0]
y_data = [2.0, 4.0 ,6.0]w = torch.Tensor([1.0]) # 预测模型的参数w应当作为tensor量进行定义
w.requires_grad = True # 将tensor中的求导标识符定义为True,默认为Falsedef forward(x):return x * w    # 注意:此时由于w是tensor张量,在进行乘法时自动讲x转换为张量,进行张量乘法def loss(x,y):y_pred = forward(x) #调用forward函数计算预测值return (y_pred - y)**2 #返回损失print("predict (before training)", 4, forward(4).item()) # 输出未训练的结果for epoch in range(100):for x,y in zip(x_data, y_data):l = loss(x,y)   #计算损失函数l.backward()    #计算梯度,注意此时进行的是计算图运算print("\t grad:", x, y, w.grad.item())  # w属于张量,运算将生成计算图,因此用item函数调用标量数据w.data = w.data - 0.01 * w.grad.item()w.grad.data.zero_()     #梯度清零print("progress:", epoch, l.item())print("predict (after training)", 4, forward(4).item()) #输出训练预测值

【 反向传播算法 Back-Propagation 数学推导以及源码详解 深度学习 Pytorch笔记 B站刘二大人(3/10)】相关推荐

  1. 【 梯度下降算法 Gradient-Descend 数学推导与源码详解 深度学习 Pytorch笔记 B站刘二大人(2/10)】

    梯度下降算法 Gradient-Descend 数学推导与源码详解 深度学习 Pytorch笔记 B站刘二大人(2/10) 数学原理分析 在第一节中我们定义并构建了线性模型,即最简单的深度学习模型,但 ...

  2. 【多输入模型 Multiple-Dimension 数学原理分析以及源码详解 深度学习 Pytorch笔记 B站刘二大人 (6/10)】

    多输入模型 Multiple-Dimension 数学原理分析以及源码源码详解 深度学习 Pytorch笔记 B站刘二大人(6/10) 数学推导 在之前实现的模型普遍都是单输入单输出模型,显然,在现实 ...

  3. 【分类器 Softmax-Classifier softmax数学原理与源码详解 深度学习 Pytorch笔记 B站刘二大人(8/10)】

    分类器 Softmax-Classifier softmax数学原理与源码详解 深度学习 Pytorch笔记 B站刘二大人 (8/10) 在进行本章的数学推导前,有必要先粗浅的介绍一下,笔者在广泛查找 ...

  4. 【 卷积神经网络CNN 数学原理分析与源码详解 深度学习 Pytorch笔记 B站刘二大人(9/10)】

    卷积神经网络CNN 数学原理分析与源码详解 深度学习 Pytorch笔记 B站刘二大人(9/10) 本章主要进行卷积神经网络的相关数学原理和pytorch的对应模块进行推导分析 代码也是通过demo实 ...

  5. 【 线性模型 Linear-Model 数学原理分析以及源码实现 深度学习 Pytorch笔记 B站刘二大人(1/10)】

    线性模型 Linear-Model 数学原理分析以及源码实现 深度学习 Pytorch笔记 B站刘二大人(1/10) 数学原理分析 线性模型是我们在初级数学问题中所遇到的最普遍也是最多的一类问题 在线 ...

  6. CNN反向传播源码实现——CNN数学推导及源码实现系列(4)

    前言 本系列文章链接: CNN前置知识:模型的数学符号定义--卷积网络从零实现系列(1)_日拱一两卒的博客-CSDN博客https://blog.csdn.net/yangwohenmai1/arti ...

  7. 前向传播算法(Forward propagation)与反向传播算法(Back propagation)

    虽然学深度学习有一段时间了,但是对于一些算法的具体实现还是模糊不清,用了很久也不是很了解.因此特意先对深度学习中的相关基础概念做一下总结.先看看前向传播算法(Forward propagation)与 ...

  8. 花书+吴恩达深度学习(三)反向传播算法 Back Propagation

    目录 0. 前言 1. 从 Logistic Regression 中理解反向传播 2. 两层神经网络中单个样本的反向传播 3. 两层神经网络中多个样本的反向传播 如果这篇文章对你有一点小小的帮助,请 ...

  9. fdct算法 java_ImageSharp源码详解之JPEG压缩原理(3)DCT变换

    DCT变换可谓是JPEG编码原理里面数学难度最高的一环,我也是因为DCT变换的算法才对JPEG编码感兴趣(真是不自量力).这一章我就把我对DCT的研究心得体会分享出来,希望各位大神也不吝赐教. 1.离 ...

最新文章

  1. 2021-7-21 Bisenet V2 网络对Cityscapes公开数据集改变原有分类(4到5分类)
  2. 软件安装——internal error2503/2502
  3. pgsql 运行状态 采集脚本
  4. Java登陆页面经常出现的问题,问一下关于登陆页面的有关问题
  5. python 折线图x时间_在Python Bokeh折线图中设置日期/时间轴上的比例
  6. 调用函数,求a+aa+aaa+....+aa...aa(n个a)
  7. 【Linux内核】内存映射原理
  8. 怎样去理解@ComponentScan注解
  9. PyTorch 1.0 中文官方教程:使用字符级别特征的 RNN 网络进行姓氏分类
  10. PowerShell实现“机器人定时在企业微信群中发送消息”功能(下)
  11. 《淘宝网开店 进货 运营 管理 客服 实战200招》——1.3 常见网上开店平台
  12. 删除文件时提示正在被使用无法删除问题/删除dll文件
  13. 第三届“拳头奖”投票进行时 Devstore志在必得
  14. 【VUE+Elemet 】最全正则验证 + 表单验证 + 注意事项
  15. 数据结构(c语言版) 计算机科学丛书,数据结构与算法分析--C语言描述(原书第2版)(计算机科学丛书)...
  16. AAAI2022行人重识别论文汇总
  17. Adobe 软件共享
  18. QQ聊天快捷键【很好用的哦】
  19. 11台计算机的英语,计算机常见英语词汇
  20. 爆款!如何利用知乎引上万流量,我是这样做的!|实战

热门文章

  1. Random Walk(随机行走)
  2. chatGPT 生成随机漫步代码
  3. Android驻留广播,Android实现Service永久驻留
  4. 计算机科学经典著作(留作纪念)
  5. Excel文档的生成和压缩
  6. C++ 实现智能指针:shared_ptr 和 unique_ptr
  7. 获取屏幕、当前网页和浏览器窗口的大小
  8. 联想拯救者R720-15ikbn安装黑苹果Mac Catalina 10.15.3
  9. Spark中的spark.sql.shuffle.partitions 和spark.default.parallelism参数设置默认partition数目
  10. 什么软件查C语言答案,C语言小测验和参考答案