pytorch 中的 state_dict 是一个简单的python的字典对象,将每一层与它的对应参数建立映射关系.(如model的每一层的weights及偏置等等)

(注意,只有那些参数可以训练的layer才会被保存到模型的state_dict中,如卷积层,线性层等等)

优化器对象Optimizer也有一个state_dict,它包含了优化器的状态以及被使用的超参数(如lr, momentum,weight_decay等)

备注:

1) state_dict是在定义了model或optimizer之后pytorch自动生成的,可以直接调用.常用的保存state_dict的格式是".pt"或'.pth'的文件,即下面命令的 PATH="./***.pt"

torch.save(model.state_dict(), PATH)

2) load_state_dict 也是model或optimizer之后pytorch自动具备的函数,可以直接调用

model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.eval()

注意:model.eval() 的重要性,在2)中最后用到了model.eval(),是因为,只有在执行该命令后,"dropout层"及"batch normalization层"才会进入 evalution 模态. 而在"训练(training)模态"与"评估(evalution)模态"下,这两层有不同的表现形式.

-------------------------------------------------------------------------------------------------------------------------------

模态字典(state_dict)的保存(model是一个网络结构类的对象)

1.1)仅保存学习到的参数,用以下命令

    torch.save(model.state_dict(), PATH)

1.2)加载model.state_dict,用以下命令

    model = TheModelClass(*args, **kwargs)
    model.load_state_dict(torch.load(PATH))
    model.eval()

    备注:model.load_state_dict的操作对象是 一个具体的对象,而不能是文件名

-----------

2.1)保存整个model的状态,用以下命令

    torch.save(model,PATH)

2.2)加载整个model的状态,用以下命令:

# Model class must be defined somewhere

    model = torch.load(PATH)

    model.eval()

--------------------------------------------------------------------------------------------------------------------------------------

state_dict 是一个python的字典格式,以字典的格式存储,然后以字典的格式被加载,而且只加载key匹配的项

----------------------------------------------------------------------------------------------------------------------

如何仅加载某一层的训练的到的参数(某一层的state)

If you want to load parameters from one layer to another, but some keys do not match, simply change the name of the parameter keys in the state_dict that you are loading to match the keys in the model that you are loading into.

conv1_weight_state = torch.load('./model_state_dict.pt')['conv1.weight']

--------------------------------------------------------------------------------------------

加载模型参数后,如何设置某层某参数的"是否需要训练"(param.requires_grad)

for param in list(model.pretrained.parameters()):param.requires_grad = False

注意: requires_grad的操作对象是tensor.

疑问:能否直接对某个层直接之用requires_grad呢?例如:model.conv1.requires_grad=False

回答:经测试,不可以.model.conv1 没有requires_grad属性.

---------------------------------------------------------------------------------------------

全部测试代码:

#-*-coding:utf-8-*-
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim# define model
class TheModelClass(nn.Module):def __init__(self):super(TheModelClass,self).__init__()self.conv1 = nn.Conv2d(3,6,5)self.pool = nn.MaxPool2d(2,2)self.conv2 = nn.Conv2d(6,16,5)self.fc1 = nn.Linear(16*5*5,120)self.fc2 = nn.Linear(120,84)self.fc3 = nn.Linear(84,10)def forward(self,x):x = self.pool(F.relu(self.conv1(x)))x = self.pool(F.relu(self.conv2(x)))x = x.view(-1,16*5*5)x = F.relu(self.fc1(x))x = F.relu(self.fc2(x))x = self.fc3(x)return x# initial model
model = TheModelClass()#initialize the optimizer
optimizer = optim.SGD(model.parameters(),lr=0.001,momentum=0.9)# print the model's state_dict
print("model's state_dict:")
for param_tensor in model.state_dict():print(param_tensor,'\t',model.state_dict()[param_tensor].size())print("\noptimizer's state_dict")
for var_name in optimizer.state_dict():print(var_name,'\t',optimizer.state_dict()[var_name])print("\nprint particular param")
print('\n',model.conv1.weight.size())
print('\n',model.conv1.weight)print("------------------------------------")
torch.save(model.state_dict(),'./model_state_dict.pt')
# model_2 = TheModelClass()
# model_2.load_state_dict(torch.load('./model_state_dict'))
# model.eval()
# print('\n',model_2.conv1.weight)
# print((model_2.conv1.weight == model.conv1.weight).size())
## 仅仅加载某一层的参数
conv1_weight_state = torch.load('./model_state_dict.pt')['conv1.weight']
print(conv1_weight_state==model.conv1.weight)model_2 = TheModelClass()
model_2.load_state_dict(torch.load('./model_state_dict.pt'))
model_2.conv1.requires_grad=False
print(model_2.conv1.requires_grad)
print(model_2.conv1.bias.requires_grad)

