系列文章目录


文章目录

  • 系列文章目录
  • 前言
  • 一、Autograd是什么?
  • 二、Autograd的使用方法
    • 1.在tensor中指定
    • 2.重要属性
  • 三、Autograd的进阶知识
    • 1、动态计算图
    • 2、梯度累加
  • 总结

前言

今天我们一起来谈一下pytorch的一个重要特性----自动求导机制Autograd。


一、Autograd是什么?

Autograd自动求导机制,是pytorch中针对神经网络反向求导与梯度更新所开发的便利工具,内嵌在tensor操作中。

二、Autograd的使用方法

1.在tensor中指定

创建Tensor时声明requires_grad参数为true

    >>> import torch>>> a=torch.randn(5, 3, requires_grad=True)>>> b=torch.randn(5, 3)>>> a.requires_grad, b.requires_gradTrue, False# 可以看到默认的Tensor是不需要求导的,设置requires_grad为True后则需要求导# 也可以通过内置函数requires_grad_()将Tensor变为需要求导>>> b.requires_grad_()>>> b.requires_gradTrue

2.重要属性

tensor的两个重要属性:
.grad:记录梯度
.grad_fn:记录操作

require_grad参数表示是否需要对该Tensor进行求导,默认为False,一旦设定为True,则该Tensor的之后的所有节点都自动默认为True.

 >>> import torch>>> a=torch.randn(3, 3)>>> b=torch.randn(3, 3, requires_grad=True)>>> c = a + b>>> c.requires_gradTrue# 通过计算生成的Tensor,由于依赖的Tensor需要求导,因此c也需要求导>>> a.grad_fn, b.grad_fn, c.grad_fn(None, None, <AddBackward1 object at 0x7fa7a53e04a8>)# a与b是自己创建的,grad_fn为None,而c的grad_fn则是一个Add函数操作>>> d = c.detach()>>> d.requires_gradFalse# Tensor.detach()函数生成的数据默认requires_grad为False。

三、Autograd的进阶知识

1、动态计算图

每一步Tensor的计算操作,会生成计算图,并将操作的function记录在Tensor的grad_fn中。每一次前向都会生成新的计算图,这是动态的来源。同时反向一次之后计算图就会清空,不能在此基础上继续进行求导。

在前向计算完后,只需对根节点进行backward函数操作,即可从当前根节点自动进行反向传播与梯度计算,从而得到每一个叶子节点的梯度,梯度计算遵循链式求导法则。

2、梯度累加

grad一般计算完继续保留,不主动清空会累加。

for i, (inputs, labels) in enumerate(trainloader):outputs = net(inputs)                   # 正向传播loss = criterion(outputs, labels)       # 计算损失函数loss = loss / accumulation_steps        # 损失标准化loss.backward()                         # 反向传播,计算梯度if (i+1) % accumulation_steps == 0:optimizer.step()                    # 更新参数optimizer.zero_grad()               # 梯度清零if (i+1) % evaluation_steps == 0:evaluate_model()

1.正向传播,将数据传入网络,得到预测结果
2.根据预测结果与label,计算损失值
3.利用损失进行反向传播,计算参数梯度
4.重复1-3,不清空梯度,而是将梯度累加
5.梯度累加达到固定次数之后,更新参数,然后将梯度清零

总结来讲,梯度累加就是每计算一个batch的梯度,不进行清零,而是做梯度的累加,当累加到一定的次数之后,再更新网络参数,然后将梯度清零。

通过这种参数延迟更新的手段,可以实现与采用大batch size相近的效果。大多数情况下,采用梯度累加训练的模型效果,要比采用小batch size训练的模型效果要好很多。


总结

以上就是今天要讲的内容,本文仅仅简单介绍了自动求导机制的使用和意义,后续的使用中一般会集成在模块中,熟悉它的相关流程更重要。

