目录

  • 理解Tensor的dim
  • 索引
    • 简单索引
    • 用1维的list,numpy,tensor索引
    • 用booltensor索引
  • 切片
  • 花式索引
  • 结语

前一段时间遇到一个花式索引的问题,在搜索良久之后没有找到确切的答案,苦苦摸索许久才理解清楚,于是想要写一篇博客来详细的讲讲我对Pytorch中Tensor索引的一些理解。包括普通的索引,切片,一起花式索引。

理解Tensor的dim

我们可以从1维的tensor开始,例如:

>>> a = torch.tensor([0, 1 ,2, 3, 4])

a可以简单的理解为一个一维的数组,但还不够形象。我么先换个形象的角度,暂时抛去数组的概念,我们可以理解为现在有五个元素(数字)排列一个轴上(类似于数轴,也就是dim),那么我们自然而然的就会想要去给每一个元素一个下标,这也就是我们使用的从0开始的整数下标。那么对于这个轴的方向,每一个下标都会确定一个元素(例如下标2,对应着a里面的2)。那么直观的来看就像下面的图一样:

我们可以注意到,现在参与排队的元素是数字,并且只在一个方向上排队,这就构成了最简单的1维的tensor。那么从1维到2维的过程有两种理解方式:

  • 仍然是在一个方向上排队,但参与排队的元素是已经按照另一个方向排好队的等长的队伍(维度增加)
  • 排队元素仍然是一个数字,但是现在按照两个方向(两个维度)进行排队

那么我们可以看看这样一个2维的tensor,

>>> b = torch.tensor([[0, 1],
>>>                  [2, 3],
>>>                  [4, 5]])

那么按照上面两种理解可以理解为:

  • 现在在dim=0的方向上有三个元素要参与排列,分别是 [ 0 , 1 ] , [ 2 , 3 ] , [ 3 , 4 ] [0, 1], [2,3],[3,4] [0,1],[2,3],[3,4]。这个时候他们自己排列的方向为dim=1,由此构成了一个2维的tensor。
  • 现在有6个元素在两个方向上dim=0,dim=1上排列,其中dim=0方向有三个位置,dim=1的方向上有两个位置。由此构成一个二维的数组。

如下图所示:

那么如果是一个3维或者跟高维度的tensor呢,同样的道理,在多个方向上对已经在某些方向上排列好了的一组数字进行排列。这也就是高维数组是数组的数组说法的一个体现。

回过来再看看二维tensor,如果我们把dim=1方向上的所有数字看作一个整体(1维的一个tensor, [ 0 , 1 ] , [ 2 , 3 ] , [ 4 , 5 ] [0,1],[2,3],[4,5] [0,1],[2,3],[4,5]三个),那么我们可以理解为b是这三个1维tensor的一个1维的tensor,那么这个时候如果我们把dim=0方向上的数字看作一个整体呢,那我们得到的是 [ 0 , 2 , 4 ] , [ 1 , 3 , 5 ] [0,2,4],[1,3,5] [0,2,4],[1,3,5]构成的的一个维tensor。那如果我们把两个维度看作一个整体,那我们得到的是单单一个二维tensor,不用排列。

那么如果是一个3维的或者更高维度的tensor,我们可以将多个维度(m)看作是一个整体,那么剩余的维度(n)构成了一个所有元素都是m维tensor的n维tensor。

那么接下来的索引,切片过程我们也能方便的理解。

索引

简单索引

使用数字进行索引,一般对于一个n维的tensor,索引形式为:
T [ d 0 , d 1 , . . . , d t ] , t < n . T[d_0, d_1,...,d_t], t<n. T[d0​,d1​,...,dt​],t<n.
那么这就相当于将其看作是一个tensor的tensor,将元素按照下标取出。
例如:

>>> a[0]
tensor(0)
>>> a[3]
tensor(3)
>>> b[1]
tensor([2, 3])
>>> b[:, 1]  # 展示将dim=0作为整体
tensor([1, 3, 5])
>>> b[1, 1]
tensor(3)

用1维的list,numpy,tensor索引

将整个tensor看作是一个由n-1维tensor构成的1维tensor。将每个取出的元素排列,构成一个新的tensor。

