第二卷 第三章 理解rank-1&rank-5精度

在讨论高级深度学习主题(例如迁移学习)之前,让我们先退后一步,讨论1级、5级和N级准确率的概念。在阅读深度学习文献时,尤其是在计算机视觉和图像分类领域,您可能会遇到排名准确性的概念。例如,几乎所有介绍在ImageNet数据集上评估的机器学习方法的论文都根据1级和5级准确度展示了他们的结果(我们将找出为什么1级和5级准确度都在后面报告)在这一章当中)。

1级和5级准确率到底是什么?它们与传统的准确度(即精度)有何不同?在本章中,我们将讨论排序准确率,学习如何实现它,然后将其应用于在Flowers-17和CALTECH-101数据集上训练的机器学习模型。

        1、排名准确率

左:我们的神经网络将尝试分类的青蛙的输入图像。右:汽车的输入图像。

排名准确率最好用一个例子来解释。假设我们正在评估在CIFAR-10数据集上训练的神经网络,其中包括十个类别:飞机、汽车、鸟、猫、鹿、狗、青蛙、马、船和卡车。给定以下输入图像(图左),我们要求神经网络计算每个类标签的概率——然后神经网络返回表(左)中列出的类标签概率。

左:我们的神经网络为图4.1(左)返回的类标签概率。右:我们的网络为图4.1(右)返回的类标签概率

概率最大的类标签是青蛙(97.3%),这确实是正确的预测。如果我们重复这个过程:

1.第1步:计算数据集中每个输入图像的类标签概率。

2.第2步:确定真实标签是否等于具有最大概率的预测类标签。

3.第3步:计算第2步为真的次数。

我们将得到1级准确率。因此,Rank-1准确度是最高预测与真实标签匹配的预测百分比——这是我们用来计算的“标准”准确度类型:取正确预测的总数,然后除以数据集中的数据点数。

然后我们可以将这个概念扩展到5级精度。我们不仅关心排名第一的预测,还关心排名前5的预测。我们的评估过程现在变成:

1.第1步:计算数据集中每个输入图像的类标签概率。

2.第2步:按降序对预测的类别标签概率进行排序,以便将概率较高的标签放在列表的前面。

3.第3步:确定第2步的前5个预测标签中是否存在真实标签。

4.第4步:计算第3步为真的次数。

Rank-5只是rank-1准确率的扩展:我们将考虑来自网络的前5个预测,而不是只关心来自分类器的#1预测。例如,让我们再次考虑将基于任意神经网络归入CIFAR-10类别的输入图像(图右)。通过我们的网络后,我们获得了表(右)中详述的类标签概率。

我们的形象显然是一辆汽车;然而,我们的网络将卡车报告为最高预测——这将被视为对1级准确度的错误预测。但是如果我们检查网络的top-5预测,我们会看到汽车实际上是第二个预测,这在计算rank-5准确率时是准确的。这种方法也可以很容易地扩展到任意rank-N精度;然而,我们通常只计算1级和5级精度——这就提出了一个问题,为什么要计算5级精度呢?

对于CIFAR-10数据集,计算5级精度有点傻,但对于大型、具有挑战性的数据集,尤其是细粒度分类,查看给定CNN的前5名预测通常很有帮助。也许我们可以在Szegedy等人中找到我们为什么计算rank-1和rank-5精度的最好例子。我们可以在左边看到一只西伯利亚哈士奇犬和一只爱斯基摩犬在右侧(图)。大多数人无法识别这两种动物之间的区别。然而,这两个类都是ImageNet数据集中的有效标签。

左:西伯利亚雪橇犬。右图:爱斯基摩犬。

在处理包含许多具有相似特征的类标签的大型数据集时,我们经常检查5级精度作为1级精度的扩展,以了解我们的网络的性能。在理想情况下,我们的1级准确率会以与5级准确率相同的速度增加,但在具有挑战性的数据集上,情况并非总是如此。

因此,我们还检查了5级的准确率,以确保我们的网络在以后的时代仍在“学习”。可能会出现这样的情况,即1级准确度在训练结束时停滞不前,但5级准确度会随着我们的网络学习更多判别特征而继续提高(但判别力不足以超过排名第一的预测)。最后,根据图像分类挑战(ImageNet是典型示例),您需要同时报告1级和5级精度。

        (1)计算Rank-1和Rank-5

计算rank-1和rank-5的精度可以通过构建一个简单的实用函数来完成。在我们的pyimagesearch模块中,我们将通过添加一个名为rating.py的文件将此功能添加到utils子模块:

ranked.py,参考代码如下:

