Pytorch——保存训练好的模型参数
文章目录
- 1.前言
- 2.torch.save(保存模型)
- 3.torch.load整个网络
- 4.torch.load网络参数(只提取参数)
- 5.调用三个函数
1.前言
训练好了一个模型, 我们当然想要保存它, 留到下次要用的时候直接提取直接用,下面我将来讲如何存储训练好的模型参数
2.torch.save(保存模型)
首先,先搭建一个神经网络
import torch
from torch import nn
import matplotlib.pyplot as plt
torch.manual_seed(11) # 使每次得到的随机数是固定的。但是如果不加上torch.manual_seed这个函数调用的话,打印出来的随机数每次都不一样x = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1) # [100] -> [100,1]
y = x.pow(2) + 0.5*torch.rand(x.size()) # y的形状与x一样def make_and_save_model():network = torch.nn.Sequential(torch.nn.Linear(1, 8),torch.nn.ReLU(),torch.nn.Linear(8, 1))optimizer = torch.optim.SGD(network.parameters(), lr=0.3) #优化器criterion = torch.nn.MSELoss() #损失函数# 训练for i in range(200):prediction = network(x) #数据放入模型后得到预测值loss = criterion(prediction, y) #计算预测值与真实值之间的误差optimizer.zero_grad() #清空梯度loss.backward() #误差反向传播optimizer.step() #更新参数torch.save(network, 'network.pth') # 保存整个网络torch.save(network.state_dict(), 'network_params.pth') # 只保存网络中的参数plt.figure(1, figsize = (10,3))plt.subplot(131)plt.title('network')plt.scatter(x.data.numpy(), y.data.numpy())plt.plot(x.data.numpy(), prediction.data.numpy(), 'yo' , lw = 5)plt.pause(1)
3.torch.load整个网络
这种方式将会提取整个神经网络, 网络大的时候可能会比较慢.
def load_whole_model():network_whole = torch.load('network.pth')prediction = network_whole(x)plt.figure(1, figsize = (10,3))plt.subplot(132)plt.title('network_whole')plt.scatter(x.data.numpy(), y.data.numpy())plt.plot(x.data.numpy(), prediction.data.numpy(), 'yo' , lw = 5)plt.pause(1)
4.torch.load网络参数(只提取参数)
这种方式将会提取所有的参数, 然后再放到你的新建网络中
def load_only_params():network_params = torch.nn.Sequential(torch.nn.Linear(1, 8),torch.nn.ReLU(),torch.nn.Linear(8, 1))network_params.load_state_dict(torch.load('network_params.pth'))prediction = network_params(x)plt.figure(1, figsize = (10,3))plt.subplot(133)plt.title('network_params')plt.scatter(x.data.numpy(), y.data.numpy())plt.plot(x.data.numpy(), prediction.data.numpy(), 'yo' , lw = 5)
5.调用三个函数
会看到加载后的模型画出的图是一样的,说明模型的参数正确加载了。
make_and_save_model()
load_whole_model()
load_only_params()
Pytorch——保存训练好的模型参数相关推荐
- python训练模型函数参数_keras读取训练好的模型参数并把参数赋值给其它模型详解...
介绍 本博文中的代码,实现的是加载训练好的模型model_halcon_resenet.h5,并把该模型的参数赋值给两个不同的新的model. 函数式模型 官网上给出的调用一个训练好模型,并输出任意层 ...
- 转载:tensorflow保存训练后的模型
训练完一个模型后,为了以后重复使用,通常我们需要对模型的结果进行保存.如果用Tensorflow去实现神经网络,所要保存的就是神经网络中的各项权重值.建议可以使用Saver类保存和加载模型的结果. 1 ...
- python如何保存训练好的模型_Python机器学习7:如何保存、加载训练好的机器学习模型...
本文将介绍如何使用scikit-learn机器学习库保存Python机器学习模型.加载已经训练好的模型.学会了这个,你才能够用已有的模型做预测,而不需要每次都重新训练模型. 本文将使用两种方法来实现模 ...
- [pytorch、学习] - 4.2 模型参数的访问、初始化和共享
参考 4.2 模型参数的访问.初始化和共享 在3.3节(线性回归的简洁实现)中,我们通过init模块来初始化模型的参数.我们也介绍了访问模型参数的简单方法.本节将深入讲解如何访问和初始化模型参数,以及 ...
- pytorch 保存、加载模型
一般保存为.pt格式,保存模型使用: torch.save(model, '保存位置') 加载模型使用: model_load = torch.load('加载模型的位置') 完整代码 import ...
- 保存训练好的模型并调用
当我们训练好一个model后,下次如果还想用这个model,我们就需要把这个model保存下来,下次直接导入就好了,不然每次都跑一遍,训练时间短还好,要是一次跑好几天的那怕是要天荒地老了. sciki ...
- libsvm 训练后,模型参数详解
本节主要就是讲解利用libsvm-mat工具箱建立分类(回归模型)后,得到的模型model里面参数的意义,以及如果通过model得到相应模型的表达式,这里主要以分类问题为例子. 测试数据使用的是lib ...
- pytorch 使用训练好的模型预测新数据
神经网络在进行完训练和测试后,如果达到了较高的正确率的话,我们可以尝试将模型用于预测新数据.总共需要两大部分:神经网络.预测函数(新图片的加载,传入模型.得出结果). 完整代码 import torc ...
- 【项目实战】pytorch载入训练好的模型并进行可视化模型预测绘图
main函数载入模型,加载图片,输出结果: if __name__ == '__main__':image = Image.open(r"C:\Users\pic\test\he_5.jpg ...
最新文章
- ChaLearn Gesture Challenge_2:examples体验
- 九宫格抽奖转盘源码分析
- 产品经理必备知识之如何用CREATE模型对用户进行行为分析
- 在网页输出10的阶乘.php,ASP网络程序设计实验报告和期末考试复习范围
- NoSql中的B-tree、B+tree和LSM-tree
- spark集群详细搭建过程及遇到的问题解决(四)
- soul群聊显示服务器异常,soul群聊状态是什么
- Kubernetes学习总结(15)—— Kubernetes 实战之部署 Mysql 集群
- mysql显示表已存在_「Docker系列」 如何在Docker中部署MySQL数据库?
- Whitelabel Error Page : spring boot项目启动后,无法访问@RequestMapping标注的请求
- 安装anaconda,jupyter基本操作说明快捷键使用
- row_number()分页返回结果顺序不确定
- [转]jQuery知识总结
- 算法:数组找出2个只出现一次的数字(其他元素出现两次)
- ST-Link下载 KELL5程序下载 STM32程序下载
- java下cmyk图片读取和转换rgb,以及图片压缩
- iphone模拟器安装app
- 用计算机数字技术制作的电影是,计算机数字技术为电影带来的空前发展.doc
- Carte作为Windows服务
- 自动驾驶专题介绍 ———— 转向系统
热门文章
- java 同步 set_Java Collections synchronizedSet()用法及代码示例
- JSON与localStorage的爱恨情仇
- ServiceLoader用法demo
- 大巴山计算机教育中心那所学校,大巴山计算机教育中心
- TCP/IP模型的简单解释
- java jdbc 占位符_java-jdbc
- linux lvs 存储层,LVS集群配置之LVS介绍
- 迁移是10g-11g ogg正好有用武之地N种方法
- 【图】二分图最大权匹配
- 使用java.util.zip包实现根据文件目录控制文件的压缩与解压