>>> a[[1, 2 ,1]]
tensor([1, 2, 1])
>>> b[[2, 0, 2]]
tensor([[4, 5],
>>>     [0, 1],
>>>     [4, 5]])
>>> c = torch.rand([4, 3])
>>> c
tensor([[0.6478, 0.3120, 0.6656],[0.4470, 0.6383, 0.6878],[0.9854, 0.9709, 0.4868],[0.1797, 0.3453, 0.9005]])
>>> c[[1, 2, 1]]
tensor([[0.4470, 0.6383, 0.6878],[0.9854, 0.9709, 0.4868],[0.4470, 0.6383, 0.6878]])

用booltensor索引

使用booltensor B对T进行索引,需要满足如下条件:
B . s i z e ( ) : ( b 0 , b 1 , . . . , b t − 1 ) T . s i z e ( ) : ( d 0 , d 1 , . . . , d t − 1 , . . . , d n − 1 ) b i = d i , ∀ i < t , n ≥ t . B.size():(b_0, b_1, ...,b_{t-1}) \\ T.size():(d_0, d_1, ...,d_{t-1}, ...,d_{n-1}) \\ b_i = d_i,\forall i < t,\\ n\ge t. B.size():(b0​,b1​,...,bt−1​)T.size():(d0​,d1​,...,dt−1​,...,dn−1​)bi​=di​,∀i<t,n≥t.
那么其意义就是将True位置的元素取出,构成一个 n − t + 1 n-t+1 n−t+1维的新的tensor。例如:

>>> boolt_1 = torch.tensor([False, True, False, True, True])
>>> a[boolt_1]
tensor([1, 3, 4])
>>> boolt_2 = torch.tensor([True, False, True])
>>> b[boolt_2]
tensor([[0, 1],[4, 5]])
>>> boolt_3 = torch.tensor([[False, True],
>>>                         [True, False],
>>>                         [True, True]])
tensor([1, 2, 4, 5])
>>> boolt_4 = torch.tensor([[False, True],
>>>                         [True, False]])
>>> b[boolt_4]  # 如果不符合条件
Traceback (most recent call last):File "<input>", line 1, in <module>
IndexError: The shape of the mask [2, 2] at index 0 does not match the shape of the indexed tensor [3, 2] at index 0

可以看到维数条件是严格要求B的所有维度的长度正好等于被索引的tensor的对应维度的长度。

切片

相信再list里面就已经学过切片的概念了,主要使用形如 [ s t a r t : e n d : s t e p ] [start : end : step] [start:end:step]的组合进行子序列的抽取,其中 : s t e p : step :step可选,默认为1, s t a r t start start和 e n d end end也可选,分别默认为0和len(obj)。
那么在tensor中遇到的主要是如下形式:
T [ s 0 : e 0 : s t 0 , . . . , s n − 1 : e n − 1 : s t n − 1 ] T[s_0:e_0:st_0,...,s_{n-1}:e_{n-1}:st_{n-1}] T[s0​:e0​:st0​,...,sn−1​:en−1​:stn−1​]
在每个维度处的切片都相当于将对该维度进行相应的切片操作,在该维度上保留对应下标的元素(数,或者tensor,或者什么都不剩)。例如:

>>> a[1:3]
tensor([1, 2])
>>> a[0:4:2]
tensor([0, 2])
>>> b[0:2, 0:1]
tensor([[0],[2]])
>>> b[1:, 1:]
tensor([[3],[5]])
>>> d = torch.randint(5, (3, 3, 3))
>>> d
tensor([[[2, 1, 4],[4, 1, 1],[1, 2, 4]],[[1, 4, 4],[0, 3, 4],[1, 2, 2]],[[4, 4, 4],[1, 3, 3],[0, 0, 4]]])
>>> d[:, 1:2, 0:2]
tensor([[[4, 1]],[[0, 3]],[[1, 3]]])

显然切片和索引可以进行组合,效果就是将所有切片操作的维度在切片之后构成整体,看作是排列的元素,剩余索引的维度就是取出对应的这些元素。例如:

>>> d[[2, 0, 1], 1:2, 0:2]
tensor([[[1, 3]],[[4, 1]],[[0, 3]]])

花式索引

花式索引也是索引的一种,就是使用tensor对tensor进行索引,形如:
T [ t 0 , t 1 , . . . , t n − 1 ] T[t_0, t_1,...,t_{n-1}] T[t0​,t1​,...,tn−1​]
其中的 t i , i = 0 , . . . , n − 1 t_i,i=0,...,n-1 ti​,i=0,...,n−1是维度不限的long型tensor。
能够执行这一操作的先决条件是, t i , i = 0 , . . . , n − 1 t_i,i=0,...,n-1 ti​,i=0,...,n−1能够广播成同一形状。广播的机制建议自行了解。
整个语句的过程大致如下:

  • 首先将 t i , i = 0 , . . . , n − 1 t_i,i=0,...,n-1 ti​,i=0,...,n−1广播成同一形状,如果它们不是同一形状的话,假设最终形状为 ( b 0 , b 1 , . . . , b s ) (b_0, b_1, ...,b_s) (b0​,b1​,...,bs​)
  • 这时 [ t 0 , t 1 , . . . , t n − 1 ] [t_0, t_1,...,t_{n-1}] [t0​,t1​,...,tn−1​]这些tensor的对应位置的元素构成1组坐标,总共 b 0 × b 1 × . . . × b s b_0\times b_1\times ...\times b_s b0​×b1​×...×bs​组坐标。
  • 每组坐标进行一次简单索引,取出的元素(可能是数,也可能是tensor)放在形状 ( b 0 , b 1 , . . . , b s ) (b_0, b_1, ...,b_s) (b0​,b1​,...,bs​)的对应位置,例如如果是 [ t 0 , t 1 , . . . , t n − 1 ] [t_0, t_1,...,t_{n-1}] [t0​,t1​,...,tn−1​]的所有 ( 0 , 0 , . . . , 0 ) (0,0, ...,0) (0,0,...,0)构成的坐标,那么将结果放在 ( 0 , 0 , . . . , 0 ) (0,0, ...,0) (0,0,...,0)处,得到最后结果。

那么我们看个例子:

>>> idx_0 = torch.tensor([[3, 2],[1, 4]])
>>> a[idx_0]
tensor([[3, 2],[1, 4]])
>>> b  # 查看一下b,
tensor([[0, 1],[2, 3],[4, 5]])
>>> idx_0 = torch.tensor([[1, 0],[2, 1]])
>>> idx_1 = torch.tensor([0, 1])
>>> b[idx_0, idx_1]
tensor([[2, 1],[4, 3]])

分析一下过程,先是idx_1广播成

tensor([[0, 1],[0, 1]])

构成四组坐标 [ 1 , 0 ] , [ 0 , 1 ] , [ 2 , 0 ] , [ 1 , 1 ] [1, 0],[0, 1],[2, 0],[1, 1] [1,0],[0,1],[2,0],[1,1],对应着 2 , 1 , 4 , 3 2, 1, 4, 3 2,1,4,3,放入对应位置得到最终结果。

那么其实花式索引也是索引的一种,不过是通过多次索引,并且组成新的tensor的更复杂的索引,同理也可以和切片结合。

结语

这些关于tensor的理解,都是个人理解,希望能够帮助到有需要的人,另外如有参考本篇博客,请注明链接,请勿直接搬运。
思考过很久,因为想要写写博客,记录自己成长的过程,但又迟迟没能动手,或者是写写停停,但总算是再一次尝试了一遍,以后也可能是随缘吧,尽力吧

