import torch

1. Requires_grad

但是,模型毕竟不是人,它的智力水平还不足够去自主辨识那些量的梯度需要计算,既然如此,就需要手动对其进行标记。

在PyTorch中,通用的数据结构tensor包含一个attributerequires_grad,它被用于说明当前量是否需要在计算中保留对应的梯度信息,以上文所述的线性回归为例,容易知道参数www为需要训练的对象,为了得到最合适的参数值,我们需要设置一个相关的损失函数,根据梯度回传的思路进行训练。

官方文档中的说明如下

If there’s a single input to an operation that requires gradient, its output will also require gradient.

只要某一个输入需要相关梯度值,则输出也需要保存相关梯度信息,这样就保证了这个输入的梯度回传。

而反之,若所有的输入都不需要保存梯度,那么输出的requires_grad会自动设置为False。既然没有了相关的梯度值,自然进行反向传播时会将这部分子图从计算中剔除。

Conversely, only if all inputs don’t require gradient, the output also won’t require it. Backward computation is never performed in the subgraphs, where all Tensors didn’t require gradients.

对于那些要求梯度的tensor,PyTorch会存储他们相关梯度信息和产生他们的操作,这产生额外内存消耗,为了优化内存使用,默认产生的tensor是不需要梯度的。

而我们在使用神经网络时,这些全连接层卷积层等结构的参数都是默认需要梯度的。

a = torch.tensor([1., 2., 3.])

print('a:', a.requires_grad)

b = torch.tensor([1., 4., 2.], requires_grad = True)

print('b:', b.requires_grad)

print('sum of a and b:', (a+b).requires_grad)

a: False

b: True

sum of a and b: True

2. Computation Graph

从PyTorch的设计原理上来说,在每次进行前向计算得到pred时,会产生一个用于梯度回传的计算图,这张图储存了进行back propagation需要的中间结果,当调用了.backward()后,会从内存中将这张图进行释放

这张计算图保存了计算的相关历史和提取计算所需的所有信息,以output作为root节点,以input和所有的参数为leaf节点,

we only retain the grad of the leaf node with requires_grad =True

在完成了前向计算的同时,PyTorch也获得了一张由计算梯度所需要的函数所组成的图

而从数据集中获得的input其requires_grad为False,故我们只会保存参数的梯度,进一步据此进行参数优化

在PyTorch中,multi-task任务一个标准的train from scratch流程为

for idx, data in enumerate(train_loader):

xs, ys = data

optmizer.zero_grad()

# 计算d(l1)/d(x)

pred1 = model1(xs) #生成graph1

loss = loss_fn1(pred1, ys)

loss.backward() #释放graph1

# 计算d(l2)/d(x)

pred2 = model2(xs)#生成graph2

loss2 = loss_fn2(pred2, ys)

loss.backward() #释放graph2

# 使用d(l1)/d(x)+d(l2)/d(x)进行优化

optmizer.step()

Computation Graph本质上是一个operation的图,所有的节点都是一个operation,而进行相应计算的参数则以叶节点的形式进行输入

借助torchviz库以下面的模型作为示例

import torch.nn.functional as F

import torch.nn as nn

class Conv_Classifier(nn.Module):

def __init__(self):

super(Conv_Classifier, self).__init__()

self.conv1 = nn.Conv2d(1, 5, 5)

self.pool1 = nn.MaxPool2d(2)

self.conv2 = nn.Conv2d(5, 16, 5)

self.pool2 = nn.MaxPool2d(2)

self.fc1 = nn.Linear(256, 20)

self.fc2 = nn.Linear(20, 10)

def forward(self, x):

x = F.relu(self.pool1((self.conv1(x))))

x = F.relu(self.pool2((self.conv2(x))))

x = F.dropout2d(x, training=self.training)

x = x.view(-1, 256)

x = F.relu(self.fc1(x))

x = F.relu(self.fc2(x))

return x

Mnist_Classifier = Conv_Classifier()

from torchviz import make_dot

input_sample = torch.rand((1, 1, 28, 28))

make_dot(Mnist_Classifier(input_sample), params=dict(Mnist_Classifier.named_parameters()))

其对应的计算梯度所需的图(计算图)为

可以看到,所有的叶子节点对应的操作都被记录,以便之后的梯度回传。

