我们先创建一个模型,使用的是pytorch笔记——简易回归问题_刘文巾的博客-CSDN博客 的主体框架,唯一不同的是,我这里用的是torch.nn.Sequential来定义模型框架,而不是那篇博客里面的类。

1 保存与加载之前的部分

#导入库
import torch#数据集
x=torch.linspace(-1,1,100).reshape(-1,1)
y=x*x+0.2*torch.rand(x.shape)#定义模型(Sequential比类简明了很多)
net=torch.nn.Sequential(torch.nn.Linear(1,10),torch.nn.ReLU(),torch.nn.Linear(10,1))#设置优化函数与损失函数
optimizer=torch.optim.SGD(net.parameters(),lr=0.2)loss_func=torch.nn.MSELoss()#进行训练
for epoch in range(1000):prediction=net(x)loss=loss_func(prediction,y)optimizer.zero_grad()#清空上一轮参数优化的参与梯度loss.backward()#损失函数反向传播optimizer.step()#梯度更新#打印模型里面的参数
for a,b in enumerate(net.parameters()):print('no:',a,'\n',b)
'''
一共有四个参数,分别对应的是每一层的w和b
no: 0 Parameter containing:
tensor([[ 1.0588],[ 0.1654],[ 0.8578],[-0.8756],[-1.0935],[ 0.7588],[ 0.9043],[-0.0723],[-0.4335],[ 0.3010]], requires_grad=True)
no: 1 Parameter containing:
tensor([-0.5888,  0.9550, -0.1572, -0.2610, -0.4367,  0.3084, -0.3802, -0.3834,0.6192,  0.3012], requires_grad=True)
no: 2 Parameter containing:
tensor([[ 0.3397,  0.0161,  0.7238,  0.6869,  0.5263,  0.1717,  0.6978,  0.1012,0.3311, -0.2264]], requires_grad=True)
no: 3 Parameter containing:
tensor([-0.0836], requires_grad=True)
'''

2 存储与加载(方法1)——直接保存模型

我们使用torch.save直接保存模型

torch.save(net,'net.pkl')

加载模型的时候,直接torch.load即可(可以看到net2参数和net是一样的)

net2=torch.load('net.pkl')
for a,b in enumerate(net2.parameters()):print('no:',a,'\n',b)'''
no: 0 Parameter containing:
tensor([[ 1.0588],[ 0.1654],[ 0.8578],[-0.8756],[-1.0935],[ 0.7588],[ 0.9043],[-0.0723],[-0.4335],[ 0.3010]], requires_grad=True)
no: 1 Parameter containing:
tensor([-0.5888,  0.9550, -0.1572, -0.2610, -0.4367,  0.3084, -0.3802, -0.3834,0.6192,  0.3012], requires_grad=True)
no: 2 Parameter containing:
tensor([[ 0.3397,  0.0161,  0.7238,  0.6869,  0.5263,  0.1717,  0.6978,  0.1012,0.3311, -0.2264]], requires_grad=True)
no: 3 Parameter containing:
tensor([-0.0836], requires_grad=True)
'''

3 存储和加载(方法2)——保存模型参数

保存的话存模型的参数

torch.save(net.state_dict(),'net_params.pkl')

加载的话,我们得先重新声明一个新的神经网络结构(用Sequential和用类都可以,有了这个新的神经网络后,才可以把参数传进去)【因为在声明新的神经网络之前,我们现在存的内容即使加载出来了,也不知道这些参数对应的结构是什么】

#声明一个新的net
net3=torch.nn.Sequential(torch.nn.Linear(1,10),torch.nn.ReLU(),torch.nn.Linear(10,1))
#加载数据
net3.load_state_dict(torch.load('net_params.pkl') )for a,b in enumerate(net3.parameters()):print('no:',a,'\n',b)
'''
和net也是一样的
no: 0 Parameter containing:
tensor([[ 1.0588],[ 0.1654],[ 0.8578],[-0.8756],[-1.0935],[ 0.7588],[ 0.9043],[-0.0723],[-0.4335],[ 0.3010]], requires_grad=True)
no: 1 Parameter containing:
tensor([-0.5888,  0.9550, -0.1572, -0.2610, -0.4367,  0.3084, -0.3802, -0.3834,0.6192,  0.3012], requires_grad=True)
no: 2 Parameter containing:
tensor([[ 0.3397,  0.0161,  0.7238,  0.6869,  0.5263,  0.1717,  0.6978,  0.1012,0.3311, -0.2264]], requires_grad=True)
no: 3 Parameter containing:
tensor([-0.0836], requires_grad=True)
'''

4 两种方法的比较

存参数的文件占用的空间少一点,这个在目前这种比较简单的模型可能还看不出来。对于那种大的模型,省下来的空间还是蛮多的。

