1 模型的两种参数

在 Pytorch 中一种模型保存和加载的方式如下:(具体见pytorch模型的保存与加载_刘文巾的博客-CSDN博客)

#save
torch.save(net.state_dict(),PATH)#load
model=MyModel(*args,**kwargs)
model.load_state_dict(torch.load(PATH))
model.eval

模型保存的是 net.state_dict() 的返回对象。

net.state_dict() 的返回对象是一个 OrderDict ,它以键值对的形式包含模型中需要保存下来的参数

上例模型中的参数就是线性层的 weight 和 bias.

模型中需要保存下来的参数包括两种:

  • 一种是反向传播需要被optimizer更新的,称之为 parameter
  • 一种是反向传播不需要被optimizer更新,称之为 buffer

第一种参数我们可以通过 model.parameters() 返回;

第二种参数我们可以通过 model.buffers() 返回。

因为我们的模型保存的是 state_dict 返回的 OrderDict,所以这两种参数不仅要满足是否需要被更新的要求,还需要被保存到OrderDict。

2 Parameter

Parameter参数有两种创建方式:

  1. 我们可以直接将模型的成员变量self.xxx通过nn.Parameter() 创建,会自动注册到parameters中,可以通过model.parameters() 返回,并且这样创建的参数会自动保存到OrderDict中去;
  2. 通过nn.Parameter() 创建普通Parameter对象,不作为模型的成员变量,然后将Parameter对象通过register_parameter()进行注册,可以通过model.parameters() 返回,注册后的参数也会自动保存到OrderDict中去;

像我们前面的nn.Conv1d,nn.Linear,nn.RNN等模型,里面的权重参数等会被自动认为是Parameter 参数

3 buffer

buffer参数我们需要创建tensor, 然后将tensor通过register_buffer()进行注册,可以通过model.buffers() 返回,注册完后参数也会自动保存到OrderDict中去。

4 为什么要注册​​​​​​​

为什么不直接将不需要进行参数修改的变量作为模型类的成员变量就好了,还要进行注册?

  1. 不进行注册,参数不能保存到 OrderDict,也就无法进行保存
  2. 模型进行参数在CPU和GPU移动时, 执行 model.to(device) ,注册后的参数可以自动进行设备移动

5 实例说明

import torch
class net(torch.nn.Module):def __init__(self):super(net,self).__init__()#创建bufferself.register_buffer('my_buffer',torch.Tensor([1,2,3]))self.a=torch.Tensor([1])self.param1=torch.nn.Parameter(torch.Tensor([1,3,5,7,9]))#方法1 创建的parameterparam2=torch.nn.Parameter(torch.Tensor([2,4,6,8,0]))self.register_parameter('param2',param2)self.l=torch.nn.Linear(1,10)def forward(self,x):passn=net()for i in n.state_dict():print(i,n.state_dict()[i])
print('*'*10)
for i in n.parameters():print(i)
print('*'*10)
for i in n.buffers():print(i)
print('*'*10)'''
param1 tensor([1., 3., 5., 7., 9.])
param2 tensor([2., 4., 6., 8., 0.])
my_buffer tensor([1., 2., 3.])
l.weight tensor([[-0.1490],[-0.2445],[-0.5296],[-0.3687],[-0.9683],[ 0.3491],[-0.8726],[-0.7213],[ 0.3201],[-0.9994]])
l.bias tensor([ 0.6718,  0.3055,  0.7755,  0.3780, -0.8169,  0.3663, -0.6937, -0.3136,0.6907,  0.8732])
**********
Parameter containing:
tensor([1., 3., 5., 7., 9.], requires_grad=True)
Parameter containing:
tensor([2., 4., 6., 8., 0.], requires_grad=True)
Parameter containing:
tensor([[-0.1490],[-0.2445],[-0.5296],[-0.3687],[-0.9683],[ 0.3491],[-0.8726],[-0.7213],[ 0.3201],[-0.9994]], requires_grad=True)
Parameter containing:
tensor([ 0.6718,  0.3055,  0.7755,  0.3780, -0.8169,  0.3663, -0.6937, -0.3136,0.6907,  0.8732], requires_grad=True)
**********
tensor([1., 2., 3.])
**********
'''

