pytorch中数组维度理解与numpy中类似,pytorch中维度用dim表示,numpy中用axis表示
这里主要想说下维度的变化。
dim = x ,表示在第x为上进行操作,那个维度会发生变化。

一、二维数组

1. 两个二维数组的拼接

维度为(2,3)与(2,4)的数组拼接后的维度是(2,7)

import torch
a = torch.Tensor(np.arange(6).reshape(2,3))
b = torch.Tensor(np.arange(8).reshape(2,4))
print(a,'\n ',a.shape)
print(b,'\n',b.shape)
c = torch.cat((a,b),dim = 1)
print('concatenate:\n',c,'\n',c.shape)

结果

tensor([[0., 1., 2.],[3., 4., 5.]]) a: torch.Size([2, 3])
tensor([[0., 1., 2., 3.],[4., 5., 6., 7.]]) torch.Size([2, 4])
concatenate:tensor([[0., 1., 2., 0., 1., 2., 3.],[3., 4., 5., 4., 5., 6., 7.]]) torch.Size([2, 7])

2. 二维数组求sum、max等

dim = 0,第一个维度划掉,得到一个一维向量。比如,a是(2,3),dim = 0,得到的结果是(3,)维的;如果dim=1,得到的结果是(2,)

print('sum dim=0',torch.sum(a,dim=0))
print('sum dim=1',torch.sum(a,dim=1))
print('******* max *****')
print('max dim=0',torch.max(a,dim=0))
print('max dim=1',torch.max(a,dim=1))

输出

tensor([[0., 1., 2.],[3., 4., 5.]]) torch.Size([2, 3])
sum dim=0 tensor([3., 5., 7.])
sum dim=1 tensor([ 3., 12.])
******* max *****
max dim=0 torch.return_types.max(
values=tensor([3., 4., 5.]),
indices=tensor([1, 1, 1]))
max dim=1 torch.return_types.max(
values=tensor([2., 5.]),
indices=tensor([2, 2]))

二、三维数组

1. 两个三维数组的拼接

两个三位数组拼接,有个要求,除了dim维,其余维的维度要相同。

  • 比如 a是(2,3,4),b是(3,2,4)那么a与b无论在哪个维上都不能拼接。因为它们没有两个相同的维度
  • 如果a与b维度相同,都是(2,3,4),那么他们无论在哪个维上都可以拼接。dim = 0,结果是(4,3,4),dim = 1,结果是(2,6,4),dim =2,结果是(2,3,8)
  • dim = x,就将两个数组dim维上的数字相加,得到最终输出维度。
a = torch.Tensor(np.arange(24).reshape(2,3,4))
b = torch.Tensor(np.arange(24,48).reshape(2,3,4))
print(a,'\n ',a.shape)
print(b,'\n',b.shape)
c = torch.cat((a,b),dim = 2)
print('concatenate:\n',c,'\n',c.shape)

输出结果

tensor([[[ 0.,  1.,  2.,  3.],[ 4.,  5.,  6.,  7.],[ 8.,  9., 10., 11.]],[[12., 13., 14., 15.],[16., 17., 18., 19.],[20., 21., 22., 23.]]]) torch.Size([2, 3, 4])
tensor([[[24., 25., 26., 27.],[28., 29., 30., 31.],[32., 33., 34., 35.]],[[36., 37., 38., 39.],[40., 41., 42., 43.],[44., 45., 46., 47.]]]) torch.Size([2, 3, 4])
concatenate:tensor([[[ 0.,  1.,  2.,  3., 24., 25., 26., 27.],[ 4.,  5.,  6.,  7., 28., 29., 30., 31.],[ 8.,  9., 10., 11., 32., 33., 34., 35.]],[[12., 13., 14., 15., 36., 37., 38., 39.],[16., 17., 18., 19., 40., 41., 42., 43.],[20., 21., 22., 23., 44., 45., 46., 47.]]]) torch.Size([2, 3, 8])

2. 三维数组求sum、max等

  • 类似于二维数组,会消去dim维度
  • shape=(2,3,4)的数组,在dim=0上求和或者取最大后,结果的shape = (3,4)
  • pytorch求max,同时返回两个值(max,indices)
a = torch.Tensor(np.arange(24).reshape(2,3,4))
print(a,'\n',a.shape)
print('sum dim=0',torch.sum(a,dim=0))
print('sum dim=1',torch.sum(a,dim=1))
print('sum dim=2',torch.sum(a,dim=2))
print('******* max *****')
print('max dim=0',torch.max(a,dim=0))
print('max dim=1',torch.max(a,dim=1))
print('max dim=2',torch.max(a,dim=2))

结果