Pytorch教程入门系列4----Autograd自动求导机制相关推荐

  1. Pytorch Autograd (自动求导机制)

    Introduce Pytorch Autograd库 (自动求导机制) 是训练神经网络时,反向误差传播(BP)算法的核心. 本文通过logistic回归模型来介绍Pytorch的自动求导机制.首先, ...

  2. pytorch基础知识整理(一)自动求导机制

    torch.autograd torch.autograd是pytorch最重要的组件,主要包括Variable类和Function类,Variable用来封装Tensor,是计算图上的节点,Func ...

  3. PyTorch 笔记Ⅱ——PyTorch 自动求导机制

    文章目录 Autograd: 自动求导机制 张量(Tensor) 梯度 使用PyTorch计算梯度数值 Autograd 简单的自动求导 复杂的自动求导 Autograd 过程解析 扩展Autogra ...

  4. 【PyTorch基础教程2】自动求导机制(学不会来打我啊)

    文章目录 第一部分:深度学习和机器学习 一.机器学习任务 二.ML和DL区别 (1)数据加载 (2)模型实现 (3)训练过程 第二部分:Pytorch部分 一.学习资源 二.自动求导机制 2.1 to ...

  5. PyTorch的计算图和自动求导机制

    文章目录 PyTorch的计算图和自动求导机制 自动求导机制简介 自动求导机制实例 梯度函数的使用 计算图构建的启用和禁用 总结 PyTorch的计算图和自动求导机制 自动求导机制简介 PyTorch ...

  6. Pytorch学习(一)—— 自动求导机制

    现在对 CNN 有了一定的了解,同时在 GitHub 上找了几个 examples 来学习,对网络的搭建有了笼统地认识,但是发现有好多基础 pytorch 的知识需要补习,所以慢慢从官网 API 进行 ...

  7. 【Torch笔记】autograd自动求导系统

    [Torch笔记]autograd自动求导系统 Pytorch 提供的自动求导系统 autograd,我们不需要手动地去计算梯度,只需要搭建好前向传播的计算图,然后使用 autograd 计算梯度即可 ...

  8. 【PyTorch学习(三)】Aurograd自动求导机制总结

    ​Aurograd自动求导机制总结 PyTorch中,所有神经网络的核心是 autograd 包.autograd 包为tensor上的所有操作提供了自动求导机制.它是一个在运行时定义(define- ...

  9. 深度学习PyTorch笔记(9):自动求导

    深度学习PyTorch笔记(9):自动求导 4. 自动求导 4.1 理解 4.2 梯度 4.3 .requires_grad与.grad_fn 4.4 调用.backward()反向传播来完成所有梯度 ...

最新文章

  1. 多模态深度学习:用深度学习的方式融合各种信息
  2. python安装步骤电脑版-超详细的小白python3.X安装教程|Python安装
  3. iOS Hacker 重签名实现无需越狱注入动态库 dylib
  4. pytorch选出数据中的前k个最大(最小)值及其索引
  5. 自助分析_为什么自助服务分析真的不是一回事
  6. (JAVA)Math类
  7. Android官方开发文档Training系列课程中文版:与其它APP交互之将用户带到其它的APP
  8. 2021国潮新消费产业洞察报告
  9. 大学生转行IT,零基础非计算机专业可以学会吗?
  10. JavaScript中对于函数的形参实参个数匹配是如何做的?
  11. 通达信波段王指标公式主图_通达信波段操作主图指标公式
  12. swarm主网BZZ挖矿:钱包如何添加BZZ合约?如何查钱包余额?
  13. Web——HTML常见标签及用法
  14. 泰山OFFICE技术讲座:标点关系穷举研究-07
  15. 转载--[数据库] MySQL汉字字段按拼音排序
  16. 1055 集体照 Python实现
  17. ROS2 基础概念 服务
  18. OpenCV3之操作例子总汇
  19. Windows Phone开发中,减小(改变)Pivot控件PivotItem的Header(标题)字号
  20. 关于快手小铃铛广告投放的方式

热门文章

  1. 如何能到移动、联通公司工作?
  2. 借力云管理网络 兰州大学开启智慧校园新时代
  3. 实战:怎样规避淘宝店铺运营的几大误区
  4. 淘宝运营 店铺动销率的定义 、公示解说、产生原因以及对应的解决方案
  5. OpenCV图像处理-模糊
  6. #积分制管理感言#河北沧州盛世今典广告传媒赵胜
  7. python历史时间轴可视化_如何使用Python创建历史时间轴
  8. Ubuntu开启VNC屏幕共享
  9. JAVA攻城狮学习路线
  10. 从光明顶一役,我们能学到什么?