pytorch1.0网络保存、提取、加载

import torch
import torch.nn.functional as F  # 包含激励函数
import matplotlib.pyplot as plt# 假数据
x = torch.unsqueeze(torch.linspace(-1,1,100),dim=1) # x data (tensor), shape=(100, 1)
y = x.pow(2) + 0.2*torch.rand(x.size())  # noisy y data (tensor), shape=(100, 1)# The code below is deprecated in Pytorch 0.4. Now, autograd directly supports tensors
# x, y = Variable(x, requires_grad=False), Variable(y, requires_grad=False)def save():# save net1# 建网络net1 = torch.nn.Sequential(torch.nn.Linear(1, 10),torch.nn.ReLU(),torch.nn.Linear(10, 1))optimizer = torch.optim.SGD(net1.parameters(), lr=0.5)loss_func = torch.nn.MSELoss()# 训练for t in range(100):prediction = net1(x)loss = loss_func(prediction, y)optimizer.zero_grad()loss.backward()optimizer.step()# plot resultplt.figure(1, figsize=(10, 3))plt.subplot(131)plt.title('Net1')plt.scatter(x.data.numpy(), y.data.numpy())plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5)# 2 ways to save the nettorch.save(net1, 'net.pkl')  # save entire net # 保存整个网络torch.save(net1.state_dict(), 'net_params.pkl')   # save only the parameters # 只保存网络中的参数 (速度快, 占内存少)# 提取网络
def restore_net():# restore entire net1 to net2net2 = torch.load('net.pkl')prediction = net2(x)# plot resultplt.subplot(132)plt.title('Net2')plt.scatter(x.data.numpy(), y.data.numpy())plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5)# 只提取网络参数
def restore_params():# 新建 net3# restore only the parameters in net1 to net3net3 = torch.nn.Sequential(torch.nn.Linear(1, 10),torch.nn.ReLU(),torch.nn.Linear(10, 1))# 将保存的参数复制到 net3# copy net1's parameters into net3net3.load_state_dict(torch.load('net_params.pkl'))prediction = net3(x)# plot resultplt.subplot(133)plt.title('Net3')plt.scatter(x.data.numpy(), y.data.numpy())plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5)plt.show()# 保存 net1 (1. 整个网络, 2. 只有参数)
# save net1
save()
# 提取整个网络
# restore entire net (may slow)
restore_net()
# 提取网络参数, 复制到新网络
# restore only the net parameters
restore_params()

转载于:https://www.cnblogs.com/jeshy/p/11199820.html

pytorch1.0神经网络保存、提取、加载相关推荐

  1. TensorFlow2.0 —— 模型保存与加载

    目录 1.Keras版本模型保存与加载 2.自定义版本模型保存与加载 3.总结 1.Keras版本模型保存与加载 保存模型权重(model.save_weights) 保存HDF5文件(model.s ...

  2. tensorflow1.0模型的保存、加载、在训练

    1.checkpoint文件总览 tensorflow保存的模型文件如下所示: .meta文件保存的是图结构,meta文件是pb(protocol buffer)格式文件,包含变量.op.集合等. c ...

  3. TensorFlow2.0:模型的保存与加载

    ** 一.权重参数的保存与加载 ** network.save_weights('weights.ckpt') network.load_weights('weights.ckpt') 权重参数的保存 ...

  4. Pytorch网络模型权重初始化、保存与加载模型、加载预训练模型、按需设置学习率

    前言 在我们对神经网络模型进行训练时,往往需要对模型进行初始化或者加载预训练模型.本文将对模型的权重初始化与加载预训练模型做一个学习记录,以便后续查询使用. 权重初始化 常见的初始化方法 PyTorc ...

  5. tensorflow 1.x Saver(保存与加载模型) 预测

    20201231 tensorflow 1.X 模型保存 https://blog.csdn.net/qq_35290785/article/details/89646248 保存模型 saver=t ...

  6. pytorch模型的保存与加载

    我们先创建一个模型,使用的是pytorch笔记--简易回归问题_刘文巾的博客-CSDN博客 的主体框架,唯一不同的是,我这里用的是torch.nn.Sequential来定义模型框架,而不是那篇博客里 ...

  7. pytorch数据加载、模型保存及加载

    主要涉及的Pytorch官方示例下图红框部分的一些翻译及备注. 1.数据加载及处理   该部分主要是用于进行数据集加载及数据预处理说明,使用的数据集为:人脸+标注坐标.demo程序需要pandas(读 ...

  8. tensor和模型 保存与加载 PyTorch

    PyTorch教程-7:PyTorch中保存与加载tensor和模型详解 保存和读取Tensor PyTorch中的tensor可以保存成 .pt 或者 .pth 格式的文件,使用torch.save ...

  9. (一次性搞定)ORB_SLAM2地图保存与加载

    (一次性搞定)ORB_SLAM2地图保存与加载 本文记录了ORB_SLAM2中地图保存与加载的过程. 参考博客: https://blog.csdn.net/qq_34254510/article/d ...

最新文章

  1. 遭遇“生活天花板”,如何用 OKR 弯道超车?
  2. PLSQL重点问题理解和实战
  3. Java实现算法导论中图的广度优先搜索(BFS)和深度优先搜索(DFS)
  4. search string iteration
  5. GlobalPointer:用统一的方式处理嵌套和非嵌套NER
  6. 图论--一般图带花树匹配--模板
  7. C#中数据类型及其转换知识点汇总
  8. c调用python第三方库_Python使用ctypes模块调用DLL函数之C语言数组与numpy数组传递...
  9. Windows Server 2008搭建域环境---安装活动目录
  10. 小霸王被申请破产重整;虎牙员工自曝被HR抬出公司;Office 2010被微软终止服务|极客头条
  11. 如何在你的blog中添加炫酷的飘雪动画效果
  12. qt中如何刷新一下屏幕_感情维护:如何在恋爱关系中分开一下,然后更坚强地回来...
  13. Http实战之Wireshark抓包分析
  14. python攻击局域网电脑_怎么攻击对方电脑?以知对方IP,且对方在线
  15. 推荐下载使用:功能强大的光盘刻录软件NERO 9.0中文版(最新官方原版+有效序列号)(转)...
  16. 冰点下载器手机版apk_冰点文库下载器app下载_冰点文库下载器手机安卓版软件下载v1.0.3...
  17. 非三星手机无法登录三星账号_如何解决所有三星手机的烦恼
  18. JavaScript【狂神笔记】
  19. WPS怎么转换成PDF?这样转换准没错
  20. Java.IO.InputStream-OutputStream

热门文章

  1. html中如何计算图片的像素,html – 浏览器的1px计算问题(子像素问题)
  2. element ui select 自动向上向下弹出_[selenium]用Selenium自动填问卷星的问卷
  3. H3C设备运行状态查询常用命令
  4. 新 CEO 谈论GitHub 被微软接管后的未来
  5. Introducing DataFrames in Apache Spark for Large Scale Data Science(中英双语)
  6. 15款Cocos2d-x游戏源码
  7. IntelJIdea 如何修改控制台字体大小和主题
  8. 大文件表空间+创建大文件表空间+查询数据库表空间类型信息+查询数据库表空间类型信息...
  9. XCODE 出现 The operation couldn't be completed.(LaunchServicesError error 0.)错误修复
  10. 教你如何监控 Apache?