tensor([[[ 0.,  1.,  2.,  3.],[ 4.,  5.,  6.,  7.],[ 8.,  9., 10., 11.]],[[12., 13., 14., 15.],[16., 17., 18., 19.],[20., 21., 22., 23.]]]) torch.Size([2, 3, 4])
sum dim=0 tensor([[12., 14., 16., 18.],[20., 22., 24., 26.],[28., 30., 32., 34.]])
sum dim=1 tensor([[12., 15., 18., 21.],[48., 51., 54., 57.]])
sum dim=2 tensor([[ 6., 22., 38.],[54., 70., 86.]])
******* max *****
max dim=0 torch.return_types.max(
values=tensor([[12., 13., 14., 15.],[16., 17., 18., 19.],[20., 21., 22., 23.]]),
indices=tensor([[1, 1, 1, 1],[1, 1, 1, 1],[1, 1, 1, 1]]))
max dim=1 torch.return_types.max(
values=tensor([[ 8.,  9., 10., 11.],[20., 21., 22., 23.]]),
indices=tensor([[2, 2, 2, 2],[2, 2, 2, 2]]))
max dim=2 torch.return_types.max(
values=tensor([[ 3.,  7., 11.],[15., 19., 23.]]),
indices=tensor([[3, 3, 3],[3, 3, 3]]))

pytorch中数组维度的理解相关推荐

  1. numpy中数组维度的理解

    参考 这篇文章主要是为了弄清楚数组按每个维度进行计算时,具体的操作是什么样的. 一.数组中的各个维度表示的是什么? 为了便于理解,用单位体表示,剥去一层中括号后,得到的数据. 1. 以二维数组为例 i ...

  2. Pytorch中tensor维度和torch.max()函数中dim参数的理解

    Pytorch中tensor维度和torch.max()函数中dim参数的理解 维度 参考了 https://blog.csdn.net/qq_41375609/article/details/106 ...

  3. pytorch中张量的阶数理解

    pytorch中张量的阶数理解 推荐打开2个页面,对比原四阶张量理解各阶的对应关系. 创建一个四阶张量: import torch x = torch.linspace(0,71,72).view(2 ...

  4. pytorch中同维度张量matmul运算

    pytorch中同维度张量matmul运算 零维矩阵乘法 零维矩阵乘法实际上就是标量(数的)乘法. import torch M01=torch.tensor(3) M02=torch.tensor( ...

  5. pytorch中gather函数的理解

    官方解释,很清楚了 torch.gather(input,dim,index,out=None) → Tensortorch.gather(input, dim, index, out=None) → ...

  6. pytorch中torch.manual_seed()的理解

    使用

  7. Pytorch中维度dim的理解使用

    0 引言 pytorch中的维度dim主要被用在torch.softmax和torch.max等等函数中.理清dim的意思对于正确使用这些函数有重要意义. 1 相关博文: Pytorch笔记:维度di ...

  8. PyTorch中F.cross_entropy()函数

    对PyTorch中F.cross_entropy()的理解 PyTorch提供了求交叉熵的两个常用函数: 一个是F.cross_entropy(), 另一个是F.nll_entropy(), 是对F. ...

  9. pytorch 中维度(Dimension)概念的理解

    pytorch 中维度(Dimension)概念的理解 Dimension为0(即维度为0时) 维度为0时,即tensor(张量)为标量.例如:神经网络中损失函数的值即为标量. 接下来我们创建一个di ...

最新文章

  1. 要活102年,阿里凭借的是什么?
  2. Jvm 系列(十一)Java 语法糖背后的真相
  3. 大数据概述 ——林子雨老师第一课
  4. python中tushare数据可以导出嘛_Python与交易策略分析tushare/baostock库介绍(附代码)...
  5. 用c语言求解n阶线性矩阵方程组,用C语言求解N阶线性矩阵方程Axb简单解法.docx
  6. 《移动应用开发》实验报告——疫情地图
  7. 远程控制漏洞CNVD-2022-10270/CNVD-2022-03672 向日葵RCE复现与解决
  8. 算法模板java_我的Java设计模式-模板方法模式
  9. gluoncv 目标检测,训练自己的数据集
  10. net: 熟悉传统的交换机芯片
  11. python android开发视频教程_程序员学习视频教程汇总
  12. AE插件:saber插件mac版怎么安装?saber插件汉化版安装教程
  13. datagrid表格序号列
  14. UE4实现风格化渲染(一):UserNormalTranslator工具的使用
  15. 局部加权回归LOESS(locally weighted regression)
  16. 太原工业学院计算机实训中心,法学实训实验中心
  17. 移动端和前端开发的共性
  18. 苏宁小BIU诞生日 机器人员工正式“入职”
  19. Kinect for Windows SDK 1.6的改进及新特性
  20. 【考研资料】计算机/软件超过百所大学的考研初试复试资料!

热门文章

  1. sqlite3 表里插入系统时间(时间戳)
  2. java实战调用数据库_实战php调用java类由java类读数据库完成相关操作(InberWrite)_PHP...
  3. 副族元素从上到下原子半径_长知识:化学元素大阅兵
  4. redhat5安装oracle详细步骤,redhat5安装oracle11g详细教程
  5. 交换机短路_你了解交换机的相关知识吗?还不赶快收藏起来
  6. Django 笔记4 -- 模板
  7. 7th思妙想 Fun事连连,今天范式7岁啦!
  8. 【Python】全国气温骤降,Python一键生成御寒指南,助你温暖过冬!!
  9. 【Python】这款拓展让你的jupyter lab使用更高效
  10. 【机器学习】基于LightGBM算法实现数据挖掘!