使用

import torch.nn as nn
import torch
class net(nn.Module):def __init__(self):super(net,self).__init__()self.register_buffer("a",torch.ones(2,3))#从此,self.a其实就是torch.ones(2,3)。def forward(self,x):return x+self.a#使用

理解

register_buffer的作用是将torch.ones(2,3)这个tensor注册到模型的 buffers() 属性中,并命名为a,这代表a对应的是一个持久态,不会有梯度传播给它,但是能被模型的state_dict记录下来。可以理解为模型的常数

注意,没有保存到模型的 buffers() 或 parameters() 属性中的参数是不会被记录到state_dict中的,在 buffers() 中的参数默认不会有梯度,parameters() 中的则相反。

我们可以将前者理解为常数,后者理解为变量。

  1. requires_grad=False
  2. 不会注册到模型参数中model.parameters()
  3. 会注册到模型model.state_dict()中。

一个很多人疑问的问题是:既然register_buffer的对象是模型中的常数,那为什么不直接使用下面的方法一,还不更直接吗?

class net(nn.Module):def __init__(self,x=None):super(net,self).__init__()self.a=torch.ones(2,3)#方法一self.register_buffer("a",torch.ones(2,3))#方法二

这么跟你说吧,如果常数是这种torch.ones(2,3)的话,两者确实在使用体验上没有任何差别(虽然后者会把torch.ones(2,3)这个常数注册到model.state_dict()中,前者不会)。

但是,我们可能会遇到这样的场景:那个常数不是这么简单的常数,而是外部传入的。

class net(nn.Module):def __init__(self,x=None):super(net,self).__init__()self.a=x#方法一self.register_buffer("a",x)#方法二x=**
x=***
x=**
#第一次运行的时候,你经过千辛万苦得到了模型中的常数x。
model=net(x)
#训练模型
#保存模型。
#完毕
#如果是方法一,你又要运行一遍获得x的过程。
x=**
x=***
x=**
model=net(x)
#载入模型model.load
#使用模型
#如果是方法二,不需要获得x,因为register_buffer会将常数x保存在state_dict中,载入就行了。
model=net(x)
#载入模型model.load
#使用模型

深入理解Pytorch之register_buffer相关推荐

  1. Lesson 7 (3) 深入理解PyTorch与PyTorch库架构

    我们已经理解了神经网络是如何诞生的,也了解了怎样的算法才是一个优秀的算法,现在我们需要借助深度学习框架(Deep learning framework)来帮助我们实现神经网络算法.在本门课程中,我们所 ...

  2. 收藏 | 万字长文带你理解Pytorch官方Faster RCNN代码

    点上方蓝字计算机视觉联盟获取更多干货 在右上方 ··· 设为星标 ★,与你不见不散 仅作学术分享,不代表本公众号立场,侵权联系删除 转载于:作者丨白裳@知乎 来源丨https://zhuanlan.z ...

  3. lstm 输入数据维度_理解Pytorch中LSTM的输入输出参数含义

    本文不会介绍LSTM的原理,具体可看如下两篇文章 Understanding LSTM Networks DeepLearning.ai学习笔记(五)序列模型 -- week1 循环序列模型 1.举个 ...

  4. 深入理解Pytorch负对数似然函数(torch.nn.NLLLoss)和交叉熵损失函数(torch.nn.CrossEntropyLoss)

    在看Pytorch的交叉熵损失函数torch.nn.CrossEntropyLoss官方文档介绍中,给出的表达式如下.不免有点疑惑为何交叉熵损失的表达式是这个样子的 loss ⁡ ( y , clas ...

  5. 深入理解pytorch分布式并行处理工具DDP——从工程实战中的bug说起

    近期博主在使用分布式并行处理工具DDP(DistributedDataParallel)训练单目深度估计模型Featdepth(源码地址:https://github.com/sconlyshoote ...

  6. 通过和resnet18和resnet50理解PyTorch的ResNet模块

    文章目录 模型介绍 resnet18模型流程 总结 resnet50 总结 resnet和resnext的框架基本相同的,这里先学习下resnet的构建,感觉高度模块化,很方便.本文算是对 PyTor ...

  7. 深度理解Pytorch中backward()

    转自https://blog.csdn.net/douhaoexia/article/details/78821428 接触pytorch很久了,也自认为对 backward 方法有一定了解,但看了这 ...

  8. pytorch 中register_buffer()

    今天在看DSSINet代码的ssim.py时,遇到了一个用法 class NORMMSSSIM(torch.nn.Module):def __init__(self, sigma=1.0, level ...

  9. pytorch的register_buffer

    Version pytorch 1.7 Belong torch.nn.Module Adds a buffer to the module. This is typically used to re ...

最新文章

  1. Git中.gitignore忽略文件(maven项目)
  2. centos安装JDK与Tomcat
  3. caffe中在某一层获得迭代次数的方法以及caffe编译时报错 error: ‘to_string‘ is not a member of ‘std‘解决方法
  4. ajax学生校验学号,ajax校验数据库数据是否存在
  5. html怎么改艺术字体颜色,html超链接字体颜色怎么改
  6. 全频音箱与分频音箱的区别
  7. 利用IP地址查询接口来查询IP归属地
  8. 有趣的算法(六):3分钟看懂插入排序(C语言实现)
  9. java基础总结06-常用api类-包装类
  10. debian编译openjdk8
  11. Linux的PDF工具,Linux 系统中的pdf阅读器以及工具
  12. 刘铁猛-深入浅出WPF-系列资源汇总
  13. 计算机组成原理不恢复余数法,计算机组成原理第七讲(除法-原码恢复余数法)(科大罗克露)...
  14. 缅甸投资环境及法律政策简介
  15. 360html文件打不开,为什么360安全卫士打不开
  16. 多年的人工智能安全机制争议 检察官、行刑者与道德家这么表示
  17. 用python写情书_《使用Python进行自然语言处理》学习笔记一 | 学步园
  18. 旅游行业数字化进程分析——2023年元旦与春节,旅游市场开启复苏模式,跨省游热度上升
  19. 论文当中图片保存png、pdf等等的要分辨率DPI
  20. 估算成本 制定预算 区别

热门文章

  1. 1.6 万字长文带你读懂 Java IO
  2. 谁说数学不好,就不能成为编程大佬
  3. JavaScript异步编程:异步的数据收集方法
  4. 人声提取工具Spleeter安装教程(linux)
  5. 基于MTCNN的人脸自动对齐技术原理及其Tensorflow实现测试
  6. 薛澜:人工智能发展要让创新驱动和敏捷治理并驾齐驱
  7. ACL 2021 | 腾讯AI Lab、港中文杰出论文:用单语记忆实现高性能NMT
  8. 20亿参数+30亿张图像,刷新ImageNet最高分!谷歌大脑华人研究员领衔发布最强Transformer...
  9. 快手开源斗地主AI,入选ICML,能否干得过「冠军」柯洁?
  10. 深度解析 | 大数据面前,统计学的价值在哪里?