Pytorch常用函数

  • 一、torch.max
    • 1.调用方式
    • 2.相关介绍
    • 3.代码实例及图示理解
  • 二、torch.argmax
    • 1.调用方式
    • 2.相关介绍
    • 3.代码实例及图示理解
  • 三、torch.max与torch.argmax的联系

一、torch.max

1.调用方式

1)torch.max(input):只需送入输入张量;

2)torch.max(input, dim, keepdim=False, *, out=None):送入张量的同时,需要指定沿着哪个维度进行最大值运算;
这两种调用方式对输入张量的形状没有要求,一维数据或者多维数据都可以。

2.相关介绍

1)返回输入张量中最大值相关数据:

  • 方式一,即不指定dim时,默认将张量展开成一维张量,然后返回第一个最大值;
  • 方式二,即指定dim时,沿着指定的dim维进行最大值运算,输出结果由剩下的维度组成,比如原始维度为H,W,若指定dim=0(即H维),则输出结果由W个元素构成;

2)如果有多个最大值则返回第一个最大值;

3.代码实例及图示理解

首先定义一个简单的方法,当传入张量x和维度dim参数时,分别打印两种调用方式对应的输出:

def print_maxvalue(x,dim=0):max_value=torch.max(x)print(max_value)print('-'*10)max_value,max_index=torch.max(x,dim=dim)print(max_value)print(max_index)

对于二维数据,其形状为(H,W)=(10,2):

x=torch.tensor([[0, 1],[2, 5],[7, 3],[5, 1],[8, 7],[7, 6],[9, 6],[4, 4],[2, 0],[9, 9]])
print_maxvalue(x,dim=0)

输出结果:

tensor(9)  # 所有元素中的第一个最大值
----------
tensor([9, 9])  # 沿着指定dim维进行最大值运算
tensor([6, 9])  # 沿着指定dim维进行最大值运算,并返回最大值对应的下标

结果分析:
(1)方式一
将张量展开成一维张量,其长度为L=10×2=20,然后返回第一个最大值9

(2)方式二
指定dim=0,此维度长度为10,表示沿着第0维进行最大值运算,分别对第0维的10个元素取最大值,并返回其对应下标

二、torch.argmax

1.调用方式

1)torch.argmax(input):只需送入输入张量;

2)torch.argmax(input, dim, keepdim=False):送入张量的同时,需要指定沿着哪个维度进行运算;
这两种调用方式对输入张量的形状没有要求,一维数据或者多维数据都可以。

2.相关介绍

1)返回输入张量中最大值的索引:

  • 方式一,即不指定dim时,默认将张量展开成一维张量,然后返回对应的下标;
  • 方式二,即指定dim时,沿着指定的dim维进行选择,输出结果由剩下的维度组成,比如原始维度为H,W,若指定dim=0(即H维),则输出结果由W个元素构成;

2)如果有多个最大值则返回第一个最大值的下标;
3)返回torch.max函数指定dim时返回的第二个值;

3.代码实例及图示理解

首先定义一个简单的方法,当传入张量x和维度dim参数时,分别打印两种调用方式对应的输出:

def print_(x,dim=0):# print(x)# print(x.shape)print('-' * 10)# 方式一max_index = torch.argmax(x)print(max_index)print('-' * 10)# 方式二max_index = torch.argmax(x, dim=dim)print(max_index)print('-' * 10)

1)一维数据:L

x=torch.tensor([8, 2, 7, 15, 1])
print_(x,dim=0)

输出结果:

tensor(3)
tensor(3)

结果分析:

这是最简单的一种方式,就类似一维数组查询最大元素对应下标的过程一致:

  • 对于方式一,传入一维张量后,直接返回第一个最大值15对应的下标3;
  • 对于方式二, 此时数据只有一个维度,故只能指定沿着维度dim=0进行运算,实质还是在所有元素中寻找最大值并返回其下标;

2)二维数据:(H,W)

x=torch.tensor([[0, 1],[2, 5],[7, 3],[5, 1],[8, 7],[7, 6],[9, 6],[4, 4],[2, 0],[9, 9]])
print_(x,dim=0)
# print_(x,dim=1)

输出结果:

dim=0:H,W->W
tensor(12)
tensor([6, 9]) # 一般分类问题就适用这种情况,在一个批次的预测输出中确定每个样本的类别,输出结果中每个元素即表示批次中每个样本对应的类别
dim=1: H,W->H
tensor(12)
tensor([1, 1, 0, 0, 0, 0, 0, 0, 0, 0])

结果分析:
(1)方式一
先将输入张量沿着所有维度展开为一维数据,然后返回第一个最大值9对应的下标12

(2)方式二
函数沿着指定的dim维度进行运算,
dim=0表示张量沿着第0维的方向进行运算,比如此处dim=0维长度为10,则表示在每列的10个元素中找到最大值并返回其下标:
此处第一列最大值为9,而其下标为6

dim=1表示张量沿着第1维的方向进行运算,比如此处dim=1维长度为2,则表示在每行的2个元素中找到最大值并返回其下标:
此处第一行最大值为1,而其下标为1

3)多维数据:(N,C,H,W)

x=torch.tensor([[[[1, 3],[7, 8]],[[8, 1],[5, 3]],[[2, 8],[4, 4]]],[[[3, 0],[2, 0]],[[0, 4],[7, 16]],[[4, 8],[4, 3]]]])print_(x,dim=0)
# print_(x,dim=1)
# print_(x,dim=2)
# print_(x,dim=3)

输出结果:

dim=0:N,C,H,W->C,H,W
tensor(19)
tensor([[[1, 0],[0, 0]],[[0, 1],[1, 1]],[[1, 0],[0, 0]]])dim=1:N,C,H,W->N,H,W
tensor(19)
tensor([[[1, 2],[0, 0]],[[2, 2],[1, 1]]])dim=2:N,C,H,W->N,C,W
tensor(19)
tensor([[[1, 1],[0, 1],[1, 0]],[[0, 0],[1, 1],[0, 0]]])dim=3:N,C,H,W->N,C,H
tensor(19)
tensor([[[1, 1],[0, 0],[1, 0]],[[0, 0],[1, 1],[1, 0]]])

结果分析:
开始就说到了,

  • 当调用方式二,指定dim时,函数会沿着指定的维度进行运算,其输出结果的维度由剩余的维度决定;
  • 使用方式一时会直接将张量展开为一维数据,然后返回第一个最大值的下标;

(1)方式一
输入张量形状为(N,C,H,W)=(2,3,2,2),可以清晰地看到,将张量展开为一维数据为长度为L=2×3×2×2=24,且第一个最大值16此时对应的下标为19。

(2)方式二
dim=0维长度为2,剩余维度为(3,2,2)

dim=1维长度为3,剩余维度为(2,2,2)

依次类推…

总结:
其实该函数应用场景最多的是分类任务在进行测试时,判断预测结果的对应类别,此时函数的输入通常为二维数据,只需要使用torch.argmax(x,dim=1)即可达到想要的结果。

三、torch.max与torch.argmax的联系

1)torch.max在寻找输入张量中最大值,而torch.argmax则是寻找最大值对应的下标;
2)二者均使用第一种方式,即未指定dim时,直接将张量展开为一维数据,torch.max返回第一个最大值本身,而torch.argmax则返回最大值的下标;
3)二者均使用第二种方式,即指定dim时,torch.max沿着指定的dim维选取最大值,同时返回最大值本身及其对应下标,而torch.argmax只返回最大值对应的下标。换句话说,torch.argmax的输出结果其实是torch.max指定dim时返回结果中的第二个元素,对应最大值的下标索引;

举个例子:
对于输入张量:

x=torch.tensor([[0, 1],[2, 5],[7, 3],[5, 1],[8, 7],[7, 6],[9, 6],[4, 4],[2, 0],[9, 9]])

torch.argmax(x,dim=0)的输出结果为:

tensor([6, 9])

torch.max(x,dim=0)的输出结果为:

torch.return_types.max(values=tensor([9, 9]),indices=tensor([6, 9]))

其中indices即表示指定dim时找到的最大值的对应下标。

