pytorch笔记——autograd和Variable
1 autograd
1.1 requires_grad
tensor中会有一个属性requires_grad 来记录之前的操作(为之后计算梯度用)。
1.2 调整tensor的requires_grad
1.3 with torch.no_grad
在这个环境里面里面生成的式子将无requires_grad
1.4 detach
内容不变,但是requires_grad将变为False
2 Variable
一般pytorch里面的运算,都是Variable级别的运算
Variable 计算时, 它一步步默默地搭建着一个庞大的系统, 叫做计算图(computational graph)。
这个图是将所有的计算步骤 (节点) 都连接起来。最后进行误差反向传递的时候, 一次性将所有 variable 里面的修改幅度 (梯度) 都计算出来, 而 普通的tensor 就没有这个能力。
2.1 获取variable里面的数据
直接print(variable)只会输出 Variable 形式的数据, 在很多时候是用不了的(比如想要用 plt 画图), 所以我们要转换一下, 将它变成 tensor 形式,或者ndarray形式等。
3 autograd 流程
3.1 无梯度的叶子节点
import torch
a = torch.tensor(2.0)
b = torch.tensor(3.0)
c = a*b
c.backward()
a.grad,b.grad
'''
RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn
'''
前项的计算图如下:
每个方框代表一个tensor,其中列出一些属性(还有其他很多属性):
data | tensor的data |
grad | 当计算gradient的时候将会存入此函数对应情况下,这个tensor的gradient |
grad_fn | 指向用于backward的函数的节点 |
is_leaf | 判断是否是叶节点 |
requires_grad |
如果是设为 如果为 在上图中,此时由于requires_grad都为False,因此没有backwards的graph. |
3.2 有梯度的叶子节点
a = torch.tensor(2.0,requires_grad=True)
b = torch.tensor(3.0)
c = a*b
c.backward()
a.grad,b.grad
#(tensor(3.), None)
前馈流程图如下:
3.2.1 backward 后馈流程图
1) 当我们调用tensor的乘法函数时,同时调用了隐性变量 ctx (context)变量的save_for_backward 函数。这样就把此函数做backward时所需要的从forward函数中获取的相关的一些值存到了ctx中。
ctx起到了缓存相关参数的作用,变成连接forward与backward之间的缓存站。
ctx中的值将会在c 做backwards时传递给对应的Mulbackward 操作.
2) 由于c是通过 c=a*b运算得来的, c的grad_fn中存了做backwards时候对应的函数.且把这个对应的backward 叫做 “MulBackward”
3) 当进行c的backwards的时候,其实也就相当于执行了 c = a*b这个函数分别对 a 与b 做的偏导。
那么理应对应两组backwards的函数,这两组backwards的函数打包存在 MulBackward的 next_functions 中。
next_function为一个 tuple list, AccumulateGrad 将会把相应得到的结果送到 a.grad中和b.grad中
4) 于是在进行 c.backward() 后, c进行关于a以及关于b进行求导。
由于b的requires_grad为False,因此b项不参与backwards运算(所以,next_function中list的第二个tuple即为None)。
c关于a的梯度为3,因此3将传递给AccumulaGrad进一步传给a.grad
因此,经过反向传播之后,a.grad 的结果将为3
3.3 稍微复杂一点的
a = torch.tensor(2.0,requires_grad = True)
b = torch.tensor(3.0,requires_grad = True)
c = a*b
d = torch.tensor(4.0,requires_grad = True)
e = c*d
e.backward()
a.grad,b.grad,d.grad
#(tensor(12.), tensor(8.), tensor(6.))
- e的grad_fn 指向节点 MulBackward, c的grad_fn指向另一个节点 MulBackward
- c 为中间值is_leaf 为False,因此并不包含 grad值,在backward计算中,并不需要再重新获取c.grad的值, backward的运算直接走相应的backward node 即可
- MulBackward 从 ctx.saved_tensor中调用有用信息, e= c+d中 e关于c的梯度通过MulBackward 获取得4. 根据链式规则, 4再和上一阶段的 c关于 a和c关于b的两个梯度值3和2相乘,最终得到了相应的值12 和8
- 因此经过backward之后,a.grad 中存入12, b.grad中存入 8
参考资料:【one way的pytorch学习笔记】(四)autograd的流程机制原理_One Way的博客-CSDN博客
pytorch笔记——autograd和Variable相关推荐
- PyTorch 笔记(13)— autograd(0.4 之前和之后版本差异)、Tensor(张量)、Gradient(梯度)
1. 背景简述 torch.autograd 是 PyTorch 中方便用户使用,专门开发的一套自动求导引擎,它能够根据输入和前向传播过程自动构建计算图,并执行反向传播. 计算图是现代深度学习框架 P ...
- (d2l-ai/d2l-zh)《动手学深度学习》pytorch 笔记(3)前言(介绍各种机器学习问题)以及数据操作预备知识Ⅲ(概率)
开源项目地址:d2l-ai/d2l-zh 教材官网:https://zh.d2l.ai/ 书介绍:https://zh-v2.d2l.ai/ 笔记基于2021年7月26日发布的版本,书及代码下载地址在 ...
- Introduction to PyTorch 笔记
文章目录 Introduction to PyTorch 笔记 Part 1 - Tensors in PyTorch (Solution).ipynb Part 2 - Neural Network ...
- PyTorch 的 Autograd详解
↑ 点击蓝字 关注视学算法 作者丨xiaopl@知乎 来源丨https://zhuanlan.zhihu.com/p/69294347 编辑丨极市平台 PyTorch 作为一个深度学习平台,在深度学习 ...
- 04_Pytorch生态、PyTorch能做什么、PyTorch之Autograd、autograd案例、GPU加速案例
1.4.初见PyTorch 1.4.1.PyTorch生态 1.4.2.PyTorch能做什么? GPU加速 自动求导 常用网络层 nn.Linear nn.Conv2d nn.LSTMnn.R ...
- 8月2日Pytorch笔记——梯度、全连接层、GPU加速、Visdom
文章目录 前言 一.常见函数的梯度 二.激活函数及其梯度 1.Sigmoid 2.Tanh 3.ReLU 三.Loss 函数及其梯度 1.Mean Squared Error(MSE) 2.Softm ...
- PyTorch 的 Autograd
PyTorch 的 Autograd 原创 AlanBupt 发布于2019-06-15 22:16:21 阅读数 1175 收藏 更新于2019-06-15 22:16:21 分类专栏: Pytho ...
- PyTorch 笔记Ⅱ——PyTorch 自动求导机制
文章目录 Autograd: 自动求导机制 张量(Tensor) 梯度 使用PyTorch计算梯度数值 Autograd 简单的自动求导 复杂的自动求导 Autograd 过程解析 扩展Autogra ...
- PYTORCH 笔记 DILATE 代码解读
dilate 完整代码路径:vincent-leguen/DILATE: Code for our NeurIPS 2019 paper "Shape and Time Distortion ...
最新文章
- Linux常用指令---ps(查看进程)
- .netcore部署到IIS上出现HTTP Error 502.5 - Process Failure问题解决
- 高通软件发布版本简称
- hdu-1251(基本字典树)
- js修改本地json文件_Flutter加载本地JSON文件教程建议收藏
- julia const报错_我爱Julia之入门-004
- php 根号2计算过程,根号2以及π的计算--关于无理数的畅想
- C++工作笔记-作用域的巧妙使用,释放堆区创建的资源
- php 添加 redis 扩展模块
- Maven 无法下载Oracle 驱动解决
- 【linux】ubuntu11.10下各种问题以及解决方案
- c语言里编译错误c131,C语言题库2.doc
- 113. 路径总和 II
- global在python中啥意思_Python中global用法详解
- 文本转语音-微软Azure-一步一步教你从注册到使用
- 国自然php代码,2020国自然单细胞项目申请——你的学科代码申请对了吗? | 单细胞专题之国基金...
- Git与bitbucket简单使用
- IE7、IE6和火狐兼容性问题
- linux mint xed中文乱码
- 天天生鲜项目——项目立项
热门文章
- 五步搞定Android开发环境部署——非常详细的Android开发环境搭建教程(转)
- CSP认证201712-4	行车路线[C++题解]:单源最短路变型、拆点、好题!
- 高精度除以低精度板子
- tcp/ip 协议栈Linux内核源码分析13 udp套接字发送流程二
- 燕赵志愿云如何认证_如何成为中国志愿服务网注册志愿者?操作秘籍!
- 如何查看keepalived版本号_Linux下Keepalived 安装与配置
- python rsa_python rsa加解密
- 学计算机用华硕电脑,请问华硕笔记本电脑什么型号比较好用,就商务办公?
- java如何获得键值_如何在java中取map中的键值 的两种方法
- c++全局类对象_史上最全 Python 面向对象编程