Pytorch使用预训练模型进行图像分类
在本文中,我们将介绍一些使用预训练网络的实际例子,这些网络出现在TorchVision模块的图像分类中。
Torchvision包包括流行的数据集,模型体系结构,和通用的图像转换为计算机视觉。基本上,如果你进入计算机视觉并使用PyTorch, Torchvision将会有很大的帮助!
1. Pre-trained Models for Image Classification
预训练模型是在大型 benchmark数据集(如ImageNet)上训练的神经网络模型。深度学习社区从这些开源模式中获益匪浅。此外,预先训练的模型是计算机视觉研究快速发展的一个主要因素。其他研究人员和从业人员可以使用这些最先进的模型,而不是重新发明一切从零开始。
下面给出了一个粗略的时间轴,说明了这些先进的模型是如何随着时间的推移而改进的。我们只包括那些模型是在Torchvision
包。
在详细介绍如何使用预先训练的模型进行图像分类之前,让我们看看有哪些预先训练的模型。我们将在这里讨论AlexNet和ResNet101作为两个主要的例子。这两个网络都接受过ImageNet数据集的训练。
ImageNet
数据集拥有超过1400
万张由斯坦福大学维护的图像。它被广泛用于各种图像相关的深度学习项目。这些图像属于不同的类别或标签。像AlexNet和ResNet101这样的预训练模型的目的是将图像作为输入并预测它的类别。
这里的“预训练”指的是深度学习架构,例如AlexNet和ResNet101,已经在一些(巨大的)数据集上进行了训练,因此产生了权重和偏差。架构与权重和偏差之间的区别应该非常清楚,因为我们将在下一节中看到,TorchVision既有架构,又有预训练的模型。
1.1. 模型推理过程
由于我们将关注如何使用预先训练的模型来预测输入的类(标签),所以让我们也讨论其中涉及的过程。这个过程被称为模型推理。整个过程由以下主要步骤组成。
- 1.读取输入图像
- 2.在图像上执行变换。例如,调整大小,中央裁剪,正规化等等。
- 3.前向传递:使用预先训练的权值找出输出向量。这个输出向量中的每个元素描述模型预测输入图像属于某个特定类的置信度。
- 4.根据得到的分数(我们在步骤3中提到的输出向量的元素),显示预测。
1.2 使用TorchVision加载预训练网络
现在,我们已经具备了模型推理的知识,并了解了预训练模型的含义,让我们看看如何在TorchVision模块的帮助下使用它们。
首先,让我们使用下面给出的命令安装TorchVision
模块。
pip install torchvision
接下来,让我们从torchvision
模块导入模型,看看我们有哪些不同的模型和架构可用。
from torchvision import models
import torchdir(models)
仔细观察我们得到的输出
['AlexNet','DenseNet','GoogLeNet','Inception3','MobileNetV2','ResNet','ShuffleNetV2','SqueezeNet','VGG',
...'alexnet','densenet','densenet121','densenet161','densenet169','densenet201','detection','googlenet','inception','inception_v3',
...
]
注意,有一个条目叫AlexNet
,另一个叫AlexNet
。大写的名称指的是Python类(AlexNet)
,而AlexNet
是一个方便的函数,它返回从AlexNet类实例化的模型。这些方便的函数也可能有不同的参数集。例如,densenet121
、densenet161
、densenet169
、densenet201
都是DenseNet
类的实例,但层数不同——分别为121,161,169和201
。
1.3. 利用AlexNet进行图像分类
让我们先从AlexNet开始。它是图像识别领域早期的突破性网络之一。如果你有兴趣了解AlexNet的架构,你可以看看我们在理解AlexNet上的帖子。
AlexNet架构
步骤1:加载预训练模型
在第一步中,我们将创建一个网络实例。我们还将传递一个参数,以便函数可以https://github.com/spmallick/learnopencv/tree/master/Inference-for-PyTorch-Models/ONNX-Caffe2d模型的权重。
alexnet = models.alexnet(pretrained=True)# You will see a similar output as below
# Downloading: "https://download.pytorch.org/models/alexnet-owt- 4df8aa71.pth" to /home/hp/.cache/torch/checkpoints/alexnet-owt-4df8aa71.pth
注意,通常PyTorch模型的扩展名为.pt或.pth
一旦下载了权重,我们就可以继续其他步骤。我们还可以检查网络架构的一些细节,如下所示。
print(alexnet)
步骤2:图像变换
一旦我们有了模型,下一步就是变换输入图像,使它们具有正确的形状和其他特征,如平均值和标准差。这些值应该与训练模型时使用的值一样
。这确保了网络将产生有意义的答案。
我们可以利用TochVision模块中的变换对输入图像进行预处理。在这种情况下,我们可以对AlexNet和ResNet使用以下转换。
from torchvision import transforms
transform = transforms.Compose([ #[1]transforms.Resize(256), #[2]transforms.CenterCrop(224), #[3]transforms.ToTensor(), #[4]transforms.Normalize( #[5]mean=[0.485, 0.456, 0.406], #[6]std=[0.229, 0.224, 0.225] #[7])])
让我们试着理解上面的代码片段中发生了什么。
- 行[1]:在这里,我们定义了一个
transform
,它是对输入图像进行的所有图像变换的组合。 - 行[2]:调整图像的大小为256×256像素。
- 行[3]:裁切图像到224×224像素左右的中心。
- 行[4]:将图像转换为PyTorch张量数据类型。
- 行[5-7]:对图像进行归一化,将其平均值和标准差设置为规定值。
第三步:加载输入图像并进行预处理
接下来,让我们加载输入图像并执行上面指定的图像转换。请注意,我们将在TorchVision
中广泛使用Pillow (PIL)
模块,因为它是TorchVision
支持的默认图像后端。
# Import Pillow
from PIL import Image
img = Image.open("dog.jpg")
接下来,对图像进行预处理,并准备批处理以通过网络。
img_t = transform(img)
batch_t = torch.unsqueeze(img_t, 0)
步骤4:模型推理
最后,是时候使用预先训练的模型来查看模型认为图像是什么。
首先,我们需要将模型置于eval
模式
alexnet.eval()
接下来,让我们执行推论。
out = alexnet(batch_t)
print(out.shape)
这一切都很好,但我们如何处理这个输出向量,其中包含1000
个元素?我们仍然没有得到图像的类(或标签)。为此,我们将首先从一个包含所有1000
个标签列表的文本文件中读取和存储标签。注意,行号指定了类号,所以确保不改变顺序是非常重要的。
with open('imagenet_classes.txt') as f:classes = [line.strip() for line in f.readlines()]
因为AlexNet
和ResNet
已经在相同的ImageNet
数据集上进行了训练,所以我们可以对两个模型使用相同的Class
列表。
现在,我们需要找出输出向量out中最大分数出现的索引。我们将使用这个指标来找出预测。
_, index = torch.max(out, 1)percentage = torch.nn.functional.softmax(out, dim=1)[0] * 100print(labels[index[0]], percentage[index[0]].item())
该模型预测的图像是拉布拉多犬有41.58%
的置信度。
但这听起来太低了。让我们看看模型认为图像还属于什么类别。
这里是输出:
[('Labrador retriever', 41.585166931152344),('golden retriever', 16.59166145324707),('Saluki, gazelle hound', 16.286880493164062),('whippet', 2.8539133071899414),('Ibizan hound, Ibizan Podenco', 2.3924720287323)]
这个模型成功地预测出这是一只狗,但它对狗的品种不是很确定。
让我们对草莓和汽车的图像做同样的尝试,看看我们得到的输出。
这是上述草莓图像得到的输出。我们可以看到,得分最高得分是“草莓”,得分接近99.99%。
[('strawberry', 99.99365997314453),('custard apple', 0.001047826954163611),('banana', 0.0008201944874599576),('orange', 0.0007371827960014343),('confectionery, confectionary, candy store', 0.0005758354091085494)]
类似地,对于上面给出的汽车图像,输出如下。
[('cab, hack, taxi, taxicab', 33.30569839477539),('sports car, sport car', 14.424001693725586),('racer, race car, racing car', 10.685123443603516),('beach wagon, station wagon, wagon, estate car, beach waggon, station waggon, waggon',7.846532821655273),('passenger car, coach, carriage', 6.985556125640869)]
就是这样!通过这4个步骤,就可以使用预先训练好的模型进行图像分类。
我们在ResNet上试试同样的方法怎么样?
1.4. 使用ResNet进行图像分类
我们将使用resnet101 - 101
层卷积神经网络。Resnet101
在训练过程中调整了大约4450
万个参数。这是巨大的!
让我们快速浏览一下使用resnet101进行图像分类所需的步骤。
# First, load the model
resnet = models.resnet101(pretrained=True)# Second, put the network in eval mode
resnet.eval()# Third, carry out model inference
out = resnet(batch_t)# Forth, print the top 5 classes predicted by the model
_, indices = torch.sort(out, descending=True)
percentage = torch.nn.functional.softmax(out, dim=1)[0] * 100
[(labels[idx], percentage[idx].item()) for idx in indices[0][:5]]
这是resnet101的预测。
[('Labrador retriever', 48.25556945800781),('dingo, warrigal, warragal, Canis dingo', 7.900787353515625),('golden retriever', 6.916920185089111),('Eskimo dog, husky', 3.6434383392333984),('bull mastiff', 3.0461232662200928)]
就像AlexNet一样,ResNet成功地预测出这是一只狗,并以48.25%的自信预测出这是一只拉布拉多寻回犬。
2. 模型比较
到目前为止,我们已经讨论了如何使用预先训练的模型来执行图像分类,但我们还没有回答的一个问题是,我们如何决定为特定的任务选择哪个模型。在本节中,我们将根据以下标准比较预训练模型:
- Top-1误差:如果最有把握的模型预测的类与真实的类不一样,就会出现Top-1误差。
-top-5错误:当真实的类不在模型预测的前5个类中(根据置信度排序),就会出现top-5错误。 - CPU上的推理时间:推理时间是模型推理步骤所花费的时间。
- GPU上的推理时间
模型大小:这里的大小表示PyTorch提供的预训练模型的.pth文件所占用的物理空间
一个好的模型会有较低的Top-1
错误,较低的Top-5
错误,较低的CPU
和GPU
推理时间和较低的模型尺寸。
所有的实验都是在相同的输入图像上进行的,并进行多次,以便对特定模型的所有结果进行平均,以便进行分析。实验是在谷歌Colab上进行的。现在,让我们看看所获得的结果。
2.1. 模型精度比较
我们要讨论的第一个标准包括Top-1和Top-5错误。top -1误差是指top预测类与真实情况不同时的误差。由于这是一个相当困难的问题,有另一个误差测量称为Top-5误差。如果前5个预测类别中没有一个是正确的,那么该预测将被归类为错误。
从图中可以看出,这两个error遵循相似的趋势。AlexNet是基于深度学习的第一次尝试,从那时起在错误方面有了改进。值得一提的是GoogLeNet, ResNet, VGGNet, ResNext。
2.2. 推理时间比较
接下来,我们将基于模型推理所花费的时间来比较模型。一个图像被多次提供给每个模型,所有迭代的推理时间被平均。在谷歌Colab
上对CPU
和GPU
执行了类似的过程。即使在顺序上有一些变化,我们可以看到SqueezeNet, ShuffleNet
和ResNet-18
的推断时间非常低,这正是我们想要的。
2.3. 模型的大小比较
很多时候,当我们在android或iOS设备上使用深度学习模型时,模型大小成为一个决定性因素,有时甚至比准确性更重要。SqueezeNet的模型尺寸最小(5 MB),其次是ShuffleNet V2 (6 MB)和MobileNet V2 (14 MB)。很明显,这些模型在使用深度学习的移动应用中更受欢迎。
2.4. 整体比较
我们讨论了在特定标准的基础上哪个模型表现得更好。我们可以将所有这些重要的细节压缩到一个气泡图中,然后根据我们的需求来决定使用哪个模型。
我们使用的x坐标是Top-1误差(越低越好)。y坐标是GPU上的推断时间,以毫秒为单位(越低越好)。气泡大小代表模型大小(越小越好)。
注意:
- 较小的气泡在模型尺寸方面更好。
- 靠近原点的气泡在精度和速度方面都更好。
3.结论
从上图可以看出,ResNet50
在所有三个参数上都是最好的模型(尺寸小,更接近原点)
DenseNets
和ResNext101
在推理时间上是昂贵的。AlexNet
和SqueezeNet
的错误率都很高。
好了,就到这里吧!在这篇文章中,我们介绍了如何使用Torchvison模块进行图像分类,使用预训练的模型-只有4个步骤的过程。我们也进行了模型比较,根据我们的项目需求来决定选择什么样的模型。在下一篇文章中,我们将介绍如何使用迁移学习来使用PyTorch训练自定义数据集上的模型。
源代码地址下载
Pytorch使用预训练模型进行图像分类相关推荐
- 使用PyTorch中的预训练模型进行图像分类
PyTorch的TorchVision模块中包含多个用于图像分类的预训练模型,TorchVision包由流行的数据集.模型结构和用于计算机视觉的通用图像转换函数组成.一般来讲,如果你进入计算机视觉和使 ...
- linux载入pytorch的预训练模型时遇到_pickle.UnpicklingError: unpickling stack underflow
linux试图载入pytorch的预训练模型resnet101时遇到如下报错: Traceback (most recent call last): File "train_baseline ...
- Pytorch——BERT 预训练模型及文本分类(情感分类)
BERT 预训练模型及文本分类 介绍 如果你关注自然语言处理技术的发展,那你一定听说过 BERT,它的诞生对自然语言处理领域具有着里程碑式的意义.本次试验将介绍 BERT 的模型结构,以及将其应用于文 ...
- pytorch载入预训练模型后,训练指定层
1.有了已经训练好的模型参数,对这个模型的某些层做了改变,如何利用这些训练好的模型参数继续训练: pretrained_params = torch.load('Pretrained_Model') ...
- Pytorch——XLNet 预训练模型及命名实体识别
介绍 在之前我们介绍和使用了 BERT 预训练模型和 GPT-2 预训练模型,分别进行了文本分类和文本生成次.我们将介绍 XLNet 预训练模型,并使用其进行命名实体识别次. 知识点 XLNet 在 ...
- Pytorch提取预训练模型特定中间层的输出
如果是你自己构建的模型,那么可以再forward函数中,返回特定层的输出特征图. 下面是介绍针对预训练模型,获取指定层的输出的方法. 如果你只想得到模型最后全连接层之前的输出,那么只需要将最后一个全连 ...
- Pytorch Resnet预训练模型参数地址
'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 'resnet34': 'https://downl ...
- pytorch bert预训练模型的加载地址
https://blog.csdn.net/weixin_39331401/article/details/109328681
- 载入pytorch的预训练模型时遇到_pickle.UnpicklingError: unpickling stack underflow
转载自https://blog.csdn.net/iteapoy/article/details/106193500
最新文章
- 刘铁岩:AI打通关键环节,加快物流行业数字化转型
- 2017年英特尔在其数据中心业务和AI方面下大注
- Docker之docker简介及其优势
- 智慧城轨信息技术架构及信息安全规范_在深圳,我们打造智慧地铁的“最强大脑”...
- mysql 本地连接_mysql开启远程连接及本地连接
- 7-60 致命的珠宝 (10分)
- spss下载以及安装详细教程
- PADS2007教程(二)——PCB封装
- newifi3刷什么固件最稳定_新路由三无线路由器刷什么固件好?
- 如何在Mac OS上从Photoshop 2020作为插件访问Topaz DeNoise AI?
- 生物医学数据统计分析-两组或多组计量资料的比较
- debian下配置防火墙iptables
- Flash 101-第1部分:锤子和凿子
- 2020牛客暑期多校训练营(第九场)	The Escape Plan of Groundhog
- codeup刷题2.5小节 C/C++快速入门->数组——《算法笔记》(胡凡)
- Elasticsearch外网无法通过ip访问
- 程序中的地址转换(虚拟地址-物理地址)
- JavaWeb Ajax的使用
- 视频教程-用project做项目计划及总结报表-研发管理
- 明天就是七夕了,用Python做了个可能会被女朋友打死的礼物!
热门文章
- 营销值得学:创业做生意如何降维打击?
- 3dmax快速实现一个逼真地毯效果
- Progressive Layered Extraction: A Novel Multi-TaskLearning Model for Personalized Recommendations
- 认识浏览器:浏览器内核/页面加载/DOM和DOM树
- oracle用升序索引去降序查询,Oracle工作札记
- Java警告The serializable class XXX does not declare a static final serialVersionUID field of type long
- ArcGIS切片生成工具-ArcGIS缓存管理
- 直播平台接入美颜SDK已成刚需,它将带来哪些影响?
- 信息学奥赛一本通:1194:移动路线
- MySql数据类型-读书笔记