Pytorch中gather函数的个人理解方法
之前一直理解不了Pytorch中gather的用法,看了官方的文档也是一头雾水。然后自己琢磨,找规律,用以下方法进行理解。
一、官方文档
torch.gather(input, dim, index, out=None) → TensorGathers values along an axis specified by dim.For a 3-D tensor the output is specified by:out[i][j][k] = input[index[i][j][k]][j][k] # dim=0out[i][j][k] = input[i][index[i][j][k]][k] # dim=1out[i][j][k] = input[i][j][index[i][j][k]] # dim=2Parameters: input (Tensor) – The source tensordim (int) – The axis along which to indexindex (LongTensor) – The indices of elements to gatherout (Tensor, optional) – Destination tensorExample:>>> t = torch.Tensor([[1,2],[3,4]])>>> torch.gather(t, 1, torch.LongTensor([[0,0],[1,0]]))1 14 3[torch.FloatTensor of size 2x2]
二、个人理解方法
首先,torch.gather中含有3个参数,第一个是input (Tensor),第二个是dim,第三个是index。
我是将index中的数字n当成取第n维的数字,然后根据index的位置从input中找到对应的值。下面举几个例子。
例1:官方文档
>>> t = torch.Tensor([[1,2],[3,4]])>>> torch.gather(t, 1, torch.LongTensor([[0,0],[1,0]]))1 14 3
如这里,dim=1,所以取的是最里面的元素:
第一个数是0表示的是从[1, 2]中取第0个数,得1;
第二个数是0表示的是从[1, 2]中取第0个数,得1;
第三个数是1表示的是从[3, 4]中取第1个数,得4;
第四个数是0表示的是从[3, 4]中取第0个数,得3;
例2:
b = torch.Tensor([[1,2,3],[4,5,6]])
print b
index_1 = torch.LongTensor([[0,1],[2,0]])
index_2 = torch.LongTensor([[0,1,1],[0,0,0]])
print(torch.gather(b, dim=1, index=index_1))
print(torch.gather(b, dim=0, index=index_2))# 结果如下:1 2 34 5 6
[torch.FloatTensor of size 2x3]1 26 4
[torch.FloatTensor of size 2x2]1 5 61 2 3
[torch.FloatTensor of size 2x3]
首先看第一个结果:[[1, 2], [6, 4]]。
dim=1,所以取的是最里面的元素:
第一个数是0表示的是从[1, 2, 3]中取第0个数,得1;
第二个数是1表示的是从[1, 2, 3]中取第1个数,得2;
第三个数是2表示的是从[4, 5, 6]中取第2个数,得6;
第四个数是0表示的是从[4, 5, 6]中取第0个数,得4;
再看第二个结果:[[1, 5, 6], [1, 2, 3]]。
dim=0,所以index中的数字表示的是对应维度的数组:
第一个数是0表示的是从[[1, 2, 3]]中取位置为(0, 0)的数,得1;
第二个数是1表示的是从[[4, 5, 6]]中取位置为(0, 2)的数,得5;
第三个数是1表示的是从[[4, 5, 6]]中取位置为(0, 3)的数,得6;
第四个数是0表示的是从[[1, 2, 3]]中取位置为(0, 0)的数,得1;
第五个数是0表示的是从[[1, 2, 3]]中取位置为(0, 1)的数,得2;
第六个数是0表示的是从[[1, 2, 3]]中取位置为(0, 2)的数,得3。
注:最后输出的形状和index是保持一致的。
Pytorch中gather函数的个人理解方法相关推荐
- pytorch中gather函数的理解
官方解释,很清楚了 torch.gather(input,dim,index,out=None) → Tensortorch.gather(input, dim, index, out=None) → ...
- PyTorch中gather()函数的用法
torch.gather(input, dim, index, out=None) → Tensor 沿给定轴,按照索引张量将原张量的指定位置的元素重新聚合成一个新的张量 参数含义: input (T ...
- Pytorch 中 gather 函数讲解
文章目录 官方解读分析 小例子 官方解读分析 该函数的功能为:沿着 dim 指定的轴收集值 torch.gather(input, dim, index, out=None) → TensorGath ...
- pytorch中repeat()函数理解
pytorch中repeat()函数理解 最近在学习过程中遇到了repeat()函数的使用,这里记录一下自己对这个函数的理解. 情况1:repeat参数个数与tensor维数一致时 a = torch ...
- pytorch 中 contiguous() 函数理解
pytorch 中 contiguous() 函数理解 文章目录 pytorch 中 contiguous() 函数理解 引言 使用 contiguous() 后记 文章抄自 Pytorch中cont ...
- matlab中fprintf函数的具体使用方法
matlab中fprintf函数的具体使用方法实例如下: fprintf函数可以将数据按指定格式写入到文本文件中.其调用格式为: 数据的格式化输出:fprintf(fid, format, varia ...
- Oracle中wm_concat函数报错解决方法
Oracle中wm_concat函数报错解决方法 参考文章: (1)Oracle中wm_concat函数报错解决方法 (2)https://www.cnblogs.com/52net/archive/ ...
- Makefile中wildcard函数的应用理解
文章目录 前言 1 "*"通配符使用场景 2 "*"通配符实例 总结 前言 如果我们想定义一系列比较类似的文件,我们很自然地就想起使用通配符.make 支持三种 ...
- python中bool函数用法_在python中bool函数的取值方法
bool是Boolean的缩写,只有真(True)和假(False)两种取值 bool函数只有一个参数,并根据这个参数的值返回真或者假. 1.当对数字使用bool函数时,0返回假(False),任何其 ...
最新文章
- C++中的接口(抽象类)
- copy时候明细消失没有了
- 2.7 负采样-深度学习第五课《序列模型》-Stanford吴恩达教授
- Windows驱动开发学习笔记(六)—— Inline HOOK
- VTK:InfoVis之WordCloud
- AVL树双旋转+图解
- 网管光纤收发器产品硬件功能及网管收发器优点介绍
- 微信小程序 AppID和AppSecret的获取方式
- SQL Server 2005 Analysis Services实践(一)
- 【mysql问题】can't connect to mysql server on 'localhost' (10060)
- 判断中文文本是否为utf8编码类型的javascript实现_Go语言实现LeetCode算法:393 UTF-8编码校验...
- 软设考点精要,精确到每页!
- DWORD winapi java_DWORD WINAPI?stdcall?
- ES索引生命周期管理ILM
- 没有音响,把手机当作电脑音响的操作。
- 基于JAVA计算机类专业考研交流学习平台计算机毕业设计源码+数据库+lw文档+系统+部署
- 计算机课程与就业关系,计算机专业课程及就业方向
- 群晖监控备份方案,为金融企业信息安全保驾护航
- mysql数据库怎么冷备份恢复_MySQL数据库的备份与恢复
- 如何OIM 11.1.1.5.0打补丁到11.1.1.5.2
热门文章
- BLE 连接和通信 的实现
- 实现同时控制施耐德(Schneider Electric)和西门子(Siemens)PLC WCS系统
- [hiho 22]线段树-lazy标记的下放
- 租房押金条丢了房东不退钱如何解决
- 通过端口 1433 连接到主机 localhost 的 TCP/IP 连接失败。错误:“Connection refused: connect。
- dedecms标签大全(非常经典)
- lunix安装tornado
- alsa框架与音频芯片移植基础
- VMware虚拟机开机黑屏问题
- fetch和ajax的区别?