PyTorch的TorchVision模块中包含多个用于图像分类的预训练模型,TorchVision包由流行的数据集、模型结构和用于计算机视觉的通用图像转换函数组成。一般来讲,如果你进入计算机视觉和使用PyTorch,TorchVision可以提供还多便利的操作!

1、使用预训练模型进行图像分类

预训练模型是在像ImageNet这样的大型基准数据集上训练得到的神经网络模型。深度学习社区从这些开源模型中获益良多,这也是计算机视觉研究迅速发展的主要原因之一。其他研究人员和实践者可以使用这些最先进的模型,而不是从零开始训练。

下面是一个粗略的时间表,说明最先进的模型是如何随着时间的推移而改进的。我们只包括了TorchVision中的那些模型。

在我们关注如何使用预训练模型进行图像分类的细节之前,让我们看看各种预训练模型是什么。我们将以AlexNet和ResNet50为两个主要例子。这两个网络都在ImageNet数据集上进行了训练。ImageNet数据集拥有斯坦福大学维护的1400多万张图像,广泛用于各种与图像相关的深度学习项目。图像属于各种标签。即使可以互换使用这两个术语,我们在此统一使用类这个术语。像AlexNet和ResNet101这样预先训练过的模型的目的是以图像作为输入并预测它的类。

这里预先训练的含义是像AlexNet和ResNet50这样的深度学习结构,已经在一些(巨大的)数据集上进行了训练,从而携带了由此训练所产生的权重和偏差。'网络架构'与‘权重和偏差’之间的这种区别应该是非常清楚的,TorchVision既有网络架构,也有预训练模型。

1.1.模型推断过程

由于我们将重点讨论如何使用预训练模型来预测输入的类(标签),所以我们也讨论一下与此相关的过程。这个过程被称为模型推断。整个过程由以下主要步骤组成.

1、读取输入图像

2、在图像上执行转换,例如调整大小、中心裁剪、正则化等,这些都属于图像的预处理,为了是输入图像符合模型的要求。

3、正向传递:这一步使用预先训练的权重来找出输出向量(二维数组),此输出向量中的每个元素描述该模型预测输入图像属于某个类的置信度。例如是狗的可能性为0.2、猫的可能性为0.5、老虎的可能性为0.1等。

4、根据所获得的置信度(步骤3中我们提到的输出向量的元素),显示预测结果。

我们先要为接下来的实验准备材料,下载如下资源:

  1. wget https://upload.wikimedia.org/wikipedia/commons/2/26/YellowLabradorLooking_new.jpg -O dog.jpg
  2. wget https://raw.githubusercontent.com/Lasagne/Recipes/master/examples/resnet50/imagenet_classes.txt

1.2.使用TorchVision加载预先训练网络

既然我们已经具备了模型推理的知识,以及预训练模型意味着什么,让我们看看如何在TorchVision模块的帮助下使用它们。

首先,让我们安装TorchVision模块。使用

pip3 install torchvision

或者

python3 -m pip install torchvision

接下来,让我们从torchvison模块导入models模块,看看我们可以使用哪些不同的模型和网络结构。

  1. from torchvision import models
  2. import torch
  3. dir(models)

结果如下:

  1. ['AlexNet',
  2. 'DenseNet',
  3. 'GoogLeNet',
  4. 'Inception3',
  5. 'MobileNetV2',
  6. 'ResNet',
  7. 'ShuffleNetV2',
  8. 'SqueezeNet',
  9. 'VGG',
  10. ...
  11. 'alexnet',
  12. 'densenet',
  13. 'densenet121',
  14. 'densenet161',
  15. 'densenet169',
  16. 'densenet201',
  17. 'detection',
  18. 'googlenet',
  19. 'inception',
  20. 'inception_v3',
  21. ...
  22. ]

请注意,有一个条目名为AlexNet,还有一个条目名为alexnet。大写名称指的是Python类(AlexNet),而alexnet是一个功能函数,它返回从AlexNet类实例化的模型。这些功能函数也可能有不同的版本,例如,densenet121, densenet161, densenet169, densenet201都返回DenseNet 类的实例,但层数不同-分别为121、161、169和201。

