目录

1 parameters()

1.1 model.parameters():

1.2 model.named_parameters():

2 state_dict()


torch.nn.Module 模块中的可学习参数都被包含在该模型的parameters 中,可以通过model.parameters()的方法获取;

state_dict()是一 个字典,包含了模型各的参数(tensor类型),多用于保存模型;

1 parameters()

1.1 model.parameters():

源码:

    def parameters(self, recurse: bool = True) -> Iterator[Parameter]:r"""Returns an iterator over module parameters.返回模块参数上的迭代器。This is typically passed to an optimizer.这通常被传递给优化器Args:recurse (bool): if True, then yields parameters of this module and all submodules. Otherwise, yields only parameters that are direct members of this module.如果为True, 则生成该模块 及其所有子模块的参数。否则,只生成该模块的直接成员的形参。Yields:Parameter: module parameterExample::>>> for param in model.parameters():>>>     print(type(param), param.size())<class 'torch.Tensor'> (20L,)<class 'torch.Tensor'> (20L, 1L, 5L, 5L)"""for name, param in self.named_parameters(recurse=recurse):yield param

可以通过Module.parameters()获得网络参数, 迭代的返回模型所有可学习的参数 --  是个生成器

有些layer不包含可学习的参数,比如(relu, maxpool),因此model.parameters()不会输出这些层;

parameters()多见于优化器的初始化;

由于parameters()是生成器,因此需要利用循环或者next()来获取数据:

例子:

>>> import torch
>>> import torch.nn as nn>>> class Net(nn.Module):
...     def __init__(self):
...             super().__init__()
...             self.linear = nn.Linear(2,2)
...     def forward(self,x):
...             out = self.linear(x)
...             return out
...
>>> net = Net()
>>> for para in net.parameters():
...     print(para)
... Parameter containing:
tensor([[-0.1954, -0.2290],[ 0.5897, -0.3970]], requires_grad=True)
Parameter containing:
tensor([-0.1808,  0.2044], requires_grad=True)>>> for para in net.named_parameters():
...     print(para)
...
('linear.weight', Parameter containing:
tensor([[-0.1954, -0.2290],[ 0.5897, -0.3970]], requires_grad=True))
('linear.bias', Parameter containing:
tensor([-0.1808,  0.2044], requires_grad=True))

1.2 model.named_parameters():

是带有layer name的model.parameters(),其以tuple方式输出,其中包含两个元素,分别为layer name和 model.parameters;

layer name有后缀 .weight, .bias用于区分权重和偏置;

源码:

    def named_parameters(self, prefix: str = '', recurse: bool = True) -> Iterator[Tuple[str, Parameter]]:r"""Returns an iterator over module parameters, yielding both the name of the parameter as well as the parameter itself.返回模块参数上的迭代器,生成参数名和参数本身。Args:prefix (str): prefix to prepend to all parameter names.recurse (bool): if True, then yields parameters of this module and all submodules. Otherwise, yields only parameters that are direct members of this module.如果为True,则生成该模块及其所有子模块的参数。否则,只生成该模块的直接成员的形参。Yields:(string, Parameter): Tuple containing the name and parameterExample::>>> for name, param in self.named_parameters():>>>    if name in ['bias']:>>>        print(param.size())"""gen = self._named_members(lambda module: module._parameters.items(),prefix=prefix, recurse=recurse)for elem in gen:yield elem

代码例子,看1.1部分;

2 state_dict()

model.state_dict()能够获得模型所有的参数,包括可学习的参数和不可学习的参数,返回值是一个有序字典OrderedDict.

这部分相当于在model.parameters()基础上,又额外获取了不可学习的参数部分;

例子:

key值是对网络参数的说明,这里是线性层的weight和bias;

>>> class Net(nn.Module):
...     def __init__(self):
...             super().__init__()
...             self.linear = nn.Linear(10,8)
...             self.dropout = nn.Dropout(0.5)
...             self.linear1 = nn.Linear(8,2)
...     def forward(self,x):
...             out = self.dropout(self.linear(x))
...             out = self.linear1(out)
...             return out
...
>>> net = Net()
>>> net.state_dict()
OrderedDict([('linear.weight', tensor([[ 0.1415, -0.2228, -0.1262,  0.0992, -0.1600,  0.0141, -0.1841, -0.1907,0.0295, -0.1853],[-0.0399, -0.2487, -0.3085,  0.1602,  0.3135,  0.1379,  0.0696,  0.0362,-0.1619, -0.0887],[-0.1244, -0.1739,  0.1211, -0.2578, -0.0561,  0.0635, -0.1976, -0.2557,0.1761,  0.2553],[ 0.0912, -0.1469, -0.3012, -0.1583, -0.0028,  0.2697,  0.1947, -0.0596,-0.2144, -0.0785],[-0.1770,  0.0411,  0.1663,  0.1861,  0.2769,  0.0990,  0.1883, -0.1801,0.2727,  0.1219],[-0.1269,  0.0713,  0.2798,  0.1760,  0.0965,  0.1144,  0.2644,  0.0274,0.0034,  0.2702],[ 0.0628,  0.0682, -0.1842,  0.1461,  0.0678, -0.2264, -0.1249, -0.1715,0.1115,  0.2459],[ 0.1198, -0.2584,  0.0234,  0.2756,  0.1174, -0.1212,  0.3024, -0.2304,-0.2950,  0.0970]])), ('linear.bias', tensor([-0.3036, -0.1933,  0.2412,  0.3137, -0.3007,  0.2386, -0.1975,  0.3127])), ('linear1.weight', tensor([[-0.1725,  0.3027,  0.1985,  0.1394, -0.1245,  0.2913,  0.0136,  0.1633],[-0.1558, -0.0865, -0.3032,  0.1374,  0.2967, -0.2886,  0.0430, -0.1246]])), ('linear1.bias', tensor([-0.1232, -0.0690]))])
>>>