# import the necessary packages
import numpy as npdef rank5_accuracy(preds, labels):# initialize the rank-1 and rank-5 accuraciesrank1 = 0rank5 = 0# loop over the predictions and ground-truth labelsfor (p, gt) in zip(preds, labels):# sort the probabilities by their index in descending# order so that the more confident guesses are at the# front of the listp = np.argsort(p)[::-1]# check if the ground-truth label is in the top-5# predictionsif gt in p[:5]:rank5 += 1# check to see if the ground-truth is the #1 predictionif gt == p[0]:rank1 += 1

        (2)实现Rank

为了演示如何计算数据集的rank-1和rank-5精度,让我们回到第2章,我们在ImageNet数据集上使用预训练的卷积神经网络作为特征提取器。基于这些提取的特征,我们对数据训练了逻辑回归分类器并评估了模型。我们现在将扩展我们的准确性报告以包括5级准确性。

当我们为逻辑回归模型计算1级和5级准确度时,请记住,任何机器学习、神经网络或深度学习模型都可以计算1级和5级准确度——这很常见在深度学习社区之外遇到这两个指标。说了这么多,打开一个新文件,将其命名为rank_accuracy.py,然后插入以下代码:

# import the necessary packages
from ranked import rank5_accuracy
import argparse
import pickle
import h5py# construct the argument parse and parse the arguments
ap = argparse.ArgumentParser()
ap.add_argument("-d", "--db", required=True,
help="path HDF5 database")
ap.add_argument("-m", "--model", required=True,
help="path to pre-trained model")
args = vars(ap.parse_args())# load the pre-trained model
print("[INFO] loading pre-trained model...")
model = pickle.loads(open(args["model"], "rb").read())# open the HDF5 database for reading then determine the index of
# the training and testing split, provided that this data was
# already shuffled *prior* to writing it to disk
db = h5py.File(args["db"], "r")
i = int(db["labels"].shape[0] * 0.75)
# make predictions on the testing set then compute the rank-1
# and rank-5 accuracies
print("[INFO] predicting...")
preds = model.predict_proba(db["features"][i:])
(rank1, rank5) = rank5_accuracy(preds, db["labels"][i:])# display the rank-1 and rank-5 accuracies
print("[INFO] rank-1: {:.2f}%".format(rank1 * 100))
print("[INFO] rank-5: {:.2f}%".format(rank5 * 100))# close the database
db.close()

        (3)在Flowers-17数据集上计算Rank

计算Flowers-17数据集的rank-1和rank-5准确度:

在Flowers-17数据集上,我们使用对从VGG16架构中提取的特征进行训练的逻辑回归分类器获得了92.06%的1级准确率。检查5级准确率,我们发现我们的分类器几乎是完美的,获得了99.41%的5级准确率。

        (4)在CALTECH-101数据集上计算Rank

让我们再试试另一个例子,这个在更大的CALTECH-101数据集上:

在这里,我们获得了95.58%的1级准确率和99.45%的5级准确率,与之前难以突破60%分类准确率的计算机视觉和机器学习技术相比,这是一个重大改进。

        2、小结

在本章中,我们回顾了1级和5级准确率的概念。Rank-1准确率是我们的真实标签以最大概率等于我们的类标签的次数。Rank-5准确度在rank-1准确度的基础上扩展,使其更加“宽松”——这里我们计算rank-5准确度作为我们的真实标签出现在前5个预测类标签中的次数最大的概率。

我们通常在大型、具有挑战性的数据集(例如ImageNet)上报告5级准确率,在这种情况下,即使是人类通常也很难正确标记图像。在这种情况下,如果地面实况标签仅存在于其前5个预测中,我们将认为我们模型的预测是“正确的”。正如我们在StarterBundle的第9章中讨论的那样,一个真正泛化能力强的网络将在其top-5概率中产生上下文相似的预测。

最后,请记住,1级和5级准确率并非特定于深度学习和图像分类——您也经常会在其他分类任务中看到这些指标。

