Pytorch中tensor维度和torch.max()函数中dim参数的理解


维度

参考了 https://blog.csdn.net/qq_41375609/article/details/106078474 ,
对于torch中定义的张量,感觉上跟矩阵类似,不过常见的矩阵是二维的。当定义一个多维的张量时,比如使用 a =torch.randn(2, 3, 4) 创建一个三维的张量,返回的是一个

[[[-0.5166,  0.8298,  2.4580, -1.9504],[ 0.1119, -0.3321, -1.3478, -1.9198],[ 0.0522, -0.6053,  0.8119, -1.3469]],[[-0.3774,  0.9283,  0.7996, -0.3882],[-1.1077,  1.0664,  0.1263, -1.0631],[-0.9061,  1.0081, -1.2769,  0.1035]]
]

当使用 a.size() 返回维度结果时,结果为 torch.Size([2, 3, 4]),这里面有三个数值,数值的个数代表维度的个数 ,所以这里有三个维度,可以理解为一个中括号代表一个维度。数值 2 处在第一个位置,第一个位置代表是第一维度,2代表这个维度有两个元素,也就是第一个 [ ] 里面两个元素,3代表在第二个维度,也就是在第一个 [ ] 中的两个元素里面,又有三个元素,依次类推。这里格式十分固定,一旦定义,必须是一个元素里面有两个元素,这两个元素中每个再包含三个元素,再包含,依此类推,否则会报错。类似与树,维数等于相似的树的深度-1(以根为第一层),每一层就是一维。
如生成一个

torch.tensor([[[1, 2, 3, 4][3, 4, 2, 1][4, 1, 2, 3]][[2, 1, 3, 4][3, 4, 2, 1][4, 1, 2, 3]]]
)

方便理解,以下图的形式展示,这里竖线代表一个维度,竖线上所有节点代表同一维度的所有元素。在下面所有图中,同颜色的元素都是按照从上往下按顺序排列的。


一、dim参数

在使用torch.max()函数和其他的一些函数时,会有dim这个参数。官网中定义使用torch.max()函数时,生成的张量维度会比原来的维度减少一维,除非原来的张量只有一维了. 要减少消去的是哪一维便是由dim参数决定的,dim参数实际上指的是我们计算过程中所要消去的维度。因为在比较时必须要指定使用哪些数字来比较 ,或者进行其他计算,比如 max 使一些数据中只要大的,sum只取和的结果,自然就会删减其他的一些数据从而引起降维。


以上面生成的三维的张量为例子,有三个维度,但是维度的数字顺序是 dim = 0, 1, 2;
当指定torch.max(a,dim=0)时,也就是要删除第一个维度,删除第一个维度的话,那还剩下两个维度,也就是dim =1 ,2 。 剩下的两个维度的参数是 3 和 4,那么删除第一个维度后应该剩下torch.tensor(3, 4)这样形式的张量, dim参数可以使用负数,也就是负的索引,与列表中的索引相似,在本例中dim = -1 与dim = 2是一样的。
返回的

values=tensor([[-0.3774,  0.9283,  2.4580, -0.3882],[ 0.1119,  1.0664,  0.1263, -1.0631],[ 0.0522,  1.0081,  0.8119,  0.1035]]),
indices=tensor([[1, 1, 0, 1], [0, 1, 1, 1],[0, 1, 0, 1]]))

从返回的结果看是这种形式,产生这种结果是因为删除了第一个维度那么该返回 3 * 4 这种二维的张量,第一维中两个元素的形式正好是 3 * 4, 那么就将这个维度的两个子元素中的相应的位置的值比较一下大小,那么会生成一个新的 3 * 4 的张量,再返回一下正好可以,indices记录的是 "在比较中胜利的元素“ 原来所属的元素的位置。例如在第一个位置上,-0.3774比 -0.5166大,所以返回-0.3774,-0.3774是在第一维度里面的第二个元素的位置上,这个位置索引为1.剩下的位置的同理。

用树状图理解

图中的不同颜色的三个子元素,在相同位置比较,大的返回形成新的元素,其他位置同理。那么黑色的维度 dim = 1 也就消除了.


dim = 0时,如图,两个3*4的子元素张量 相对应的位置 比较大小,剩下一个3 * 4的二维张量

当dim = 2或者 dim = -1,删除的是最后一个维度,在这个例子中吗,将所有的第三维的子元素最大的值返回,返回2 * 3,看起来就像是找所在矩阵一行里面的最大值一样。

values=tensor([[2.4580, 0.1119, 0.8119],[0.9283, 1.0664, 1.0081]]),
indices=tensor([[2, 0, 2],[1, 1, 1]]))

举一个sum()例子,当使用上述使用torch.sum(a,dim = 1),消去第二个维度,剩下一,三维度,也就是2 * 4形状的张量。将第二维上面的三个子元素相同位置的相加,第二维也就不见了,第一维中的两个元素的子元素就从3*4形成了一个1 *4的,总的形状就变成了2 * 4

