pytorch网络输入图片的格式是[B,C,H,W],分别为批大小(batchsize),图片通道数(channel),图片高(height),图片宽(width)。

图片读取方式主要有两种:(1)通过PIL进行读取;(2)通过opencv进行读取。分别进行介绍。

(1)通过PIL进行读取:

通过PIL的Image读取的图片是一个图片对象,可以进行裁剪翻转等torchvision.transforms变换。

torchvision.transforms可以对图像对象进行一系列裁剪、翻转等转换操作,其中也包括转换为tensor张量。(transforms.ToTensor())

(2)通过opencv进行读取

opencv读取的是ndarray格式,不能进行torchvision.transforms变换。

两种读取图像的测试代码如下,其中Net为自己随便写的一个网络,此处输出的类别为7类,模型替换为自己的,同时修改class_names和权重文件。注意,读取的图片路径也要修改。

import torch
from torch import nn
from PIL import Image
from torchvision import transforms,datasets
import cv2class Net(nn.Module):def __init__(self):super(Net, self).__init__()       self.conv1 = nn.Conv2d(in_channels = 3, out_channels = 32, padding = 1,kernel_size = 3)self.pool1 = nn.MaxPool2d(kernel_size = 2,stride = 2)self.conv2 = nn.Conv2d(in_channels = 32, out_channels = 64, padding = 1,kernel_size = 3)self.conv3 = nn.Conv2d(in_channels = 64, out_channels = 128, padding = 1,kernel_size = 3)self.conv4 = nn.Conv2d(in_channels = 128, out_channels = 128, padding = 1, kernel_size = 3)self.relu = nn.ReLU()self.flatten = nn.Flatten()self.linear1 = nn.Linear(128 * 7 * 7, 512)self.linear2 = nn.Linear(512, 6)self.softmax = nn.Softmax()def forward(self, x):x = self.relu(self.conv1(x))x = self.pool1(x)x = self.relu(self.conv2(x))x = self.pool1(x)x = self.relu(self.conv3(x))x = self.pool1(x)x = self.relu(self.conv4(x))x = self.flatten(x)x = self.relu(self.linear1(x))y = self.linear2(x)return y  class_names = ['GC', 'GL', 'NL', 'RC', 'RL', 'UK'] #这个顺序很重要,要和训练时候的类名顺序一致device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")##载入模型并读取权重
model = Net()
model.load_state_dict(torch.load("./data/detect_light.pt"))
model.to(device)
model.eval()img_path = '/home/jwd/dataset/roi455.jpg'#(1)此处为使用PIL进行测试的代码
transform_valid = transforms.Compose([transforms.Resize((56, 56), interpolation=2),transforms.ToTensor()]
)
img = Image.open(img_path)
img_ = transform_valid(img).unsqueeze(0) #拓展维度##(2)此处为使用opencv读取图像的测试代码,若使用opencv进行读取,将上面(1)注释掉即可。
# img = cv2.imread(img_path)
# img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
# img = cv2.resize(img, (56, 56))
# img_ = torch.from_numpy(img).float().permute(2, 0, 1).unsqueeze(0)/255img_ = img_.to(device)
outputs = model(img_)#输出概率最大的类别
_, indices = torch.max(outputs,1)
percentage = torch.nn.functional.softmax(outputs, dim=1)[0] * 100
perc = percentage[int(indices)].item()
result = class_names[indices]
print('predicted:', result)# 得到预测结果,并且从大到小排序
# _, indices = torch.sort(outputs, descending=True)
# 返回每个预测值的百分数
# percentage = torch.nn.functional.softmax(outputs, dim=1)[0] * 100
# print([(class_names[idx], percentage[idx].item()) for idx in indices[0][:5]])

