'''本文件用于举例说明pytorch保存和加载文件的方法'''import torch as torch
import torchvision as tv
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision.transforms as transforms
import os# 参数声明
batch_size = 32
epochs = 10
WORKERS = 0  # dataloder线程数
test_flag = False  # 测试标志,True时加载保存好的模型进行测试
ROOT = '/home/pxt/pytorch/cifar'  # MNIST数据集保存路径
log_dir = '/home/pxt/pytorch/logs/cifar_model.pth'  # 模型保存路径
# 加载MNIST数据集
transform = tv.transforms.Compose([transforms.ToTensor(),transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])train_data = tv.datasets.CIFAR10(root=ROOT, train=True, download=True, transform=transform)
test_data = tv.datasets.CIFAR10(root=ROOT, train=False, download=False, transform=transform)train_load = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True, num_workers=WORKERS)
test_load = torch.utils.data.DataLoader(test_data, batch_size=batch_size, shuffle=False, num_workers=WORKERS)# 构造模型
class Net(nn.Module):def __init__(self):super(Net, self).__init__()self.conv1 = nn.Conv2d(3, 64, 3, padding=1)self.conv2 = nn.Conv2d(64, 128, 3, padding=1)self.conv3 = nn.Conv2d(128, 256, 3, padding=1)self.conv4 = nn.Conv2d(256, 256, 3, padding=1)self.pool = nn.MaxPool2d(2, 2)self.fc1 = nn.Linear(256 * 8 * 8, 1024)self.fc2 = nn.Linear(1024, 256)self.fc3 = nn.Linear(256, 10)def forward(self, x):x = F.relu(self.conv1(x))x = self.pool(F.relu(self.conv2(x)))x = F.relu(self.conv3(x))x = self.pool(F.relu(self.conv4(x)))x = x.view(-1, x.size()[1] * x.size()[2] * x.size()[3])x = F.relu(self.fc1(x))x = F.relu(self.fc2(x))x = self.fc3(x)return xmodel = Net().cpu()criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)# 模型训练
def train(model, train_loader, epoch):model.train()train_loss = 0for i, data in enumerate(train_loader, 0):x, y = datax = x.cpu()y = y.cpu()optimizer.zero_grad()y_hat = model(x)loss = criterion(y_hat, y)loss.backward()optimizer.step()train_loss += lossprint('正在进行第{}个epoch中的第{}次循环'.format(epoch,i))loss_mean = train_loss / (i + 1)print('Train Epoch: {}\t Loss: {:.6f}'.format(epoch, loss_mean.item()))# 模型测试
def test(model, test_loader):model.eval()test_loss = 0correct = 0with torch.no_grad():for i, data in enumerate(test_loader, 0):x, y = datax = x.cpu()y = y.cpu()optimizer.zero_grad()y_hat = model(x)test_loss += criterion(y_hat, y).item()pred = y_hat.max(1, keepdim=True)[1]correct += pred.eq(y.view_as(pred)).sum().item()test_loss /= (i + 1)print('Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(test_loss, correct, len(test_data), 100. * correct / len(test_data)))def main():# 如果test_flag=True,则加载已保存的模型并进行测试,测试以后不进行此模块以后的步骤if test_flag:# 加载保存的模型直接进行测试机验证checkpoint = torch.load(log_dir)model.load_state_dict(checkpoint['model'])optimizer.load_state_dict(checkpoint['optimizer'])start_epoch = checkpoint['epoch']test(model, test_load)return# 如果有保存的模型,则加载模型,并在其基础上继续训练if os.path.exists(log_dir):checkpoint = torch.load(log_dir)model.load_state_dict(checkpoint['model'])optimizer.load_state_dict(checkpoint['optimizer'])start_epoch = checkpoint['epoch']print('加载 epoch {} 成功!'.format(start_epoch))else:start_epoch = 0print('无保存了的模型,将从头开始训练!')for epoch in range(start_epoch+1, epochs):train(model, train_load, epoch)test(model, test_load)# 保存模型state = {'model':model.state_dict(), 'optimizer':optimizer.state_dict(), 'epoch':epoch}torch.save(state, log_dir)if __name__ == '__main__':main()

pytorch保存和加载文件的方法,从断点处继续训练相关推荐

  1. python torch exp_Python:PyTorch 保存和加载训练过的网络 (八十)

    保存和加载模型 在这个 notebook 中,我将为你展示如何使用 Pytorch 来保存和加载模型.这个步骤十分重要,因为你一定希望能够加载预先训练好的模型来进行预测,或是根据新数据继续训练. %m ...

  2. PyTorch | 保存和加载模型教程

    点击上方"算法猿的成长",选择"加为星标" 第一时间关注 AI 和 Python 知识 图片来自 Unsplash,作者: Jenny Caywood 2019 ...

  3. Pytorch 保存和加载模型

    当保存和加载模型时,需要熟悉三个核心功能: 1. torch.save :将序列化对象保存到磁盘.此函数使用Python的 pickle 模块进行序列化.使 用此函数可以保存如模型.tensor.字典 ...

  4. egret白鹭引擎保存加载文件到本地的实现方案

    最近几天我在做游戏场景编辑器开发,所以有一些绘画内容编辑的内容需要存储文件到硬盘中,最好是用户可以进行保存自己编辑的项目到硬盘然后随时打开项目文件,就想flash的fla文件一样,可以存储到u盘在u盘 ...

  5. pytorch保存模型的两种方法

    文章目录 前言 一.保存整个模型 二.只保存参数 模型不同后缀名的区别 总结 前言 模型的本质是一堆用某种结构存储起来的参数 用数据对模型进行训练后得到了比较理想的模型,就需要将其存储起来,然后在需要 ...

  6. pytorch保存和加载模型state_dict

    保存模型: torch.save({'epoch': epoch + 1,'state_dict': model.state_dict(),'optimizer': optimizer.state_d ...

  7. 保存和加载模型的方法

    目录 保存模型权重 保存整个模型 保存模型权重 1. 使用回调函数保存 2. 手动保存 这种是在model.fit时传入保存checkpoint的回调函数.使用的回调函数是tf.keras.callb ...

  8. 加载dict_PyTorch 7.保存和加载pytorch模型的两种方法

    众所周知,python的对象都可以通过torch.save和torch.load函数进行保存和加载(不知道?那你现在知道了(*^_^*)),比如: x1 = {"d":" ...

  9. pytorch保存模型pth_Day159:模型的保存与加载

    网络结构和参数可以分开的保存和加载,因此,pytorch保存模型有两种方法: 保存 整个模型 (结构+参数) 只保存模型参数(官方推荐) # 保存整个网络torch.save(model, check ...

最新文章

  1. Openssl req命令
  2. python使用matplotlib可视化不同年份、不同项目的均值(mean)对比条形图(bar plot comparision of mean with different years)
  3. MVC--Razor(2)
  4. python线性回归模型预处理_线性回归-2 数据预处理与模型验证评估
  5. php数组为什么其他语言,PHP语言特性和各版本的差异
  6. 换掉flash,flex,用FlashDevelop
  7. 计算机网络-思维导图(2)物理层
  8. AcWing 253. 普通平衡树
  9. SpringBoot部署项目到Linux上传文件路径问题
  10. Linux系统(四) echo和重定向、管道的概念和使用
  11. python入门教程(非常详细),从零基础入门到精通,看完这一篇就够了
  12. word添加自定义样式(导入normal.dotm)
  13. linux下mysql免安装_linux下免安装版本mysql5.5 配置
  14. 微信开放平台和微信公众平台配置流程简介,
  15. 跨平台应用开发进阶(四) :uni-app 实现上传图片
  16. Python数据挖掘——文本分析
  17. 配置VS2019 执行cu文件
  18. 百度地图上定位自己所在的位置
  19. 华硕顽石第四代FL5900u拆机换电池
  20. Golang iota详解

热门文章

  1. “无法找到运行搜索助理需要的一个文件”的解决办法
  2. SpringBoot2 整合FreeMarker模板,完成页面静态化处理
  3. 大数据之Azkaban部署
  4. hadoop--历史服务器配置
  5. 【java SOAP】对SOAP的一个个人印象
  6. Python数据清洗基本流程
  7. 怎么发现RAC环境中#39;library cache pin#39;等待事件的堵塞者(Blocker)?
  8. 快速判断一个数是否是2的幂次方
  9. workstation虚拟机详尽教程
  10. Suofanker:为什么搜索引擎重视链接