深入理解Pytorch之register_buffer
使用
import torch.nn as nn
import torch
class net(nn.Module):def __init__(self):super(net,self).__init__()self.register_buffer("a",torch.ones(2,3))#从此,self.a其实就是torch.ones(2,3)。def forward(self,x):return x+self.a#使用
理解
register_buffer
的作用是将torch.ones(2,3)这个tensor注册到模型的 buffers() 属性中,并命名为a,这代表a对应的是一个持久态,不会有梯度传播给它,但是能被模型的state_dict记录下来。可以理解为模型的常数。
注意,没有保存到模型的 buffers() 或 parameters() 属性中的参数是不会被记录到state_dict中的,在 buffers() 中的参数默认不会有梯度,parameters() 中的则相反。
我们可以将前者理解为常数,后者理解为变量。
- requires_grad=False
- 不会注册到模型参数中model.parameters()
- 会注册到模型model.state_dict()中。
一个很多人疑问的问题是:既然register_buffer的对象是模型中的常数,那为什么不直接使用下面的方法一,还不更直接吗?
class net(nn.Module):def __init__(self,x=None):super(net,self).__init__()self.a=torch.ones(2,3)#方法一self.register_buffer("a",torch.ones(2,3))#方法二
这么跟你说吧,如果常数是这种torch.ones(2,3)
的话,两者确实在使用体验上没有任何差别(虽然后者会把torch.ones(2,3)
这个常数注册到model.state_dict()中,前者不会)。
但是,我们可能会遇到这样的场景:那个常数不是这么简单的常数,而是外部传入的。
class net(nn.Module):def __init__(self,x=None):super(net,self).__init__()self.a=x#方法一self.register_buffer("a",x)#方法二x=**
x=***
x=**
#第一次运行的时候,你经过千辛万苦得到了模型中的常数x。
model=net(x)
#训练模型
#保存模型。
#完毕
#如果是方法一,你又要运行一遍获得x的过程。
x=**
x=***
x=**
model=net(x)
#载入模型model.load
#使用模型
#如果是方法二,不需要获得x,因为register_buffer会将常数x保存在state_dict中,载入就行了。
model=net(x)
#载入模型model.load
#使用模型
深入理解Pytorch之register_buffer相关推荐
- Lesson 7 (3) 深入理解PyTorch与PyTorch库架构
我们已经理解了神经网络是如何诞生的,也了解了怎样的算法才是一个优秀的算法,现在我们需要借助深度学习框架(Deep learning framework)来帮助我们实现神经网络算法.在本门课程中,我们所 ...
- 收藏 | 万字长文带你理解Pytorch官方Faster RCNN代码
点上方蓝字计算机视觉联盟获取更多干货 在右上方 ··· 设为星标 ★,与你不见不散 仅作学术分享,不代表本公众号立场,侵权联系删除 转载于:作者丨白裳@知乎 来源丨https://zhuanlan.z ...
- lstm 输入数据维度_理解Pytorch中LSTM的输入输出参数含义
本文不会介绍LSTM的原理,具体可看如下两篇文章 Understanding LSTM Networks DeepLearning.ai学习笔记(五)序列模型 -- week1 循环序列模型 1.举个 ...
- 深入理解Pytorch负对数似然函数(torch.nn.NLLLoss)和交叉熵损失函数(torch.nn.CrossEntropyLoss)
在看Pytorch的交叉熵损失函数torch.nn.CrossEntropyLoss官方文档介绍中,给出的表达式如下.不免有点疑惑为何交叉熵损失的表达式是这个样子的 loss ( y , clas ...
- 深入理解pytorch分布式并行处理工具DDP——从工程实战中的bug说起
近期博主在使用分布式并行处理工具DDP(DistributedDataParallel)训练单目深度估计模型Featdepth(源码地址:https://github.com/sconlyshoote ...
- 通过和resnet18和resnet50理解PyTorch的ResNet模块
文章目录 模型介绍 resnet18模型流程 总结 resnet50 总结 resnet和resnext的框架基本相同的,这里先学习下resnet的构建,感觉高度模块化,很方便.本文算是对 PyTor ...
- 深度理解Pytorch中backward()
转自https://blog.csdn.net/douhaoexia/article/details/78821428 接触pytorch很久了,也自认为对 backward 方法有一定了解,但看了这 ...
- pytorch 中register_buffer()
今天在看DSSINet代码的ssim.py时,遇到了一个用法 class NORMMSSSIM(torch.nn.Module):def __init__(self, sigma=1.0, level ...
- pytorch的register_buffer
Version pytorch 1.7 Belong torch.nn.Module Adds a buffer to the module. This is typically used to re ...
最新文章
- Git中.gitignore忽略文件(maven项目)
- centos安装JDK与Tomcat
- caffe中在某一层获得迭代次数的方法以及caffe编译时报错 error: ‘to_string‘ is not a member of ‘std‘解决方法
- ajax学生校验学号,ajax校验数据库数据是否存在
- html怎么改艺术字体颜色,html超链接字体颜色怎么改
- 全频音箱与分频音箱的区别
- 利用IP地址查询接口来查询IP归属地
- 有趣的算法(六):3分钟看懂插入排序(C语言实现)
- java基础总结06-常用api类-包装类
- debian编译openjdk8
- Linux的PDF工具,Linux 系统中的pdf阅读器以及工具
- 刘铁猛-深入浅出WPF-系列资源汇总
- 计算机组成原理不恢复余数法,计算机组成原理第七讲(除法-原码恢复余数法)(科大罗克露)...
- 缅甸投资环境及法律政策简介
- 360html文件打不开,为什么360安全卫士打不开
- 多年的人工智能安全机制争议 检察官、行刑者与道德家这么表示
- 用python写情书_《使用Python进行自然语言处理》学习笔记一 | 学步园
- 旅游行业数字化进程分析——2023年元旦与春节,旅游市场开启复苏模式,跨省游热度上升
- 论文当中图片保存png、pdf等等的要分辨率DPI
- 估算成本 制定预算 区别
热门文章
- 1.6 万字长文带你读懂 Java IO
- 谁说数学不好,就不能成为编程大佬
- JavaScript异步编程:异步的数据收集方法
- 人声提取工具Spleeter安装教程(linux)
- 基于MTCNN的人脸自动对齐技术原理及其Tensorflow实现测试
- 薛澜:人工智能发展要让创新驱动和敏捷治理并驾齐驱
- ACL 2021 | 腾讯AI Lab、港中文杰出论文:用单语记忆实现高性能NMT
- 20亿参数+30亿张图像,刷新ImageNet最高分!谷歌大脑华人研究员领衔发布最强Transformer...
- 快手开源斗地主AI,入选ICML,能否干得过「冠军」柯洁?
- 深度解析 | 大数据面前,统计学的价值在哪里?