PyTorch中的topk方法以及分类Top-K准确率的实现

Top-K 准确率

在分类任务中的类别数很多时(如ImageNet中1000类),通常任务是比较困难的,有时模型虽然不能准确地将ground truth作为最高概率预测出来,但通过学习,至少groud truth的准确率能够在所有类中处于很靠前的位置,这在现实生活中也是有一定应用意义的,因此除了常规的Top-1 Acc,放宽要求的Tok-K Acc也是某些分类任务的重要指标之一。

Tok-K准确率:即指在模型的预测结果中,前K个最高概率的类中有groud truth,就认为在Tok-K准确率的要求下,模型分类成功了。

PyTorch中的topk方法

PyTorch中并没有直接提供计算模型Top-K分类准确率的接口,但是提供了一个topk方法,用来获得某tensor某维度中最高或最低的K个值。

函数接口

torch.topk(input, k, dim=None, largest=True, sorted=True, *, out=None) -> (Tensor, LongTensor)

同样有tensor.topk的使用方式,参数及返回值类似。

参数说明

input:输入张量

dim:指定在哪个维度取topk

k:前k大或前k小值

largest:取最大(True)或最小(False)

sorted:返回值是否有序

返回值说明

返回两个张量:values和indices,分别对应前k大/小值的数值和索引,注意返回值的各维度的意义,不要搞反了,后面实验会说。

实验

我们在这里模拟常见的分类任务的情况,设置batch size为4,类别数为10,这样模型输出应为形状为(10,4)的张量。

output = torch.rand(4, 10)
print(output)
print('*'*100)
values, indices = torch.topk(output, k=2, dim=1, largest=True, sorted=True)
print("values: ", values)
print("indices: ", indices)
print('*'*100)
print(output.topk(k=2, dim=1, largest=True, sorted=False))      # tensor.topk的用法

输出:

tensor([[0.7082, 0.5335, 0.9494, 0.7792, 0.3288, 0.6303, 0.0335, 0.6918, 0.0778,0.6404],[0.3881, 0.8676, 0.7700, 0.6266, 0.8843, 0.8902, 0.4336, 0.5385, 0.8372,0.1204],[0.9717, 0.2727, 0.9086, 0.7797, 0.1216, 0.4793, 0.1149, 0.1544, 0.7292,0.0459],[0.0424, 0.0809, 0.1597, 0.4177, 0.4798, 0.7107, 0.9683, 0.7502, 0.1536,0.3994]])
****************************************************************************************************
values:  tensor([[0.9494, 0.7792],[0.8902, 0.8843],[0.9717, 0.9086],[0.9683, 0.7502]])
indices:  tensor([[2, 3],[5, 4],[0, 2],[6, 7]])
****************************************************************************************************
torch.return_types.topk(
values=tensor([[0.9494, 0.7792],[0.8902, 0.8843],[0.9717, 0.9086],[0.9683, 0.7502]]),
indices=tensor([[2, 3],[5, 4],[0, 2],[6, 7]]))

注意输出的行是用户指定的dim的k个最大/小值(实验中sorted=True,所以是有序返回的),列是其他未指定的维度,不要搞反了。

分类Top-K准确率的实现

实现

借助刚刚介绍的PyTorch中的topk方法实现的分类任务的Top-K准确率计算方法。

def accuracy(output, target, topk=(1, )):       # output.shape (bs, num_classes), target.shape (bs, )"""Computes the accuracy over the k top predictions for the specified values of k"""with torch.no_grad():maxk = max(topk)batch_size = target.size(0)_, pred = output.topk(maxk, 1, True, True)pred = pred.t()correct = pred.eq(target.view(1, -1).expand_as(pred))res = []for k in topk:correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)res.append(correct_k.mul_(100.0 / batch_size))return res

实验

我们同样拿上面的分类任务做实验,batch size为4,类别数为10,给定label:2,1,8,5,为了方便观察,计算Top-1,2准确率(ImageNet-1K中通常计算Top-1,5准确率)。

测试代码:

output = torch.rand(4, 10)
label = torch.Tensor([2, 1, 8, 5]).unsqueeze(dim=1)
print(output)
print('*'*100)
values, indices = torch.topk(output, k=2, dim=1, largest=True, sorted=True)
print("values: ", values)
print("indices: ", indices)
print('*'*100)print(accuracy(output, label, topk=(1, 2)))

输出:

tensor([[0.8721, 0.7391, 0.1365, 0.3017, 0.2840, 0.2400, 0.6473, 0.3965, 0.5449,0.7518],[0.7120, 0.8533, 0.2809, 0.9515, 0.2971, 0.8182, 0.5498, 0.0797, 0.8027,0.6916],[0.4540, 0.8468, 0.9022, 0.5144, 0.2007, 0.7292, 0.5559, 0.0290, 0.6664,0.2076],[0.1793, 0.0205, 0.7322, 0.4918, 0.6194, 0.9179, 0.1639, 0.6346, 0.8829,0.3573]])
****************************************************************************************************
values:  tensor([[0.8721, 0.7518],[0.9515, 0.8533],[0.9022, 0.8468],[0.9179, 0.8829]])
indices:  tensor([[0, 9],[3, 1],[2, 1],[5, 8]])
[tensor([25.]), tensor([50.])]