Pytorch中Tensor的索引,切片以及花式索引(fancy indexing)相关推荐

  1. pytorch tensor查找0_在PyTorch中Tensor的查找和筛选例子

    本文源码基于版本1.0,交互界面基于0.4.1 import torch 按照指定轴上的坐标进行过滤 index_select() 沿着某tensor的一个轴dim筛选若干个坐标 >>&g ...

  2. Pytorch中tensor.view().permute().contiguous()函数理解

    Pytorch中tensor.view().permute().contiguous()函数理解 yolov3中有一行这样的代码,在此记录一下三个函数的含义 # 例子中batch_size为整型,le ...

  3. Pytorch中tensor维度和torch.max()函数中dim参数的理解

    Pytorch中tensor维度和torch.max()函数中dim参数的理解 维度 参考了 https://blog.csdn.net/qq_41375609/article/details/106 ...

  4. Pytorch中tensor.expand()和tensor.expand_as()函数

    Pytorch中tensor.expand函数 Tensor.expand()函数详解 Tensor.expand_as()函数 Tensor.expand()函数详解 函数语法: # 官方解释: D ...

  5. PyTorch中tensor介绍

          PyTorch中的张量(Tensor)如同数组和矩阵一样,是一种特殊的数据结构.在PyTorch中,神经网络的输入.输出以及网络的参数等数据,都是使用张量来进行描述.       torc ...

  6. pyTorch中tensor运算

    文章目录 PyTorch的简介 PyTorch中主要的包 PyTorch的安装 使用GPU的原因 使数据在GPU上运行 什么使Tensor(张量) 一些术语介绍 Tensor的属性介绍(Rank,ax ...

  7. pandas中series一维数组的创建、索引的更改+索引切片和布尔索引+dataframe二维数组的创建、基本属性、索引方法(传统方法和lociloc)、nan操作、排序+案例

    目录 一.为什么要学习pandas? 二.pandas的常用数据类型 1.series--一维的且带标签的数组 (1)创建一维数组 (2)通过列表形式创建的series带标签数组可以改变索引,传入索引 ...

  8. pytorch中tensor、backward一些总结

    目录 说明 Tensor Tensor的创建 Tensor(张量)基本数据类型与常用属性 Tensor的自动微分 设置不可积分计算 pytorch 计算图 backward一些细节 该文章解决问题如下 ...

  9. pytorch中tensor常用is_contiguous含义

    is_contiguous 根据名字就可以知道判断是否连续相邻, pytorch中不管任意维度的张量底层都是一维tensor,只是取决于你怎么读,因此每个tensor中标量都是连续的.如果我们将矩阵进 ...

最新文章

  1. mysql慢查询开启及分析方法
  2. 使用nginx源代码编译安装lnmp
  3. 小白入门angular-cli的第一次旅程(学习目标 1.路由的基础知识 参数订阅写法)
  4. 阿里云主机安装开发工具包报错处理
  5. [小技巧]ASP.NET Core中如何预压缩静态文件
  6. windows2008 sp2 x64安装 ocs 2007 r2 笔记
  7. Android 网络评分机制
  8. C++面向对象小练习:几何图形类
  9. win7网络不显示共享计算机,win7查找不到网络计算机怎么办_win7看不到网络计算机怎么解决-win7之家...
  10. android app跳转到微信
  11. 计算机开机闪烁进不去,电脑开机左上角横杠一直闪进不去系统怎么办
  12. Modelsim的tcl命令
  13. BIT2023 智慧社区综合管理系统-一周目
  14. 【每周CV论文推荐】基于GAN的图像修复值得阅读的文章
  15. UE4中三维几何总结——几何学基础
  16. CAD看图如何在电脑上快速找到并打开指定CAD图纸
  17. 武宣计算机培训学校,武宣县职业技术学校
  18. Kinect开发学习笔记之(五)不带游戏者ID的深度数据的提取
  19. 洛谷 P1506 拯救oibh总部 题解(洪水填充法的模板)
  20. Java修炼之凡界篇 筑基期 第02卷 语法 番外1 原码 反码 补码

热门文章

  1. 计算机除法用什么函数,“excel除法函数公式是哪个“excel中公式或者函数怎么用呢?谢谢...
  2. 【Classical Network】DeepLabv3+ 中Dilation ASPP 以及Decoder模块
  3. winEdt使用教程
  4. MacType最强配置图解
  5. 公众号如何关联小程序?
  6. 如何在个人博客上添加自己的备案信息
  7. 唯在珠峰之巅,能欣赏到如许壮阔的5G时代
  8. 互联网公司端午节福利大揭秘[高清图文]
  9. HTML + CSS + Javascript 简易示例
  10. [职场]工作多久才能换工作?下一个工作年薪该多高?