tensor([[-0.3525, -0.1076,  1.9221, -5.2171],[-2.3912,  3.0028, -0.3510, -1.3478]])

再举一个例子,使用torch.randn(2, 3, 4, 5) 创建一个四维张量,使用torch.max(dim=-3),也就是torch.max(dim=1)

torch.tensor([[[[ 0.7106,  1.3332, -1.0423, -0.1609, -0.2846],[ 0.6400,  2.2507, -0.5740, -0.9986,  0.0066],[-0.0527,  1.4097, -0.4439,  0.4846,  1.5418],[ 1.0027,  0.9398,  1.5202, -1.1660, -0.1230]],[[ 0.5725, -1.7838, -0.7320, -1.4419,  1.5762],[ 0.6407,  0.0527,  1.7005,  1.6350, -0.2610],[ 1.3307, -0.3210, -1.7203,  0.9050,  0.2442],[ 0.9418, -0.1511,  0.8248, -0.0786, -0.6153]],[[ 1.0182,  0.3190, -0.3408, -2.1801, -0.3931],[ 1.2325, -0.3304,  1.0116,  0.0791, -1.1174],[ 0.2331, -0.9062,  0.5680,  1.6061, -1.0933],[ 0.6935, -0.5140, -0.5178,  1.2557,  0.2319]]],[[[ 1.0916,  0.7171, -0.7936,  1.1741, -0.5457],[-0.6541, -0.6720, -0.7892, -0.6961, -1.1030],[ 1.8680, -0.1746,  0.8455, -1.1021,  0.6855],[ 1.2070, -0.6152, -1.3345, -0.0724,  1.2062]],[[-0.5130, -0.5510, -0.8278, -0.2279, -1.4425],[ 0.2073,  1.3065, -0.0326, -1.2566,  0.6097],[-1.0413,  1.2638, -0.8479, -0.0353, -0.7191],[ 0.0662,  0.7683,  0.2145, -0.0988, -2.3348]],[[ 0.6631, -0.0040, -0.0681,  1.1681,  1.3904],[-0.1761,  1.4668,  0.9670, -0.5629,  0.2941],[-0.6235,  0.1844, -0.4321, -0.0581, -0.9352],[ 0.1717, -0.9188,  0.3014, -0.0734, -0.1324]]]])

在这里面,当dim = 1,也就是要动第二个维度手,那么删掉它后剩下torch.randn(2,4, 5)形式,那么就
[[ 0.7106, 1.3332, -1.0423, -0.1609, -0.2846],
[ 0.6400, 2.2507, -0.5740, -0.9986, 0.0066],
[-0.0527, 1.4097, -0.4439, 0.4846, 1.5418],
[ 1.0027, 0.9398, 1.5202, -1.1660, -0.1230]]


[[ 0.5725, -1.7838, -0.7320, -1.4419, 1.5762],
[ 0.6407, 0.0527, 1.7005, 1.6350, -0.2610],
[ 1.3307, -0.3210, -1.7203, 0.9050, 0.2442],
[ 0.9418, -0.1511, 0.8248, -0.0786, -0.6153]]
还有
[[ 1.0182, 0.3190, -0.3408, -2.1801, -0.3931],
[ 1.2325, -0.3304, 1.0116, 0.0791, -1.1174],
[ 0.2331, -0.9062, 0.5680, 1.6061, -1.0933],
[ 0.6935, -0.5140, -0.5178, 1.2557, 0.2319]]

这三个子元素相应为位置比较大小,大的留下,生成新的张量,列如对于第一个位置,1.0182 比 0.5725 和 0.7106 大,所以它留下,它在元素在要是动手的维度里面的位置索引为2,其它同理
但是这个维度还之前还有一个维度,那么只要对所有的同维度的做相同操作就可以了,所以返回之如下

values=tensor([[[ 1.0182,  1.3332, -0.3408, -0.1609,  1.5762],[ 1.2325,  2.2507,  1.7005,  1.6350,  0.0066],[ 1.3307,  1.4097,  0.5680,  1.6061,  1.5418],[ 1.0027,  0.9398,  1.5202,  1.2557,  0.2319]],[[ 1.0916,  0.7171, -0.0681,  1.1741,  1.3904],[ 0.2073,  1.4668,  0.9670, -0.5629,  0.6097],[ 1.8680,  1.2638,  0.8455, -0.0353,  0.6855],[ 1.2070,  0.7683,  0.3014, -0.0724,  1.2062]]]),
indices=tensor([[[2, 0, 2, 0, 1],[2, 0, 1, 1, 0],[1, 0, 2, 2, 0],[0, 0, 0, 2, 2]],[[0, 0, 2, 0, 2],[1, 2, 2, 2, 1],[0, 1, 0, 1, 0],[0, 1, 2, 0, 0]]]))

