导读

深度学习中模型的计算图可以被分为两种,静态图动态图,这两种模型的计算图各有优劣。

静态图需要我们先定义好网络的结构,然后再进行计算,所以静态图的计算速度快,但是debug比较的困难,因为只有当给计算图输入数据之后模型的参数才会有值。

动态图则是边运行边构建,动态图的优点在于可以在搭建网络的时候看见变量的值便于检查,缺点就是前向计算不方便优化,因为不知道下一步计算是做什么。

针对于这两种不同的计算图,paddlepladdle提供了多种不同的方式来保存和加载

模型保存和加载


  • paddle.save:模型参数和超参的保存,支持动态图和静态图
  • paddle.load:模型参数和超参的加载,支持动态图和静态图
  • paddle.jit.save:动态图模型参数和结构的保存
  • paddle.jit.load:动态图模型参数和结构的加载
  • paddle.static.save_inference_model:静态图模型参数和结构的保存
  • paddle.static.load_inference_model:静态图模型参数和结构的加载

除此之外,paddlepaddle还提供了动态图转静态图来训练和保存模型,用来加快模型的训练效率。

模型参数的保存和加载

  • 模型训练

定义一个线程的二元一次方程,通过随机生成一些输入数据来计算输出,来训练模型

import paddle
import numpy as np
from paddle import optimizer,nnnp.random.seed(28)#设置样本的数量
num_samples = 1000
#模型训练的超参数设置
epoch = 10
batch_size = 512
class_num = 1
input_size = 2
learing_rate = 0.01class LinearData(paddle.io.Dataset):def __init__(self,num_samples,input_size):super(LinearData, self).__init__()self._num_samples = num_samples# 生成一个线性方程的数据# 设置线程方程的w和bw = np.random.rand(input_size)b = np.random.rand()#随机生成输入数据,根据线程方程的参数来生成输出self._x = np.random.rand(num_samples,2).astype("float32")self._y = np.sum(w * self._x,axis=1) + bself._y = self._y.reshape(-1,class_num).astype("float32")def __getitem__(self, idx):return self._x[idx],self._y[idx]def __len__(self):return self._num_samples#构建模型
class SimpleNet(nn.Layer):def __init__(self,input_size,num_classes=class_num):super(SimpleNet, self).__init__()self._linear = nn.Linear(input_size,class_num)def forward(self, x):output = self._linear(x)return outputdef train(data_loader,model,loss_fn,opt):#开始训练模型for epoch_idx in range(epoch):for batch_idx,batch_data in enumerate(data_loader):batch_x, batch_y = batch_datapred_batch_y = model(batch_x)#计算Lossbatch_loss = loss_fn(pred_batch_y,batch_y)#更新参数batch_loss.backward()opt.step()opt.clear_grad()#打印日志print("epoch:{},batch idx:{},loss:{:.4f}".format(epoch_idx,batch_idx,np.mean(batch_loss.numpy())))#加载数据集
dataset = LinearData(num_samples,input_size)
data_loader = paddle.io.DataLoader(dataset,shuffle=True,batch_size=batch_size)
#初始化模型
model = SimpleNet(input_size,class_num)
#定义loss函数
loss_fn = paddle.nn.loss.MSELoss()
#优化器设置
opt = paddle.optimizer.sgd.SGD(learning_rate=learing_rate,parameters=model.parameters())
#训练模型
train(data_loader,model,loss_fn,opt)
  • 动态图的参数保存

通过paddle.save函数来保存模型的参数和优化器的参数

#保存模型的参数
paddle.save(model.state_dict(),"model.pdparams")
#保存优化器的参数
paddle.save(opt.state_dict(),"opt.pdparams")
  • 参数的加载

通过paddle.load来从磁盘中加载模型和优化器的参数

#加载模型的参数
model.set_state_dict(paddle.load("model.pdparams"))
#加载优化器的参数
opt.set_state_dict(paddle.load("opt.pdparams"))

静态图模型参数和结构的保存

  • 构建一个静态图模型
    构建了一个简单的静态图模型,只包含了输入和输出
