PyTorch 中对 tensor 的很多操作如 sum、softmax 等都可以设置 dim 参数用来指定操作在哪一维进行。PyTorch 中的 dim 类似于 numpy 中的 axis,这篇文章来总结一下 PyTorch 中的 dim 操作。首先看一下这个图,图中给出了维度标号,注意区分正负,从左往右数,括号代表的维度分别是 0 和 1 和 2,从右往做为 -3 和 -2 和 -1。待会儿会用到。

图1

括号之间是嵌套关系,代表了不同的维度。从左往右数,两个括号代表的维度分别是 0 和 1 ,在第 0 维遍历得到向量,在第 1 维遍历得到标量.

a = torch.tensor([[1,2],[3,4]])

则 3 个括号代表的维度从左往右分别为 0, 1, 2,在第 0 维遍历得到矩阵,在第 1 维遍历得到向量,在第 2 维遍历得到标量。

b = torch.tensor([[[3, 2], [1, 4]],[[5, 6], [7, 8]]])#张量

在某一维度求和(或者进行其他操作)就是对该维度中的元素进行求和。对于矩阵 a

a = torch.tensor([[1,2],[3,4]])

求 a 在第 0 维的和,因为第 0 维代表最外边的括号,括号中的元素为向量[1, 2],[3, 4],第 0 维的和就是第 0 维中的元素相加,也就是两个向量[1, 2],[3, 4]相加,所以结果为[4,6]

s = torch.sum(a, dim=0)
print(s)

输出

tensor([4, 6])

可以看到,a 是 2 维矩阵,而相加的结果为 1 维向量,可以使用参数keepdim=True来保证维度数目不变。

s = torch.sum(a, dim=0, keepdim=True)
print(s)

输出

tensor([[4, 6]])

同理的现在对dim=1进行操作

a = torch.tensor([[1,2],[3,4]])
s = torch.sum(a,dim=1,keepdim=False)
print(s)
#输出 tensor([3, 7])

keepdim = True

a = torch.tensor([[1,2],[3,4]])
s = torch.sum(a,dim=1,keepdim=True)
print(s)
# 输出 tensor([[3],[7]])

现在对三维张量进行操作

b = torch.tensor([[[3, 2], [1, 4]], [[5, 6], [7, 8]]])
print(b)
# 输出 tensor([[[3, 2],[1, 4]],[[5, 6],[7, 8]]])      

将 b 在第 0 维相加,第 0 维为最外层括号,最外层括号中的元素为矩阵[[3, 2], [1, 4]]和[[5, 6], [7, 8]]。在第 0 维求和,就是将第 0 维中的元素(矩阵)相加

b = torch.tensor([[[3, 2], [1, 4]], [[5, 6], [7, 8]]])
#print(b)
s = torch.sum(b,dim=0)
print(s)
# 输出
tensor([[ 8,  8],[ 8, 12]])
keepdim = True
#输出tensor([[[ 8,  8],[ 8, 12]]])  

求 b 在第 1 维的和,就是将 b 第 1 维中的元素[3, 2]和[1, 4], [5, 6]和 [7, 8]相加,所以

[3,2]+[1,4]=[4,6],[5,6]+[7,8]=[12,14]

b = torch.tensor([[[3, 2], [1, 4]], [[5, 6], [7, 8]]])
#print(b)
s = torch.sum(b,dim=1)
print(s)
#输出 tensor([[ 4,  6],[12, 14]])
keepdim = True
#输出
tensor([[[ 4,  6]],[[12, 14]]])     

则在 b 的第 2 维求和,就是对标量 3 和 2, 1 和 4, 5 和 6 , 7 和 8 求和

b = torch.tensor([[[3, 2], [1, 4]], [[5, 6], [7, 8]]])
#print(b)
s = torch.sum(b,dim=2)
print(s)
#输出 tensor([[ 5,  5],[11, 15]])
keepdim = True
#输出 tensor([[[ 5],[ 5]],[[11],[15]]])    

现在再来看看其他dim有关的api

1.torch.max

在二维中

dim = 0

这个时候取的是矩阵的最大值以及下标即[[1,2],[3,4]]中的最大值和下标,那么应该是[3,4]

a = torch.tensor([[1,2],[3,4]])
print(a)
print(torch.max(a,dim=0))
#输出
values=tensor([3, 4]),
indices=tensor([1, 1]))

dim = 1

这个时候取的是标量是在[1,2]和[3,4]中找到最大值

a = torch.tensor([[1,2],[3,4]])
print(a)
print(torch.max(a,dim=1))
# 输出
values=tensor([2, 4]),
indices=tensor([1, 1]))

在三维中

dim = 0

则是对比这两个矩阵返回同位置最大的值

b = torch.tensor([[[3, 2], [1, 4]], [[5, 6], [7, 8]]])
print(b)
print(torch.max(b,dim=0))
# 输出
tensor([[[3, 2],[1, 4]],[[5, 6],[7, 8]]])
torch.return_types.max(
values=tensor([[5, 6],[7, 8]]),
indices=tensor([[1, 1],[1, 1]]))

dim = 1

[3, 2], [1, 4],[5, 6], [7, 8]这四个中返回最大值

b = torch.tensor([[[3, 2], [1, 4]], [[5, 6], [7, 8]]])
print(b)
print(torch.max(b,dim=1))
# 输出
tensor([[[3, 2],[1, 4]],[[5, 6],[7, 8]]])
torch.return_types.max(
values=tensor([[3, 4],[7, 8]]),
indices=tensor([[0, 1],[1, 1]]))