Pytorch中tensor维度和torch.max()函数中dim参数的理解相关推荐

  1. torch.max()函数==》返回该维度的最大值以及该维度最大值对应的索引

    今天在学习TTSR的过程总遇到了一行代码,我发现max()函数竟然可以返回两个值,于是我决定重新学习一下这个函数 R_lv3_star, R_lv3_star_arg = torch.max(R_lv ...

  2. Pytorch中torch.nn.Softmax的dim参数含义

    自己搞了一晚上终于搞明白了,下文说的很透彻,做个记录,方便以后翻阅 Pytorch中torch.nn.Softmax的dim参数含义

  3. 习题 9.5 建立一个对象数组,内放5个学生的数据(学号、成绩),设立一个函数max,用指向对象的指针作函数参数,在max函数中找出5个学生中成绩最高者,并输出其学号。

    C++程序设计(第三版) 谭浩强 习题9.5 个人设计 习题 9.5 建立一个对象数组,内放5个学生的数据(学号.成绩),设立一个函数max,用指向对象的指针作函数参数,在max函数中找出5个学生中成 ...

  4. Python中求最大值和最小值max()函数、min()函数

    [小白从小学Python.C.Java] [Python全国计算机等级考试] [Python数据分析考试必会题] ● 标题与摘要 Python中求最大值和最小值 max()函数.min()函数 ● 选 ...

  5. python max函数中使用key

    博客转移到个人站点:python max函数中使用key 代码: a = dict(((1,3),(0,-1),(3,21))) m = max(a, key=a.get) 为什么这返回与最大值对应的 ...

  6. _,predicted = torch.max(outputs.data,dim)

    dim=1时,按行返回最大值所在索引 dim=0时,按列返回最大值所在索引 _,predicted = torch.max(outputs.data,dim):返回最大值所在索引 predicted ...

  7. 习题 8.21 用指向指针的指针的方法对n个整数排序并输出。要求将排序单独写成一个函数。n个整数在主函数中输入,最后在主函数中输出。

    C程序设计(第四版) 谭浩强 习题8.21 个人设计 习题 8.21 用指向指针的指针的方法对n个整数排序并输出.要求将排序单独写成一个函数.n个整数在主函数中输入,最后在主函数中输出. 代码块: 方 ...

  8. 习题 6.20 用指向指针的指针的方法对n个整数排序并输出。要求将排序单独写成一个函数。整数和n在主函数中输入。最后在主函数中输出。

    C++程序设计(第三版) 谭浩强 习题6.20 个人设计 习题 6.20 用指向指针的指针的方法对n个整数排序并输出.要求将排序单独写成一个函数.整数和n在主函数中输入.最后在主函数中输出. 代码块: ...

  9. PyTorch疑难杂症(1)——torch.matmul()函数用法总结

    目录 一.函数介绍 二.常见用法 2.1 两个一维向量的乘积运算 2.2 两个二维矩阵的乘积运算 2.3 一个一维向量和一个二维矩阵的乘积运算 2.4 一个二维矩阵和一个一维向量的乘积运算 2.5 其 ...

最新文章

  1. 真正的全栈工程师!B站硬核UP主自己造了一个激光雷达
  2. java 微信退款接口_java版微信和支付宝退款接口
  3. 【技术综述】如何Finetune一个小网络到移动端(时空性能分析篇)
  4. [九]RabbitMQ-客户端源码之Consumer
  5. [css] 父元素下有子元素,子元素也有高度但父元素的高度为何为0呢?分析下可能出现的原因及解决方法
  6. pytorch 三维点分类_基于深度学习的三维重建——MVSNet系列论文解读
  7. 哈佛机器人,学会了轻功水上漂
  8. 存储设备在linux名称,Linux下的存储设备的管理
  9. 梁刚:基于云原生技术建设“武汉健康云”云平台架构
  10. 操作系统之三种进程通信方式
  11. 基于Hadoop的数据分析案例-陌陌聊天软件数据分析
  12. Java到底能做什么事情呢?
  13. 卫生间里的上下铺,那滋味~
  14. CAN开发 入门知识总结
  15. virtualbox安装mac os x雪豹
  16. 小米智能插座接入HomeKit
  17. lisp 左手钢筋_左手键配置程序
  18. 判断入射满射c语言编码,数学上可以分三类函数包括() 答案:单射双射满射...
  19. 操作系统的概念、四个特征以及os的发展和分类
  20. STM32的空闲中断

热门文章

  1. mysql将VARBINARY转为字符串显示方法
  2. hbuilder php打包,关于hbuilder打包h5+app
  3. Qt开发技术:QDBus介绍、编译与Demo
  4. linux ip白名单,黑名单
  5. 华大开发板SW失效,无法下载程序
  6. 京东市值达4600亿元创历史新高
  7. Oracle转换MySql之递归start with
  8. HTML5+CSS期末大作业:环保网站设计——环境保护(10页) 含设计报告 HTML+CSS+JavaScript 静态HTML环境保护网页制作下载 DIV+CSS环保网页设计代码
  9. 跟着禅一练功夫-少林八段锦对身体有什么样的益处
  10. php使用常量和变量输出圆的面积,PHP常量和变量分别是什么?有什么区别?