训练过程中保存模型参数,就不怕断电了——沃资基·索德

在训练完成之前,我们需要每隔一段时间保存模型当前参数值,一方面可以防止断电重跑,另一方面可以观察不同迭代次数模型的表现;在训练完成以后,我们需要保存模型参数值用于后续的测试过程。所以,保存的对象包含网络参数值、优化器参数值、epoch值等等。

一、定义一个容易识别的网络

在正式介绍模型的保存和加载之前,我们首先定义一个基本的网络Net,它只包含一个全连接层:

class Net(nn.Module):def __init__(self):super(Net, self).__init__()self.layer = nn.Linear(1, 1)self.layer.weight = nn.Parameter(torch.FloatTensor([[10]]))self.layer.bias = nn.Parameter(torch.FloatTensor([1]))def forward(self, x):y = self.layer(x)return y

将全连接的权重w和偏差b分别设置为10和1,全连接的计算方式如下:

假设输入x=1,可以知道y值为11:

测试一下输出是不是11,代码如下:

x = torch.FloatTensor([[1]])
net = Net()
out = net(x)
print(out)

输出:tensor([[11.]], grad_fn=<AddmmBackward>),说明上述计算是正确的。不采用参数随机初始化,而是用特殊的数值初始化,是因为我们希望重载模型的时候,能够从特殊数值一眼判断出保存和重载过程是否正确,也可以把权重设置为一张图片数值,然后判断加载的参数值能不能恢复原图。

二、保存Net的参数值

保存模型参数之前,需要知道Net的参数值存储在其state_dict(状态字典)属性中,我们查看一下net的state_dict包含哪些参数:

print(net.state_dict())

我们将会得到net包含的所有参数名称与参数值

包含一个weight和一个bias,对应的值分别是10和1,和我们之前定义的全连接层一致。我们需要保存的就是这个state_dict,保存的函数为“torch.save()”,参数是我们需要保存的dict和存储路径

torch.save(obj=net.state_dict(), f="models/net.pth")

现在,同级目录models下将会出现net.pth文件,pth文件中的内容就是net的参数名称和值对应的state_dict,如下:

三、加载Net参数值并用于新的模型

最后一个步骤就是从pth文件中重新获取Net参数值,并把参数值装载到新定义的Model对象中。这里我们重新定义一个结构和Net类相同的类Model,区别仅仅是Model参数初始值和Net不同,代码如下:

class Model(nn.Module):def __init__(self):super(Model, self).__init__()self.layer = nn.Linear(1, 1)self.layer.weight = nn.Parameter(torch.FloatTensor([[0]]))self.layer.bias = nn.Parameter(torch.FloatTensor([0]))def forward(self, x):out = self.layer(x)return out

这里将Model的初始值权重w和偏差都设置为0,查看其state_dict:

model = Model()
print(model.state_dict())

得到的w和b值与预期相同,均为0,如下:

现在,我们将model对象的参数值设置为net.pth中的值,需要使用“model.load_state_dict()”函数重置model的参数值为"torch.load(models/ net.pth)"中的参数值,如下:

model.load_state_dict(torch.load("models/net.pth"))
print(model.state_dict())

至此,model的w和b值就不再是0了,而是net中w和b对应的10和1,如下:

其中参数值重载的核心函数为“model.load_state_dict()”,每个继承自nn.Module的网络都能通过这个函数设定参数值。

四、优化器与epoch的保存

保存优化器参数值和epoch值的主要目的是用于继续训练,保存的流程依旧是先“torch.save()”再“torch.load_state_dict()”,我们首先定义一个Adam优化器、一个任意的epoch值与net如下:

net = Net()
Adam = optim.Adam(params=net.parameters(), lr=0.001, betas=(0.5, 0.999))
epoch = 96

现在,创建一个字典来保存所有的对象,并用save函数保存这个字典

all_states = {"net": net.state_dict(), "Adam": Adam.state_dict(), "epoch": epoch}
torch.save(obj=all_states, f="models/all_states.pth")

所有的对象都被保存到models文件夹下了:

可以使用load()函数把所有的对象再次提取出来:

reload_states = torch.load("models/all_states.pth")
print(reload_states)

得到的所有参数如下:

五、总结

pytorch中state_dict()和load_state_dict()函数配合使用可以实现状态的获取与重载,load()和save()函数配合使用可以实现参数的存储与读取。其中最重要的部分是“字典”的概念,因为参数对象的存储是需要“名称”——“值”对应(即键值对),读取时也是通过键值对读取的。

参考:

https://www.pytorchtutorial.com/pytorch-note5-save-and-restore-models/

https://blog.csdn.net/Code_Mart/article/details/88254444

