TorchVision中给出了AlexNet的pretrained模型,模型存放位置为https://download.pytorch.org/models/alexnet-owt-4df8aa71.pth ,可通过models.alexnet函数下载,此函数实现在torchvision/models/alexnet.py中,下载后在Ubuntu上存放在~/.cache/torch/hub/checkpoints目录下,在Windows上存放在C:\Users\spring\.cache\torch\hub\checkpoints目录下,其中spring为用户名。

AlexNet的介绍参考:https://blog.csdn.net/fengbingchun/article/details/112709281

在推理(inference)过程中,模型的输入是一个tensor,shape需要是[1,c,h,w],原始图像进行预处理操作包括:

(1).resize到短边为256,长边等比缩放。

(2).在中心裁剪图像大小到224*224。

(3).将数据从numpy.ndarray转换到tensor;原数据shape为[h,w,c],转换后tensor shape为[c,h,w];原数据值范围为[0,255],转换后值范围为[0.0,1.0]。

(4).使用均值和标准差对tensor图像进行归一化。

(5).将tensor的shape从[c,h,w]转换到[1,c,h,w]。

模型是通过ImageNet数据集训练获得的,它的图像分类数是1000,ImageNet数据集的介绍参考:https://blog.csdn.net/fengbingchun/article/details/88606621

以下为测试代码:

import torch
from torchvision import models
from torchvision import transforms
import cv2
from PIL import Image
import math
import numpy as np#print(dir(models))images_path = "../../data/image/"
images_name = ["5.jpg", "6.jpg", "7.jpg", "8.jpg", "9.jpg", "10.jpg"]
images_data = [] # opencv
tensor_data = [] # pytorch tensordef images_stitch(images, cols=3, name="result.jpg"): # 图像简单拼接'''images: list, opencv image data; cols: number of images per line; name: save image result name'''width_total = 660width, height = width_total // cols, width_total // colsnumber = len(images)height_total = height * math.ceil(number / cols)mat1 = np.zeros((height_total, width_total, 3), dtype="uint8") # in Python images are represented as NumPy arraysfor idx in range(number):height_, width_, _ = images[idx].shapeif height_ != width_:if height_ > width_:width_ = math.floor(width_ / height_ * width)height_ = heightelse:height_ = math.floor(height_ / width_ * height)width_ = widthelse:height_, width_ = height, widthmat2 = cv2.resize(images[idx], (width_, height_))offset_y, offset_x = (height - height_) // 2, (width - width_) // 2start_y, start_x = idx // cols * height, idx % cols * widthmat1[start_y + offset_y:start_y + height_+offset_y, start_x + offset_x:start_x + width_+offset_x, :] = mat2cv2.imwrite(images_path+name, mat1)for name in images_name:img = cv2.imread(images_path + name)print(f"name: {images_path+name}, opencv image shape: {img.shape}") # (h,w,c)images_data.append(img)img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)img_pil = Image.fromarray(img)transform = transforms.Compose([transforms.Resize(256),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])tensor = transform(img_pil)print(f"tensor shape: {tensor.shape}, max: {torch.max(tensor)}, min: {torch.min(tensor)}") # (c,h,w)tensor = torch.unsqueeze(tensor, 0) # 返回一个新的tensor,对输入的既定位置插入维度1print(f"tensor shape: {tensor.shape}, max: {torch.max(tensor)}, min: {torch.min(tensor)}") # (1,c,h,w)tensor_data.append(tensor)images_stitch(images_data)model = models.alexnet(pretrained=True) # AlexNet网络
#print(model) # 可查看模型结构,与torchvision/models/alexnet.py中一致
model.eval() # AlexNet is required to be put in evaluation mode in order to do prediction/evaluationwith open("imagenet_classes.txt") as f:classes = [line.strip() for line in f.readlines()] # the line number specified the class numberfor x in range(len(tensor_data)):prediction = model(tensor_data[x])#print(prediction.shape) # [1,1000]_, index = torch.max(prediction, 1)percentage = torch.nn.functional.softmax(prediction, dim=1)[0] * 100print(f"result: {classes[index[0]]}, {percentage[index[0]].item()}")print("test finish")

执行结果如下:以下原始测试图像来自网络,每张图像仅输出可信度值最高的一个类别。从上往下,从左往右,每张图像的分类结果依次是:goldfish(金鱼)、hen(母鸡)、ostrich(鸵鸟)、African crocodile(非洲鳄鱼)、goose(鹅)、hartebeest(羚羊)。

GitHub:https://github.com/fengbingchun/PyTorch_Test

