之前一直理解不了Pytorch中gather的用法,看了官方的文档也是一头雾水。然后自己琢磨,找规律,用以下方法进行理解。

一、官方文档

torch.gather(input, dim, index, out=None) → TensorGathers values along an axis specified by dim.For a 3-D tensor the output is specified by:out[i][j][k] = input[index[i][j][k]][j][k]  # dim=0out[i][j][k] = input[i][index[i][j][k]][k]  # dim=1out[i][j][k] = input[i][j][index[i][j][k]]  # dim=2Parameters:  input (Tensor) – The source tensordim (int) – The axis along which to indexindex (LongTensor) – The indices of elements to gatherout (Tensor, optional) – Destination tensorExample:>>> t = torch.Tensor([[1,2],[3,4]])>>> torch.gather(t, 1, torch.LongTensor([[0,0],[1,0]]))1  14  3[torch.FloatTensor of size 2x2]

二、个人理解方法

首先,torch.gather中含有3个参数,第一个是input (Tensor),第二个是dim,第三个是index。

我是将index中的数字n当成取第n维的数字,然后根据index的位置从input中找到对应的值。下面举几个例子。

例1:官方文档

    >>> t = torch.Tensor([[1,2],[3,4]])>>> torch.gather(t, 1, torch.LongTensor([[0,0],[1,0]]))1  14  3

如这里,dim=1,所以取的是最里面的元素:

第一个数是0表示的是从[1, 2]中取第0个数,得1;

第二个数是0表示的是从[1, 2]中取第0个数,得1;

第三个数是1表示的是从[3, 4]中取第1个数,得4;

第四个数是0表示的是从[3, 4]中取第0个数,得3;

例2:

b = torch.Tensor([[1,2,3],[4,5,6]])
print b
index_1 = torch.LongTensor([[0,1],[2,0]])
index_2 = torch.LongTensor([[0,1,1],[0,0,0]])
print(torch.gather(b, dim=1, index=index_1))
print(torch.gather(b, dim=0, index=index_2))# 结果如下:1  2  34  5  6
[torch.FloatTensor of size 2x3]1  26  4
[torch.FloatTensor of size 2x2]1  5  61  2  3
[torch.FloatTensor of size 2x3]

首先看第一个结果:[[1, 2], [6, 4]]。

dim=1,所以取的是最里面的元素:

第一个数是0表示的是从[1, 2, 3]中取第0个数,得1;

第二个数是1表示的是从[1, 2, 3]中取第1个数,得2;

第三个数是2表示的是从[4, 5, 6]中取第2个数,得6;

第四个数是0表示的是从[4, 5, 6]中取第0个数,得4;

再看第二个结果:[[1, 5, 6], [1, 2, 3]]。

dim=0,所以index中的数字表示的是对应维度的数组:

第一个数是0表示的是从[[1, 2, 3]]中取位置为(0, 0)的数,得1;

第二个数是1表示的是从[[4, 5, 6]]中取位置为(0, 2)的数,得5;

第三个数是1表示的是从[[4, 5, 6]]中取位置为(0, 3)的数,得6;

第四个数是0表示的是从[[1, 2, 3]]中取位置为(0, 0)的数,得1;

第五个数是0表示的是从[[1, 2, 3]]中取位置为(0, 1)的数,得2;

第六个数是0表示的是从[[1, 2, 3]]中取位置为(0, 2)的数,得3。

注:最后输出的形状和index是保持一致的。

Pytorch中gather函数的个人理解方法相关推荐

  1. pytorch中gather函数的理解

    官方解释,很清楚了 torch.gather(input,dim,index,out=None) → Tensortorch.gather(input, dim, index, out=None) → ...

  2. PyTorch中gather()函数的用法

    torch.gather(input, dim, index, out=None) → Tensor 沿给定轴,按照索引张量将原张量的指定位置的元素重新聚合成一个新的张量 参数含义: input (T ...

  3. Pytorch 中 gather 函数讲解

    文章目录 官方解读分析 小例子 官方解读分析 该函数的功能为:沿着 dim 指定的轴收集值 torch.gather(input, dim, index, out=None) → TensorGath ...

  4. pytorch中repeat()函数理解

    pytorch中repeat()函数理解 最近在学习过程中遇到了repeat()函数的使用,这里记录一下自己对这个函数的理解. 情况1:repeat参数个数与tensor维数一致时 a = torch ...

  5. pytorch 中 contiguous() 函数理解

    pytorch 中 contiguous() 函数理解 文章目录 pytorch 中 contiguous() 函数理解 引言 使用 contiguous() 后记 文章抄自 Pytorch中cont ...

  6. matlab中fprintf函数的具体使用方法

    matlab中fprintf函数的具体使用方法实例如下: fprintf函数可以将数据按指定格式写入到文本文件中.其调用格式为: 数据的格式化输出:fprintf(fid, format, varia ...

  7. Oracle中wm_concat函数报错解决方法

    Oracle中wm_concat函数报错解决方法 参考文章: (1)Oracle中wm_concat函数报错解决方法 (2)https://www.cnblogs.com/52net/archive/ ...

  8. Makefile中wildcard函数的应用理解

    文章目录 前言 1 "*"通配符使用场景 2 "*"通配符实例 总结 前言 如果我们想定义一系列比较类似的文件,我们很自然地就想起使用通配符.make 支持三种 ...

  9. python中bool函数用法_在python中bool函数的取值方法

    bool是Boolean的缩写,只有真(True)和假(False)两种取值 bool函数只有一个参数,并根据这个参数的值返回真或者假. 1.当对数字使用bool函数时,0返回假(False),任何其 ...

最新文章

  1. C++中的接口(抽象类)
  2. copy时候明细消失没有了
  3. 2.7 负采样-深度学习第五课《序列模型》-Stanford吴恩达教授
  4. Windows驱动开发学习笔记(六)—— Inline HOOK
  5. VTK:InfoVis之WordCloud
  6. AVL树双旋转+图解
  7. 网管光纤收发器产品硬件功能及网管收发器优点介绍
  8. 微信小程序 AppID和AppSecret的获取方式
  9. SQL Server 2005 Analysis Services实践(一)
  10. 【mysql问题】can't connect to mysql server on 'localhost' (10060)
  11. 判断中文文本是否为utf8编码类型的javascript实现_Go语言实现LeetCode算法:393 UTF-8编码校验...
  12. 软设考点精要,精确到每页!
  13. DWORD winapi java_DWORD WINAPI?stdcall?
  14. ES索引生命周期管理ILM
  15. 没有音响,把手机当作电脑音响的操作。
  16. 基于JAVA计算机类专业考研交流学习平台计算机毕业设计源码+数据库+lw文档+系统+部署
  17. 计算机课程与就业关系,计算机专业课程及就业方向
  18. 群晖监控备份方案,为金融企业信息安全保驾护航
  19. mysql数据库怎么冷备份恢复_MySQL数据库的备份与恢复
  20. 如何OIM 11.1.1.5.0打补丁到11.1.1.5.2

热门文章

  1. BLE 连接和通信 的实现
  2. 实现同时控制施耐德(Schneider Electric)和西门子(Siemens)PLC WCS系统
  3. [hiho 22]线段树-lazy标记的下放
  4. 租房押金条丢了房东不退钱如何解决
  5. 通过端口 1433 连接到主机 localhost 的 TCP/IP 连接失败。错误:“Connection refused: connect。
  6. dedecms标签大全(非常经典)
  7. lunix安装tornado
  8. alsa框架与音频芯片移植基础
  9. VMware虚拟机开机黑屏问题
  10. fetch和ajax的区别?