参考:PyTorch中model.state_dict(),model.modules(),model.children(),model.named_children()等含义_yaoyz105的博客-CSDN博客_model.state_dict()

model.parameters()与model.state_dict() - 知乎

pytorch - state_dict() , parameters() 详解相关推荐

  1. pytorch MSELoss参数详解

    pytorch MSELoss参数详解 import torch import numpy as np loss_fn = torch.nn.MSELoss(reduce=False, size_av ...

  2. pytorch实战:详解查准率(Precision)、查全率(Recall)与F1

    pytorch实战:详解查准率(Precision).查全率(Recall)与F1 1.概述 本文首先介绍了机器学习分类问题的性能指标查准率(Precision).查全率(Recall)与F1度量,阐 ...

  3. PyTorch Python API详解大全(持续更新ing...)

    诸神缄默不语-个人CSDN博文目录 具体内容以官方文档为准. 最早更新时间:2021.4.23 最近更新时间:2023.1.9 文章目录 0. 常用入参及函数统一解释 1. torch 1.1 Ten ...

  4. 【小白学PyTorch】10.pytorch常见运算详解

    点击上方"小白学视觉",选择加"星标"或"置顶" 重磅干货,第一时间送达 <<小白学PyTorch>> 参考目录: ...

  5. 3 矩阵运算_小白学PyTorch——pytorch常见运算详解

    公众号关注 "DL-CVer" 设为 "星标",DLCV消息即可送达! 参考目录: 1 矩阵与标量 2 哈达玛积 3 矩阵乘法 4 幂与开方 5 对数运算 6 ...

  6. python训练手势分类器_使用Pytorch训练分类器详解(附python演练)

    [前言]:你已经了解了如何定义神经网络,计算loss值和网络里权重的更新.现在你也许会想数据怎么样? 目录: 一.数据 二.训练一个图像分类器 使用torchvision加载并且归一化CIFAR10的 ...

  7. PyTorch 的 Autograd详解

    ↑ 点击蓝字 关注视学算法 作者丨xiaopl@知乎 来源丨https://zhuanlan.zhihu.com/p/69294347 编辑丨极市平台 PyTorch 作为一个深度学习平台,在深度学习 ...

  8. Pytorch LSTM初识(详解LSTM+torch.nn.LSTM()实现)1

    pytorch  LSTM1初识 目录 pytorch  LSTM1初识 ​​​​​​​​​​​​​​​​​​​​​ 一.LSTM简介1

  9. win10开始不显示python_win10从零安装配置pytorch全过程图文详解

    1.安装anaconda (anaconda内置python在内的许多package,所以不用另外下载python) 可以点击下面的清华开源软件镜像站,在官网下载anaconda不如在这下的快 htt ...

最新文章

  1. leetcode算法题--最后一块石头的重量 II★
  2. LiteDB源码解析系列(3)索引原理详解
  3. 做了44年保洁员,一生只会5个字,她却成为香港大学院士
  4. 各种抠图动态图片_不用手。自动、智能抠图,图片去背景
  5. 人口增长(信息学奥赛一本通-T1070)
  6. 高级定时器的各种框图和HAL库重要结构
  7. Struts入门经验(二)
  8. excel表中怎么插入visio_如何插入或 Visio 中粘贴的 Excel 工作表-阿里云开发者社区...
  9. 一起来学SpringBoot | 第二篇:SpringBoot配置详解
  10. eclipse报错: Unhandled event loop exception No more handles
  11. iOS UIDatePicker
  12. 08服务器端口映射,windows_Server_2008_R2_NAT服务器_端口映射.pdf
  13. 随机过程及其在金融领域中的应用 第二章 习题 及 答案
  14. 劲爆!群晖docker视频
  15. 流媒体弱网优化之路(FEC+mediasoup)——mediasoup的Nack优化以及FEC引入
  16. vue element-ui实现金额数字添加千分位并保留两位小数
  17. 用blockly制作诗词学习游戏
  18. getshell之Nexus远程命令执行(CVE-2020-10199)
  19. C#连接Access数据库(详解)
  20. Vue 2 即将成为过去

热门文章

  1. Web课程设计——“念念手账”网页APP制作
  2. java计算机毕业设计高校共享单车管理系统(附源码、数据库)
  3. ADAS“中国战事”升级
  4. 双边滤波(Bilateral filter)原理介绍及matlab程序实现
  5. 极智AI | Pytorch 中常用乘法的 TensorRT 实现
  6. 中软国际用一场自我进化,推动云市场跨入下一幕
  7. 腾讯 AI Lab 2018年度回顾
  8. 没错,这就是AIR-CT2504-K9的内心!
  9. iframe 在firefox火狐浏览器 动态获取内容不展示问题
  10. C++无依赖库的websocket实现