multi task训练torch_手把手教你使用PyTorch(2)-requires_gradamp;computation graph相关推荐

  1. multi task训练torch_采用single task模型蒸馏到Multi-Task Networks

    论文地址. 这篇论文主要研究利用各个single task model来分别作为teacher model,用knowledge distillation的方法指导一个multi task model ...

  2. 简单有趣的 NLP 教程:手把手教你用 PyTorch 辨别自然语言(附代码)

     简单有趣的 NLP 教程:手把手教你用 PyTorch 辨别自然语言(附代码) 雷锋网(公众号:雷锋网)按:本文作者甄冉冉,原载于作者个人博客,雷锋网已获授权. 最近在学pyTorch的实际应用 ...

  3. 手把手教你用Pytorch代码实现Transformer模型(超详细的代码解读)

    手把手教你用Pytorch代码实现Transformer模型(超详细代码解读)

  4. 实例 :手把手教你用PyTorch快速准确地建立神经网络(附4个学习用例)

    作者:Shivam Bansal:翻译:陈之炎:校对:丁楠雅: 本文约5600字,建议阅读30+分钟. 本文中,我们将探讨PyTorch的全部内容.我们将不止学习理论,还包括编写4个不同的用例,看看P ...

  5. 手把手教你洞悉 PyTorch 模型训练过程,彻底掌握 PyTorch 项目实战!(文末重金招聘导师)...

    (文末重金招募导师) 在CVPR 2020会议接收中,PyTorch 使用了405次,TensorFlow 使用了102次,PyTorch使用数是TensorFlow的近4倍. 自2019年开始,越来 ...

  6. multi task训练torch_Multi-task Learning的三个小知识

    本文译自Deep Multi-Task Learning – 3 Lessons Learned by Zohar Komarovsky 在过去几年里,Multi-Task Learning (MTL ...

  7. multi task训练torch_Pytorch多机多卡分布式训练

    被这东西刁难两天了,终于想办法解决掉了,来造福下人民群众. 关于Pytorch分布训练的话,大家一开始接触的往往是DataParallel,这个wrapper能够很方便的使用多张卡,而且将进程控制在一 ...

  8. matlab文档查阅使用训练(手把手教你阅读matlab文档)全网首发原创

    本文章是为了,熟悉阅读matlab的帮助文档而设立,其实更多的应该是理论知识,我读本科的时候,刚接触matlab发现相当难使用,也不能静下心来看帮助文档,总想买本书,照着敲语法,到了研1时候,也试着买 ...

  9. 独家 | 手把手教你用PyTorch快速准确地建立神经网络(附4个学习用例)

    作者:Shivam Bansal,2019年1月14日 翻译:陈之炎 校对:丁楠雅 本文约5600字,建议阅读30+分钟. 本文中,我们将探讨PyTorch的全部内容.我们将不止学习理论,还包括编写4 ...

最新文章

  1. CTF中智能合约部署交互基础
  2. Elasticsearch搜索引擎:ES的segment段合并原理
  3. @Resource注解研究和在SAP Hybris ECP中的应用
  4. 工业相机与民用相机的区别_工业相机和普通相机的区别详解
  5. oracle查看执行计划入门
  6. 多个redistemplate_Spring boot 使用多个RedisTemplate
  7. php mysql备份类_php MYSQL 数据备份类
  8. 计算机注销命令,Win7使用DOS命令实现定时自动关机,注销、重启的方法
  9. java作用域对象笔记_Java学习笔记(七)——对象
  10. 关于VBScript的运行环境
  11. Mysql和mono_c# – 让Linq与Mysql和Mono玩得很好,有可能吗?
  12. mysql主从不同步监控_MySQL主从同步监控
  13. [转]外贸出口流程图
  14. 仙童的ua741运算放大器内部电路
  15. xshell如何使用
  16. java oracle spatial_安装Oracle Spatial数据组件
  17. window10无法访问局域网共享文件夹
  18. 无线耳机除了苹果哪个牌子好?类似苹果耳机的蓝牙耳机推荐
  19. 全国高中数学联赛 2020 年二试第四题
  20. C#取得DataTable最大值、最小值

热门文章

  1. java 前置通知_spring aop中的前置通知
  2. jq匹配偶数行_jquery怎么实现奇偶行不同背景颜色?
  3. 深度ip转换器手机版app_房串串经纪人版app下载-房串串经纪人版app手机版 v1.0.0...
  4. c语言注释部分两侧的分界符号分别是,c语言中界定注释的符号分别是什么?
  5. 每日一题——王道考研2.2.4.1
  6. 钻井缸套排量_中国石化顺北特深层及川渝页岩气钻完井关键技术集成:碳酸盐岩酸压技术、优快钻井技术、页岩气强化体积改造技术、高温高压窄间隙固井技术...
  7. PHP使用SMTP邮件服务器
  8. [react-router] 请你说说react的路由的优缺点?
  9. [html] html5点击返回键怎样不让它返回上一页?
  10. [css] 举例说明:not()的使用场景有哪些