Python视觉深度学习系列教程 第二卷 第3章 理解rank-1rank-5精度相关推荐

  1. Python视觉深度学习系列教程 第二卷 第4章 微调网络

            第二卷 第四章 微调网络         在上一章中,我们学习了如何将预训练的卷积神经网络视为特征提取器.使用这个特征提取器,我们通过网络向前传播我们的图像数据集,提取给定层的激活,并 ...

  2. Python视觉深度学习系列教程 第二卷 第9章 Kaggle竞赛:Cat与Dog

    第二卷 第九章 Kaggle竞赛:Cat与Dog 在本章中,我们将扩展我们的工作并学习如何为HDF5数据集定义一个图像生成器,适用于使用Keras训练卷积神经网络.该生成器将打开HDF5数据集,为要训 ...

  3. Python视觉深度学习系列教程 第二卷 第10章 GoogLeNet

    第二卷 第十章 GoogLeNet 在本章中,我们将研究GoogLeNet 架构. 首先,与 AlexNet 和 VGGNet 相比,模型架构很小(权重本身为约28MB).作者能够通过移除完全连接的层 ...

  4. Python视觉深度学习系列教程 第三卷 第8章 在ImageNet上训练SqueezeNet

            第三卷 第八章 在ImageNet上训练SqueezeNet         关于在ImageNet大规模视觉识别挑战 (ILSVRC) 上训练深度神经网络的最后一章中,将讨论Sque ...

  5. Python视觉深度学习系列教程 第三卷 第5章 在ImageNet上训练VGGNet

            第三卷 第五章 在ImageNet上训练VGGNet 在本章中,我们将从头开始学习如何在 ImageNet 数据集上训练 VGG16 网络架构. 该网络的特点是简单,仅使用3*3 卷积 ...

  6. Python视觉深度学习系列教程 第三卷 第2章 什么是ImageNet?

            第三卷 第二章 什么是ImageNet? 在本章中,我们将讨论 ImageNet 数据集和相关的 ImageNet 大规模视觉识别挑战 (ILSVRC) . 这一挑战是评估图像分类算法 ...

  7. Python视觉深度学习系列教程 第三卷 第9章 Kaggle竞赛:情绪识别

            第三卷 第九章 Kaggle竞赛:情绪识别 在本章中,我们将解决Kaggle的面部表情识别挑战.为了完成这项任务,我们将在训练数据上从头开始训练一个类似VGG的网络,同时考虑到我们的网 ...

  8. Python视觉深度学习系列教程 第三卷 第12章 年龄和性别预测

            第三卷 第十二章 年龄和性别预测 为了构建一个能够识别照片中人物年龄和性别的系统,我们将使用 Adience 数据集.我们训练两个模型,一个用于年龄识别,另一个用于性别识别.此外,我们 ...

  9. Python视觉深度学习系列教程 第三卷 第14章 从头开始训练Faster R-CNN

            第三卷 第十四章 从头开始训练Faster R-CNN 本章的目的是达到以下四点: 1.在您的系统上安装和配置 TensorFlow Object Detection API. 2.在 ...

  10. Python视觉深度学习系列教程 第一卷 第21章 案例:使用CNN破解验证码

            第一卷 第二十一章 案例:使用CNN破解验证码 Breaking captchas with deep learning, Keras, and TensorFlow - PyImag ...

最新文章

  1. 博士称因待遇不公要离职,被学校要求返还51万元补偿费
  2. 新疆乌鲁木齐3D打印智能硬件接活咯
  3. 2019春第一次课程设计实验报告
  4. java聊天室程序_Java简易聊天室程序socket
  5. Problem B: C语言习题 学生成绩输入和输出
  6. C++|Qt之QTcpServer基本用法
  7. 高能!8段代码演示Numpy数据运算的神操作
  8. python 魔法方法常用_12个常用的IPython魔法命令
  9. SQL Profile (总结4)--使用演示示例
  10. 如何写出优秀的开源简历
  11. PKM2 - PKManager 基于内容的个人知识管理工具 5M 绿色免费
  12. 医疗机构做直播前的预告应该怎么做?
  13. 如何使用idea自带的数据库可视化工具
  14. CentOS7中安装PostgreSQL客户端
  15. 【mysql 练习题】查询和“01”号同学所学课程完全相同的其他同学的学号
  16. mysql mtq_Mysql 入门学习指南
  17. NTFS磁盘读写工具Mounty 1.9 Mac免费版
  18. 超级炫酷的动态登陆界面视频背景
  19. 显示屏漏光会有什么影响
  20. 领域驱动设计系列文章(1)——通过现实例子显示领域驱动设计的威力

热门文章

  1. 怎么进入服务器修改跳转域名,域名怎么跳转到别的网站
  2. 我的世界java凋零_我的世界:玩家还原Java版已“消失”的三个结构,造型让人难忘?...
  3. CMD控制台光标无法显示
  4. PPT怎么画出好看的三维示意图
  5. Android 内存剖析 – 发现潜在问题
  6. 转:python中range和xrange的区别
  7. 3D立体显示大屏幕拼接视频墙系统解决方案【转】
  8. 梳理50道经典计算机网络面试题
  9. 黑客全军覆没 书生安全云实战各路高手
  10. 黑客用“勒索病毒”展示肌肉,但你了解什么是“白帽黑客”吗?