1.3.利用AlexNet进行图像分类

让我们首先从AlexNet开始,它是图像识别中早期突破性的网络之一。 其网络结构如下:

第一步:加载预先训练好的模型

在第一步中,我们将创建一个网络实例。我们需要传递一个参数,以便函数可以下载模型的权重。。

  1. alexnet = models.alexnet(pretrained=True)
  2. # You will see a similar output as below
  3. # Downloading: "https://download.pytorch.org/models/alexnet-owt- 4df8aa71.pth" to /home/frank/.cache/torch/checkpoints/alexnet-owt-4df8aa71.pth

模型下载需要10分钟左右,模型文件会缓存在用户的相应目录中,下次加载就不需要下载了。

注:我的电脑上缓存在了/home/frank/.cache/torch/checkpoints/alexnet-owt-4df8aa71.pth

请注意,通常PyTorch模型的扩展为.pt或.pt, 一旦权重被下载,我们就可以接着继续其他步骤。我们还可以查看网络结构的一些细节如下。

print(alexnet)

上面这条语句只显示AlexNet网络结构中的各种操作和层,不显示具体的权重和偏差。

步骤2:设置图像转换

一旦我们有了模型,下一步就是转换输入图像,使它们具有正确的形状和其他特征,如均值和标准差。这些值应该类似于在训练模型时使用的值。这确保了网络将产生有正确的推断结果。 我们可以借助TochVision模块中存在的transforms对输入图像进行预处理。本例中,我们可以对AlexNet和ResNet使用以下转换。

  1. from torchvision import transforms
  2. transform = transforms.Compose([ #[1]
  3. transforms.Resize(256), #[2]
  4. transforms.CenterCrop(224), #[3]
  5. transforms.ToTensor(), #[4]
  6. transforms.Normalize( #[5]
  7. mean=[0.485, 0.456, 0.406], #[6]
  8. std=[0.229, 0.224, 0.225] #[7]
  9. )])

第[1]行:这里我们定义了一个变量,它是对输入图像进行的所有图像转换的组合。

第[2]行:将图像调整为256×256像素。

第[3]行:将图像中心裁剪出来,大小为224×224像素。

第[4]行:将图像转换为PyTorch张量(tensor)数据类型。

行[5-7]:通过将图像的平均值和标准差设置为指定的值来正则化图像。

步骤3:加载输入图像并对其进行预处理

接下来,让我们加载输入图像并执行上面指定的图像转换。请注意,我们将广泛使用pillow(PIL)模块,因为它是torchvision支持的默认的图像后端(引擎)。

  1. from PIL import Image
  2. img = Image.open("dog.jpg")

注:一只黄色的拉布拉多狗

接下来,对图像进行预处理,并将图像tensor增加一个维度,因为一张图像只有3个维度,但模型要求输入是4纬张量,也就是默认是输入一批图像,而不是一张。经过处理后,我们的batch_t也代表一批图像,不过其中只有一张图像而已。

  1. img_t = transform(img)
  2. batch_t = torch.unsqueeze(img_t, 0)

步骤4:模型推断

最后,是时候使用预训练模型来看看模型认为图像是什么了。 首先,我们需要将我们的模型置于eval模式。然后进行推断。

  1. alexnet.eval()
  2. out = alexnet(batch_t)
  3. print(out.shape)

out为一个二维向量,行为1,列为1000。前面我们提到,模型输入要求是一批图像,如果我们输入5张图像,则out的行为5,列为1000,列表示1000个类,每个类的置信度,行代表每个图像。故每一行中的1000个元素,分别表示该行对应图像为每个类的可能性。

我们如何处理out这个二维向量呢?我们还没有得到图像的类(或标签)。为此,我们将首先从一个包含所有1000个标签的列表的文本文件中读取和存储标签。请注意,行号确定了类号,因此确保不更改该顺序是非常重要的。

  1. with open('imagenet_classes.txt') as f:
  2. classes = [line.strip() for line in f.readlines()]

classes为含有1000个类名称字符串的列表(ImageNet数据集共包含1000个类)。

