multi task训练torch_手把手教你使用PyTorch(2)-requires_gradamp;computation graph
import torch
1. Requires_grad
但是,模型毕竟不是人,它的智力水平还不足够去自主辨识那些量的梯度需要计算,既然如此,就需要手动对其进行标记。
在PyTorch中,通用的数据结构tensor包含一个attributerequires_grad,它被用于说明当前量是否需要在计算中保留对应的梯度信息,以上文所述的线性回归为例,容易知道参数www为需要训练的对象,为了得到最合适的参数值,我们需要设置一个相关的损失函数,根据梯度回传的思路进行训练。
官方文档中的说明如下
If there’s a single input to an operation that requires gradient, its output will also require gradient.
只要某一个输入需要相关梯度值,则输出也需要保存相关梯度信息,这样就保证了这个输入的梯度回传。
而反之,若所有的输入都不需要保存梯度,那么输出的requires_grad会自动设置为False。既然没有了相关的梯度值,自然进行反向传播时会将这部分子图从计算中剔除。
Conversely, only if all inputs don’t require gradient, the output also won’t require it. Backward computation is never performed in the subgraphs, where all Tensors didn’t require gradients.
对于那些要求梯度的tensor,PyTorch会存储他们相关梯度信息和产生他们的操作,这产生额外内存消耗,为了优化内存使用,默认产生的tensor是不需要梯度的。
而我们在使用神经网络时,这些全连接层卷积层等结构的参数都是默认需要梯度的。
a = torch.tensor([1., 2., 3.])
print('a:', a.requires_grad)
b = torch.tensor([1., 4., 2.], requires_grad = True)
print('b:', b.requires_grad)
print('sum of a and b:', (a+b).requires_grad)
a: False
b: True
sum of a and b: True
2. Computation Graph
从PyTorch的设计原理上来说,在每次进行前向计算得到pred时,会产生一个用于梯度回传的计算图,这张图储存了进行back propagation需要的中间结果,当调用了.backward()后,会从内存中将这张图进行释放
这张计算图保存了计算的相关历史和提取计算所需的所有信息,以output作为root节点,以input和所有的参数为leaf节点,
we only retain the grad of the leaf node with requires_grad =True
在完成了前向计算的同时,PyTorch也获得了一张由计算梯度所需要的函数所组成的图
而从数据集中获得的input其requires_grad为False,故我们只会保存参数的梯度,进一步据此进行参数优化
在PyTorch中,multi-task任务一个标准的train from scratch流程为
for idx, data in enumerate(train_loader):
xs, ys = data
optmizer.zero_grad()
# 计算d(l1)/d(x)
pred1 = model1(xs) #生成graph1
loss = loss_fn1(pred1, ys)
loss.backward() #释放graph1
# 计算d(l2)/d(x)
pred2 = model2(xs)#生成graph2
loss2 = loss_fn2(pred2, ys)
loss.backward() #释放graph2
# 使用d(l1)/d(x)+d(l2)/d(x)进行优化
optmizer.step()
Computation Graph本质上是一个operation的图,所有的节点都是一个operation,而进行相应计算的参数则以叶节点的形式进行输入
借助torchviz库以下面的模型作为示例
import torch.nn.functional as F
import torch.nn as nn
class Conv_Classifier(nn.Module):
def __init__(self):
super(Conv_Classifier, self).__init__()
self.conv1 = nn.Conv2d(1, 5, 5)
self.pool1 = nn.MaxPool2d(2)
self.conv2 = nn.Conv2d(5, 16, 5)
self.pool2 = nn.MaxPool2d(2)
self.fc1 = nn.Linear(256, 20)
self.fc2 = nn.Linear(20, 10)
def forward(self, x):
x = F.relu(self.pool1((self.conv1(x))))
x = F.relu(self.pool2((self.conv2(x))))
x = F.dropout2d(x, training=self.training)
x = x.view(-1, 256)
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
return x
Mnist_Classifier = Conv_Classifier()
from torchviz import make_dot
input_sample = torch.rand((1, 1, 28, 28))
make_dot(Mnist_Classifier(input_sample), params=dict(Mnist_Classifier.named_parameters()))
其对应的计算梯度所需的图(计算图)为
可以看到,所有的叶子节点对应的操作都被记录,以便之后的梯度回传。
multi task训练torch_手把手教你使用PyTorch(2)-requires_gradamp;computation graph相关推荐
- multi task训练torch_采用single task模型蒸馏到Multi-Task Networks
论文地址. 这篇论文主要研究利用各个single task model来分别作为teacher model,用knowledge distillation的方法指导一个multi task model ...
- 简单有趣的 NLP 教程:手把手教你用 PyTorch 辨别自然语言(附代码)
简单有趣的 NLP 教程:手把手教你用 PyTorch 辨别自然语言(附代码) 雷锋网(公众号:雷锋网)按:本文作者甄冉冉,原载于作者个人博客,雷锋网已获授权. 最近在学pyTorch的实际应用 ...
- 手把手教你用Pytorch代码实现Transformer模型(超详细的代码解读)
手把手教你用Pytorch代码实现Transformer模型(超详细代码解读)
- 实例 :手把手教你用PyTorch快速准确地建立神经网络(附4个学习用例)
作者:Shivam Bansal:翻译:陈之炎:校对:丁楠雅: 本文约5600字,建议阅读30+分钟. 本文中,我们将探讨PyTorch的全部内容.我们将不止学习理论,还包括编写4个不同的用例,看看P ...
- 手把手教你洞悉 PyTorch 模型训练过程,彻底掌握 PyTorch 项目实战!(文末重金招聘导师)...
(文末重金招募导师) 在CVPR 2020会议接收中,PyTorch 使用了405次,TensorFlow 使用了102次,PyTorch使用数是TensorFlow的近4倍. 自2019年开始,越来 ...
- multi task训练torch_Multi-task Learning的三个小知识
本文译自Deep Multi-Task Learning – 3 Lessons Learned by Zohar Komarovsky 在过去几年里,Multi-Task Learning (MTL ...
- multi task训练torch_Pytorch多机多卡分布式训练
被这东西刁难两天了,终于想办法解决掉了,来造福下人民群众. 关于Pytorch分布训练的话,大家一开始接触的往往是DataParallel,这个wrapper能够很方便的使用多张卡,而且将进程控制在一 ...
- matlab文档查阅使用训练(手把手教你阅读matlab文档)全网首发原创
本文章是为了,熟悉阅读matlab的帮助文档而设立,其实更多的应该是理论知识,我读本科的时候,刚接触matlab发现相当难使用,也不能静下心来看帮助文档,总想买本书,照着敲语法,到了研1时候,也试着买 ...
- 独家 | 手把手教你用PyTorch快速准确地建立神经网络(附4个学习用例)
作者:Shivam Bansal,2019年1月14日 翻译:陈之炎 校对:丁楠雅 本文约5600字,建议阅读30+分钟. 本文中,我们将探讨PyTorch的全部内容.我们将不止学习理论,还包括编写4 ...
最新文章
- CTF中智能合约部署交互基础
- Elasticsearch搜索引擎:ES的segment段合并原理
- @Resource注解研究和在SAP Hybris ECP中的应用
- 工业相机与民用相机的区别_工业相机和普通相机的区别详解
- oracle查看执行计划入门
- 多个redistemplate_Spring boot 使用多个RedisTemplate
- php mysql备份类_php MYSQL 数据备份类
- 计算机注销命令,Win7使用DOS命令实现定时自动关机,注销、重启的方法
- java作用域对象笔记_Java学习笔记(七)——对象
- 关于VBScript的运行环境
- Mysql和mono_c# – 让Linq与Mysql和Mono玩得很好,有可能吗?
- mysql主从不同步监控_MySQL主从同步监控
- [转]外贸出口流程图
- 仙童的ua741运算放大器内部电路
- xshell如何使用
- java oracle spatial_安装Oracle Spatial数据组件
- window10无法访问局域网共享文件夹
- 无线耳机除了苹果哪个牌子好?类似苹果耳机的蓝牙耳机推荐
- 全国高中数学联赛 2020 年二试第四题
- C#取得DataTable最大值、最小值
热门文章
- java 前置通知_spring aop中的前置通知
- jq匹配偶数行_jquery怎么实现奇偶行不同背景颜色?
- 深度ip转换器手机版app_房串串经纪人版app下载-房串串经纪人版app手机版 v1.0.0...
- c语言注释部分两侧的分界符号分别是,c语言中界定注释的符号分别是什么?
- 每日一题——王道考研2.2.4.1
- 钻井缸套排量_中国石化顺北特深层及川渝页岩气钻完井关键技术集成:碳酸盐岩酸压技术、优快钻井技术、页岩气强化体积改造技术、高温高压窄间隙固井技术...
- PHP使用SMTP邮件服务器
- [react-router] 请你说说react的路由的优缺点?
- [html] html5点击返回键怎样不让它返回上一页?
- [css] 举例说明:not()的使用场景有哪些