import paddle#开启静态图模型
paddle.enable_static()#设置静态图模型的输入
input = paddle.static.data(name="input",shape=[None,10],dtype="float32")
#设置模型的输出
output = paddle.static.nn.fc(input,2)#将模型放在CPU上执行
place = paddle.CPUPlace()
#静态图的执行器
exe = paddle.static.Executor(place)
#获取到静态图模型并且运行
exe.run(paddle.static.default_startup_program())
  • 保存静态图的模型和参数
prog = paddle.static.default_startup_program()
#保存静态图的参数
paddle.save(prog.state_dict(),"static.pdparams")
#保存静态图的模型
paddle.save(prog,"static.pdmodel")
  • 静态图模型的加载和初始化
#加载静态图的模型
prog = paddle.load("static.pdmodel")
#加载模型的参数
params = paddle.load("static.pdparams")
#初始化模型的参数
prog.set_state_dict(params)

动态图转静态图的模型保存和加载

为了便于构建模型和调试,我们通常会选择动态图的方式来构建模型,如果想要加快模型的训练效率以及方便在训练完成之后保存模型的结构,这时候我们可以将动态图转换成为静态图来解决这两个问题。
针对这种情况paddlepaddle提供了两种方式来实现:

  1. 先将动态图转换成为静态图模型进行训练,然后再保存
  2. 采用动态图进行训练,训练完成之后再保存模型
  • 动态图转静态图进行训练
    这种方法的优点就是通过将动态图转换为静态图进行训练,可以提升模型的训练效率,缺点就是不方便调试
    paddle提供了一种比较简单的方法,只需要通过paddle.jit.to_static来装饰forward方法即可,非常简单
#构建模型
class SimpleNet(nn.Layer):def __init__(self,input_size,num_classes=class_num):super(SimpleNet, self).__init__()self._linear = nn.Linear(input_size,num_classes)@paddle.jit.to_staticdef forward(self, x):output = self._linear(x)return output

然后保存模型的时候使用paddle.jit.save方法即可

paddle.jit.save(model,"model")

保存成功之后会生成三个文件model.pdiparamsmodel.pdiparams.infomodel.pdmodel,如果使用paddle.jit.to_static装饰了多个forward方法,则会生成多个模型文件。

如果想要让保存的模型能够支持动态输入,只需要指定InputSepc参数即可

from paddle.static import InputSpec
#构建模型
class SimpleNet(nn.Layer):def __init__(self,input_size,num_classes=class_num):super(SimpleNet, self).__init__()self._linear = nn.Linear(input_size,num_classes)@paddle.jit.to_static(input_spec=[InputSpec(shape=[None, input_size], dtype='float32')])def forward(self, x):output = self._linear(x)return output

模型的加载和预测

#加载模型
model = paddle.jit.load("model")
#输出模型参数
print(model.state_dict())
#构建一个模型的输入数据
input_array = np.array([[1,2],[3,4],[5,6]],dtype=np.float32)
inputs = paddle.to_tensor(input_array,place=paddle.CUDAPlace(0),stop_gradient=False,dtype=paddle.float32)
print(inputs)
#模型预测
predit = model(inputs)
print(predit)
  • 动态图训练保存模型
    相比于动态图转静态图进行训练而言,我们不需要给模型添加装饰方法,只需要使用paddle.jit.save来保存模型即可,在保存模型的时候只需要指定一下模型输入的shape即可
paddle.jit.save(model,"model",input_spec=[InputSpec(shape=[None,input_size],dtype="float32")])

注意:在使用Layer构建模型的时候,不要把loss的计算写到forward方法中

