pytorch 中的 state_dict 是一个简单的python的字典对象,将每一层与它的对应参数建立映射关系.(如model的每一层的weights及偏置等等)

(注意,只有那些参数可以训练的layer才会被保存到模型的state_dict中,如卷积层,线性层等等)

优化器对象Optimizer也有一个state_dict,它包含了优化器的状态以及被使用的超参数(如lr, momentum,weight_decay等)

备注:

1) state_dict是在定义了model或optimizer之后pytorch自动生成的,可以直接调用.常用的保存state_dict的格式是".pt"或'.pth'的文件,即下面命令的 PATH="./***.pt"

torch.save(model.state_dict(), PATH)
2) load_state_dict 也是model或optimizer之后pytorch自动具备的函数,可以直接调用

model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.eval()
注意:model.eval() 的重要性,在2)中最后用到了model.eval(),是因为,只有在执行该命令后,"dropout层"及"batch normalization层"才会进入 evalution 模态. 而在"训练(training)模态"与"评估(evalution)模态"下,这两层有不同的表现形式.

-------------------------------------------------------------------------------------------------------------------------------

模态字典(state_dict)的保存(model是一个网络结构类的对象)

1.1)仅保存学习到的参数,用以下命令

torch.save(model.state_dict(), PATH)

1.2)加载model.state_dict,用以下命令

model = TheModelClass(*args, **kwargs)
    model.load_state_dict(torch.load(PATH))
    model.eval()

备注:model.load_state_dict的操作对象是 一个具体的对象,而不能是文件名

-----------

2.1)保存整个model的状态,用以下命令

torch.save(model,PATH)

2.2)加载整个model的状态,用以下命令:

# Model class must be defined somewhere

model = torch.load(PATH)

model.eval()

--------------------------------------------------------------------------------------------------------------------------------------

state_dict 是一个python的字典格式,以字典的格式存储,然后以字典的格式被加载,而且只加载key匹配的项

----------------------------------------------------------------------------------------------------------------------

如何仅加载某一层的训练的到的参数(某一层的state)

If you want to load parameters from one layer to another, but some keys do not match, simply change the name of the parameter keys in the state_dict that you are loading to match the keys in the model that you are loading into.

conv1_weight_state = torch.load('./model_state_dict.pt')['conv1.weight']
--------------------------------------------------------------------------------------------

加载模型参数后,如何设置某层某参数的"是否需要训练"(param.requires_grad)

for param in list(model.pretrained.parameters()):
param.requires_grad = False
注意: requires_grad的操作对象是tensor.

疑问:能否直接对某个层直接之用requires_grad呢?例如:model.conv1.requires_grad=False

回答:经测试,不可以.model.conv1 没有requires_grad属性.

---------------------------------------------------------------------------------------------

全部测试代码:

#-*-coding:utf-8-*-
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

# define model
class TheModelClass(nn.Module):
def __init__(self):
super(TheModelClass,self).__init__()
self.conv1 = nn.Conv2d(3,6,5)
self.pool = nn.MaxPool2d(2,2)
self.conv2 = nn.Conv2d(6,16,5)
self.fc1 = nn.Linear(16*5*5,120)
self.fc2 = nn.Linear(120,84)
self.fc3 = nn.Linear(84,10)

def forward(self,x):
x = self.pool(F.relu(self.conv1(x)))
x = self.pool(F.relu(self.conv2(x)))
x = x.view(-1,16*5*5)
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x

# initial model
model = TheModelClass()

#initialize the optimizer
optimizer = optim.SGD(model.parameters(),lr=0.001,momentum=0.9)

# print the model's state_dict
print("model's state_dict:")
for param_tensor in model.state_dict():
print(param_tensor,'\t',model.state_dict()[param_tensor].size())

print("\noptimizer's state_dict")
for var_name in optimizer.state_dict():
print(var_name,'\t',optimizer.state_dict()[var_name])

print("\nprint particular param")
print('\n',model.conv1.weight.size())
print('\n',model.conv1.weight)

print("------------------------------------")
torch.save(model.state_dict(),'./model_state_dict.pt')
# model_2 = TheModelClass()
# model_2.load_state_dict(torch.load('./model_state_dict'))
# model.eval()
# print('\n',model_2.conv1.weight)
# print((model_2.conv1.weight == model.conv1.weight).size())
## 仅仅加载某一层的参数
conv1_weight_state = torch.load('./model_state_dict.pt')['conv1.weight']
print(conv1_weight_state==model.conv1.weight)

model_2 = TheModelClass()
model_2.load_state_dict(torch.load('./model_state_dict.pt'))
model_2.conv1.requires_grad=False
print(model_2.conv1.requires_grad)
print(model_2.conv1.bias.requires_grad)
---------------------
作者:wzg2016
来源:CSDN
原文:https://blog.csdn.net/strive_for_future/article/details/83240081
版权声明:本文为博主原创文章,转载请附上博文链接!

