听名字就知道这个函数是用来求tensor中某个dim的前k大或者前k小的值以及对应的index。

用法

torch.topk(input, k, dim=None, largest=True, sorted=True, out=None) -> (Tensor, LongTensor)

  • input:一个tensor数据
  • k:指明是得到前k个数据以及其index
  • dim: 指定在哪个维度上排序, 默认是最后一个维度
  • largest:如果为True,按照大到小排序; 如果为False,按照小到大排序
  • sorted:返回的结果按照顺序返回
  • out:可缺省,不要

topk最常用的场合就是求一个样本被网络认为前k个最可能属于的类别。我们就用这个场景为例,说明函数的使用方法。
假设一个tensor F ∈ R N × D F \in R^{N \times D} F∈RN×D,N是样本数目,一般等于batch size, D是类别数目。我们想知道每个样本的最可能属于的那个类别,其实可以用torch.max得到。如果要使用topk,则k应该设置为1。

import torchpred = torch.randn((4, 5))
print(pred)
values, indices = pred.topk(1, dim=1, largest=True, sorted=True)
print(indices)
# 用max得到的结果,设置keepdim为True,避免降维。因为topk函数返回的index不降维,shape和输入一致。
_, indices_max = pred.max(dim=1, keepdim=True)print(indices_max == indices)
# pred
tensor([[-0.1480, -0.9819, -0.3364,  0.7912, -0.3263],[-0.8013, -0.9083,  0.7973,  0.1458, -0.9156],[-0.2334, -0.0142, -0.5493,  0.0673,  0.8185],[-0.4075, -0.1097,  0.8193, -0.2352, -0.9273]])
# indices, shape为 【4,1】,
tensor([[3],   #【0,0】代表 第一个样本最可能属于第一类别[2],   # 【1, 0】代表第二个样本最可能属于第二类别[4],[2]])
# indices_max等于indices
tensor([[True],[True],[True],[True]])

现在在尝试一下k=2

import torchpred = torch.randn((4, 5))
print(pred)
values, indices = pred.topk(2, dim=1, largest=True, sorted=True)  # k=2
print(indices)
# pred
tensor([[-0.2203, -0.7538,  1.8789,  0.4451, -0.2526],[-0.0413,  0.6366,  1.1155,  0.3484,  0.0395],[ 0.0365,  0.5158,  1.1067, -0.9276, -0.2124],[ 0.6232,  0.9912, -0.8562,  0.0148,  1.6413]])
# indices
tensor([[2, 3],[2, 1],[2, 1],[4, 1]])

可以发现indices的shape变成了【4, k】,k=2。
其中indices[0] = [2,3]。其意义是说明第一个样本的前两个最大概率对应的类别分别是第3类和第4类。

大家可以自行print一下values。可以发现values的shape和indices的shape是一样的。indices描述了在values中对应的值在pred中的位置。

PyTorch中的topk函数详解相关推荐

  1. PyTorch中torch.norm函数详解

    torch.norm() 是 PyTorch 中的一个函数,用于计算输入张量沿指定维度的范数.具体而言,当给定一个输入张量 x 和一个整数 p 时,torch.norm(x, p) 将返回输入张量 x ...

  2. PyTorch中的matmul函数详解

    PyTorch中的两个张量的乘法可以分为两种: 两个张量对应的元素相乘(element-wise),在PyTorch中可以通过torch.mul函数(或者∗*∗运算符)实现 两个张量矩阵相乘(Matr ...

  3. timm 视觉库中的 create_model 函数详解

    timm 视觉库中的 create_model 函数详解 最近一年 Vision Transformer 及其相关改进的工作层出不穷,在他们开源的代码中,大部分都用到了这样一个库:timm.各位炼丹师 ...

  4. 【Pytorch】torch.argmax 函数详解

    文章目录 一.一个参数时的 torch.argmax 函数 1. 介绍 2. 实例 二.多个参数时的 torch.argmax 函数 1. 介绍 2. 实例 实例1:二维矩阵 实例2:三维矩阵 实例3 ...

  5. python getattr_Python中的getattr()函数详解:

    标签:Python中的getattr()函数详解: getattr(object, name[, default]) -> value Get a named attribute from an ...

  6. python input函数详解_对Python3中的input函数详解

    下面介绍python3中的input函数及其在python2及pyhton3中的不同. python3中的ininput函数,首先利用help(input)函数查看函数信息: 以上信息说明input函 ...

  7. Python中的bbox_overlaps()函数详解

    Python中的bbox_overlaps()函数详解 想要编写自己的目标检测算法,就需要掌握bounding box(边界框)之间的关系.在这之中,bbox_overlaps()函数是一个非常实用的 ...

  8. java的匿名函数_JAVA语言中的匿名函数详解

    本文主要向大家介绍了JAVA语言中的匿名函数详解,通过具体的内容向大家展示,希望对大家学习JAVA语言有所帮助. 一.使用匿名内部类 匿名内部类由于没有名字,所以它的创建方式有点儿奇怪.创建格式如下: ...

  9. PyTorch入门笔记-matmul函数详解

    PyTorch入门笔记-matmul函数详解 本文转载自:PyTorch入门笔记-matmul函数详解 - 腾讯云开发者社区-腾讯云 (tencent.com) 41409)]

最新文章

  1. java局部变量说法不正确的是_关于Java的成员变量和局部变量,下面说法错误的是...
  2. matlab中基本函数的用法
  3. linux awk 脚本格式,偷偷学习shell脚本之awk编辑器
  4. Java程序设计实验2
  5. 联想g400从u盘启动计算机,联想g400怎么进bios设置u盘启动图文教程
  6. Windows显卡切换
  7. 【GT跑车】GT跑车是什么意思 GT跑车有哪些
  8. mysql工作日_mysql自定义函数计算时间段内的工作日(支持跨年)
  9. 编制投标书常见的115个错误
  10. 51单片机 AT24C02 PROTEUS 读写程序 源码
  11. CANoe隐藏属性——Multi CANoe
  12. MathType的下载和安装以及添加到word中
  13. Intellij idea 第一天
  14. 解决WORD “未找到引用源”问题
  15. Android实现 制作隐藏图片效果 (幻影坦克)
  16. 波士顿矩阵图的制作--基于Excel
  17. 学习笔记1--自动驾驶汽车介绍
  18. 关于汉字与Ascii码
  19. SQL SERVER 解析XML字符串
  20. 海康摄像头SDK开机启动第一个摄像头不显示问题

热门文章

  1. 云目录(DaaS )快速入门
  2. C/C++抽红包系统
  3. 英语-新视野大学英语四课后翻译(全)
  4. 新5G网络架构较复杂 设立面对不少挑战
  5. 随机过程(random process)
  6. IDC许可证是什么证?IDC办理条件及材料
  7. 双通道(双CPU)服务器主板上内存条的安装方式
  8. 学生宿舍管理数据库设计(下)
  9. dba怎么报考_2019年报考DBA需要什么条件,要求是不是很高?
  10. 主板维修从入门到精通