这是Pytorch学习之路的第五篇

遇到问题

虽然已经知道了怎么保存已经训练好的网络模型,但是还是不知道怎么调用。其他博客中讲的有点简略,还需要自己摸索一下:

PyTorch要加载已经训练好的网络模型,需要保留什么代码,增加什么代码?

解决方法(只讨论仅加载参数的方法)

导入的库都不变,且只有测试模型前代码需要做改动:

import torch.nn as nn
import torch.nn.functional as F
#以下为需要保留的代码
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
class CNNNet(nn.Module):def __init__(self):super(CNNNet, self).__init__()self.conv1 = nn.Conv2d(in_channels=3,out_channels=16,kernel_size=5,stride=1)self.pool1 = nn.MaxPool2d(kernel_size=2,stride=2)self.conv2 = nn.Conv2d(in_channels=16,out_channels=36,kernel_size=3,stride=1)self.pool2 = nn.MaxPool2d(kernel_size=2,stride=2)#self.aap = nn.AdaptiveAvgPool2d(1)self.fc1 = nn.Linear(1296,128)self.fc2 = nn.Linear(128,10)#self.fc3 = nn.Linear(36,10)def forward(self,x):x = self.pool1(F.relu(self.conv1(x)))x = self.pool2(F.relu(self.conv2(x)))#x = self.aap(x)#x = x.view(x.shape[0],-1)#x = self.fc3(x)x = x.view(-1,36*6*6)#print("x.shape:{}".format(x.shape))x = F.relu(self.fc2(F.relu(self.fc1(x))))return xmodel = CNNNet()#以下为新增代码
model.load_state_dict(torch.load('./model/model.pth'))#再加载网络的参数
model = model.to(device)
print("load success")

注意

model = torch.load('./model/model.pth')

会报错

原因未知。

效果

成功

灵感来源

  1. pytorch:无法加载CNN模型并做预测TypeError:'collections. OrderedDict’对象不可调用(转载)
  2. Pytorch文档阅读(五)如何保存、加载网络模型(转载)

PyTorch如何加载已经训练好的网络模型相关推荐

  1. PyTorch 加载预训练权重

    前言  使用PyTorch官方提供的权重或者其他第三方提供的权重对相同模型的参数进行初始化,在数据量较少的前提下,可以帮助模型更快地收敛到最优点,达到更好的效果,即迁移学习.  在大部分的迁移学习场景 ...

  2. 【Pytorch】加载torchvision中预训练好的模型并修改默认下载路径(使用models.__dict__[model_name]()读取)

    说明 使用torchvision.model加载预训练好的模型时,发现默认下载路径在系统盘下面的用户目录下(这个你执行的时候就会发现),即C:\用户名\.cache\torch\.checkpoint ...

  3. Pytorch 词嵌入word_embedding2实例(加载已训练词向量)

    目录 1.加载已训练好的词嵌入 2.是否需要重新训练词嵌入 3.不重新训练词嵌入时优化器设置

  4. torch编程-加载预训练权重-模型冻结-解耦-梯度不反传

    1)加载预训练权重 net = torchvision.models.resnet50(pretrained=False) # 构建模型 pretrained_model = torch.load(p ...

  5. PyTorch中加载模型权重

    在做深度学习项目时,从头训练一个模型是需要大量时间和算力的,我们通常采用加载预训练权重的方法,而我们往往面临以下几种情况: #mermaid-svg-freoBrrdezozjyan {font-fa ...

  6. PyTorch数据加载处理

    PyTorch数据加载处理 PyTorch提供了许多工具来简化和希望数据加载,使代码更具可读性. 1.下载安装包 • scikit-image:用于图像的IO和变换 • pandas:用于更容易地进行 ...

  7. 使用torchvision.models.inception_v3(pretrained=True)加载预训练的模型每次都特别慢

    欢迎大家关注笔者,你的关注是我持续更博的最大动力 原创文章,转载告知,盗版必究 使用torchvision.models.inception_v3(pretrained=True)加载预训练的模型每次 ...

  8. PyTorch数据加载器

    We'll be covering the PyTorch DataLoader in this tutorial. Large datasets are indispensable in the w ...

  9. java加载tensorflow训练的PB模型记录

    java加载tensorflow训练的PB模型记录 python训练 1. 模型的输入输出定义 2. 训练时保存模型的方法 java加载模型 1.maven依赖 2. Java代码实例 tensor注 ...

最新文章

  1. useGeneratedKeys的详解
  2. 441. Arranging Coins
  3. linux修改ssh登陆端口号,Linux 6 修改ssh默认远程端口号的操作步骤
  4. Qt Creator导入3D资产Importing 3D Assets
  5. Windows+Caffe(Faster RCNN/RFCN/SSD)编译(Cuda7.5+Cuda8.0)未完待续
  6. predis如何实现phpredis的pconnect方法
  7. iOS当中的设计模式
  8. 2021企业直播新观察——市场升温蕴藏机会,消费场景左右未来
  9. ADOQuery 的几个事件
  10. PyTorch 1.0 中文官方教程:用例子学习 PyTorch
  11. C#笔记05 方法和参数
  12. UnityShader7:内置包含文件UnityCG.cginc与GG/HLSL语义
  13. 拓端tecdat|R语言对股票风险“溃疡指数”( Ulcer Index)曲面图可视化
  14. React Native重构路线图发布!
  15. weblogic错误页面
  16. 10分钟临时邮箱,无限邮箱
  17. 数据库设计 资源表与资源收藏表的设计
  18. 网络安全工程师的入门学习的路径
  19. 硕士毕业,两年北漂算法工程打工生活【上】
  20. 不求人,手把手教你学会微信WIFI!

热门文章

  1. 查看电脑主机ID的两种方法
  2. web前端 运用CSS动画 实现图像的幻灯片 自动播放 切换效果
  3. 物理渗透-Mifare Classic S50(IC)卡分析
  4. 计算机科学与技术导师保研推荐信,研究生推荐信共篇
  5. 电子工程类职称包含计算机专业吗,电子信息工程专业技术职称
  6. 使用SQLMonitor监视访问ORACLE的“服务”
  7. 1024| 只为程序员们打Call(多重福利)
  8. android crosswalk 集成
  9. Java的定时器Timer和定时任务TimerTask应用以及原理简析
  10. WIN10安装DB2详细教程(附安装文件)