由于AlexNet和ResNet已经在相同的Image Net数据集上进行了训练,我们可以对这两种模型使用相同的类列表。 现在,我们需要找出输出向量out中的最大置信度发生在哪个位置。我们将用这个位置的下标来得出预测。

  1. _, index = torch.max(out, 1)
  2. percentage = torch.nn.functional.softmax(out, dim=1)[0] * 100
  3. print(classes[index[0]], percentage[index[0]].item())

_, index = torch.max(out, 1) 的功能是,取出二维向量out中每一行的最大值及下标。index为每行最大值的下标组成的列表。

percentage = torch.nn.functional.softmax(out, dim=1)[0] * 100 的功能是,对二维向量out中的每一行进行归一化(softmax是常用的归一化指数函数),然后取出第一行并使每个元素乘以100,得到本例中拉布拉多狗对应的每种类型的可能性(即置信度)。

print(classes[index[0]], percentage[index[0]].item())的功能是,打印类名及其置信度。classes[index[0]]中,index[0]是第一行最大值的下标,即第一张图片的最大置信度的下标,index[1]为第二张图片的,index[2]是第三张图片的,以此类推。classes[index[0]]即是最大置信度对应的类名称。所以classes列表的元素顺序不可更改。percentage[index[0]].item()中,index[0]的含义同上,percentage[index[0]]代表最大置信度那一项,.item()取出该项的值。

好了!该模型预测图像是一只拉布拉多猎犬,置信度为41.58%。

但这听起来太低了。让我们看看模型认为图像属于其他类的置信度。

  1. _, indices = torch.sort(out, descending=True)
  2. [(classes[idx], percentage[idx].item()) for idx in indices[0][:5]]

torch.sort将out进行排序,默认对每一行排序,这里我们指定以递减的方式排序。

[(classes[idx], percentage[idx].item()) for idx in indices[0][:5]]中,indices[0][:5]产生一个临时的一维列表,包含indices的第一行前5个元素,也就是置信度最高的5个元素的下标值。

结果如下:

  1. [('Labrador retriever', 41.585166931152344),
  2. ('golden retriever', 16.59166145324707),
  3. ('Saluki, gazelle hound', 16.286880493164062),
  4. ('whippet', 2.8539133071899414),
  5. ('Ibizan hound, Ibizan Podenco', 2.3924720287323)]

不知道你注意到没有,所有这些都是狗的品种。因此,模型设法预测,这是一只狗的置信度很高,但它不是很确定的狗的品种。 就这样!你所需要的就是这4个步骤来使用预训练模型进行图像分类。 我们试试ResNet怎么样?

1.4.使用ResNet进行图像分类

我们将使用ResNet50(一个50层卷积神经网络)。 让我们快速地看一下使用ResNet50进行图像分类所需的步骤。

  1. # First, load the model
  2. resnet = models.resnet50(pretrained=True)
  3. # Second, put the network in eval mode
  4. resnet.eval()
  5. # Third, carry out model inference
  6. out = resnet(batch_t)
  7. # Forth, print the top 5 classes predicted by the model
  8. _, indices = torch.sort(out, descending=True)
  9. percentage = torch.nn.functional.softmax(out, dim=1)[0] * 100
  10. [(classes[idx], percentage[idx].item()) for idx in indices[0][:5]]

结果如下

  1. [('Labrador retriever', 48.25556945800781),
  2. ('dingo, warrigal, warragal, Canis dingo', 7.900787353515625),
  3. ('golden retriever', 6.916920185089111),
  4. ('Eskimo dog, husky', 3.6434383392333984),
  5. ('bull mastiff', 3.0461232662200928)]

就像AlexNet一样,ResNet预测它是一只狗,并预测它是一只拉布拉多猎犬,概率为48.26%。

2.模型比较

到目前为止,我们已经讨论了如何使用预先训练的模型来执行图像分类,但我们还没有回答的一个问题是,我们如何决定为特定的任务选择哪种模型。在本节中,我们将根据以下标准对预先训练的模型进行比较:

1、top-1错误:如果模型预测的置信度最高的类与真正的类不相同,则会发生top-1错误。