TorchVision中通过AlexNet网络进行图像分类相关推荐

  1. 深度学习经典网络解析图像分类篇(二):AlexNet

    深度学习经典网络解析图像分类篇(二):AlexNet 1.背景介绍 2.ImageNet 3.AlexNet 3.1AlexNet简介 3.2AlexNet网络架构 3.2.1第一层(CONV1) 3 ...

  2. 经典再读 | NASNet:神经架构搜索网络在图像分类中的表现

    (图片付费下载于视觉中国) 作者 | Sik-Ho Tsang 译者 | Rachel 编辑 | Jane 出品 | AI科技大本营(ID:rgznai100) [导读]从 AutoML 到 NAS, ...

  3. 使用PyTorch中的预训练模型进行图像分类

    PyTorch的TorchVision模块中包含多个用于图像分类的预训练模型,TorchVision包由流行的数据集.模型结构和用于计算机视觉的通用图像转换函数组成.一般来讲,如果你进入计算机视觉和使 ...

  4. 深度学习入门笔记之ALexNet网络

    Alex提出的alexnet网络结构模型,在imagenet2012图像分类challenge上赢得了冠军.作者训练alexnet网络时大致将120万张图像的训练集循环了90次,在两个NVIDIA G ...

  5. Pytorch:使用Alexnet网络实现CIFAR10分类

    全部代码: https://github.com/SPECTRELWF/pytorch-cnn-study 网络介绍: Alexnet网络是CV领域最经典的网络结构之一了,在2012年横空出世,并在当 ...

  6. AlexNet网络的结构详解与实现

    点击上方"小白学视觉",选择加"星标"或"置顶" 重磅干货,第一时间送达 一:AlexNet网络结构 在2012年ImageNet图像分类任 ...

  7. AlexNet网络复现

    AlexNet 学习流程 阅读AlexNet论文原文 搜集学习资源:视频讲解-博客资源 熟悉AlexNet网络结构 代码复现,清楚网络结构中层与层之间的操作 AlexNet论文 原论文:imagene ...

  8. 使用AlexNet网络区分宝可梦和数码宝贝

    想法来源:李宏毅老师的机器学习课 通过改写Alexnet网络最后一层全连接层,使其能分辨宝可梦和数码宝贝. 数据集准备以及预处理 数据来源: 宝可梦 数码宝贝 下载宝可梦和数码宝贝,设置训练集和测试集 ...

  9. TorchVision中使用FasterRCNN+ResNet50+FPN进行目标检测

    TorchVision中给出了使用ResNet-50-FPN主干(backbone)构建Faster R-CNN的pretrained模型,模型存放位置为https://download.pytorc ...

最新文章

  1. 资源推荐 | 知识图谱顶会文献集锦(附链接)
  2. Jackson:数组json字符串转对象集合(List)的两种方式
  3. Python的numpy库中rand(),randn(),randint(),random_integers()的使用
  4. ISP、IAP、ICP的区别!
  5. win7系统, vim的_vimrc文件无法修改
  6. vue el-upload上传组件限制文件类型:accept属性
  7. xxx钻石商城功能开发需求
  8. mysql查询错误_一个奇怪的MySQL查询错误
  9. LeetCode 865. 具有所有最深结点的最小子树(递归)
  10. String s1=new String(“abc“); 和String s1=“abc“区别
  11. python 双向循环链表实现_python实现双向循环链表基本结构及其基本方法
  12. 一线大厂在用的反爬虫方法,看我如何破了它!
  13. SpringBoot+JMail
  14. [Win32] 打字游戏MFC版
  15. Heartbeat简介
  16. 理解手机中的感应器模块:重力感应/光线感应/电子罗盘/陀螺仪模块功能
  17. ubuntu 18.04快捷显示桌面
  18. codecademy SQL lesson2
  19. 计算机为动态分区无法安装系统,磁盘动态分区形式的电脑怎么重装系统win10
  20. 2G/3G LAC与4G/5G TAC的协同优化

热门文章

  1. 【YOLOV4】(7) 特征提取网络代码复现(CSPDarknet53+SPP+PANet+Head),附Tensorflow完整代码
  2. 三、如何搞自定义数据集?
  3. 检测硬盘使用时长_如何检测硬盘问题
  4. 有java基础的人学python_准备自学Python ,会java,有什么建议吗?
  5. 结构化场景中的RGB-D SLAM
  6. 【fiveKeyPress】2秒内五次点击键盘任意键(或组合键)触发自定义事件(以Pause/Break键为例)
  7. python threading模块多线程源码示例(二)
  8. 《DDIA》读书笔记(一):可靠性、可扩展性、可维护性
  9. Python学习 day01打卡
  10. Django模板系统和admin模块