pytorch 状态字典:state_dict相关推荐

  1. pytorch 状态字典:state_dict 模型和参数保存

    pytorch 中的 state_dict 是一个简单的python的字典对象,将每一层与它的对应参数建立映射关系.(如model的每一层的weights及偏置等等) (注意,只有那些参数可以训练的l ...

  2. PyTorch | 保存和加载模型教程

    点击上方"算法猿的成长",选择"加为星标" 第一时间关注 AI 和 Python 知识 图片来自 Unsplash,作者: Jenny Caywood 2019 ...

  3. PyTorch学习笔记2:nn.Module、优化器、模型的保存和加载、TensorBoard

    文章目录 一.nn.Module 1.1 nn.Module的调用 1.2 线性回归的实现 二.损失函数 三.优化器 3.1.1 SGD优化器 3.1.2 Adagrad优化器 3.2 分层学习率 3 ...

  4. coggle11月打卡—pytorch与CV竞赛

    文章目录 任务1:PyTorch张量计算与Numpy的转换 任务2:梯度计算和梯度下降过程 1.学习自动求梯度原理 1.1 pytorch自动求导初步认识 1.2 tensor的创建与属性设置 1.3 ...

  5. python 神经网络 多进程_Pytorch多进程最佳实践

    预备知识 模型并行( model parallelism ):即把模型拆分放到不同的设备进行训练,分布式系统中的不同机器(GPU/CPU等)负责网络模型的不同部分 -- 例如,神经网络模型的不同网络层 ...

  6. pytorch - state_dict() , parameters() 详解

    目录 1 parameters() 1.1 model.parameters(): 1.2 model.named_parameters(): 2 state_dict() torch.nn.Modu ...

  7. 使用pytorch构建图片分类器

    分类器任务和数据介绍 构造一个将不同图像进行分类的神经网络分类器, 对输入的图片进行判别并完成分类. 本案例采用CIFAR10数据集作为原始图片数据. CIFAR10数据集介绍: 数据集中每张图片的尺 ...

  8. python torch exp_Python:PyTorch 保存和加载训练过的网络 (八十)

    保存和加载模型 在这个 notebook 中,我将为你展示如何使用 Pytorch 来保存和加载模型.这个步骤十分重要,因为你一定希望能够加载预先训练好的模型来进行预测,或是根据新数据继续训练. %m ...

  9. PyTorch模型的保存加载以及数据的可视化

    文章目录 PyTorch模型的保存和加载 模块和张量的序列化和反序列化 模块状态字典的保存和载入 PyTorch数据的可视化 TensorBoard的使用 总结 PyTorch模型的保存和加载 在深度 ...

最新文章

  1. 产品Backlog(Product Backlog)是什么?
  2. SQL SERVER全面优化-------索引有多重要?
  3. php商品分类显示商品,ecshop首页显示全部商品分类的方法
  4. Sybase数据库优化手册
  5. Callback Functions Tutorial
  6. Java程序员从笨鸟到菜鸟之(五十八)细谈Hibernate(九)hibernate一对一关系映射...
  7. 隐藏nginx 版本号信息(转)
  8. axure数据报表元件库_axure图表元件库 axure自制的组件库(包括数据组件)
  9. C语言编程方法技巧,C语言编程小技巧分享
  10. python基础——闭包函数和生成器
  11. ios 加速计效果实现
  12. 时间标准 GMT, UTC, CST
  13. ctfshow菜狗杯wp
  14. GIT提示Another git process seems to be running in this repository
  15. Windows禁用端口(445端口为例)
  16. RecyclerView 配合 DiffUtil,RecyclerView局部刷新
  17. 前端构建3D建模知识(css,html)
  18. 如果你喜欢的女孩有了男朋友,但她男朋友比你差很多,怎么办?
  19. zbar android解码错误,Android原生编解码接口 MediaCodec 之——踩坑
  20. Python相关环境变量配置和模拟手机app登录

热门文章

  1. 学海泛舟系列文章开篇语
  2. choco设置后续软件默认安装路径
  3. npm install 安装软件,出现 operation not permitted, mkdir
  4. 总算有人讲明白了什么是特性阻抗什么是阻抗匹配
  5. 5G网络架构与组网部署
  6. BoundsChecker简易使用教程
  7. 云服务器文件打包,如何把云服务器的文件打包出来
  8. 在windows上部署IIS web服务
  9. 东方信息苑c语言,上海市东方社区信息苑一览表.PDF
  10. JavaWeb:request.setAttribute()和session.setAttribute()的区别