可以看到在top1准确率时只有最后一个样本与标签对应,故Top-1准确率为1 / 4 =25%,而在top2准确率时样本2,4预测成功了,Top-2准确率为50%,符合我们的预期。

有疑惑或异议欢迎留言讨论。

Ref:https://pytorch.org/docs/master/generated/torch.topk.html#torch-topk

PyTorch中的topk方法以及分类Top-K准确率的实现相关推荐

  1. Lesson 15.2 学习率调度在PyTorch中的实现方法

    Lesson 15.2 学习率调度在PyTorch中的实现方法   学习率调度作为模型优化的重要方法,也集成在了PyTorch的optim模块中.我们可以通过下述代码将学习率调度模块进行导入. fro ...

  2. Pytorch函数之topk()方法

    根据Pytorch中的手册可以看到,topk()方法用于返回输入数据中特定维度上的前k个最大的元素. torch.topk(input, k, dim=None, largest=True, sort ...

  3. pytorch 中的topk函数

    pytorch中topk() 函数用法 1. 函数介绍 最近在代码中看到这两个语句 maxk = max(topk) _, pred = output.topk(maxk, 1, True, True ...

  4. PyTorch中的topk函数详解

    听名字就知道这个函数是用来求tensor中某个dim的前k大或者前k小的值以及对应的index. 用法 torch.topk(input, k, dim=None, largest=True, sor ...

  5. numpy实现torch的topk方法

    torch中提供了topk方法用来返回矩阵中对应维度中最大的K个元素以及在对应维度中的index,但是numpy并没有提供和torch一样的topk方法,所以在这里通过numpy的argpartiti ...

  6. Top k问题(线性时间选择算法)

    问题描述:给定n个整数,求其中第k小的数. 分析:显然,对所有的数据进行排序,即很容易找到第k小的数.但是排序的时间复杂度较高,很难达到线性时间,哈希排序可以实现,但是需要另外的辅助空间. 这里我提供 ...

  7. 循环神经网络实现文本情感分类之Pytorch中LSTM和GRU模块使用

    循环神经网络实现文本情感分类之Pytorch中LSTM和GRU模块使用 1. Pytorch中LSTM和GRU模块使用 1.1 LSTM介绍 LSTM和GRU都是由torch.nn提供 通过观察文档, ...

  8. 利用Pytorch中深度学习网络进行多分类预测(multi-class classification)

    从下面的例子可以看出,在 Pytorch 中应用深度学习结构非常容易 执行多类分类任务. 在 iris 数据集的训练表现几乎是完美的. import torch.nn as nn import tor ...

  9. PyTorch中nn.Module类中__call__方法介绍

    在PyTorch源码的torch/nn/modules/module.py文件中,有一条__call__语句和一条forward语句,如下: __call__ : Callable[-, Any] = ...

最新文章

  1. 线性回归介绍及分别使用最小二乘法和梯度下降法对线性回归C++实现
  2. python int str_python int str
  3. 29 仿京东放大镜案例
  4. 那些关于区块链革命的事情
  5. 基础功能2-python修改文件中所有文件名
  6. Linux时间date与timedatectl
  7. c mysql 异步查询_C#Mysql – 使用锁在数据库上查询异步等待服务器
  8. SAP License:如何修改科目为为未清项目管理
  9. Ubuntu添加swap分区
  10. linux触摸屏代码解析,Linux触摸屏驱动解析
  11. 项目总结25:海康威视SDK-Java二次开发-客流量分析
  12. SpringBoot修改默认端口号
  13. php curl 下载文件
  14. bzoj3162 独钓寒江雪
  15. openwrt运行n2n服务器,Windows下使用N2N搭建局域网,全球局域网(重写)
  16. 哪些公司有计算机财务管理,计算机财务管理汇总.doc
  17. 文件夹总是在新窗口打开
  18. SODA-大型活动大规模人群的识别和疏散:从公交2.0到公交3.0
  19. python 傅里叶曲线拟合
  20. 设计计算机程序时 要考虑计算的过程,算法和程序设计练习题

热门文章

  1. DM7数据库DMAP服务异常,报错“ dmap init failed, code[-7157]: 管道文件已存在”
  2. 孔雀石绿磷酸盐检测试剂盒的特点和应用
  3. 一个实战案例带你走完python数据分析全流程:豆瓣电影评论的关键词云图制作
  4. 基本共射放大电路的动态分析(低频、Ri、Ro大小对电路影响的分析)
  5. MyBatisPlus代码生成器使用
  6. 谷底飞龙的技术博客集
  7. 复选框checkbox实现批量删除
  8. 苹果cms模板_首涂第三套苹果CMSv10自适应视频站模板
  9. 苹果cms模板_苹果CMS V10 开源影视系统,搭建一个属于自己的影视网
  10. DVWA11_Insecure CAPTCHA(不安全的验证码)