Pytorch加载模型并进行图像分类预测
目录
1. 整体流程
1)实例化模型
2)加载模型
3)输入图像
4)输出分类结果
5)完整代码
2. 处理图像
1) How can i convert an RGB image into grayscale in Python?
2)PIL 处理图像的基本操作
3)图像通道数的理解
4)Convert 3 channel image to 2 channel
5)图像通道转换
6)将所有的图像合并为一个numpy数组
7)torch.from_numpy VS torch.Tensor
8)torch.squeeze() 和torch.unsqueeze()
3.issue
1)TypeError: 'module' object is not callable
2)TypeError: 'collections.OrderedDict' object is not callable
3) TypeError: __init__() missing 1 required positional argument: 'XX'
4) RuntimeError: Error(s) in loading state_dict for PythonNet: Missing key(s) in state_dict:
5) RuntimeError: Expected 4-dimensional input for 4-dimensional weight [128, 1, 3, 3], but got 2-dimensional input of size [480, 640] instead
6)RuntimeError: Input type (torch.FloatTensor) and weight type (torch.cuda.FloatTensor) should be the same or input should be a MKLDNN tensor and weight is a dense tensor
1. 整体流程
1)实例化模型
Assume that the content of YourClass.py is:
class YourClass:# ......
If you use:
from YourClassParentDir import YourClass # means YourClass
from model import PythonNetnet= PythonNet(T=16).eval().cuda()
2)加载模型
import torchnet.load_state_dict(torch.load('checkpoint_max.pth'),False)
3)输入图像
目的:从文件夹中加载所有图像组合为一个 numpy array 作为模型输入
原始图像输入维度:(480,640)
目标图像输入维度:(16,1,128,128)
import glob
from PIL import Image
import numpy as np
from torch.autograd import Variable#获取图像路径
filelist = glob.glob('./testdata/a/*.jpg')#打开图像open('frame_path')--》转换为灰度图convert('L')--》缩放图像resize((width, height)) --》合并文件夹中的所有图像为一个numpy array
x = np.array([np.array(Image.open(frame).convert('L').resize((128,128))) for frame in filelist])#用torch.from_numpy这个方法将numpy类转换成tensor类
x = torch.from_numpy(x).type(torch.FloatTensor).cuda()#扩充数据维度
x = Variable(torch.unsqueeze(x,dim=1).float(),requires_grad=False)
4)输出分类结果
outputs = net(x)
_, predicted = torch.max(outputs,1)
torch.max()这个函数返回输入张量中所有元素的最大值。
返回一个命名元组(values,indices),其中values是给定维dim中输入张量的每一行的最大值。indices是找到的每个最大值的索引位置(argmax)。也就是说,返回的第一个值是对应图像在类别中的最大概率值,第二个值是最大概率值的对应类别。
Pytorch 分类问题输出结果的数据整理方式:_, predicted = torch.max(outputs.data, 1) - stardsd - 博客园
5)完整代码
from PIL import Image
from torch.autograd import Variable
import numpy as np
import torch
import glob
from model import PythonNet##############处理输入图像#######################################
#获取图像路径
filelist = glob.glob('./testdata/a/*.jpg')#打开图像open('frame_path')--》转换为灰度图convert('L')--》缩放图像resize((width, height)) --》合并文件夹中的所有图像为一个numpy array
x = np.array([np.array(Image.open(frame).convert('L').resize((128,128))) for frame in filelist])#用torch.from_numpy这个方法将numpy类转换成tensor类
x = torch.from_numpy(x).type(torch.FloatTensor).cuda()#扩充数据维度
x = Variable(torch.unsqueeze(x,dim=1).float(),requires_grad=False)#############定义预测函数#######################################
def predict(x):net= PythonNet(T=16).eval().cuda()net.load_state_dict(torch.load('checkpoint_max.pth'),False)outputs = net(x)_, predicted = torch.max(outputs,1)print("_:",_)print("predicted:",predicted)print("outputs:",outputs)############输入图像进行预测#####################################
predict(x)
2. 处理图像
1) How can i convert an RGB image into grayscale in Python?
matplotlib - How can I convert an RGB image into grayscale in Python? - Stack Overflow
2)PIL 处理图像的基本操作
python——PIL Image处理图像_aaon22357的博客-CSDN博客
3)图像通道数的理解
关于图像通道的理解 | TinaCristal's Blog
4)Convert 3 channel image to 2 channel
python - I have converted 3 channel RGB image into 2 channels grayscale image, How to decrease greyscale channels to 1? - Stack Overflow
5)图像通道转换
图像通道转换——从np.ndarray的[w, h, c]转为Tensor的[c, w, h]_莫邪莫急的博客-CSDN博客
6)将所有的图像合并为一个numpy数组
python — 如何在numpy数组中加载多个图像?
7)torch.from_numpy VS torch.Tensor
torch.from_numpy VS torch.Tensor_麦克斯韦恶魔的博客-CSDN博客
8)torch.squeeze() 和torch.unsqueeze()
pytorch学习 中 torch.squeeze() 和torch.unsqueeze()的用法_xiexu911的博客-CSDN博客_torch.unsqueeze
3.issue
1)TypeError: 'module' object is not callable
python - TypeError: 'module' object is not callable - Stack Overflow
2)TypeError: 'collections.OrderedDict' object is not callable
pytorch加载模型报错TypeError: ‘collections.OrderedDict‘ object is not callable_xiaoqiaoliushuiCC的博客-CSDN博客
3) TypeError: __init__() missing 1 required positional argument: 'XX'
Python成功解决TypeError: __init__() missing 1 required positional argument: ‘comment‘_肥鼠路易的博客-CSDN博客
4) RuntimeError: Error(s) in loading state_dict for PythonNet: Missing key(s) in state_dict:
pytorch加载模型报错RuntimeError: Error(s) in loading state_dict for SSD:Missing key(s) in state_dict:... - 代码先锋网
5) RuntimeError: Expected 4-dimensional input for 4-dimensional weight [128, 1, 3, 3], but got 2-dimensional input of size [480, 640] instead
RuntimeError: Expected 4-dimensional input for 4-dimensional weight 64 3 3, but got 3-dimensional in_Steven_ycs的博客-CSDN博客
6)RuntimeError: Input type (torch.FloatTensor) and weight type (torch.cuda.FloatTensor) should be the same or input should be a MKLDNN tensor and weight is a dense tensor
RuntimeError: Input type (torch.cuda.FloatTensor) and weight type (torch.DoubleTensor) should be the_一千零一夜的博客-CSDN博客
Pytorch加载模型并进行图像分类预测相关推荐
- Pytorch加载模型只导入部分层权重,即跳过指定网络层的方法
需求 Pytorch加载模型时,只导入部分层权重,跳过部分指定网络层.(权重文件存储为dict形式) 方法一 常见方法:加载权重时用if对网络层进行筛选 ''' # model为定义的网络结构: cl ...
- PyTorch加载模型model.load_state_dict()问题,Unexpected key(s) in state_dict: “module.features..,Expected .
希望将训练好的模型加载到新的网络上.如上面题目所描述的,PyTorch在加载之前保存的模型参数的时候,遇到了问题. Unexpected key(s) in state_dict: "mod ...
- pytorch加载模型报错Unexpected key(s) in state_dict: module.conv1.weight, module.bn1
文章目录 背景 报错 原因 解决 背景 Pytorch在加载模型参数的时候,有两种情况可能出现这种问题: 自己写的网络结构,例如: 代码 import models arch = 'resnet50' ...
- pytorch加载模型时出现.....ckpt_100.pth is a zip archive (did you mean to use torch.jit.load()?)
在测试加载训练好的模型时出现上方问题,参考这篇文章,原因是训练和测试的torch版本不一致. 训练的时候是1.6,测试的时候是1.2,因此需要先在1.6版本下加载模型,重新保存,在保存的时候设置use ...
- pytorch 加载模型 模型大小测试速度
直接加载整个模型 Pytorch保存和加载整个模型: save_net=model if hasattr(model, 'module'):save_net=model.module torch.sa ...
- pytorch 加载模型_福利,PyTorch中文版官方教程来啦(附下载)
PyTorch 中文版官方教程来了. PyTorch 是近期最为火爆的深度学习框架之一,然而其中文版官方教程久久不来.近日,一款完整的 PyTorch 中文版官方教程出炉,读者朋友从中可以更好的学习了 ...
- 使用PyTorch加载模型部分参数方法
前言 在深度学习领域,经常需要使用其他人已训练好的模型进行改进或微调,这个时候我们通常会希望加载预训练模型文件的参数,如果网络结构不变,只需要使用load_state_dict方法即可.而当我们改动网 ...
- pytorch 加载模型:
1.直接加载网络 import torchpthfile = r'E:\models\squeezenet1_1.pth'net = torch.load(pthfile)print(net) 方法2 ...
- pytorch 加载模型报错:‘function‘ object has no attribute ‘copy‘
太粗心了,保存模型的时候写错了,写成了如下: torch.save(model,model_file) 而实际上应该是: torch.save(model.state_dict(),model_fil ...
最新文章
- slf4j+log4j打印日志,控制台无日志输出
- switch case 支持的 6 种数据类型!
- 32位数型计算机什么意思,展示32位是什么意思
- Unity c#中Attribute用法详解
- 深入理解Magento – 第五章 – Magento资源配置
- jdbctemplate 批量删除_10秒3步批量去除PDF水印
- notepad++ :正则表达式系统教程(zz)
- 也谈虚拟化的服务器选型,以及性能考虑
- MFC初探 —— 子窗体相对于显示屏位置固定
- 2003退休去世领了2年退休金没回本就死了能退吗?
- 下列有关计算机系统叙述正确,()下列有关计算机系统软件的叙述正确的是____
- java如何检测redis是否可用
- LM2596、LM2576
- 一文搞懂数据结构之 递归-八皇后问题
- php的include once,php include_once的使用方法详解
- [4G5G专题-83]:架构 - 移动通信网2G/3G/4G/5G/6G网络架构的演进历程
- python黑色的_python怎么设置黑色背景
- 谱聚类算法入门教程(三)—— 求f^TLf的最小值
- 彩虹域名授权平台系统正版源码 带下载更新功能
- php的bs_PHP能否做BS架构的开发?
热门文章
- 非常好的产品研发管理文章,后面问题回答的很精彩(转)
- word两个不同表格合并,防止自动调整
- P1472 奶牛家谱 Cow Pedigrees
- fractions库的使用
- 济南近郊出游——线路指南
- android APP开发时,全屏手机适配的问题解决
- 如何扩展计算机c盘的控件,如何无损扩展C盘空间大小,这一招足够!
- python全栈工程师薪水_Python工程师薪资刷出新高度,有望成为世界上最流行的编程语言...
- 为什么很多人会觉得FPGA难学?
- matlab摩托车刹车问题,安全骑行篇,摩托车刹车的基本知识与技巧!