• object.grad.zero_()的意思是清0object的梯度值。
    下面做个实验。
x = torch.arange(4.0)
x.requires_grad_(True)
x.grad
# 注意此时为None,不为0
y = 2 * torch.dot(x, x)
y.backward()
x.grad
# tensor([ 0.,  4.,  8., 12.])
x.grad.zero_()
x.grad
# tensor([0., 0., 0., 0.])
  • 在默认情况下,PyTorch会累积梯度,我们需要清除之前的值,假如不清0会出现什么现象,看下面的实验。
x = torch.arange(5.0)
x.requires_grad_(True)
y = 2 * torch.dot(x, x)
y.backward()
x.grad
# Out[58]: tensor([ 0.,  4.,  8., 12., 16.])
z = 2 * torch.dot(x, x)
z.backward()
x.grad
# Out[61]: tensor([ 0.,  8., 16., 24., 32.]),结果不对
  • 那么上面这个错误结果是怎么来的呢?
    PyTorch会累积梯度,tensor([ 0., 8., 16., 24., 32.]) = tensor([ 0., 4., 8., 12., 16.]) + tensor([ 0., 4., 8., 12., 16.])得到的结果;

  • 所以下面这段代码的意思是迭代param时不需要构建计算图,并且迭代完成后就把param.grad清0,因为再一次调用sgd时就是下一个batch得到的param.grad,batch和batch是没有关系的。

def sgd(params, lr, batch_size):  #@save"""小批量随机梯度下降。"""with torch.no_grad():for param in params:param -= lr * param.grad / batch_sizeparam.grad.zero_()

参考资料

  1. https://zh-v2.d2l.ai/chapter_preliminaries/autograd.html;

pytorch之object.grad.zero_()相关推荐

  1. 【pytorch】|tensor grad

    计算图与动态图机制 计算图是用来描述运算的有向无环图.计算图有两个主要元素:结点(Node)和边(Edge).结点表示数据,如向量,矩阵,张量:边表示运算,如加减乘除卷积等. 下面用计算图表示:y = ...

  2. Pytorch:ToTensor(object)类

    PyTorch在做一般的深度学习图像处理任务时,先使用dataset类和dataloader类读入图片,在读入的时候需要做transform变换,其中transform一般都需要ToTensor()操 ...

  3. [转]一文解释PyTorch求导相关 (backward, autograd.grad)

    PyTorch是动态图,即计算图的搭建和运算是同时的,随时可以输出结果:而TensorFlow是静态图. 在pytorch的计算图里只有两种元素:数据(tensor)和 运算(operation) 运 ...

  4. Pytorch的grad、backward()、zero_grad()

    grad 梯度 什么样的tensor有grad? pytorch中只有torch.float和复杂类型才能有grad. x = torch.tensor([1, 2, 3, 4], requires_ ...

  5. PyTorch-Adam优化算法原理,公式,应用

    概念:Adam 是一种可以替代传统随机梯度下降过程的一阶优化算法,它能基于训练数据迭代地更新神经网络权重.Adam 最开始是由 OpenAI 的 Diederik Kingma 和多伦多大学的 Jim ...

  6. 笔记 | PyTorch安装及入门教程

    点击上方"小白学视觉",选择加"星标"或"置顶" 重磅干货,第一时间送达 本文内容概述如何安装PyTorch以及PyTorch的一些简单操作 ...

  7. pytorch adagrad_【学习笔记】Pytorch深度学习—优化器(二)

    点击文末 阅读原文,体验感更好哦! 前面学习过了Pytorch中优化器optimizer的基本属性和方法,优化器optimizer的主要功能是 "管理模型中的可学习参数,并利用参数的梯度gr ...

  8. PyTorch官方教程中文版:入门强化教程代码学习

    PyTorch之数据加载和处理 from __future__ import print_function, division import os import torch import pandas ...

  9. [PyTorch] 译+注:一个例子,让你明白PyTorch框架

    文章目录 Introduction Motivation Table of Contents A Simple Regression Problem (一个简单的线性回归) Data Generati ...

最新文章

  1. 【转】C# DateTime 日期计算
  2. boost::mp11::mp_set_push_front相关用法的测试程序
  3. Java中FileInputStream和FileOutputStream类实现文件夹及文件的复制粘贴
  4. poj2385 基础的动态规划算法 挑战程序设计竞赛
  5. 组织JSON数据、JSON转换
  6. PHP 判断客户端请求是 Android 还是 IOS
  7. ceph 代码分析 读_五大常见存储系统PK | Ceph、GlusterFS、MooseFS、HDFS、DRBD
  8. 转:error LNK2001 错误
  9. ajax 自动提示信息,自动提示使用AJAX
  10. vs不能调试_20200717调试记录(五十四)
  11. 几个支持 FreeSWITCH 的网络电话的安装与使用(linphone、MicroSIP、Sipdroid)
  12. 新团队团队融合研讨会_新的网络研讨会:如何避免持续交付的隐性成本
  13. 腾讯IM发送消息20001
  14. win10如何删除输入法_30秒解决Win10下输入法栏消失无法输入中文的难题
  15. 计算机考研408必考重难点整理(2022考纲大改后,陆续更新中。。)
  16. HTML实现两行两列单元表
  17. 计算机财务管理系统基础知识,计算机财务管理实习报告
  18. linux怎么查看内容并显示行号,linux中查看文件时显示行号
  19. 我的第一个博客----浅谈人生观价值观
  20. 终年32岁的传奇数学家,生前寂寂无闻,一个世纪后却让硅谷领袖们集体落泪致敬

热门文章

  1. 前端学习(750):作用域导读
  2. 12.多媒体和超链接标签及其应用实例
  3. 数字图像处理技术的应 用领域
  4. linux mysql数据库定时备份
  5. vue事件委托传递节点防止向下传递穿透
  6. 输入 3 个正数,判断能否构成一个三角形。
  7. iOS 去除警告 看我就够了
  8. CodeForces 139C Literature Lesson(模拟)
  9. 简单了解static
  10. HDU 1402 A * B Problem Plus FFT