PyTorch中的topk函数详解
听名字就知道这个函数是用来求tensor中某个dim的前k大或者前k小的值以及对应的index。
用法
torch.topk(input, k, dim=None, largest=True, sorted=True, out=None) -> (Tensor, LongTensor)
- input:一个tensor数据
- k:指明是得到前k个数据以及其index
- dim: 指定在哪个维度上排序, 默认是最后一个维度
- largest:如果为True,按照大到小排序; 如果为False,按照小到大排序
- sorted:返回的结果按照顺序返回
- out:可缺省,不要
topk最常用的场合就是求一个样本被网络认为前k个最可能属于的类别。我们就用这个场景为例,说明函数的使用方法。
假设一个tensor F ∈ R N × D F \in R^{N \times D} F∈RN×D,N是样本数目,一般等于batch size, D是类别数目。我们想知道每个样本的最可能属于的那个类别,其实可以用torch.max得到。如果要使用topk,则k应该设置为1。
import torchpred = torch.randn((4, 5))
print(pred)
values, indices = pred.topk(1, dim=1, largest=True, sorted=True)
print(indices)
# 用max得到的结果,设置keepdim为True,避免降维。因为topk函数返回的index不降维,shape和输入一致。
_, indices_max = pred.max(dim=1, keepdim=True)print(indices_max == indices)
# pred
tensor([[-0.1480, -0.9819, -0.3364, 0.7912, -0.3263],[-0.8013, -0.9083, 0.7973, 0.1458, -0.9156],[-0.2334, -0.0142, -0.5493, 0.0673, 0.8185],[-0.4075, -0.1097, 0.8193, -0.2352, -0.9273]])
# indices, shape为 【4,1】,
tensor([[3], #【0,0】代表 第一个样本最可能属于第一类别[2], # 【1, 0】代表第二个样本最可能属于第二类别[4],[2]])
# indices_max等于indices
tensor([[True],[True],[True],[True]])
现在在尝试一下k=2
import torchpred = torch.randn((4, 5))
print(pred)
values, indices = pred.topk(2, dim=1, largest=True, sorted=True) # k=2
print(indices)
# pred
tensor([[-0.2203, -0.7538, 1.8789, 0.4451, -0.2526],[-0.0413, 0.6366, 1.1155, 0.3484, 0.0395],[ 0.0365, 0.5158, 1.1067, -0.9276, -0.2124],[ 0.6232, 0.9912, -0.8562, 0.0148, 1.6413]])
# indices
tensor([[2, 3],[2, 1],[2, 1],[4, 1]])
可以发现indices的shape变成了【4, k】,k=2。
其中indices[0] = [2,3]。其意义是说明第一个样本的前两个最大概率对应的类别分别是第3类和第4类。
大家可以自行print一下values。可以发现values的shape和indices的shape是一样的。indices描述了在values中对应的值在pred中的位置。
PyTorch中的topk函数详解相关推荐
- PyTorch中torch.norm函数详解
torch.norm() 是 PyTorch 中的一个函数,用于计算输入张量沿指定维度的范数.具体而言,当给定一个输入张量 x 和一个整数 p 时,torch.norm(x, p) 将返回输入张量 x ...
- PyTorch中的matmul函数详解
PyTorch中的两个张量的乘法可以分为两种: 两个张量对应的元素相乘(element-wise),在PyTorch中可以通过torch.mul函数(或者∗*∗运算符)实现 两个张量矩阵相乘(Matr ...
- timm 视觉库中的 create_model 函数详解
timm 视觉库中的 create_model 函数详解 最近一年 Vision Transformer 及其相关改进的工作层出不穷,在他们开源的代码中,大部分都用到了这样一个库:timm.各位炼丹师 ...
- 【Pytorch】torch.argmax 函数详解
文章目录 一.一个参数时的 torch.argmax 函数 1. 介绍 2. 实例 二.多个参数时的 torch.argmax 函数 1. 介绍 2. 实例 实例1:二维矩阵 实例2:三维矩阵 实例3 ...
- python getattr_Python中的getattr()函数详解:
标签:Python中的getattr()函数详解: getattr(object, name[, default]) -> value Get a named attribute from an ...
- python input函数详解_对Python3中的input函数详解
下面介绍python3中的input函数及其在python2及pyhton3中的不同. python3中的ininput函数,首先利用help(input)函数查看函数信息: 以上信息说明input函 ...
- Python中的bbox_overlaps()函数详解
Python中的bbox_overlaps()函数详解 想要编写自己的目标检测算法,就需要掌握bounding box(边界框)之间的关系.在这之中,bbox_overlaps()函数是一个非常实用的 ...
- java的匿名函数_JAVA语言中的匿名函数详解
本文主要向大家介绍了JAVA语言中的匿名函数详解,通过具体的内容向大家展示,希望对大家学习JAVA语言有所帮助. 一.使用匿名内部类 匿名内部类由于没有名字,所以它的创建方式有点儿奇怪.创建格式如下: ...
- PyTorch入门笔记-matmul函数详解
PyTorch入门笔记-matmul函数详解 本文转载自:PyTorch入门笔记-matmul函数详解 - 腾讯云开发者社区-腾讯云 (tencent.com) 41409)]
最新文章
- java局部变量说法不正确的是_关于Java的成员变量和局部变量,下面说法错误的是...
- matlab中基本函数的用法
- linux awk 脚本格式,偷偷学习shell脚本之awk编辑器
- Java程序设计实验2
- 联想g400从u盘启动计算机,联想g400怎么进bios设置u盘启动图文教程
- Windows显卡切换
- 【GT跑车】GT跑车是什么意思 GT跑车有哪些
- mysql工作日_mysql自定义函数计算时间段内的工作日(支持跨年)
- 编制投标书常见的115个错误
- 51单片机 AT24C02 PROTEUS 读写程序 源码
- CANoe隐藏属性——Multi CANoe
- MathType的下载和安装以及添加到word中
- Intellij idea 第一天
- 解决WORD “未找到引用源”问题
- Android实现 制作隐藏图片效果 (幻影坦克)
- 波士顿矩阵图的制作--基于Excel
- 学习笔记1--自动驾驶汽车介绍
- 关于汉字与Ascii码
- SQL SERVER 解析XML字符串
- 海康摄像头SDK开机启动第一个摄像头不显示问题