paddlepaddle模型的保存和加载相关推荐

  1. 线性回归之模型的保存和加载

    线性回归之模型的保存和加载 1 sklearn模型的保存和加载API from sklearn.externals import joblib   [目前这行代码报错,直接写import joblib ...

  2. numpy将所有数据变为0和1_PyTorch 学习笔记(二):张量、变量、数据集的读取、模组、优化、模型的保存和加载...

    一. 张量 PyTorch里面最基本的操作对象就是Tensor,Tensor是张量的英文,表示的是一个多维的矩阵,比如零维就是一个点,一维就是向量,二维就是一般的矩阵,多维就相当于一个多维的数组,这和 ...

  3. PyTorch学习笔记2:nn.Module、优化器、模型的保存和加载、TensorBoard

    文章目录 一.nn.Module 1.1 nn.Module的调用 1.2 线性回归的实现 二.损失函数 三.优化器 3.1.1 SGD优化器 3.1.2 Adagrad优化器 3.2 分层学习率 3 ...

  4. PyTorch | 模型的保存和加载

    PyTorch | 模型的保存和加载 一.模型参数的保存和加载 二.完整模型的保存和加载 一.模型参数的保存和加载 torch.save(module.state_dict(), path):使用mo ...

  5. pytorch模型的保存和加载、checkpoint

    pytorch模型的保存和加载.checkpoint 其实之前笔者写代码的时候用到模型的保存和加载,需要用的时候就去度娘搜一下大致代码,现在有时间就来整理下整个pytorch模型的保存和加载,开始学习 ...

  6. tensorflow 模型的保存和加载

    为了让训练结果可以复用,需要将训练得到的神经网络模型持久化,也就是把模型的参数保存下来,并保证可以持久化后的模型文件中还原出保存的模型. 1. 保存模型 tensorflow提供了一个API可以方便的 ...

  7. PyTorch基础-模型的保存和加载-09

    模型的保存 import numpy as np import torch from torch import nn,optim from torch.autograd import Variable ...

  8. 调gensim库,word2vec模型的保存和加载

    一.模型的保存 模型保存可以有很多种格式,根据格式的不同可以分为2种,一种是保存为.model的文件,一种是非.model文件的保存.我常用的保存格式是.model和.vector直接上代码和结果: ...

  9. 机器学习算法------2.11 模型的保存和加载(joblib.dump()、joblib.load())

    #  模型保存 joblib.dump(estimator, "./data/test.pkl") # 模型加载 estimator = joblib.load("./d ...

最新文章

  1. mysql关于时间的面试题,mysql时间设置默认值MySQL常见面试题
  2. python 异步io_python之同步IO和异步IO
  3. Nand Flash基础知识与坏块管理机制的研究
  4. 当数据库遇见FPGA:X-DB异构计算如何实现百万级TPS?
  5. 【Python】选择Python2还是Python3?
  6. C. Sum of Log(数位dp)
  7. JS:ES6-2 const 关键字
  8. 201521460005 实验五
  9. [原创]c# 加解密通用类
  10. 学习数据结构 AVL树
  11. mysql如何更新一个表中的某个字段值等于另一个表的某个字段值
  12. StarUml:Exception EOleSysError in module StarUML.ex
  13. 138译码器的工作原理
  14. 找工作杂谈(一)2019年春招复习资料总结
  15. JPA Native Query(本地查询)及查询结果转换
  16. dijkstra算法为什么不能计算负权重?
  17. EasyPoi的简介
  18. Oracle 字符集从GBK升级到Utf8
  19. springboot毕设项目社团管理系统7qls9(java+VUE+Mybatis+Maven+Mysql)
  20. 解决 docker 容器无法正常解析域名

热门文章

  1. 特征工程--特征离散化的意义
  2. 云原生WebAssembly应用程序已来
  3. 室内地图在哪些方面提升了我们的生活便利性?
  4. 懒人法解决IPTV和宽带的单线复用问题
  5. python爬虫爬汽车图片_python爬虫爬取汽车网站外型图片
  6. 让MyEclipse注册码不过期的方法 MyEclipse注册码
  7. Word分词标题 和JDK的contain的测试日志显示本地的笔记本 的效率基本上都是1秒以上,显然是Word分词标题 占优势,可是服务器上JDK与Word分析显然无区别,针对8W数据的检索
  8. 3D立方体旋转相册特效
  9. 计算机无法自动连接网络地址,ip地址错误网络无法连接怎么办-ip地址错误网络无法连接解决办法 - 河东软件园...
  10. java 傅里叶变换 频谱_傅里叶变换分析频谱(FFT)