2、top-5错误:当真正的类不在模型预测置信度最高的前5个类中时,会发生前top-5错误(按置信度排序)。

3、CPU的推断时间:推断时间是模型推理过程所花费的时间。

4、GPU的推断时间 : 当推理运行于gpu时,所花费的推理时间

5、模型大小:这里的大小代表由PyTorch提供的预训练模型的.pth(或.pt) 文件所占用的物理空间。

一个好的模型将具有低的TOP-1错误,低的TOP-5错误,低的CPU和GPU上的推理时间和低的模型大小。 所有的实验都是在同一个输入图像上进行的,并且多次进行,这样就可以将特定模型的所有结果的平均值进行分析。实验是在GoogleColab上进行的。现在,让我们看看所获得的结果。

2.1.模型的准确性比较

我们要讨论的第一个标准是TOP-1和TOP-5错误。TOP-1错误要求较为苛刻,还有另一个错误度量称为TOP-5错误。如果前5个预测类中没有一个是正确的,则预测被归类为错误。

从图表中注意到,这两个错误都遵循类似的趋势。AlexNet是基于深度学习的第一次尝试,此后错误率有所改善。值得一提的是GoogLeNet、ResNet、VGGNet、ResNext。

2.2.推断时间比较

接下来,我们将根据模型推理所需的时间对模型进行比较。多次向每个模型提供一幅图像,并对所有迭代的推理时间进行平均。对CPU和GoogleColab上的GPU执行了类似的过程。尽管顺序上有一些变化,但我们可以看到,SqueseNet、ShuffleNet和ResNet-18的推理时间非常低,这正是我们想要的。

2.3.模型大小比较

很多时候,当我们在Android或iOS设备上使用深度学习模型时,模型大小成为决定因素,有时甚至比准确性更重要。挤压网的最小模型大小(5MB),其次是ShuffleNetV2(6MB)和MobileNetV2(14MB)。很明显,为什么这些模型在使用深度学习的移动应用程序中是首选的。

2.4.总体比较

我们讨论了哪种模型在某个特定标准上表现更好。我们可以在一个气泡图中包含所有这些重要的细节,然后我们可以参考这些细节,根据我们的需求来决定要选择哪个模型。 我们使用的x坐标是Top-1错误(较低更好)。其中y坐标是GPU上的推理时间,单位为毫秒(更低更好),气泡大小代表模型大小(较低更好)

注: 较小的气泡在模型大小方面更好。 在原点附近的气泡在精度和速度方面都更好。

3.最终结果

从上面的图表可以清楚地看出,ResNet50是所有三个参数(尺寸小,更接近原点)的最佳模型。 在推理时间上,DenseNets和ResNext101是昂贵的。 AlexNet 和SqueezeNet 有相当高的错误率。

在这篇文章中,我们介绍了如何使用TorchVision模块来使用预先训练的模型进行图像分类,只需4步。我们还进行了模型比较,以决定选择什么样的模型,当然这取决于我们的项目需求。在下一篇文章中,我们将介绍在PyTorch中如何使用迁移学习,在自定义数据集上训练模型。

参考链接:https://www.learnopencv.com/pytorch-for-beginners-image-classification-using-pre-trained-models/

