目录

0. 前言

1. Pytorch框架加载与保存权重的方法

2. 实例问题说明

3. 加载权重数据

4. 保存权重数据


0. 前言

在深度学习实际应用中,往往涉及到的神经元网络模型都很大,权重参数众多,因此会导致训练epoch次数很多,训练时间长。

如果每次调整非模型相关的参数(训练数据集、优化函数类型、学习率、迭代次数)都要重新训练一次模型,这显然会浪费大量的训练时间。

而且,对于一些成熟的网络模型,已经有前人做过大量的“预训练”,这时如果能基于前人预训练的结果,训练自己的数据集,明显会事半功倍。

因此,加载与保存权重在深度学习实际使用中有很大的必要。

1. Pytorch框架加载与保存权重的方法

①加载权重的方法: .load_state_dict()方法说明:

.load_state_dict()定义:

def load_state_dict(self, state_dict: 'OrderedDict[str, Tensor]',strict: bool = True):

- state_dict :即要加载的权重,通常是一个文件地址;

- strick: 可以理解为等于"True"时是“精确匹配”,要求要加载的权重与要被加载权重的模型完全匹配。

Pytorch源文件注释:

Args:state_dict (dict): a dict containing parameters andpersistent buffers.strict (bool, optional): whether to strictly enforce that the keysin :attr:`state_dict` match the keys returned by this module's:meth:`~torch.nn.Module.state_dict` function. Default: ``True``

*小注释:meth笔误了,应该是mesh,网格

②保存权重的方法:.save()方法说明:

.save()定义:

def save(obj, f: Union[str, os.PathLike, BinaryIO, IO[bytes]],pickle_module=pickle, pickle_protocol=DEFAULT_PROTOCOL, _use_new_zipfile_serialization=True) -> None:

- obj:要保存的权重参数;

- f:保存的文件路径

这里仅说明.save()在保存网络模型权重数据上的作用。实际上.save()还有很多应用,例如:保存整个网络,这里不再赘述。

Pytorch源文件注释:

"""Saves an object to a disk file.See also: `saving-loading-tensors`Args:obj: saved objectf: a file-like object (has to implement write and flush) or a string oros.PathLike object containing a file namepickle_module: module used for pickling metadata and objectspickle_protocol: can be specified to override the default protocol

2. 实例问题说明

首先说明本次的实例问题:本次要构建的神经元网络为一个“平方网络”,即网络输出数据为输入数据的平方。

网络模型结构:

输入(1)→全连接层(1×5)→Sigmoid激活函数(5)→全连接层(5×5)→Sigmoid激活函数(5)→全连接层(5×1)→输出(1)

训练数据:

输入数据[1, 2, 3, 4, 5];

输出数据[1, 4, 9, 16, 25]

3. 加载权重数据

直接上代码

import torchclass LinearNet(torch.nn.Module):def __init__(self, input_size, output_size):super().__init__()self.net = torch.nn.Sequential(torch.nn.Linear(in_features=input_size, out_features= 5, bias=True),torch.nn.Sigmoid(),torch.nn.Linear(in_features= 5, out_features=5, bias=True),torch.nn.Sigmoid(),torch.nn.Linear(in_features=5, out_features=output_size, bias=True))def forward(self,x):return self.net(x)square_net = LinearNet(1,1)square_net.load_state_dict(torch.load('weight.pth'))  #直接加载已经训练好的权重if __name__ == '__main__':print(square_net(torch.tensor([3.16],dtype=torch.float32)))

其中weight.pth是我已经训练好的权重数据路径,这里定义好网络模型后,直接加载权重数据,不必关心这个权重是如何训练来的,更不必关系具体权重的值是多少。测试输入为3.16输出为:

tensor([9.9180], grad_fn=<AddBackward0>)

这里要注意的是:因为上面strict默认为True,即为“精确匹配”,这里新构建的网络模型结构必须和权重来源的网络模型结构相同。

4. 保存权重数据

import torchinput = torch.tensor([[1],[2],[3],[4],[5]], dtype=torch.float32)
output = torch.tensor([[1],[4],[9],[16],[25]], dtype=torch.float32)class LinearNet(torch.nn.Module):def __init__(self, input_size, output_size):super().__init__()self.net = torch.nn.Sequential(torch.nn.Linear(in_features=input_size, out_features= 5, bias=True),torch.nn.Sigmoid(),torch.nn.Linear(in_features= 5, out_features=5, bias=True),torch.nn.Sigmoid(),torch.nn.Linear(in_features=5, out_features=output_size, bias=True))def forward(self,x):return self.net(x)Loss = torch.nn.MSELoss()
linear_net = LinearNet(1,1)
opt = torch.optim.SGD(linear_net.parameters(), lr= 0.003)for k in range(1000):opt.zero_grad()for i in range(len(input)):train_out = linear_net(input[i])loss = Loss(train_out, output[i])loss.backward()opt.step()torch.save(linear_net.state_dict(),'weight.pth')   #保存.pth权重文件for keys,values in linear_net.state_dict().items():   #查看权重名称及值print(keys)print(values)print('************************************************************************')if __name__ == '__main__':print(linear_net(torch.tensor([3.16],dtype=torch.float32)))

这里可以看到具体的训练过程及相关的训练参数,权重保存在'weight.pth'文件中。

可以通过print查看具体的权重数值:

net.0.weight
tensor([[-0.8204],[-1.7341],[-0.6987],[ 0.9370],[-1.5558]])
************************************************************************
net.0.bias
tensor([ 0.9285,  2.1061,  1.0247, -2.9221,  7.1159])
************************************************************************
net.2.weight
tensor([[-1.6075, -1.3072, -1.5342,  2.4527, -3.9922],[-0.7101, -1.5125, -0.6791,  2.0325, -2.3406],[-1.1707, -1.6899, -0.9883,  2.9682, -1.5409],[-1.1992, -2.0559, -0.7610,  2.3890, -1.3782],[-1.1274, -1.7907, -1.0860,  2.3549, -3.6847]])
************************************************************************
net.2.bias
tensor([0.4826, 0.7057, 0.9702, 1.0532, 0.4214])
************************************************************************
net.4.weight
tensor([[7.3601, 4.7667, 6.2473, 5.0187, 7.2028]])
************************************************************************
net.4.bias
tensor([-0.2476])
************************************************************************

通过实例学习Pytorch加载权重.load_state_dict()与保存权重.save()相关推荐

  1. Pytorch加载模型只导入部分层权重,即跳过指定网络层的方法

    需求 Pytorch加载模型时,只导入部分层权重,跳过部分指定网络层.(权重文件存储为dict形式) 方法一 常见方法:加载权重时用if对网络层进行筛选 ''' # model为定义的网络结构: cl ...

  2. Pytorch 加载和保存模型

    目录 保存和加载模型 1.  什么是状态字典:state_dict? 2.保存和加载推理模型 2.1 保存/加载 state_dict (推荐使用) 2.2 保存/加载完整模型 3. 保存和加载 Ch ...

  3. pytorch加载模型报错Unexpected key(s) in state_dict: module.conv1.weight, module.bn1

    文章目录 背景 报错 原因 解决 背景 Pytorch在加载模型参数的时候,有两种情况可能出现这种问题: 自己写的网络结构,例如: 代码 import models arch = 'resnet50' ...

  4. pytorch 驱动不兼容_解决Pytorch 加载训练好的模型 遇到的error问题

    这是一个非常愚蠢的错误 debug的时候要好好看error信息 提醒自己切记好好对待error!切记!切记! -----------------------分割线---------------- py ...

  5. 用pytorch加载训练模型

    用pytorch加载.pth格式的训练模型 在pytorch/vision/models网页上有很多现成的经典网络模型可以调用,其中包括alexnet.vgg.googlenet.resnet.inc ...

  6. Pytorch加载模型并进行图像分类预测

    目录 1. 整体流程 1)实例化模型 2)加载模型 3)输入图像 4)输出分类结果 5)完整代码 2. 处理图像 1) How can i convert an RGB image into gray ...

  7. pytorch加载自己的图片数据集的两种方法

    目录 ImageFolder 加载数据集 使用pytorch提供的Dataset类创建自己的数据集. Dataset加载数据集 接下来我们就可以构建我们的网络架构: 训练我们的网络: 保存网络模型(这 ...

  8. pytorch加载训练数据集dataloader操作耗费时间太久,该如何解决?

    笔者在使用pytorch加载训练数据进行模型训练的时候,发现数据加载需要耗费太多时间,该如何缩短数据加载的时间消耗呢?经过查询相关文档,总结实际操作过程如下: 1.尽量将jpg等格式的文件保存为bmp ...

  9. Pytorch加载torchvision从本地下载好的预训练模型的简单解决方案

      大家好,我是爱编程的喵喵.双985硕士毕业,现担任全栈工程师一职,热衷于将数据思维应用到工作与生活中.从事机器学习以及相关的前后端开发工作.曾在阿里云.科大讯飞.CCF等比赛获得多次Top名次.喜 ...

最新文章

  1. 对象检测工具包mmdetection简介、安装及测试代码
  2. 决定系数 均方误差mse_回归模型评价指标 SSE, MSE、RMSE、MAE、R-SQUARED
  3. struts2文件上传一个错误的解决
  4. java--模板方法模式
  5. “约见”面试官系列之常见面试题第三十二篇之async和await(建议收藏)
  6. LeetCode-665:非递减数列
  7. TOSCA自动化测试工具视频资料
  8. 系统登录界面(收集)
  9. gdb调试出现optimized out解决方法
  10. firefox 53支持java_JavaSelenium 2.53在Firefox 47上不起作用
  11. 电脑重装操作系统——使用U盘安装(简略步骤)
  12. 算法题打卡-超人进化(剑指offer第一天)
  13. 本地管理表空间(LMT)与自动段空间管理(ASSM)概念(未看)
  14. PPP概念股一览 PPP概念股盈利预测
  15. Vi/Vim 编辑器常见命令
  16. 氚云徐平俊:低代码赛道热度陡升,今年增长目标200%
  17. winserver-记录共享文件夹操作日志
  18. 这个截图神器,能轻松碾压QQ和微信。。。
  19. 英语语法总结--虚拟语气
  20. 高光谱数据集下载Indian_pines, Salinas, Pavia Centre and University, Cuprite, Kennedy Space Center, Botswana等

热门文章

  1. 【MySQL】新版本特性
  2. 如何解决win11“无法枚举容器中的对象,访问被拒绝”、“右键新建只有文件夹,没有其他选项”的问题。
  3. Python OJ输入输出
  4. ES6新特性总结(2)解构赋值、模板字符串、Symbol
  5. 如何计算字符串的字节长度
  6. JS元素的提取,删除 ,添加,修改
  7. C#入门4——计算自由落体运动
  8. IDEA-Translation3.0插件右键无文档翻译解决
  9. Linux应用开发入门
  10. Web前端技术 Web学习资料 Web学习路线 Web入门宝典(不断更新中)