pytorch模型的保存与加载
我们先创建一个模型,使用的是pytorch笔记——简易回归问题_刘文巾的博客-CSDN博客 的主体框架,唯一不同的是,我这里用的是torch.nn.Sequential来定义模型框架,而不是那篇博客里面的类。
1 保存与加载之前的部分
#导入库
import torch#数据集
x=torch.linspace(-1,1,100).reshape(-1,1)
y=x*x+0.2*torch.rand(x.shape)#定义模型(Sequential比类简明了很多)
net=torch.nn.Sequential(torch.nn.Linear(1,10),torch.nn.ReLU(),torch.nn.Linear(10,1))#设置优化函数与损失函数
optimizer=torch.optim.SGD(net.parameters(),lr=0.2)loss_func=torch.nn.MSELoss()#进行训练
for epoch in range(1000):prediction=net(x)loss=loss_func(prediction,y)optimizer.zero_grad()#清空上一轮参数优化的参与梯度loss.backward()#损失函数反向传播optimizer.step()#梯度更新#打印模型里面的参数
for a,b in enumerate(net.parameters()):print('no:',a,'\n',b)
'''
一共有四个参数,分别对应的是每一层的w和b
no: 0 Parameter containing:
tensor([[ 1.0588],[ 0.1654],[ 0.8578],[-0.8756],[-1.0935],[ 0.7588],[ 0.9043],[-0.0723],[-0.4335],[ 0.3010]], requires_grad=True)
no: 1 Parameter containing:
tensor([-0.5888, 0.9550, -0.1572, -0.2610, -0.4367, 0.3084, -0.3802, -0.3834,0.6192, 0.3012], requires_grad=True)
no: 2 Parameter containing:
tensor([[ 0.3397, 0.0161, 0.7238, 0.6869, 0.5263, 0.1717, 0.6978, 0.1012,0.3311, -0.2264]], requires_grad=True)
no: 3 Parameter containing:
tensor([-0.0836], requires_grad=True)
'''
2 存储与加载(方法1)——直接保存模型
我们使用torch.save直接保存模型
torch.save(net,'net.pkl')
加载模型的时候,直接torch.load即可(可以看到net2参数和net是一样的)
net2=torch.load('net.pkl')
for a,b in enumerate(net2.parameters()):print('no:',a,'\n',b)'''
no: 0 Parameter containing:
tensor([[ 1.0588],[ 0.1654],[ 0.8578],[-0.8756],[-1.0935],[ 0.7588],[ 0.9043],[-0.0723],[-0.4335],[ 0.3010]], requires_grad=True)
no: 1 Parameter containing:
tensor([-0.5888, 0.9550, -0.1572, -0.2610, -0.4367, 0.3084, -0.3802, -0.3834,0.6192, 0.3012], requires_grad=True)
no: 2 Parameter containing:
tensor([[ 0.3397, 0.0161, 0.7238, 0.6869, 0.5263, 0.1717, 0.6978, 0.1012,0.3311, -0.2264]], requires_grad=True)
no: 3 Parameter containing:
tensor([-0.0836], requires_grad=True)
'''
3 存储和加载(方法2)——保存模型参数
保存的话存模型的参数
torch.save(net.state_dict(),'net_params.pkl')
加载的话,我们得先重新声明一个新的神经网络结构(用Sequential和用类都可以,有了这个新的神经网络后,才可以把参数传进去)【因为在声明新的神经网络之前,我们现在存的内容即使加载出来了,也不知道这些参数对应的结构是什么】
#声明一个新的net
net3=torch.nn.Sequential(torch.nn.Linear(1,10),torch.nn.ReLU(),torch.nn.Linear(10,1))
#加载数据
net3.load_state_dict(torch.load('net_params.pkl') )for a,b in enumerate(net3.parameters()):print('no:',a,'\n',b)
'''
和net也是一样的
no: 0 Parameter containing:
tensor([[ 1.0588],[ 0.1654],[ 0.8578],[-0.8756],[-1.0935],[ 0.7588],[ 0.9043],[-0.0723],[-0.4335],[ 0.3010]], requires_grad=True)
no: 1 Parameter containing:
tensor([-0.5888, 0.9550, -0.1572, -0.2610, -0.4367, 0.3084, -0.3802, -0.3834,0.6192, 0.3012], requires_grad=True)
no: 2 Parameter containing:
tensor([[ 0.3397, 0.0161, 0.7238, 0.6869, 0.5263, 0.1717, 0.6978, 0.1012,0.3311, -0.2264]], requires_grad=True)
no: 3 Parameter containing:
tensor([-0.0836], requires_grad=True)
'''
4 两种方法的比较
存参数的文件占用的空间少一点,这个在目前这种比较简单的模型可能还看不出来。对于那种大的模型,省下来的空间还是蛮多的。
pytorch模型的保存与加载相关推荐
- TensorFlow2.0:模型的保存与加载
** 一.权重参数的保存与加载 ** network.save_weights('weights.ckpt') network.load_weights('weights.ckpt') 权重参数的保存 ...
- [TensorFlow深度学习入门]实战七·简便方法实现TensorFlow模型参数保存与加载(ckpt方式)
[TensorFlow深度学习入门]实战七·简便方法实现TensorFlow模型参数保存与加载(ckpt方式) 个人网站–> http://www.yansongsong.cn TensorFl ...
- pytorch保存模型pth_Day159:模型的保存与加载
网络结构和参数可以分开的保存和加载,因此,pytorch保存模型有两种方法: 保存 整个模型 (结构+参数) 只保存模型参数(官方推荐) # 保存整个网络torch.save(model, check ...
- Pytorch:模型的保存与加载 torch.save()、torch.load()、torch.nn.Module.load_state_dict()
Pytorch 保存和加载模型后缀:.pt 和.pth 1 torch.save() [source] 保存一个序列化(serialized)的目标到磁盘.函数使用了Python的pickle程序用于 ...
- torch.save()模型的保存于加载
保存模型主要分为两类:保存整个模型和只保存模型参数 1.保存加载整个模型(不推荐): 保存整个网络模型,网络结构+权重参数 torch.save(model,'net.pth') 加载整个网络模型(可 ...
- MXNET学习笔记(二):模型的保存与加载
当序列化 NDArray 的时候,我们序列化的是NDArray 中保存的 tensor 值.当序列化 Symbol 的时候,我们序列化的是 Graph. Symbol序列化 当序列化 Symbol 的 ...
- pb 保存变量文件名_【Tensorflow 2.0 正式版教程】模型的保存、加载与迁移
模型的保存和加载可以直接通过Model类的save_weights和load_weights实现.默认的保存格式为tensorflow的checkpoint格式,也可以手动设置保存为h5文件. mo ...
- 【xgboost】xgboost模型的保存与加载
xgboost模型的保存方法 有多种方法可以保存xgboost模型,包括pickle,joblib,以及原生的save_model,load_model函数 其中Pickle是Python中序列化对象 ...
- tensorflow1.0模型的保存、加载、在训练
1.checkpoint文件总览 tensorflow保存的模型文件如下所示: .meta文件保存的是图结构,meta文件是pb(protocol buffer)格式文件,包含变量.op.集合等. c ...
最新文章
- windows 2003 下oracle从10.2.0.1升级到10.2.0.4
- *27.硬实时和软实时
- Spring Boot——获取上传文件的MD5值解决方案
- boost::container实现devector选项程序
- 微软 Visual Studio 2019 16.5 发布:.NET 移动开发、生产力
- net core 3.0 之Grpc新特性小试牛刀
- Map的value值降序排序与升序排序(java)
- 如何用更短时间写出高质量的博客文章经验分享
- 数据结构与算法(Python)第二天
- iPhone Objective-C EXC_BAD_ACCESS问题
- Gulp介绍与入门实践
- 在java中如何定义一个方法,个人编程学习网 - Java-方法中的术语和如何定义方法...
- zendstudio html插件,ZendStudio安装Aptana插件(html,css,js代码提示功能)_html/css_WEB-ITnose...
- iPhone蓝牙回控,iPhone手机互联,认证
- python控制苹果手机触摸屏失灵怎么办_苹果手机触摸屏失灵了,怎么解决?
- Python将普通视频变成动漫视频,这就是知识的力量~
- 谢惠民,恽自求,易法槐,钱定边编数学分析习题课讲义23.2.3练习题参考解答[来自陶哲轩小弟]...
- 将数组倒置java_java数组元素倒置
- 2021年四川高考成绩及录取结果查询,2021年四川高考录取结果查询时间及查询入口,录取结果多少号公布...
- 【推荐】真正的安卓网络摄像机(Android IPCamera)任意浏览器输入IP地址即可观看视频
热门文章
- Should i Backup all my domain controllers
- CentOS系统提示用户名不在sudoers文件中
- SaaS系统给企业带来了哪些优势
- 利用SSH端口转发功能实现X转发
- BUG管理系统(Mantis)迁移实录
- 为SharePoint网站创建自定义导航菜单
- mac80211解析之发送速率控制
- IDEA2021.03 项目全部变红,但是可以正常编译运行
- CSP认证201604-4	游戏[C++题解]:bfs、拆点、迷宫问题加强版、三维数组
- Acwing900. 整数划分[计数类dp]:完全背包解法