pytorch模型的保存与加载相关推荐

  1. TensorFlow2.0:模型的保存与加载

    ** 一.权重参数的保存与加载 ** network.save_weights('weights.ckpt') network.load_weights('weights.ckpt') 权重参数的保存 ...

  2. [TensorFlow深度学习入门]实战七·简便方法实现TensorFlow模型参数保存与加载(ckpt方式)

    [TensorFlow深度学习入门]实战七·简便方法实现TensorFlow模型参数保存与加载(ckpt方式) 个人网站–> http://www.yansongsong.cn TensorFl ...

  3. pytorch保存模型pth_Day159:模型的保存与加载

    网络结构和参数可以分开的保存和加载,因此,pytorch保存模型有两种方法: 保存 整个模型 (结构+参数) 只保存模型参数(官方推荐) # 保存整个网络torch.save(model, check ...

  4. Pytorch:模型的保存与加载 torch.save()、torch.load()、torch.nn.Module.load_state_dict()

    Pytorch 保存和加载模型后缀:.pt 和.pth 1 torch.save() [source] 保存一个序列化(serialized)的目标到磁盘.函数使用了Python的pickle程序用于 ...

  5. torch.save()模型的保存于加载

    保存模型主要分为两类:保存整个模型和只保存模型参数 1.保存加载整个模型(不推荐): 保存整个网络模型,网络结构+权重参数 torch.save(model,'net.pth') 加载整个网络模型(可 ...

  6. MXNET学习笔记(二):模型的保存与加载

    当序列化 NDArray 的时候,我们序列化的是NDArray 中保存的 tensor 值.当序列化 Symbol 的时候,我们序列化的是 Graph. Symbol序列化 当序列化 Symbol 的 ...

  7. pb 保存变量文件名_【Tensorflow 2.0 正式版教程】模型的保存、加载与迁移

    模型的保存和加载可以直接通过Model类的save_weights和load_weights实现.默认的保存格式为tensorflow的checkpoint格式,也可以手动设置保存为h5文件. mo ...

  8. 【xgboost】xgboost模型的保存与加载

    xgboost模型的保存方法 有多种方法可以保存xgboost模型,包括pickle,joblib,以及原生的save_model,load_model函数 其中Pickle是Python中序列化对象 ...

  9. tensorflow1.0模型的保存、加载、在训练

    1.checkpoint文件总览 tensorflow保存的模型文件如下所示: .meta文件保存的是图结构,meta文件是pb(protocol buffer)格式文件,包含变量.op.集合等. c ...

最新文章

  1. windows 2003 下oracle从10.2.0.1升级到10.2.0.4
  2. *27.硬实时和软实时
  3. Spring Boot——获取上传文件的MD5值解决方案
  4. boost::container实现devector选项程序
  5. 微软 Visual Studio 2019 16.5 发布:.NET 移动开发、生产力
  6. net core 3.0 之Grpc新特性小试牛刀
  7. Map的value值降序排序与升序排序(java)
  8. 如何用更短时间写出高质量的博客文章经验分享
  9. 数据结构与算法(Python)第二天
  10. iPhone Objective-C EXC_BAD_ACCESS问题
  11. Gulp介绍与入门实践
  12. 在java中如何定义一个方法,个人编程学习网 - Java-方法中的术语和如何定义方法...
  13. zendstudio html插件,ZendStudio安装Aptana插件(html,css,js代码提示功能)_html/css_WEB-ITnose...
  14. iPhone蓝牙回控,iPhone手机互联,认证
  15. python控制苹果手机触摸屏失灵怎么办_苹果手机触摸屏失灵了,怎么解决?
  16. Python将普通视频变成动漫视频,这就是知识的力量~
  17. 谢惠民,恽自求,易法槐,钱定边编数学分析习题课讲义23.2.3练习题参考解答[来自陶哲轩小弟]...
  18. 将数组倒置java_java数组元素倒置
  19. 2021年四川高考成绩及录取结果查询,2021年四川高考录取结果查询时间及查询入口,录取结果多少号公布...
  20. 【推荐】真正的安卓网络摄像机(Android IPCamera)任意浏览器输入IP地址即可观看视频

热门文章

  1. Should i Backup all my domain controllers
  2. CentOS系统提示用户名不在sudoers文件中
  3. SaaS系统给企业带来了哪些优势
  4. 利用SSH端口转发功能实现X转发
  5. BUG管理系统(Mantis)迁移实录
  6. 为SharePoint网站创建自定义导航菜单
  7. mac80211解析之发送速率控制
  8. IDEA2021.03 项目全部变红,但是可以正常编译运行
  9. CSP认证201604-4 游戏[C++题解]:bfs、拆点、迷宫问题加强版、三维数组
  10. Acwing900. 整数划分[计数类dp]:完全背包解法