pytorch - state_dict() , parameters() 详解
目录
1 parameters()
1.1 model.parameters():
1.2 model.named_parameters():
2 state_dict()
torch.nn.Module 模块中的可学习参数都被包含在该模型的parameters 中,可以通过model.parameters()的方法获取;
state_dict()是一 个字典,包含了模型各的参数(tensor类型),多用于保存模型;
1 parameters()
1.1 model.parameters():
源码:
def parameters(self, recurse: bool = True) -> Iterator[Parameter]:r"""Returns an iterator over module parameters.返回模块参数上的迭代器。This is typically passed to an optimizer.这通常被传递给优化器Args:recurse (bool): if True, then yields parameters of this module and all submodules. Otherwise, yields only parameters that are direct members of this module.如果为True, 则生成该模块 及其所有子模块的参数。否则,只生成该模块的直接成员的形参。Yields:Parameter: module parameterExample::>>> for param in model.parameters():>>> print(type(param), param.size())<class 'torch.Tensor'> (20L,)<class 'torch.Tensor'> (20L, 1L, 5L, 5L)"""for name, param in self.named_parameters(recurse=recurse):yield param
可以通过Module.parameters()获得网络参数, 迭代的返回模型所有可学习的参数 -- 是个生成器;
有些layer不包含可学习的参数,比如(relu, maxpool),因此model.parameters()不会输出这些层;
parameters()多见于优化器的初始化;
由于parameters()是生成器,因此需要利用循环或者next()来获取数据:
例子:
>>> import torch
>>> import torch.nn as nn>>> class Net(nn.Module):
... def __init__(self):
... super().__init__()
... self.linear = nn.Linear(2,2)
... def forward(self,x):
... out = self.linear(x)
... return out
...
>>> net = Net()
>>> for para in net.parameters():
... print(para)
... Parameter containing:
tensor([[-0.1954, -0.2290],[ 0.5897, -0.3970]], requires_grad=True)
Parameter containing:
tensor([-0.1808, 0.2044], requires_grad=True)>>> for para in net.named_parameters():
... print(para)
...
('linear.weight', Parameter containing:
tensor([[-0.1954, -0.2290],[ 0.5897, -0.3970]], requires_grad=True))
('linear.bias', Parameter containing:
tensor([-0.1808, 0.2044], requires_grad=True))
1.2 model.named_parameters():
是带有layer name的model.parameters(),其以tuple方式输出,其中包含两个元素,分别为layer name和 model.parameters;
layer name有后缀 .weight, .bias用于区分权重和偏置;
源码:
def named_parameters(self, prefix: str = '', recurse: bool = True) -> Iterator[Tuple[str, Parameter]]:r"""Returns an iterator over module parameters, yielding both the name of the parameter as well as the parameter itself.返回模块参数上的迭代器,生成参数名和参数本身。Args:prefix (str): prefix to prepend to all parameter names.recurse (bool): if True, then yields parameters of this module and all submodules. Otherwise, yields only parameters that are direct members of this module.如果为True,则生成该模块及其所有子模块的参数。否则,只生成该模块的直接成员的形参。Yields:(string, Parameter): Tuple containing the name and parameterExample::>>> for name, param in self.named_parameters():>>> if name in ['bias']:>>> print(param.size())"""gen = self._named_members(lambda module: module._parameters.items(),prefix=prefix, recurse=recurse)for elem in gen:yield elem
代码例子,看1.1部分;
2 state_dict()
model.state_dict()能够获得模型所有的参数,包括可学习的参数和不可学习的参数,返回值是一个有序字典OrderedDict.
这部分相当于在model.parameters()基础上,又额外获取了不可学习的参数部分;
例子:
key值是对网络参数的说明,这里是线性层的weight和bias;
>>> class Net(nn.Module):
... def __init__(self):
... super().__init__()
... self.linear = nn.Linear(10,8)
... self.dropout = nn.Dropout(0.5)
... self.linear1 = nn.Linear(8,2)
... def forward(self,x):
... out = self.dropout(self.linear(x))
... out = self.linear1(out)
... return out
...
>>> net = Net()
>>> net.state_dict()
OrderedDict([('linear.weight', tensor([[ 0.1415, -0.2228, -0.1262, 0.0992, -0.1600, 0.0141, -0.1841, -0.1907,0.0295, -0.1853],[-0.0399, -0.2487, -0.3085, 0.1602, 0.3135, 0.1379, 0.0696, 0.0362,-0.1619, -0.0887],[-0.1244, -0.1739, 0.1211, -0.2578, -0.0561, 0.0635, -0.1976, -0.2557,0.1761, 0.2553],[ 0.0912, -0.1469, -0.3012, -0.1583, -0.0028, 0.2697, 0.1947, -0.0596,-0.2144, -0.0785],[-0.1770, 0.0411, 0.1663, 0.1861, 0.2769, 0.0990, 0.1883, -0.1801,0.2727, 0.1219],[-0.1269, 0.0713, 0.2798, 0.1760, 0.0965, 0.1144, 0.2644, 0.0274,0.0034, 0.2702],[ 0.0628, 0.0682, -0.1842, 0.1461, 0.0678, -0.2264, -0.1249, -0.1715,0.1115, 0.2459],[ 0.1198, -0.2584, 0.0234, 0.2756, 0.1174, -0.1212, 0.3024, -0.2304,-0.2950, 0.0970]])), ('linear.bias', tensor([-0.3036, -0.1933, 0.2412, 0.3137, -0.3007, 0.2386, -0.1975, 0.3127])), ('linear1.weight', tensor([[-0.1725, 0.3027, 0.1985, 0.1394, -0.1245, 0.2913, 0.0136, 0.1633],[-0.1558, -0.0865, -0.3032, 0.1374, 0.2967, -0.2886, 0.0430, -0.1246]])), ('linear1.bias', tensor([-0.1232, -0.0690]))])
>>>
参考:PyTorch中model.state_dict(),model.modules(),model.children(),model.named_children()等含义_yaoyz105的博客-CSDN博客_model.state_dict()
model.parameters()与model.state_dict() - 知乎
pytorch - state_dict() , parameters() 详解相关推荐
- pytorch MSELoss参数详解
pytorch MSELoss参数详解 import torch import numpy as np loss_fn = torch.nn.MSELoss(reduce=False, size_av ...
- pytorch实战:详解查准率(Precision)、查全率(Recall)与F1
pytorch实战:详解查准率(Precision).查全率(Recall)与F1 1.概述 本文首先介绍了机器学习分类问题的性能指标查准率(Precision).查全率(Recall)与F1度量,阐 ...
- PyTorch Python API详解大全(持续更新ing...)
诸神缄默不语-个人CSDN博文目录 具体内容以官方文档为准. 最早更新时间:2021.4.23 最近更新时间:2023.1.9 文章目录 0. 常用入参及函数统一解释 1. torch 1.1 Ten ...
- 【小白学PyTorch】10.pytorch常见运算详解
点击上方"小白学视觉",选择加"星标"或"置顶" 重磅干货,第一时间送达 <<小白学PyTorch>> 参考目录: ...
- 3 矩阵运算_小白学PyTorch——pytorch常见运算详解
公众号关注 "DL-CVer" 设为 "星标",DLCV消息即可送达! 参考目录: 1 矩阵与标量 2 哈达玛积 3 矩阵乘法 4 幂与开方 5 对数运算 6 ...
- python训练手势分类器_使用Pytorch训练分类器详解(附python演练)
[前言]:你已经了解了如何定义神经网络,计算loss值和网络里权重的更新.现在你也许会想数据怎么样? 目录: 一.数据 二.训练一个图像分类器 使用torchvision加载并且归一化CIFAR10的 ...
- PyTorch 的 Autograd详解
↑ 点击蓝字 关注视学算法 作者丨xiaopl@知乎 来源丨https://zhuanlan.zhihu.com/p/69294347 编辑丨极市平台 PyTorch 作为一个深度学习平台,在深度学习 ...
- Pytorch LSTM初识(详解LSTM+torch.nn.LSTM()实现)1
pytorch LSTM1初识 目录 pytorch LSTM1初识 一.LSTM简介1
- win10开始不显示python_win10从零安装配置pytorch全过程图文详解
1.安装anaconda (anaconda内置python在内的许多package,所以不用另外下载python) 可以点击下面的清华开源软件镜像站,在官网下载anaconda不如在这下的快 htt ...
最新文章
- leetcode算法题--最后一块石头的重量 II★
- LiteDB源码解析系列(3)索引原理详解
- 做了44年保洁员,一生只会5个字,她却成为香港大学院士
- 各种抠图动态图片_不用手。自动、智能抠图,图片去背景
- 人口增长(信息学奥赛一本通-T1070)
- 高级定时器的各种框图和HAL库重要结构
- Struts入门经验(二)
- excel表中怎么插入visio_如何插入或 Visio 中粘贴的 Excel 工作表-阿里云开发者社区...
- 一起来学SpringBoot | 第二篇:SpringBoot配置详解
- eclipse报错: Unhandled event loop exception No more handles
- iOS UIDatePicker
- 08服务器端口映射,windows_Server_2008_R2_NAT服务器_端口映射.pdf
- 随机过程及其在金融领域中的应用 第二章 习题 及 答案
- 劲爆!群晖docker视频
- 流媒体弱网优化之路(FEC+mediasoup)——mediasoup的Nack优化以及FEC引入
- vue element-ui实现金额数字添加千分位并保留两位小数
- 用blockly制作诗词学习游戏
- getshell之Nexus远程命令执行(CVE-2020-10199)
- C#连接Access数据库(详解)
- Vue 2 即将成为过去
热门文章
- Web课程设计——“念念手账”网页APP制作
- java计算机毕业设计高校共享单车管理系统(附源码、数据库)
- ADAS“中国战事”升级
- 双边滤波(Bilateral filter)原理介绍及matlab程序实现
- 极智AI | Pytorch 中常用乘法的 TensorRT 实现
- 中软国际用一场自我进化,推动云市场跨入下一幕
- 腾讯 AI Lab 2018年度回顾
- 没错,这就是AIR-CT2504-K9的内心!
- iframe 在firefox火狐浏览器 动态获取内容不展示问题
- C++无依赖库的websocket实现