pytorch 状态字典:state_dict 模型和参数保存相关推荐

  1. pytorch 状态字典:state_dict

    pytorch 中的 state_dict 是一个简单的python的字典对象,将每一层与它的对应参数建立映射关系.(如model的每一层的weights及偏置等等) (注意,只有那些参数可以训练的l ...

  2. PyTorch:存储和恢复模型并查看参数,load_state_dict(),state_dict()

    # save torch.save(model.state_dict(), PATH)# load model = MyModel(*args, **kwargs) model.load_state_ ...

  3. python保存模型与参数_基于pytorch的保存和加载模型参数的方法

    当我们花费大量的精力训练完网络,下次预测数据时不想再(有时也不必再)训练一次时,这时候torch.save(),torch.load()就要登场了. 保存和加载模型参数有两种方式: 方式一: torc ...

  4. pytorch保存模型pth_pytorch中保存的模型文件.pth深入解析

    前言:前面有专门的讲解关于如何深入查询模型的参数信息,可以参考这篇文章: 沈鹏燕:pytorch教程之nn.Module类详解​zhuanlan.zhihu.com 本次来解析一下我们通常保存的模型文 ...

  5. python保存模型与参数_Pytorch - 模型和参数的保存与恢复

    模型训练后,需要保存到文件,以供测试和部署:或,继续之前的训练状态. 1. Best Practices 主要有两种模型序列化保存和加载恢复的方法. 1.1 方法 M1 - 推荐 只保存和加载恢复模型 ...

  6. python保存模型与参数_如何导出python中的模型参数

    模型的保存和读取 1.tensorflow保存和读取模型:tf.train.Saver() .save()#保存模型需要用到save函数 save( sess, save_path, global_s ...

  7. Keras保存和载入训练好的模型和参数

    1.保存模型 my_model = create_model_function( ...... )my_model.compile( ...... )my_model.fit( ...... )mod ...

  8. Unity-Live2d(模型参数设置,当前参数保存与恢复所保存参数, 部分位置透明度设置,自动眨眼)

    Unity-Live2D 概述:这是我学习Unity中Live2d的相关操作的一个笔记,欢迎各位同好和大牛的指点.(参考siki学院出的视频学的) 模型参数设置 先来说一下这个模型参数是个什么东西,之 ...

  9. 使用PyTorch加载模型部分参数方法

    前言 在深度学习领域,经常需要使用其他人已训练好的模型进行改进或微调,这个时候我们通常会希望加载预训练模型文件的参数,如果网络结构不变,只需要使用load_state_dict方法即可.而当我们改动网 ...

最新文章

  1. C# Byte数组与Int16数组之间的转换
  2. Django-路由控制
  3. 鼠标键盘唤醒计算机,除了按下电源按钮唤醒计算机,WIN10也可以使用鼠标或键盘来唤醒...
  4. 极市分享|第32期 张德兵小美:分布式人脸识别及工业级运用经验
  5. 使用Java实现K-Means聚类算法
  6. 一作发14篇SCI,累计IF60,博士前两年,他也曾走过弯路
  7. java笔记:熟练掌握线程技术---基础篇之解决资源共享的问题(中)--前篇
  8. Swift开发之NSStringFromClass的使用和代替方法
  9. java泛型的英文_Java泛型一:泛型的定义及规则
  10. vue3移动端腾讯地图坐标拾取,获取当前定位(腾讯、高德、百度、天地图),火星坐标GCJ-02–>百度坐标BD-09,根据坐标经纬度计算两点距离的方法,点击链接打开地图导航的方法
  11. 【转载】Unity3D研究院之静态自动检查代码缺陷与隐患
  12. 实验项目一 俄罗斯方块游戏
  13. Django详细教程(图文)
  14. Linux 系统设置 : insmod 命令详解
  15. Linux下文件名乱码的解决方法
  16. 中国“苹果皮”之父:希望与苹果公司展开合作
  17. win10应用商店打不开_微软上架新版QQ,秒杀正版!升级win10,体验超越原版的自带应用...
  18. 享元模式 - Unity
  19. b374k php webshell
  20. 查看IC中文文档的网站

热门文章

  1. 2022-2028年中国专用化学品行业投资分析及前景预测报告
  2. 2022-2028年中国再生天然橡胶行业市场调查分析及未来前景分析报告
  3. jquery autocomplete demo
  4. 判别模型和生成模型的区别
  5. TVM性能评估分析(四)
  6. Kubeedge Edged概述
  7. 使用NVIDIA A100 TF32获得即时加速
  8. 处理器解决物联网和人工智能的融合
  9. 将深度学习低延迟推理性能提高一倍
  10. Linux内存技术分析(下)