【torch.argmax与torch.max详解】相关推荐

  1. sgd 参数 详解_关于torch.optim的灵活使用详解(包括重写SGD,加上L1正则)

    torch.optim的灵活使用详解 1. 基本用法: 要构建一个优化器Optimizer,必须给它一个包含参数的迭代器来优化,然后,我们可以指定特定的优化选项, 例如学习速率,重量衰减值等. 注:如 ...

  2. torch 归一化,momentum用法详解

    torch 有两个地方用Momentum动量,冲量, 一,优化器中的Momentum 主要是在训练网络时,最开始会对网络进行权值初始化,但是这个初始化不可能是最合适的:因此可能就会出现损失函数在训练的 ...

  3. 剖析 | torch.nn.functional.softmax维度详解

    写代码,看代码都要心中有数,输入是什么,输出是什么,结果是如何计算出来的. 一维数据: # -*- coding: utf-8 -*- import torch import numpy as np ...

  4. torch.flatten、np.flatten 详解

    超链接:深度学习工作常用方法汇总,矩阵维度变化.图片.视频等操作,包含(torch.numpy.opencv等) B站视频讲解链接 1. 展平 :flatten torch版: x.flatten(n ...

  5. Pytorch 中的数据类型 torch.utils.data.DataLoader 参数详解

    DataLoader是PyTorch中的一种数据类型,它定义了如何读取数据方式.详情也可参考本博主的另一篇关于torch.utils.data.DataLoader(https://blog.csdn ...

  6. Pytorch之torch.nn.functional.pad函数详解

    torch.nn.functional.pad是PyTorch内置的矩阵填充函数 (1).torch.nn.functional.pad函数详细描述如下: torch.nn.functional.pa ...

  7. Python内置函数 max 详解

    python文档中定义了很多内置函数,今天有个同学问到max函数到底在什么情况下可以使用,模模糊糊的记得在序列中都可以使用,但是并不是准确的回答.以下是更详细的内容 一.参数 首先在文档中查看max函 ...

  8. V-Ray 6 带着新工具走来了~V-Ray 6 for 3DS MAX 详解~

    盼星星盼月亮,V-Ray 6 终于按照计划,与大家见面了. V-Ray官方发布了V-Ray 6 的功能以及新玩法~代表着新 版本的vary渲染器向我们走来了~ 那么新版的V-Ray 6又有哪些值得我们 ...

  9. 深度学习关于NLLLoss损失的数学向个人详解

    一.起因与目的 写这篇文章的起因,就是网络上查了很多NLLLoss(Negative Log-Likelihood Loss,负对数似然损失)相关的详解,但是要么没有讲透,要么就是只讲了如何应用.而我 ...

最新文章

  1. Babel 快速入门
  2. [Android] Android颜色对应的xml配置值
  3. 探寻新的治疗方法,研究人员用VR可视化DNA结构
  4. spring源码解析五
  5. 【JAVA编码专题】总结
  6. hdu 4442 Physical Examination (2012年金华赛区现场赛A题)
  7. lstm需要优化的参数_使用PyTorch手写代码从头构建LSTM,更深入的理解其工作原理...
  8. 开源库Magicodes.Storage正式发布
  9. Windows 下 Redis 的下载和安装
  10. 以色列:新发明大幅提高太阳能发电效率
  11. oracle job有定时执行的功能,可以在指定的时间点或每天的某个时间点自行执行任务。...
  12. 云原生数据库风起云涌,华为云GaussDB破浪前行
  13. 使用TensorFlow.js进行AI在网络摄像头中翻译手势和手语
  14. Android第三十三期 - Dialog的应用
  15. shell for循环命令行_24 道 shell 脚本面试题
  16. micropython和python区别-MicroPython与Python速度对比
  17. AcWing 851. spfa求最短路(解决负边权最短路)
  18. 笔记本电脑装android系统安装教程,电脑上安装Android 10小白教程,大屏Android用起来...
  19. border:0和boder:none区别
  20. 活法 - 第五章 宇宙潮涌 因果之法

热门文章

  1. [HITSC] 2021期末复习-第九章
  2. Istio 1.1安装部署实践
  3. pandas打印某一列_Pandas速查手册中文版
  4. 员工跳槽面试美团,两次面试通过却被offer审核放鸽子,结果蒙了
  5. python注释几种类型
  6. TexturePacker 导出 Egret(白鹭引擎)格式的图集和图片字体
  7. jzoj4240 [五校联考5day2]游行 拓扑排序+倍增lca+线段树优化建图
  8. 2021制定计划待办事项清单的记事便签
  9. 52个python基础代码,你全都知道吗?
  10. 用Session存储数据