Pytorch state_dict介绍
Introduce
在pytorch中,torch.nn.Module模块中的state_dict变量存放训练过程中需要学习的权重和偏执系数,state_dict作为python的字典对象将每一层的参数映射成tensor张量,需要注意的是torch.nn.Module模块中的state_dict只包含卷积层和全连接层的参数,当网络中存在batchnorm时,例如vgg网络结构,torch.nn.Module模块中的state_dict也会存放batchnorm's running_mean,关于batchnorm详解可见https://blog.csdn.net/wzy_zju/article/details/81262453
torch.optim模块中的Optimizer优化器对象也存在一个state_dict对象,此处的state_dict字典对象包含state和param_groups的字典对象,而param_groups key对应的value也是一个由学习率,动量等参数组成的一个字典对象。
因为state_dict本质上Python字典对象,所以可以很好地进行保存、更新、修改和恢复操作(python字典结构的特性),从而为PyTorch模型和优化器增加了大量的模块化。
Sample
通过一个简单的案例来输出state_dict字典对象中存放的变量
#encoding:utf-8
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import numpy as mp
import matplotlib.pyplot as plt
import torch.nn.functional as F
#define model
class TheModelClass(nn.Module):
def __init__(self):
super(TheModelClass,self).__init__()
self.conv1=nn.Conv2d(3,6,5)
self.pool=nn.MaxPool2d(2,2)
self.conv2=nn.Conv2d(6,16,5)
self.fc1=nn.Linear(16*5*5,120)
self.fc2=nn.Linear(120,84)
self.fc3=nn.Linear(84,10)
def forward(self,x):
x=self.pool(F.relu(self.conv1(x)))
x=self.pool(F.relu(self.conv2(x)))
x=x.view(-1,16*5*5)
x=F.relu(self.fc1(x))
x=F.relu(self.fc2(x))
x=self.fc3(x)
return x
def main():
# Initialize model
model = TheModelClass()
#Initialize optimizer
optimizer=optim.SGD(model.parameters(),lr=0.001,momentum=0.9)
#print model's state_dict
print('Model.state_dict:')
for param_tensor in model.state_dict():
#打印 key value字典
print(param_tensor,'\t',model.state_dict()[param_tensor].size())
#print optimizer's state_dict
print('Optimizer,s state_dict:')
for var_name in optimizer.state_dict():
print(var_name,'\t',optimizer.state_dict()[var_name])
if __name__=='__main__':
main()
具体的输出结果如下:可以很清晰的观测到state_dict中存放的key和value的值
Model.state_dict:
conv1.weight torch.Size([6, 3, 5, 5])
conv1.bias torch.Size([6])
conv2.weight torch.Size([16, 6, 5, 5])
conv2.bias torch.Size([16])
fc1.weight torch.Size([120, 400])
fc1.bias torch.Size([120])
fc2.weight torch.Size([84, 120])
fc2.bias torch.Size([84])
fc3.weight torch.Size([10, 84])
fc3.bias torch.Size([10])
Optimizer,s state_dict:
state {}
param_groups [{'lr': 0.001, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'params': [367949288, 367949432, 376459056, 381121808, 381121952, 381122024, 381121880, 381122168, 381122096, 381122312]}]
Pytorch state_dict介绍相关推荐
- pytorch学习笔记(九):PyTorch结构介绍
PyTorch结构介绍 对PyTorch架构的粗浅理解,不能保证完全正确,但是希望可以从更高层次上对PyTorch上有个整体把握.水平有限,如有错误,欢迎指错,谢谢! 几个重要的类型 和数值相关的 T ...
- pytorch - state_dict() , parameters() 详解
目录 1 parameters() 1.1 model.parameters(): 1.2 model.named_parameters(): 2 state_dict() torch.nn.Modu ...
- Pytorch:visdom介绍
一.介绍 在深度学习领域,模型训练是一个必须的过程,因此常常需要实时监听并可视化一些数据,如损失值loss,正确率acc等.在Tensorflow中,最常使用的工具非Tensorboard莫属:在Py ...
- [转载] Pytorch基础介绍
参考链接: PyTorch的基础 Pytorch安装.https://zhuanlan.zhihu.com/p/26871672. Pytorch中文文档.https://pytorch-cn.rea ...
- PyTorch框架:(1)基本处理操作
目录 1.PyTorch框架介绍 2.安装Pytorch 2.1.CPU版本的安装命令: 2.2.GPU版本的安装命令: 2.2.1.安装CUDA 3.基本使用方法 4.Pytorch中的自动求导机制 ...
- PyTorch 深度剖析:如何保存和加载PyTorch模型?
点击上方"视学算法",选择加"星标"或"置顶" 重磅干货,第一时间送达 作者丨科技猛兽 编辑丨极市平台 导读 本文详解了PyTorch 模型 ...
- Pytorch学习 - 保存模型和重新加载
Pytorch学习 - 保存和加载模型 1. 3个函数 2. 模型不同后缀名的区别 3. 保存和重载模型 参考资料: Pytorch官方文档链接 某博客 1. 3个函数 torch.save() : ...
- CNN入门+猫狗大战(Dogs vs. Cats)+PyTorch入门
一些修改(修改后的代码) 修改原网络的输出方式.原网络采用的交叉熵torch.nn.CrossEntropyLoss()进行Loss计算,而这个函数内部是已经进行了softmax处理的(参考),所以网 ...
- pytorch入门使用及前置知识(2)NLP
1 深度学习的介绍 深度学习(deep learning) 是机器学习的分支,是一种以人工神经网络为架构,对数据进行特征学习的算法. 机器学习是人工智能的一种实现方式,深度学习是机器学习中的一 ...
- pytorch创建模型并训练(初探文本分类问题)
本博客对pytorch在深度学习上的使用进行了介绍,本博客并不会对怎么训练一个好的模型进行介绍(其实我也不会),我觉得训练一个好的模型首先得选对一个模型(关键的问题在于模型如何设计),然后再经 ...
最新文章
- 知乎热议:高数、线代应该成为计算机专业学习的重心吗?
- python中参数的位置传递和名称传递各有什么优缺点_Python开发TCP和UDP的区别是什么?优缺点对比总结...
- 变量命名规范 匈牙利 下划线 骆驼 帕斯卡
- python常用内置模块-Python常用内置模块之xml模块
- 印度式画线乘法基本操作
- [GitHub] 75+的 C# 数据结构和算法实现
- 话里话外:中小型装备制造企业竞争优势构建之路
- 调试远程服务器上的代码时报错:调试设置中的Python路径无效
- 如何挖掘评论中的关键信息
- cout和printf的区别
- Android 真实 简历
- 中国十大芯片企业排名
- linux与ipad传输文件,实用!三种iPhone与Windows电脑互传文件操作技巧,建议收藏...
- promise、axios 理解
- 武汉星起航跨境:跨境电商新蓝海,南非跨境电商市场迸发活力
- layui搜索重置功能
- 学习gitlab-runner
- Adobe photoshop工具箱工具名称中英文对照
- 菲律宾德拉斯大学计算机专业,2020年菲律宾大学以及各专业排行榜
- Django urls 下划线的坑-Using the URLconf defined in xxx, Django tried these URL patterns, in thi