任务要求:利用torchvision中的预训练CNN模型来对真实的图像进行分类,预测每张图片的top5类别。
数据: real_image, class_index.json

导入:

import torch
from torchvision import models, datasets, transforms
from torch.utils.data import DataLoader, Dataset
from PIL import Image
import numpy as np
import torchvisionimport os
import json
import time
import matplotlib.pyplot as plt
%matplotlib inline

1. 类别索引:

构建类别索引词典

f = open('./data/class_index.json')
class_index = json.load(f)
print('class num:', len(class_index))
class_dict = {int(k): v[1] for k, v in class_index.items()}
print(class_dict)

2. 预训练模型:

加载预训练CNN模型

import ssl
alexnet = models.alexnet(pretrained=True)
# import ssl
# resnet = models.resnet50(pretrained=True)

3. 图像预处理:

图像缩放、裁剪、转Tensor、归一化

image_transforms = transforms.Compose([transforms.Resize(256),                    transforms.CenterCrop(224),                transforms.ToTensor(),                     transforms.Normalize(                      mean=[0.485, 0.456, 0.406],                std=[0.229, 0.224, 0.225]                  )
])

4. 测试数据集加载:

构建测试数据集,迭代返回预处理后的Tensor格式图像和原始图像

class TestDataset():def __init__(self, root, transforms=None):imgs = os.listdir(root)self.imgs = [os.path.join(root, img) for img in imgs]self.transforms = transformsdef __getitem__(self, index):img_path = self.imgs[index]img_pil = Image.open(img_path)label = Noneimg_np = np.asarray(img_pil)data = self.transforms(img_pil)return data, img_npdef __len__(self):return len(self.imgs)
test_dir = './data/real_image/'
test_dataset = TestDataset(test_dir, image_transforms)
print('test image num:', test_dataset.__len__())

运行结果如下:

test image num: 20

5. 模型预测图像类别:

在测试模式下,对于每张图片显示原始图像,并输出模型预测的top5类别及top1类别

alexnet.eval()
for data, img_np in test_dataset:img = torch.unsqueeze(data, 0)output = alexnet(img)_, index = torch.max(output, 1)index=index.numpy()percentage = torch.nn.functional.softmax(output, dim=1)[0] * 100plt.imshow(img_np, aspect='auto')plt.show()print('top1类别:')print(class_dict[index[0]], percentage[index[0]].item())_, indices = torch.sort(output, descending=True)indices=indices.numpy()print('top5类别:')for idx in indices[0][:5]:print((class_dict[idx], percentage[idx].item()))print()

运行结果如下(部分):



……

python与机器学习(七)下——torchvision预训练模型测试真实图像分类相关推荐

  1. 深度学习之openvino预训练模型测试(车牌识别)

    0 背景 在上一篇文章<深度学习之openvino预训练模型测试>,我们介绍了如何使用 intel 提供的预训练模型完成语义分割任务.但在用 public 预训练模型时,发现我的 open ...

  2. 深度学习之openvino预训练模型测试

    0 背景 在<深度学习之win10安装配置openvino>中我们介绍了 openvino 的安装方法,本文对下一步的使用进行一个介绍. 1 模型介绍 openvino 提供了一系列的预训 ...

  3. pix2pix学习系列(1):预训练模型测试pix2pix

    pix2pix学习系列(1):预训练模型测试pix2pix 参考文献: [Pytorch系列-66]:生成对抗网络GAN - 图像生成开源项目pytorch-CycleGAN-and-pix2pix ...

  4. 【小白学PyTorch】5.torchvision预训练模型与数据集全览

    [机器学习炼丹术]的学习笔记分享 <<小白学PyTorch>> 小白学PyTorch | 4 构建模型三要素与权重初始化 小白学PyTorch | 3 浅谈Dataset和Da ...

  5. mmsegmentation使用教程 使用 OpenMMLab 的 MM segmention 下Swin-T 预训练模型 语义分割 推理的记录

    前言: 给大家说一下怎么使用mmsegmention时,毕竟自己用mmsegmention走过很多弯路,然后结合其他人的文章和mmsegmention自己的doc来写下这个教程 第一部分: 我的电脑配 ...

  6. python寻找近义词:预训练模型 nltk+20newsbydate / gensim glove 转 word2vec

    本文用python寻找英文近义词(中文:https://github.com/huyingxi/Synonyms) 使用的都是预训练模型 方法一.nltk+20newsbydate (运行时下载太慢/ ...

  7. [Pytorch系列-66]:生成对抗网络GAN - 图像生成开源项目pytorch-CycleGAN-and-pix2pix - 使用预训练模型测试pix2pix模型

    作者主页(文火冰糖的硅基工坊):文火冰糖(王文兵)的博客_文火冰糖的硅基工坊_CSDN博客 本文网址:https://blog.csdn.net/HiWangWenBing/article/detai ...

  8. RoNIN: Robust Neural Inertial Navigation预训练模型测试

    接着上篇模型驱动PDR的文章继续往下讲: RoNIN: Robust Neural Inertial Navigation in the Wild: Benchmark, Evaluations, a ...

  9. NeXtVLAD 飞酱预训练模型测试

    hi,dear 大佬: 找遍了全网,只有飞酱提供了预训练的模型,请使用_final版本的,下面我将用inceptionV3提取图像特征然后经过该模型得到concat之前聚类之后的特征,该特征我将用做e ...

最新文章

  1. 三、MySql查询语句执行的特征
  2. 【html 及 HTML5所有标签汇总】★★★
  3. UA MATH567 高维统计 专题0 为什么需要高维统计理论?——以线性判别分析为例
  4. 3种常用的防盗链的方式
  5. Tunnel Warfare(HDU1540+线段树+区间合并)
  6. day1||python
  7. 05Prism WPF 入门实战 - Navigation
  8. Oracle入门(五A)之conn命令
  9. python实例26[计算MD5]
  10. pyqt5设置dialog的标题_Python GUI教程(一):在PyQt5中创建第一个GUI图形用户界面...
  11. 关于LINUX的NVIDIA显卡驱动安装
  12. cad隐藏图层命令快捷键_cad与天正局部隐藏对象大法
  13. 国内外低代码平台一览
  14. 学习总结:Handler机制
  15. QT安装 and VS2019中安装QT插件
  16. adb shell 小米手机_小米手机ADB删除系统应用去广告。
  17. MATLAB直方图图像去雾算法实现
  18. 四足机器人|机器狗|仿生机器人|多足机器人|Adams仿真|Simulink仿真|基于CPG的四足机器人Simulink与Adams虚拟样机|源码可直接执行|绝对干货!需要资料及指导的可以联系我!
  19. 爬取初试----猫眼电影,猫眼评分
  20. 阿里点赞立法惩治刷单炒信:坚决拥护、全力支持

热门文章

  1. cratedb导入json文件
  2. Broadcom fullmac WLAN 驱动解析(1)
  3. enq: TT - contention等待事件
  4. jQuery选择器理解
  5. War3窗口限定小工具发布
  6. JFace中TableViewer的使用
  7. 性能测试工具选型原则
  8. 包教会一对一跟着CNS学单细胞测序(含空间转录组、chipseq、RNAseq、Atacseq 和外显子等)3月13日开始...
  9. 设置oracle每行显示字符个数,Oracle一列的多行数据拼成一行显示字符-Oracle
  10. 5条能让web前端至少手拿20万年薪的特性!