文章目录

  • 1.前言
  • 2.torch.save(保存模型)
  • 3.torch.load整个网络
  • 4.torch.load网络参数(只提取参数)
  • 5.调用三个函数

1.前言

训练好了一个模型, 我们当然想要保存它, 留到下次要用的时候直接提取直接用,下面我将来讲如何存储训练好的模型参数

2.torch.save(保存模型)

首先,先搭建一个神经网络

import torch
from torch import nn
import matplotlib.pyplot as plt
torch.manual_seed(11)    # 使每次得到的随机数是固定的。但是如果不加上torch.manual_seed这个函数调用的话,打印出来的随机数每次都不一样x = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1)  # [100] -> [100,1]
y = x.pow(2) + 0.5*torch.rand(x.size())  # y的形状与x一样def make_and_save_model():network = torch.nn.Sequential(torch.nn.Linear(1, 8),torch.nn.ReLU(),torch.nn.Linear(8, 1))optimizer = torch.optim.SGD(network.parameters(), lr=0.3)   #优化器criterion = torch.nn.MSELoss()     #损失函数# 训练for i in range(200):prediction = network(x)      #数据放入模型后得到预测值loss = criterion(prediction, y)    #计算预测值与真实值之间的误差optimizer.zero_grad()       #清空梯度loss.backward()           #误差反向传播optimizer.step()          #更新参数torch.save(network, 'network.pth')  # 保存整个网络torch.save(network.state_dict(), 'network_params.pth')   # 只保存网络中的参数plt.figure(1, figsize = (10,3))plt.subplot(131)plt.title('network')plt.scatter(x.data.numpy(), y.data.numpy())plt.plot(x.data.numpy(), prediction.data.numpy(), 'yo' , lw = 5)plt.pause(1)

3.torch.load整个网络

这种方式将会提取整个神经网络, 网络大的时候可能会比较慢.

def load_whole_model():network_whole = torch.load('network.pth')prediction = network_whole(x)plt.figure(1, figsize = (10,3))plt.subplot(132)plt.title('network_whole')plt.scatter(x.data.numpy(), y.data.numpy())plt.plot(x.data.numpy(), prediction.data.numpy(), 'yo' , lw = 5)plt.pause(1)

4.torch.load网络参数(只提取参数)

这种方式将会提取所有的参数, 然后再放到你的新建网络中

def load_only_params():network_params = torch.nn.Sequential(torch.nn.Linear(1, 8),torch.nn.ReLU(),torch.nn.Linear(8, 1))network_params.load_state_dict(torch.load('network_params.pth'))prediction = network_params(x)plt.figure(1, figsize = (10,3))plt.subplot(133)plt.title('network_params')plt.scatter(x.data.numpy(), y.data.numpy())plt.plot(x.data.numpy(), prediction.data.numpy(), 'yo' , lw = 5)

5.调用三个函数

会看到加载后的模型画出的图是一样的,说明模型的参数正确加载了。

make_and_save_model()
load_whole_model()
load_only_params()

Pytorch——保存训练好的模型参数相关推荐

  1. python训练模型函数参数_keras读取训练好的模型参数并把参数赋值给其它模型详解...

    介绍 本博文中的代码,实现的是加载训练好的模型model_halcon_resenet.h5,并把该模型的参数赋值给两个不同的新的model. 函数式模型 官网上给出的调用一个训练好模型,并输出任意层 ...

  2. 转载:tensorflow保存训练后的模型

    训练完一个模型后,为了以后重复使用,通常我们需要对模型的结果进行保存.如果用Tensorflow去实现神经网络,所要保存的就是神经网络中的各项权重值.建议可以使用Saver类保存和加载模型的结果. 1 ...

  3. python如何保存训练好的模型_Python机器学习7:如何保存、加载训练好的机器学习模型...

    本文将介绍如何使用scikit-learn机器学习库保存Python机器学习模型.加载已经训练好的模型.学会了这个,你才能够用已有的模型做预测,而不需要每次都重新训练模型. 本文将使用两种方法来实现模 ...

  4. [pytorch、学习] - 4.2 模型参数的访问、初始化和共享

    参考 4.2 模型参数的访问.初始化和共享 在3.3节(线性回归的简洁实现)中,我们通过init模块来初始化模型的参数.我们也介绍了访问模型参数的简单方法.本节将深入讲解如何访问和初始化模型参数,以及 ...

  5. pytorch 保存、加载模型

    一般保存为.pt格式,保存模型使用: torch.save(model, '保存位置') 加载模型使用: model_load = torch.load('加载模型的位置') 完整代码 import ...

  6. 保存训练好的模型并调用

    当我们训练好一个model后,下次如果还想用这个model,我们就需要把这个model保存下来,下次直接导入就好了,不然每次都跑一遍,训练时间短还好,要是一次跑好几天的那怕是要天荒地老了. sciki ...

  7. libsvm 训练后,模型参数详解

    本节主要就是讲解利用libsvm-mat工具箱建立分类(回归模型)后,得到的模型model里面参数的意义,以及如果通过model得到相应模型的表达式,这里主要以分类问题为例子. 测试数据使用的是lib ...

  8. pytorch 使用训练好的模型预测新数据

    神经网络在进行完训练和测试后,如果达到了较高的正确率的话,我们可以尝试将模型用于预测新数据.总共需要两大部分:神经网络.预测函数(新图片的加载,传入模型.得出结果). 完整代码 import torc ...

  9. 【项目实战】pytorch载入训练好的模型并进行可视化模型预测绘图

    main函数载入模型,加载图片,输出结果: if __name__ == '__main__':image = Image.open(r"C:\Users\pic\test\he_5.jpg ...

最新文章

  1. ChaLearn Gesture Challenge_2:examples体验
  2. 九宫格抽奖转盘源码分析
  3. 产品经理必备知识之如何用CREATE模型对用户进行行为分析
  4. 在网页输出10的阶乘.php,ASP网络程序设计实验报告和期末考试复习范围
  5. NoSql中的B-tree、B+tree和LSM-tree
  6. spark集群详细搭建过程及遇到的问题解决(四)
  7. soul群聊显示服务器异常,soul群聊状态是什么
  8. Kubernetes学习总结(15)—— Kubernetes 实战之部署 Mysql 集群
  9. mysql显示表已存在_「Docker系列」 如何在Docker中部署MySQL数据库?
  10. Whitelabel Error Page : spring boot项目启动后,无法访问@RequestMapping标注的请求
  11. 安装anaconda,jupyter基本操作说明快捷键使用
  12. row_number()分页返回结果顺序不确定
  13. [转]jQuery知识总结
  14. 算法:数组找出2个只出现一次的数字(其他元素出现两次)
  15. ST-Link下载 KELL5程序下载 STM32程序下载
  16. java下cmyk图片读取和转换rgb,以及图片压缩
  17. iphone模拟器安装app
  18. 用计算机数字技术制作的电影是,计算机数字技术为电影带来的空前发展.doc
  19. Carte作为Windows服务
  20. 自动驾驶专题介绍 ———— 转向系统

热门文章

  1. java 同步 set_Java Collections synchronizedSet()用法及代码示例
  2. JSON与localStorage的爱恨情仇
  3. ServiceLoader用法demo
  4. 大巴山计算机教育中心那所学校,大巴山计算机教育中心
  5. TCP/IP模型的简单解释
  6. java jdbc 占位符_java-jdbc
  7. linux lvs 存储层,LVS集群配置之LVS介绍
  8. 迁移是10g-11g ogg正好有用武之地N种方法
  9. 【图】二分图最大权匹配
  10. 使用java.util.zip包实现根据文件目录控制文件的压缩与解压