利用pytorch训练好的模型测试单张图片相关推荐

  1. Pytorch训练Bilinear CNN模型笔记

    Pytorch训练Bilinear CNN模型笔记 注:一个项目需要用到机器学习,而本人又是一个python小白,根据老师的推荐,然后在网上查找了一些资料,终于实现了目的. 参考文献: Caltech ...

  2. 如何调用 caffe 训练好的模型对输入图片进行测试

    如何调用 caffe 训练好的模型对输入图片进行测试 该部分包括两篇文章 win10 下 caffe 的第一个测试程序(附带详细讲解) 主要讲解如何利用 caffe 来训练模型. 如何调用 caffe ...

  3. 利用GPT2训练中文闲聊模型

    利用GPT2模型训练中文闲聊模型 最近看了一下GPT2模型,看到很多博主都用来写诗歌,做问答等,小编突然萌生一个想法,利用GPT2来训练一个闲聊模型!!(小说生成器模型已经破产,写出来的东西狗屁不通, ...

  4. 利用PaddleOCR训练车牌识别模型

    目录 1--前言 2--生成车牌数据集 3--构建车牌数据集标签 4--自定义字典 5--训练模型 6--模型转换和推理 7--模型转换为onnx模型 8--参考 1--前言 ①系统:Ubuntu18 ...

  5. 使用SSD训练自己的模型(从图片标注开始)

    此文章参考了https://blog.csdn.net/zzZ_CMing/article/details/81131101  在此表示感谢,如果有侵权的地方可联系本人删除 训练手表模型步骤 未经允许 ...

  6. GPU测试单张图片时间过长

    为什么在测试GPU和CPU的速度的时候会出现GPU反而比CPU慢的 我之前为了测试resnext101网络在CPU和GPU上的单张图片测试程序 很显然不对啊,后来我发现,GPU在刚启动测试第一张图片, ...

  7. PyTorch:保存/加载训练好的模型测试

    保存 torch.save(model.state_dict(), './cnn.pth') 加载 model = VGG16() #加载模型前要创建一个模型的实例对象 model.load_stat ...

  8. 利用 PyTorch 训练神经网络(详细版)

    点击关注我哦 欢迎关注 "小白玩转Python",发现更多 "有趣" "A little learning is a dangerous thing; ...

  9. pytorch加载的模型测试的结果和保存时测试的结果不一致

    假设有一个dropout网络net,训练过程中用测试集进行了测试,接着将该网络进行了保存 torch.save(net.state_dict(), path) 然后将保存的网络加载出来: net=cl ...

  10. pytorch训练的pt模型转换为onnx(nn.DataParallel()、model、model.state_dict())

    pt转onnx流程与常见问题 pt转onnx流程 pt转onnx流程 1.读取pt模型文件,文件既可以是torch.save(model,path)整体保存的模型,也可以是保存的字典文件. // An ...

最新文章

  1. 在GitHub上搭建GitHub Pages博客-- Jekyll
  2. [No000010F]Git8/9-使用GitHub
  3. 【Docker实战之入门】Dockerfile详细分析:构建docker镜像(4)构建动态网站WordPress...
  4. 【Python】疫情卷土重来?Python可视化带你追踪疫情的最新动态
  5. P1600 天天爱跑步
  6. [笑]每个人都有脑袋脱线的时候……
  7. mongoose换成mysql_Package - tms-koa
  8. django 表单html5,我们如何在django管理表单中添加动态html5数据属性
  9. 安徽50岁计算机职称免考,50岁以上评职称免考外语
  10. 多个线程“打架抢夺”同一个资源,该如何让它们安分?
  11. 中国生态系统服务空间数据集/食物生产、土壤保持、水源涵养、防风固沙、生物多样性、碳固定
  12. 2019java后端面试集合篇最值得收藏的(一)
  13. 深度学习---之显存单位,KiB,MiB与MB区别
  14. storm和vgj vgj_风暴很忙:VGJ.Storm新阵容亮相DAC预选赛
  15. 业务安全情报第四期:新能源车企重金打造的私域流量,成为黑灰产“掘金发财”的新目标
  16. uni-app 评论五星
  17. 小米手机刷android one,小米手机(Mi One)刷机教程详解完整版 (刷MIUI官方刷机包)...
  18. C#:Winform 打字测速程序 Typer
  19. Blog技巧,让Google把你的blog翻译成英文
  20. nodejs安装ffi模块调用dll详解

热门文章

  1. 给别的计算机硬盘装系统,在一台计算机上装好系统的硬盘移到另一个电脑能用吗?...
  2. centos7.3根目录空间扩展
  3. 自我鉴定200字大专生计算机专业,大专毕业生自我鉴定200字
  4. 淘宝商品上传API接口
  5. 针对s3c2440芯片制作交叉编译工具链
  6. Gitlab Code Review
  7. PageOffice国产版的授权及离线注册
  8. 考研高数 专题7:方程根的存在性及个数(零点定理-罗尔定理;单调性-罗尔定理推论)
  9. 小黄鸡.Net版(Simsimi.Net)
  10. 去中心化的联邦学习专栏