pytorch笔记 pytorch模型中的parameter与buffer相关推荐

  1. PyTorch 笔记Ⅱ——PyTorch 自动求导机制

    文章目录 Autograd: 自动求导机制 张量(Tensor) 梯度 使用PyTorch计算梯度数值 Autograd 简单的自动求导 复杂的自动求导 Autograd 过程解析 扩展Autogra ...

  2. 『Pytorch笔记』Pycharm中使用CUDA_VISIBLE_DEVICES=0!

    Pycharm中使用CUDA_VISIBLE_DEVICES=0! 如果使用多GPU运行程序(或者指定GPU的个数),可以直接使用CUDA_VISIBLE_DEVICES=0,1,2,3python ...

  3. pytorch 笔记:手动实现AR (auto regressive)

    1 导入库& 数据说明 import numpy as np import torch import matplotlib.pyplot as plt from tensorboardX im ...

  4. 【多输入模型 Multiple-Dimension 数学原理分析以及源码详解 深度学习 Pytorch笔记 B站刘二大人 (6/10)】

    多输入模型 Multiple-Dimension 数学原理分析以及源码源码详解 深度学习 Pytorch笔记 B站刘二大人(6/10) 数学推导 在之前实现的模型普遍都是单输入单输出模型,显然,在现实 ...

  5. (pytorch-深度学习系列)pytorch构造深度学习模型-学习笔记

    pytorch构造深度学习模型 1. 通过继承module类的方式来构造模型 Module类是nn模块里提供的一个模型构造类,是所有神经网络模块的基类. 可以继承基类并重构 __init()__函数和 ...

  6. pytorch 模型中的bn层一键转化为同步bn(syncbn)

    pytorch 将模型中的所有BatchNorm2d layer转换为SyncBatchNorm layer: (单机多卡设置下) import torch.distributed as distdi ...

  7. 【pytorch笔记】(五)自定义损失函数、学习率衰减、模型微调

    本文目录: 1. 自定义损失函数 2. 动态调整学习率 3. 模型微调-torchvision 3.1 使用已有模型 3.2 训练特定层 1. 自定义损失函数 虽然pytorch提供了许多常用的损失函 ...

  8. 『PyTorch』学习笔记 2 —— 模型 Finetune

    目录 前言 1. 为什么要 Model Finetune? 2. 模型微调的步骤 3. 模型微调训练方法 4. 示例(finetune_resnet18) 4.1 不使用trick:所有的参数使用同一 ...

  9. Pytorch两种模型保存方式

    以字典方式保存,更容易解析和可视化 Pytorch两种模型保存方式 大黑_7e1b关注 2019.02.12 17:49:35字数 13阅读 5,907 只保存模型参数 # 保存 torch.save ...

最新文章

  1. 实验三:XML模型(一)
  2. java的基本数据类型有
  3. python实现基于八方向判断的断裂连接
  4. 【Spring MVC学习】spring mvc入门示例
  5. PHP-什么是PHP?为什么用PHP?有谁在用PHP?
  6. 机器人实现屠宰自动化
  7. java开发人员_Java 8:开发人员怎么看?
  8. textview点击展开全部或收起,内容过长显示省略号,设置行间距,字间距,跑马灯显示
  9. 从Java到Go面向对象--继承思想.md
  10. Oracle中 drop user 和 drop user cascade 的区别
  11. rabbitmq在exchange下的两种使用模式
  12. HTML实现简易音乐网站
  13. 字体大宝库:40套为网页设计师准备的时尚字体(下篇)
  14. FPGA实现的线性反馈移位寄存器LFSR
  15. 如何选择合适的代理IP?以下3点需要注意
  16. AD14.3绘制PCB教程
  17. excel 自定义参数(text函数)
  18. java+MySQL基于ssm的公文流转关管理系统
  19. 海康威视rtsp转rtmp(java稳定版)
  20. 蓝桥杯 特殊的回文数 C语言

热门文章

  1. 基于QT Plugin框架结构
  2. leetcode 8. String to Integer (atoi)
  3. [Java开发之路]Java字符串
  4. 异常:System.BadImageFormatException,未能加载正确的程序集XXX或其某一依赖项
  5. Http Status 304响应状态的资源更新机制
  6. windows puppet manifests 文件维护
  7. 移动互联网服务客户端开发技巧 ( Webview及正则)
  8. 如何解决Beyond Compare内容相同仍然标示红色
  9. kattis ones简单题取模运算+枚举
  10. 基于labview的温湿度数据采集_【零偏原创】基于FPGA的多路SPI接口并行数据采集系统...