pytorch1.0神经网络保存、提取、加载
pytorch1.0网络保存、提取、加载
import torch import torch.nn.functional as F # 包含激励函数 import matplotlib.pyplot as plt# 假数据 x = torch.unsqueeze(torch.linspace(-1,1,100),dim=1) # x data (tensor), shape=(100, 1) y = x.pow(2) + 0.2*torch.rand(x.size()) # noisy y data (tensor), shape=(100, 1)# The code below is deprecated in Pytorch 0.4. Now, autograd directly supports tensors # x, y = Variable(x, requires_grad=False), Variable(y, requires_grad=False)def save():# save net1# 建网络net1 = torch.nn.Sequential(torch.nn.Linear(1, 10),torch.nn.ReLU(),torch.nn.Linear(10, 1))optimizer = torch.optim.SGD(net1.parameters(), lr=0.5)loss_func = torch.nn.MSELoss()# 训练for t in range(100):prediction = net1(x)loss = loss_func(prediction, y)optimizer.zero_grad()loss.backward()optimizer.step()# plot resultplt.figure(1, figsize=(10, 3))plt.subplot(131)plt.title('Net1')plt.scatter(x.data.numpy(), y.data.numpy())plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5)# 2 ways to save the nettorch.save(net1, 'net.pkl') # save entire net # 保存整个网络torch.save(net1.state_dict(), 'net_params.pkl') # save only the parameters # 只保存网络中的参数 (速度快, 占内存少)# 提取网络 def restore_net():# restore entire net1 to net2net2 = torch.load('net.pkl')prediction = net2(x)# plot resultplt.subplot(132)plt.title('Net2')plt.scatter(x.data.numpy(), y.data.numpy())plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5)# 只提取网络参数 def restore_params():# 新建 net3# restore only the parameters in net1 to net3net3 = torch.nn.Sequential(torch.nn.Linear(1, 10),torch.nn.ReLU(),torch.nn.Linear(10, 1))# 将保存的参数复制到 net3# copy net1's parameters into net3net3.load_state_dict(torch.load('net_params.pkl'))prediction = net3(x)# plot resultplt.subplot(133)plt.title('Net3')plt.scatter(x.data.numpy(), y.data.numpy())plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5)plt.show()# 保存 net1 (1. 整个网络, 2. 只有参数) # save net1 save() # 提取整个网络 # restore entire net (may slow) restore_net() # 提取网络参数, 复制到新网络 # restore only the net parameters restore_params()
转载于:https://www.cnblogs.com/jeshy/p/11199820.html
pytorch1.0神经网络保存、提取、加载相关推荐
- TensorFlow2.0 —— 模型保存与加载
目录 1.Keras版本模型保存与加载 2.自定义版本模型保存与加载 3.总结 1.Keras版本模型保存与加载 保存模型权重(model.save_weights) 保存HDF5文件(model.s ...
- tensorflow1.0模型的保存、加载、在训练
1.checkpoint文件总览 tensorflow保存的模型文件如下所示: .meta文件保存的是图结构,meta文件是pb(protocol buffer)格式文件,包含变量.op.集合等. c ...
- TensorFlow2.0:模型的保存与加载
** 一.权重参数的保存与加载 ** network.save_weights('weights.ckpt') network.load_weights('weights.ckpt') 权重参数的保存 ...
- Pytorch网络模型权重初始化、保存与加载模型、加载预训练模型、按需设置学习率
前言 在我们对神经网络模型进行训练时,往往需要对模型进行初始化或者加载预训练模型.本文将对模型的权重初始化与加载预训练模型做一个学习记录,以便后续查询使用. 权重初始化 常见的初始化方法 PyTorc ...
- tensorflow 1.x Saver(保存与加载模型) 预测
20201231 tensorflow 1.X 模型保存 https://blog.csdn.net/qq_35290785/article/details/89646248 保存模型 saver=t ...
- pytorch模型的保存与加载
我们先创建一个模型,使用的是pytorch笔记--简易回归问题_刘文巾的博客-CSDN博客 的主体框架,唯一不同的是,我这里用的是torch.nn.Sequential来定义模型框架,而不是那篇博客里 ...
- pytorch数据加载、模型保存及加载
主要涉及的Pytorch官方示例下图红框部分的一些翻译及备注. 1.数据加载及处理 该部分主要是用于进行数据集加载及数据预处理说明,使用的数据集为:人脸+标注坐标.demo程序需要pandas(读 ...
- tensor和模型 保存与加载 PyTorch
PyTorch教程-7:PyTorch中保存与加载tensor和模型详解 保存和读取Tensor PyTorch中的tensor可以保存成 .pt 或者 .pth 格式的文件,使用torch.save ...
- (一次性搞定)ORB_SLAM2地图保存与加载
(一次性搞定)ORB_SLAM2地图保存与加载 本文记录了ORB_SLAM2中地图保存与加载的过程. 参考博客: https://blog.csdn.net/qq_34254510/article/d ...
最新文章
- 遭遇“生活天花板”,如何用 OKR 弯道超车?
- PLSQL重点问题理解和实战
- Java实现算法导论中图的广度优先搜索(BFS)和深度优先搜索(DFS)
- search string iteration
- GlobalPointer:用统一的方式处理嵌套和非嵌套NER
- 图论--一般图带花树匹配--模板
- C#中数据类型及其转换知识点汇总
- c调用python第三方库_Python使用ctypes模块调用DLL函数之C语言数组与numpy数组传递...
- Windows Server 2008搭建域环境---安装活动目录
- 小霸王被申请破产重整;虎牙员工自曝被HR抬出公司;Office 2010被微软终止服务|极客头条
- 如何在你的blog中添加炫酷的飘雪动画效果
- qt中如何刷新一下屏幕_感情维护:如何在恋爱关系中分开一下,然后更坚强地回来...
- Http实战之Wireshark抓包分析
- python攻击局域网电脑_怎么攻击对方电脑?以知对方IP,且对方在线
- 推荐下载使用:功能强大的光盘刻录软件NERO 9.0中文版(最新官方原版+有效序列号)(转)...
- 冰点下载器手机版apk_冰点文库下载器app下载_冰点文库下载器手机安卓版软件下载v1.0.3...
- 非三星手机无法登录三星账号_如何解决所有三星手机的烦恼
- JavaScript【狂神笔记】
- WPS怎么转换成PDF?这样转换准没错
- Java.IO.InputStream-OutputStream
热门文章
- html中如何计算图片的像素,html – 浏览器的1px计算问题(子像素问题)
- element ui select 自动向上向下弹出_[selenium]用Selenium自动填问卷星的问卷
- H3C设备运行状态查询常用命令
- 新 CEO 谈论GitHub 被微软接管后的未来
- Introducing DataFrames in Apache Spark for Large Scale Data Science(中英双语)
- 15款Cocos2d-x游戏源码
- IntelJIdea 如何修改控制台字体大小和主题
- 大文件表空间+创建大文件表空间+查询数据库表空间类型信息+查询数据库表空间类型信息...
- XCODE 出现 The operation couldn't be completed.(LaunchServicesError error 0.)错误修复
- 教你如何监控 Apache?