使用PyTorch中的预训练模型进行图像分类相关推荐

  1. Pytorch使用预训练模型进行图像分类

    在本文中,我们将介绍一些使用预训练网络的实际例子,这些网络出现在TorchVision模块的图像分类中. Torchvision包包括流行的数据集,模型体系结构,和通用的图像转换为计算机视觉.基本上, ...

  2. pytorch加载预训练模型_Pytorch-Transformers 1.0发布,支持六个预训练框架,含27个预训练模型...

    AI 科技评论按:刚刚在Github上发布了开源 Pytorch-Transformers 1.0,该项目支持BERT, GPT, GPT-2, Transfo-XL, XLNet, XLM等,并包含 ...

  3. AI:2020年6月22日北京智源大会演讲分享之09:40-10:10Mari 教授《基于显式上下文表征的语言处理》、10:10-10:40周明教授《多语言及多模态任务中的预训练模型》

    AI:2020年6月22日北京智源大会演讲分享之09:40-10:10Mari 教授<基于显式上下文表征的语言处理>.10:10-10:40周明教授<多语言及多模态任务中的预训练模型 ...

  4. pytorch官网预训练模型百度云下载 VGG16,Densnet169,inception_v3

    在深度学习领域采用预训练的模型参数进行迁移学习往往会得到事半功倍的效果.但是在使用pytorch加载预训练模型是往往会在下载模型时报错.VGG16报错:https://download.pytorch ...

  5. 周明教授《多语言及多模态任务中的预训练模型》Mari 教授《基于显式上下文表征的语言处理》

    AI:2020年6月22日北京智源大会演讲分享之09:40-10:10Mari 教授<基于显式上下文表征的语言处理>. 10:10-10:40周明教授<多语言及多模态任务中的预训练模 ...

  6. 机器学习花朵图像分类_在PyTorch中使用转移学习进行图像分类

    想了解更多好玩的人工智能应用,请关注公众号"机器AI学习 数据AI挖掘","智能应用"菜单中包括:颜值检测.植物花卉识别.文字识别.人脸美妆等有趣的智能应用.. ...

  7. 离线或在线加载pytorch、mmdetection预训练模型vgg、resnet、alexnet等

    pytorch预训练模型包含多个经典网络,比如resnet系列.vgg系列和alexnet等,预训练模型可以提高网络提取特征的能力,提升训练模型的性能.下面介绍一下加载预训练模型的两种方式: 第一种是 ...

  8. 在PyTorch中使用卷积神经网络建立图像分类模型

    概述 在PyTorch中构建自己的卷积神经网络(CNN)的实践教程 我们将研究一个图像分类问题--CNN的一个经典和广泛使用的应用 我们将以实用的格式介绍深度学习概念 介绍 我被神经网络的力量和能力所 ...

  9. Pytorch中更改预训练权重文件的下载位置

    目录 1. 参考链接 2. 更改方法 3. 一个小技巧 1. 参考链接 Pytorch更改预训练权重下载位置 pytorch---修改预训练模型下载路径 2. 更改方法 在线加载的预训练权重默认存放位 ...

最新文章

  1. android在线切图工具,9Cut切图工具
  2. [译]Node v5.0.0 (Stable)
  3. Bootstrap学习遇到的role属性--- 无障碍网页应用属性
  4. 【问链财经-区块链基础知识系列】 第二十一课 区块链应用于大宗商品供应链金融
  5. Task 1 天池赛 - 二手车交易价格预测
  6. 客户要求ASP.NET Core API返回特定格式,怎么办?
  7. pycharm引用python_在Python/Pycharm中找不到引用“xxx”
  8. PowerDesigner生成建表脚本时字段超过15字符就发生错误
  9. 监控mysql主从的工具_zabbix利用percona-toolkit工具监控Mysql主从同步状态
  10. 【Mac】 自带的播放器quicktimeplayer 如何带声音2倍速播放
  11. 在 uniapp 中使用阿里图标
  12. Qt QTreeView 详解
  13. ac管理器管理员密码忘记了_选择密码管理器
  14. JSch执行shell命令
  15. Zookeeper安装部署调试命令
  16. Serial Programming HOWTO——Linux 串口编程HOWTO
  17. 【IntelliJ IDEA 2019.2】java读取发送pc串口数据
  18. 蓝绿发布,灰度发布及滚动发布
  19. springboot打war包
  20. 索引(SqlServer2008)

热门文章

  1. Java中Arrays数组工具类的使用全解
  2. 蒸馏神经网络(Distill the Knowledge in a Neural Network)
  3. 检测ip和port是否可连接
  4. 微信ios版本的两个灰度功能和一些小改变
  5. js 字符串删除首尾_js去除字符串首尾空格
  6. 基因组学(Geonomics)
  7. MySQL数据库课程设计_Wincc实现与数据库的交互以及报表的实现方式
  8. 【T+】畅捷通T+登录的时候,提示仅支持以手机号或邮箱登录。
  9. java计算工作日_Java工作日计算工具类
  10. 解决win10 自动同步时间灰色