angularjs中state的参数4_一文梳理pytorch保存和重载模型参数攻略相关推荐

  1. Pytorch 如何 优化/调整 模型参数

    Pytorch 如何自动优化/调整 模型超参 文章目录 Pytorch 如何自动优化/调整 模型超参 背景 优化模型参数 贝叶斯优化 深度学习框架下的参数优化 平台安装 使用参考 参考 背景 对于优化 ...

  2. python保存模型与参数_基于pytorch的保存和加载模型参数的方法

    当我们花费大量的精力训练完网络,下次预测数据时不想再(有时也不必再)训练一次时,这时候torch.save(),torch.load()就要登场了. 保存和加载模型参数有两种方式: 方式一: torc ...

  3. Java伽马什么意思,游戏设置中的伽马值是什么意思 | 手游网游页游攻略大全

    发布时间:2016-03-21 中法司马         我现在永远支持5红 3红4蓝的 就请闪远 因为会玩的诸葛前期 直接拼药. 我以前都是4红2蓝 但是呢~ 总被诸葛压着打 哭求了 然后看了马超他 ...

  4. ML之FE:特征工程中数据缺失值填充的简介、方法、全部代码实现之详细攻略

    ML之FE:特征工程中数据缺失值填充的简介.方法.全部代码实现之详细攻略 目录 特征工程中数据缺失值填充的简介.方法.经典案例

  5. Python之pandas:pandas中常见的数据类型转换四大方法以及遇到的一些坑之详细攻略

    Python之pandas:pandas中常见的数据类型转换四大方法以及遇到的一些坑之详细攻略 目录 pandas中常见的数据类型转换方法 T1.读取时直接转换数据类型 T2.采用astype

  6. Python:python语言中与时间有关的库函数简介、安装、使用方法之详细攻略

    Python:python语言中与时间有关的库函数简介.安装.使用方法之详细攻略 目录 与时间有关的库函数 案例应用 1.打印程序块前后运行时间 #T1.采用time库

  7. PyTorch计算损失函数对模型参数的Hessian矩阵

    前言 在实现Per-FedAvg的代码时,遇到如下问题: 可以发现,我们需要求损失函数对模型参数的Hessian矩阵. 模型定义 我们定义一个比较简单的模型: class ANN(nn.Module) ...

  8. 英魂之刃服务器维护中修改,英魂之刃gg修改教程 | 手游网游页游攻略大全

    发布时间:2016-04-09 今天小编要给大家带来的是辅助秒世界BOOS修改教程,如果你正好在寻找的辅助修改器,修改教程那就来对地方了哟. 辅助秒世界BO ... 标签: 剑魂之刃辅助 剑魂之刃修改 ...

  9. 第七史诗无限显示服务器连接中,第七史诗神器满破是什么意思?神器满破攻略...

    第七史诗中为了增添游戏的多样化,游戏中很重要的一个内容就是神器,很多萌新只知道怎么使用神器,但是大佬们对于神器的讨论,往往会听到一个词叫做满破,很多萌新都不知道什么叫做满破,这里小编带来的就是第七史诗 ...

最新文章

  1. 清华团队将Transformer用到3D点云分割上后,效果好极了
  2. 通信系统计算机仿真上机实验报告,昆明理工大学计算机仿真实验.docx
  3. csrediscore访问redis集群_搭建文档 | centos7.6环境下redis5.0.8集群搭建
  4. 动态生成CheckBox(Winform程序)
  5. unity 开发总结
  6. Android Monkey(转载)
  7. 自学Java必看的知识点,猿们怎么看?
  8. scala基础之隐式转换
  9. 马斯克公开特斯拉Model 3成本 价值这个数...
  10. 信息: 开始协议处理句柄[http-nio-8080]_你必须要知道的HTTP协议原理
  11. git flow命令
  12. AI中台——智能聊天机器人平台的架构与应用
  13. Unihan(统汉字)常用字段介绍
  14. azkaban 与 java任务_任务调度工具oozie和azkaban的对比
  15. winrar命令行加压解密
  16. 接着前几期内容继续对单片机怎么学习来做一个了解
  17. git仓库-客户端软件安装配置过程
  18. 【Scratch考级99图】图32-等级考试scratch绘制复杂图形8个八边形 少儿编程 scratch画图案例教程
  19. DIY:利用单片机自制的RGB拖尾流水灯,含电路图、源代码、演示视频、效果图
  20. 【四月答题勋章】四月答题勋章获取方法

热门文章

  1. keras从入门到放弃(十一)电影评价预测
  2. 北京内推 | 华为高斯实验室招聘AI算法工程师/实习生
  3. 从多篇2021年顶会论文看多模态预训练模型最新研究进展
  4. 博士申请 | 佐治亚理工学院陈永昕教授招收机器学习理论方向博士生
  5. ACL 2021 | 为什么机器阅读理解模型会学习走捷径?
  6. 知识表示与融入技术前沿进展及应用
  7. 热门的模型跨界,Transformer、GPT做CV任务一文大盘点
  8. 百万奖金!交通事件、医学病理、广告检测,江苏大数据开发与应用大赛报名...
  9. 实习推荐 | 腾讯AI Lab虚拟人中心招聘算法工程师实习生
  10. 直播 | 平安人寿资深算法工程师姚晓远:对话生成模型的探析与创新