dim = 2

对标量做比较

b = torch.tensor([[[3, 2], [1, 4]], [[5, 6], [7, 8]]])
print(b)
print(torch.max(b,dim=2))
# 输出
tensor([[[3, 2],[1, 4]],[[5, 6],[7, 8]]])
torch.return_types.max(
values=tensor([[3, 4],[6, 8]]),
indices=tensor([[0, 1],[1, 1]]))

那么同样的现在来思考下dim=-3,-2,-1的情况,图1中已经给出来了正负数维度对应的关系了,现在我们就取dim=-1看看

b = torch.tensor([[[3, 2], [1, 4]], [[5, 6], [7, 8]]])
c = torch.max(b,dim=-1)
print(c)

输出:

torch.return_types.max(
values=tensor([[3, 4],[6, 8]]),
indices=tensor([[0, 1],[1, 1]]))

发现了吗dim=-1和dim=2输出的是相同的,剩余的可自行验证。

pytorch中维度dim的理解相关推荐

  1. Pytorch中维度dim的理解使用

    0 引言 pytorch中的维度dim主要被用在torch.softmax和torch.max等等函数中.理清dim的意思对于正确使用这些函数有重要意义. 1 相关博文: Pytorch笔记:维度di ...

  2. pytorch 中维度(Dimension)概念的理解

    pytorch 中维度(Dimension)概念的理解 Dimension为0(即维度为0时) 维度为0时,即tensor(张量)为标量.例如:神经网络中损失函数的值即为标量. 接下来我们创建一个di ...

  3. pytorch中张量的阶数理解

    pytorch中张量的阶数理解 推荐打开2个页面,对比原四阶张量理解各阶的对应关系. 创建一个四阶张量: import torch x = torch.linspace(0,71,72).view(2 ...

  4. pytorch中gather函数的理解

    官方解释,很清楚了 torch.gather(input,dim,index,out=None) → Tensortorch.gather(input, dim, index, out=None) → ...

  5. pytorch中torch.manual_seed()的理解

    使用

  6. pytorch中数组维度的理解

    pytorch中数组维度理解与numpy中类似,pytorch中维度用dim表示,numpy中用axis表示 这里主要想说下维度的变化. dim = x ,表示在第x为上进行操作,那个维度会发生变化. ...

  7. 关于Pytorch中dim使用的一点记录

    pytorch的许多函数,例如torch.cat().torch.max().torch.mul()等,都包含了dim参数.关于dim这个函数,我想许多人跟我一样,一知半解,比较模糊,下面我就把自己关 ...

  8. PyTorch中F.cross_entropy()函数

    对PyTorch中F.cross_entropy()的理解 PyTorch提供了求交叉熵的两个常用函数: 一个是F.cross_entropy(), 另一个是F.nll_entropy(), 是对F. ...

  9. Pytorch中tensor维度和torch.max()函数中dim参数的理解

    Pytorch中tensor维度和torch.max()函数中dim参数的理解 维度 参考了 https://blog.csdn.net/qq_41375609/article/details/106 ...

  10. Pytorch中dim的理解

    dim的定义 dim 表示维度 x = torch.randn(2, 3, 3)print(x) print(x.size()) print(x.dim()) 输出: tensor([[[-1.694 ...

最新文章

  1. 获取App Store中App的ipa包
  2. 常用,好用的js代码
  3. 史上最大规模,天猫新零售如何爆改100家大润发?
  4. 备战“双11”,阿里云为企业提供一站式资源保障服务
  5. android 开发中java.lang.verifyerror问题
  6. python获取最近N天工作日列表、节假日列表
  7. 为啥用redis解决会话呢?
  8. (数据库系统概论|王珊)第一章绪论-第三节:数据库系统的结构
  9. SAP License:给SAP顾问的5个小贴士
  10. c 无回显读取字符/不按回车即获取字符
  11. html5视频播放器 知乎,iPhone、iPad 如何播放网页调用优酷视频?
  12. 内网穿透:看这一篇就够了!
  13. 北理在线作业答案c语言,北理乐学C语言答案,最新.doc
  14. 安装maven(mvn命令)
  15. [Boost.asio] 深入linux网络编程(四):使用asio搭建商用服务器
  16. 易语言 服务器抓包,易语言调用wincap实现网卡抓包
  17. 国际标准智商测试题目
  18. 携程日处理20亿数据,实时用户行为服务系统架构实践
  19. 最早采用二进制的计算机,计算机 | 中国古代人最早提出的二进制思想?
  20. Linux 之 开机自启动

热门文章

  1. ❤️万字攻略,详解腾讯面试❤️
  2. 数据结构之SWUSTOJ1038: 顺序表中重复数据的删除
  3. docker build -t myip .报错怎么办?
  4. 转载:ecCodes 学习 利用ecCodes Python API对GRIB文件进行读写
  5. 三国杀全武将台词大全(标准+神话再临+一将成名12345+SP+国战+其他+皮肤,更新中)
  6. 【网络实验】10G网络下的真实带宽——CPU负载与网卡TSO、GSO
  7. Redis基础知识 底层数据结构的实现 redis中的对象概念
  8. android 连接tftp 服务器
  9. 傅里叶变换中采样频率(fs)的解读
  10. 计算机语言phal语言,phalapi