1. Serialization and Deserialization

Saving and loading a model is exactly serialization and deserialization. Serialization and deserialization describe the conversion of data between memory and disk: a model lives in memory as an object, and objects in memory cannot be kept for long, so a trained model needs to be written out to disk.

On disk, data is stored as a binary sequence of 0s and 1s. Serialization means writing an object from memory to disk in the form of such a binary sequence.

Deserialization reads the stored binary data back into memory and reconstructs the object, so the model can be used in memory again. The purpose of serialization and deserialization is long-term storage of data.
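This round trip can be seen with a single tensor: torch.save serializes an in-memory object to a binary file on disk, and torch.load deserializes it back. A minimal sketch, using a temporary file so it runs anywhere:

```python
import os
import tempfile

import torch

# Serialize: an in-memory tensor object -> a binary sequence on disk
t = torch.tensor([1.0, 2.0, 3.0])
path = os.path.join(tempfile.gettempdir(), "tensor_demo.pt")
torch.save(t, path)

# Deserialize: the binary sequence on disk -> an in-memory object again
t_loaded = torch.load(path)
print(torch.equal(t, t_loaded))  # True
```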

2. Two Ways to Save and Load a Model

Serialization and deserialization in PyTorch

2.1 torch.save

Main parameters

  • obj: the object to save; it can be a model, a tensor, or a set of parameters; any Python object can be saved;
  • f: the output path;

Two ways to save (the official recommendation is the second, which takes less storage)

  1. Save the whole Module:
torch.save(net, path)
  2. Save only the model parameters:
state_dict = net.state_dict()
torch.save(state_dict, path)

The following code illustrates both methods:

import torch
import numpy as np
import torch.nn as nn
from toolss.common_tools import set_seed


class LeNet2(nn.Module):
    def __init__(self, classes):
        super(LeNet2, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 6, 5),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(6, 16, 5),
            nn.ReLU(),
            nn.MaxPool2d(2, 2))
        self.classifier = nn.Sequential(
            nn.Linear(16*5*5, 120),
            nn.ReLU(),
            nn.Linear(120, 84),
            nn.ReLU(),
            nn.Linear(84, classes))

    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size()[0], -1)
        x = self.classifier(x)
        return x

    def initialize(self):
        for p in self.parameters():
            p.data.fill_(20191104)


net = LeNet2(classes=2019)

# "training": fill the weights so saving has a visible effect
print("Before training: ", net.features[0].weight[0, ...])
net.initialize()
print("After training: ", net.features[0].weight[0, ...])

# method 1: save the whole model
path_model = "./model.pkl"  # save path
torch.save(net, path_model)

# method 2: save only the model parameters
net_state_dict = net.state_dict()
path_state_dict = "./model_state_dict.pkl"  # save path
torch.save(net_state_dict, path_state_dict)
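As a rough check of the "takes less storage" claim, the sketch below saves the same small model both ways and compares file sizes. The toy model here is made up for illustration, not the LeNet2 above:

```python
import os
import tempfile

import torch
import torch.nn as nn

net = nn.Sequential(nn.Linear(8, 4), nn.ReLU(), nn.Linear(4, 2))

d = tempfile.mkdtemp()
path_model = os.path.join(d, "model.pkl")
path_state = os.path.join(d, "model_state_dict.pkl")

torch.save(net, path_model)               # whole Module: pickles the object graph
torch.save(net.state_dict(), path_state)  # parameters only

# The whole-module file carries the same tensors plus the pickled
# module structure, so it is at least as large as the state_dict file.
print(os.path.getsize(path_model) >= os.path.getsize(path_state))  # True
```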

2.2 torch.load

Main parameters

  • f: the file path;
  • map_location: specifies the device the loaded tensors are mapped to, CPU or GPU;

Setting map_location aside for now, observe the two ways torch.load is used in PyTorch (loading a whole model, and loading a state_dict and applying it to a model):

import torch
import numpy as np
import torch.nn as nn
from toolss.common_tools import set_seed


class LeNet2(nn.Module):
    def __init__(self, classes):
        super(LeNet2, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 6, 5),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(6, 16, 5),
            nn.ReLU(),
            nn.MaxPool2d(2, 2))
        self.classifier = nn.Sequential(
            nn.Linear(16*5*5, 120),
            nn.ReLU(),
            nn.Linear(120, 84),
            nn.ReLU(),
            nn.Linear(84, classes))

    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size()[0], -1)
        x = self.classifier(x)
        return x

    def initialize(self):
        for p in self.parameters():
            p.data.fill_(20191104)


# ================================== load the whole model ===========================
# flag = 1
flag = 0
if flag:
    path_model = "./model.pkl"
    net_load = torch.load(path_model)
    print(net_load)

# ================================== load the model parameters ======================
flag = 1
# flag = 0
if flag:
    path_state_dict = "./model_state_dict.pkl"
    state_dict_load = torch.load(path_state_dict)
    print(state_dict_load.keys())

# ================================== update the model parameters ====================
flag = 1
# flag = 0
if flag:
    net_new = LeNet2(classes=2019)
    print("Before loading: ", net_new.features[0].weight[0, ...])
    net_new.load_state_dict(state_dict_load)
    print("After loading: ", net_new.features[0].weight[0, ...])
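The map_location parameter deferred above tells torch.load which device the loaded tensors should live on; it matters when a checkpoint saved on a GPU machine is loaded on a CPU-only one. A minimal sketch:

```python
import os
import tempfile

import torch

path = os.path.join(tempfile.gettempdir(), "t_demo.pt")
torch.save(torch.ones(3), path)

# Force every storage in the file onto the CPU, regardless of the
# device it lived on when it was saved.
t = torch.load(path, map_location="cpu")
print(t.device)  # cpu
```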

3. Resuming Training from a Checkpoint

Checkpointing solves the problem of having to retrain from scratch when training is interrupted for some reason: model parameters are saved periodically during training, so that after an interruption training can pick up where it left off. A checkpoint stores both the model parameters and the optimizer state:

checkpoint = {"model_state_dict": net.state_dict(),
              "optimizer_state_dict": optimizer.state_dict(),
              "epoch": epoch}
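A self-contained sketch of saving and reloading such a checkpoint; the toy model and optimizer here are stand-ins for the net and optimizer in the snippet above:

```python
import os
import tempfile

import torch
import torch.nn as nn
import torch.optim as optim

net = nn.Linear(4, 2)
optimizer = optim.SGD(net.parameters(), lr=0.01)
epoch = 4

checkpoint = {"model_state_dict": net.state_dict(),
              "optimizer_state_dict": optimizer.state_dict(),
              "epoch": epoch}

path = os.path.join(tempfile.gettempdir(), "checkpoint_demo.pkl")
torch.save(checkpoint, path)

loaded = torch.load(path)
print(sorted(loaded.keys()))  # ['epoch', 'model_state_dict', 'optimizer_state_dict']
print(loaded["epoch"])        # 4
```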

The code below shows checkpointing in an actual training script:

import os
import random
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
import torch.optim as optim
from PIL import Image
from matplotlib import pyplot as plt
from model.lenet import LeNet
from toolss.my_dataset import RMBDataset
from toolss.common_tools import set_seed
import torchvision

set_seed(1)  # set random seed
rmb_label = {"1": 0, "100": 1}

# hyper-parameters
checkpoint_interval = 5
MAX_EPOCH = 10
BATCH_SIZE = 16
LR = 0.01
log_interval = 10
val_interval = 1

# ============================ step 1/5: data ============================
split_dir = os.path.join("F:/Pytorch框架班/Pytorch-Camp-master/代码合集/rmb_split")
train_dir = os.path.join(split_dir, "train")
valid_dir = os.path.join(split_dir, "valid")

norm_mean = [0.485, 0.456, 0.406]
norm_std = [0.229, 0.224, 0.225]

train_transform = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.RandomCrop(32, padding=4),
    transforms.RandomGrayscale(p=0.8),
    transforms.ToTensor(),
    transforms.Normalize(norm_mean, norm_std),
])

valid_transform = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
    transforms.Normalize(norm_mean, norm_std),
])

# build MyDataset instances
train_data = RMBDataset(data_dir=train_dir, transform=train_transform)
valid_data = RMBDataset(data_dir=valid_dir, transform=valid_transform)

# build DataLoaders
train_loader = DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)
valid_loader = DataLoader(dataset=valid_data, batch_size=BATCH_SIZE)

# ============================ step 2/5: model ============================
net = LeNet(classes=2)
net.initialize_weights()

# ============================ step 3/5: loss function ============================
criterion = nn.CrossEntropyLoss()

# ============================ step 4/5: optimizer ============================
optimizer = optim.SGD(net.parameters(), lr=LR, momentum=0.9)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=6, gamma=0.1)  # LR decay policy

# ============================ step 5/5: training ============================
train_curve = list()
valid_curve = list()

start_epoch = -1
for epoch in range(start_epoch + 1, MAX_EPOCH):

    loss_mean = 0.
    correct = 0.
    total = 0.

    net.train()
    for i, data in enumerate(train_loader):

        # forward
        inputs, labels = data
        outputs = net(inputs)

        # backward
        optimizer.zero_grad()
        loss = criterion(outputs, labels)
        loss.backward()

        # update weights
        optimizer.step()

        # accumulate classification statistics
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).squeeze().sum().numpy()

        # log training info
        loss_mean += loss.item()
        train_curve.append(loss.item())
        if (i + 1) % log_interval == 0:
            loss_mean = loss_mean / log_interval
            print("Training:Epoch[{:0>3}/{:0>3}] Iteration[{:0>3}/{:0>3}] Loss: {:.4f} Acc:{:.2%}".format(
                epoch, MAX_EPOCH, i + 1, len(train_loader), loss_mean, correct / total))
            loss_mean = 0.

    scheduler.step()  # update learning rate

    if (epoch + 1) % checkpoint_interval == 0:  # save the checkpoint for resuming
        checkpoint = {"model_state_dict": net.state_dict(),
                      "optimizer_state_dict": optimizer.state_dict(),
                      "epoch": epoch}
        path_checkpoint = "./checkpoint_{}_epoch.pkl".format(epoch)
        torch.save(checkpoint, path_checkpoint)

    if epoch > 5:  # simulate an unexpected interruption
        print("Training interrupted unexpectedly...")
        break

    # validate the model
    if (epoch + 1) % val_interval == 0:

        correct_val = 0.
        total_val = 0.
        loss_val = 0.
        net.eval()
        with torch.no_grad():
            for j, data in enumerate(valid_loader):
                inputs, labels = data
                outputs = net(inputs)
                loss = criterion(outputs, labels)

                _, predicted = torch.max(outputs.data, 1)
                total_val += labels.size(0)
                correct_val += (predicted == labels).squeeze().sum().numpy()

                loss_val += loss.item()

            valid_curve.append(loss.item())
            print("Valid:\t Epoch[{:0>3}/{:0>3}] Iteration[{:0>3}/{:0>3}] Loss: {:.4f} Acc:{:.2%}".format(
                epoch, MAX_EPOCH, j + 1, len(valid_loader), loss_val / len(valid_loader), correct / total))

train_x = range(len(train_curve))
train_y = train_curve

train_iters = len(train_loader)
valid_x = np.arange(1, len(valid_curve) + 1) * train_iters * val_interval  # valid records epoch loss; convert to iterations
valid_y = valid_curve

plt.plot(train_x, train_y, label='Train')
plt.plot(valid_x, valid_y, label='Valid')

plt.legend(loc='upper right')
plt.ylabel('loss value')
plt.xlabel('Iteration')
plt.show()

After checkpoints have been saved, resuming training consists of loading the checkpoint and restoring the model parameters, the optimizer state, and the epoch counter before re-entering the training loop; the rest of the resume script is identical to the training script above.

