pytorch 保存模型,加载预训练模型问题
在写pytorch代码时,遇到问题:加载预训练模型时在验证集上测试的psnr结果与训练时验证集的psnr差异特别大。
源代码:
pretrained_dict = torch.load('epochG_515.pth')
net.load_state_dict(pretrained_dict)
net = prepare(net)
valdata = Data(root=os.path.join(args.dir_data, args.data_val), args=args, train=False)
valset = DataLoader(valdata, batch_size=1, shuffle=False, num_workers=1)val_psnr = 0
val_ssim = 0
with torch.no_grad():timer_test = util.timer()for batch, (lr, hr, filename) in enumerate(valset):lr, hr = prepare(lr), prepare(hr)sr = net(lr)print(sr.shape, hr.shape)val_psnr = val_psnr + cal_psnr(hr[0].data.cpu(), sr[0].data.cpu())val_ssim = val_ssim + cal_ssim(hr[0].data.cpu(), sr[0].data.cpu())print("Test psnr: {:.3f}".format(val_psnr / (len(valset))))print('Forward: {:.2f}s\n'.format(timer_test.toc()))print(val_ssim / (len(valset)))
后面修改在测试前加上net.eval()
pretrained_dict = torch.load('epochG_515.pth')
net.load_state_dict(pretrained_dict)
net = prepare(net)
valdata = Data(root=os.path.join(args.dir_data, args.data_val), args=args, train=False)
valset = DataLoader(valdata, batch_size=1, shuffle=False, num_workers=1)val_psnr = 0
val_ssim = 0
with torch.no_grad():net.eval()timer_test = util.timer()for batch, (lr, hr, filename) in enumerate(valset):lr, hr = prepare(lr), prepare(hr)sr = net(lr)print(sr.shape, hr.shape)val_psnr = val_psnr + cal_psnr(hr[0].data.cpu(), sr[0].data.cpu())val_ssim = val_ssim + cal_ssim(hr[0].data.cpu(), sr[0].data.cpu())print("Test psnr: {:.3f}".format(val_psnr / (len(valset))))print('Forward: {:.2f}s\n'.format(timer_test.toc()))print(val_ssim / (len(valset)))
pytorch 保存模型,加载预训练模型问题相关推荐
- Tensorflow保存模型和加载预训练模型
训练好的模型需要保存下来或者加载已经训练完成的模型,就用到了ckpt文件. 目录 1.了解tensorflow保存的文件 (1)checkpoint (2)MyModel.meta (3)MyMode ...
- Pytorch网络模型权重初始化、保存与加载模型、加载预训练模型、按需设置学习率
前言 在我们对神经网络模型进行训练时,往往需要对模型进行初始化或者加载预训练模型.本文将对模型的权重初始化与加载预训练模型做一个学习记录,以便后续查询使用. 权重初始化 常见的初始化方法 PyTorc ...
- pytorch:加载预训练模型(多卡加载单卡预训练模型,多GPU,单GPU)
在pytorch加载预训练模型时,可能遇到以下几种情况. 分为以下几种 在pytorch加载预训练模型时,可能遇到以下几种情况. 1.多卡训练模型加载单卡预训练模型 2. 多卡训练模型加载多卡预训练模 ...
- HuggingFace学习3:加载预训练模型完成机器翻译(中译英)任务
加载模型页面为:https://huggingface.co/liam168/trans-opus-mt-zh-en 文章目录 整理文件 跑通程序,测试预训练模型 拆解Pipeline,逐步进行翻译任 ...
- MXNet快速入门之训练加载预训练模型(四)
前言 在前面几篇文章中详细介绍了MXNet的一些特点以及入门基础知识,本篇文章主要介绍如何使用MXNet来训练模型.加载模型进行预测.预训练模型以及MXNet中GPU使用的相关知识. 在介绍训练模型之 ...
- keras冻结_Keras 实现加载预训练模型并冻结网络的层
在解决一个任务时,我会选择加载预训练模型并逐步fine-tune.比如,分类任务中,优异的深度学习网络有很多. ResNet, VGG, Xception等等... 并且这些模型参数已经在imagen ...
- 加载预训练模型时报错 KeyError: param ‘initial_lr‘ is not specified in param_groups[0]
在加载预训练模型继续训练时,程序报错:KeyError: "param 'initial_lr' is not specified in param_groups[0] when resum ...
- Caffe2教程实例,加载预训练模型
Caffe2教程实例,加载预训练模型 概述 本教程使用模型库中的预训练模型squeezenet 里分类我们自己的图片.我们需要提供要分类图片的路径或者URL信息作为输入.了解ImageNet对象代码可 ...
- pytorch 保存、加载模型
一般保存为.pt格式,保存模型使用: torch.save(model, '保存位置') 加载模型使用: model_load = torch.load('加载模型的位置') 完整代码 import ...
- 解决HuggingFace加载预训练模型时报错TypeError: expected str, bytes or os.PathLike object, not NoneType
完整报错: TypeError: expected str, bytes or os.PathLike object, not NoneType 解决方法 检查下载的组件: 步骤1:完整的下载组件,包 ...
最新文章
- Java网络编程之简单UDP通信
- kettle、Oozie、camus、gobblin
- C++ 双向链表的建立与遍历
- html 文字输出语音,html 录音与文本转语音demo
- 给不会调用C++STL库中二分函数lower_bound,upper_bound,binary_search同学的一些话!
- 水滴石穿C语言之extern声明辨析
- ios进度条Demo一个
- 使用windbg通过vtable找到优化后的this指针
- python代码运行助手下载_Python自学:使用代码运行助手
- .bin文件如何打开并使用
- Java字节码技术javassist
- 「硬见小百科」几种镜像恒流源电路分析
- 【论文】论文中的参考文献:国标GB/T 7714-2015文献类型与格式
- python 病毒 基因_#Python#提取基因对应的蛋白质名
- 一个计算机专业学生几年的Java编程经验汇总
- Spring 实战 第4版 读书笔记
- 【Visual-Hull + Bregman】基于Visual-Hull + Bregman算法的三维重建算法matlab仿真
- Java笔记(P393/234-P400/241)
- 机器学习和人工智能有什么关系?
- 中国大学生年度人物评选
热门文章
- 考拉升级https经验
- OVM 免费虚拟化软件迭代时间调整,提高产品稳定性!
- 在同个工程中使用 Swift 和 Objective-C(Swift 2.0更新)-b
- 有了WCF,Socket是否已人老珠黄?
- STC学习:串口通信
- memset函数的使用
- python struct_struct
- 为什么训练时测试准确率大幅度波动_Nature Mach Intell|类药性预测准确率有极限...
- java数组清空能释放jvm内存嘛_JVM面试题汇总
- 伙伴算法的核心思想是回收时进行相邻块的合并_Linux内存管理之伙伴算法