pytorch 深入理解 tensor.scatter_ ()用法

在 pytorch 库下理解 torch.tensor.scatter()的用法。作者在网上搜索了很多方法,最后还是觉得自己写一篇更为详细的比较好,转载请注明。
首先,scatter() 和 scatter_() 的作用是一样的,但是 scatter() 不会直接修改原来的 Tensor,而 scatter_() 会修改原先的 Tensor。

1 API格式

torch.Tensor.scatter_(dim, index, src) → Tensor

字面意思:对一个 torch.Tensor 进行操作,dim,index,src三个为输入的参数。

  • dim 就是在哪个维度进行操作,注意,dim 的不同,在其他条件相同的条件下得到的output 也不同。
  • index 是输入的索引。
  • src 就是输入的向量,也就是 input。

最后,函数返回一个 Tensor。

2 具体示例

import torch as th
# import torch 包a = th.rand(2,5)
# 初始化向量 a,size 为 (2, 5),二维向量,2行5列,每个元素是 0 到 1 的均匀分布采样
# 把 a 作为 src,也就是 input
# a 的初始化数值如下:
src tensor:
tensor([[0.6789, 0.7350, 0.6104, 0.7777, 0.9613],[0.1432, 0.8788, 0.3269, 0.0063, 0.6070]])# 初始化 b 为size 为 (3, 5) 的向量,二维向量,3行5列,每个元素被初始化为 0
b = th.zeros(3, 5).scatter_(dim = 0,index = th.LongTensor([[0, 1, 2, 0, 0],[2, 0, 0, 1, 2]]),src = a
)
# dim = 0, out:
tensor([[0.6789, 0.8788, 0.3269, 0.7777, 0.9613],[0.0000, 0.7350, 0.0000, 0.0063, 0.0000],[0.1432, 0.0000, 0.6104, 0.0000, 0.6070]])# 初始化 c 为size 为 (3, 5) 的向量,二维向量,3行5列,每个元素被初始化为 0
c = th.zeros(3, 5).scatter_(dim = 1,th.LongTensor([[0, 1, 2, 0, 0],[2, 0, 0, 1, 2]]),src = a
)
# dim = 1, out:
tensor([[0.9613, 0.7350, 0.6104, 0.0000, 0.0000],[0.3269, 0.0063, 0.6070, 0.0000, 0.0000],[0.0000, 0.0000, 0.0000, 0.0000, 0.0000]])

下面来解释一下,b,c 内的元素分别是怎么得到的。

2.1 dim = 0 下的结果分析

先说 b,也就是 dim =0 下得到的结果。我们来看下官方给的说明文字:

self[index[i][j][k]][j][k] = src[i][j][k] # if dim == 0
self[i][index[i][j][k]][k] = src[i][j][k] # if dim == 1
self[i][j][index[i][j][k]] = src[i][j][k] # if dim == 2

因为这时 dim = 0,且只有 2 个维度,所以我们只用看第一行就行。

self [index[i][j]] [j] = src[i][j] # if dim == 0

仅用这一个公式就确定了 b 中所有元素的取值,与 a 的映射关系。这里等号左边的 self 可看做 output,也就是 b;src 是我们的输入向量,也就是 a。这里的 i,j 分别是输入向量 src 的 size 的取值。比如,本例中 a 的 size 为 (2,5),也就是说,对于 a 中的元素,i 的取值为 0,1;j 的取值为 0,1,2,3,4。a 中的元素的索引也就是(0,0),(0,1),… (0,4);(1,0),(1,1),…(1,4) 完毕,一共 2*5 = 10 个元素。
了解了这些以后,通过举例来说明 b 中的元素都是如何确定的。

index = th.LongTensor([[0, 1, 2, 0, 0],[2, 0, 0, 1, 2]]),
我们列举一些元素来说明其映射关系当 i = 0,j = 0 时,
我们用类似上述确定 a 索引的方式确定了 index[i][j] = 0,
这里的 0 就是 [0,1,2,0,0] 中最左边的 0,
则 b = out[index[i][j]][j] = out[0][0] = src[0][0] = 0.6789当 i = 0,j = 1 时,index[0][1] = 1,
这里的 1 就是 [0,1,2,0,0] 中的 1,
同理,b = out[index[i][j]][j] = out[1][1] = src[0][1] = 0.7350当 i = 0,j = 2 时,index[0][2] = 2,
这里的 2 就是 [0,1,2,0,0] 中的 2,
同理,b = out[index[i][j]][j] = out[2][2] = src[0][2] = 0.6104
注意,这里的out[2][2] 不是第 2 行,第 2 列的元素,是第 3 行,第 3 列的元素当 i = 1,j = 1 时,index[1][1] = 0,
这里的 0 就是 [2,0,0,1,2] 中最**左**边的 0,
同理,b = out[index[i][j]][j] = out[0][1] = src[1][1] = 0.8788当 i = 1,j = 3 时,index[1][3] = 0,
这里的 0 就是 [2,0,0,1,2] 中最**右**边的 0,
同理,b = out[index[i][j]][j] = out[0][1] = src[1][3] = 0.0063当 i = 1,j = 4 时,index[1][4] = 2,
这里的 2 就是 [2,0,0,1,2] 中最**左**边的 0,
同理,b = out[index[i][j]][j] = out[0][1] = src[1][4] = 0.6070由此得到了 b 中有映射关系的元素,剩余的元素,由于 b 被初始化为全 0 向量,所以剩余的元素均为 0 。

dim = 1的时候,同理。只是换了一种映射机制,如法炮制。

有任何关于内容不够详细,解释不清,错误等欢迎留言。转载请注明,支持原创,谢谢。

pytorch 深入理解 tensor.scatter_ ()用法相关推荐

  1. PyTorch torch.Tensor.contiguous() 用法与理解

    中文文档: contiguous() → Tensor         返回一个内存连续的有相同数据的 tensor,如果原 tensor 内存连续则返回原 tensor 英文文档: contiguo ...

  2. 实践教程 | 浅谈 PyTorch 中的 tensor 及使用

    点击上方"视学算法",选择加"星标"或"置顶" 重磅干货,第一时间送达 作者 | xiaopl@知乎(已授权) 来源 | https://z ...

  3. 【深度学习理论】一文搞透pytorch中的tensor、autograd、反向传播和计算图

    转载:https://zhuanlan.zhihu.com/p/145353262 前言 本文的主要目标: 一遍搞懂反向传播的底层原理,以及其在深度学习框架pytorch中的实现机制.当然一遍搞不定两 ...

  4. Pytorch中的collate_fn函数用法

    Pytorch中的collate_fn函数用法 官方的解释:   Puts each data field into a tensor with outer dimension batch size ...

  5. pytorch 创建张量tensor

    pytorch 创建张量tensor 先看下面一张图 通过上图有了一个直观了解后,我们开始尝试创建一下. 先创建一个标量和一个向量 a = torch.tensor([1]) #标量 print(a) ...

  6. pytorch 实现张量tensor,图片,CPU,GPU,数组等的转换

    pytorch 实现张量tensor,图片,CPU,GPU,数组等的转换 1, 创建pytorch 的Tensor张量: torch.rand((3,224,224)) #创建随机值的三维张量,大小为 ...

  7. YDOOK:Pytorch : AI : torch.tensor.size() 与 torch.tensor.shape 的区别

    YDOOK:Pytorch : AI : torch.tensor.size() 与 torch.tensor.shape 的区别 区别: 1. torch.tensor.size() 可用通过 :t ...

  8. YDOOK:Pytorch教程:tensor 张量内各个值同时相加一个数

    YDOOK:Pytorch教程:tensor 张量内各个值同时相加一个数 © YDOOK Jinwei Lin, shiye.work import torch import numpy as npt ...

  9. 如何理解yield的用法

    原创不易,转载前请注明博主的链接地址:Blessy_Zhu https://blog.csdn.net/weixin_42555080 本次代码的环境: 运行平台: Windows Python版本: ...

最新文章

  1. 触摸矫正+android,android触摸矫正解方程
  2. 京东数科首次公开:强一致、高性能分布式事务中间件JDTX
  3. 嵌入式C语言查表法的项目应用
  4. javascript中的for in循环和for循环的使用
  5. 浅谈设计模式01-策略模式
  6. MySQL 5.6 my.cnf 参数说明(转)
  7. 对弈类游戏的人工智能(3)--博弈树优化
  8. 常常被人忽略的VC备份
  9. parzen窗估计如何进行结果分析_Parzen窗方法的分析和研究
  10. ARM处理器指定运行核
  11. 关于 Docker ,你必须了解的核心都在这里
  12. XP操作系统安装的硬盘空间要求
  13. MyScript ---LateX公式编辑排版
  14. REST-assured简介
  15. 从初级开发者到资深架构师,看这
  16. ceph 集群报 mds cluster is degraded 故障排查
  17. 王晓亮:关于技术人的十年!
  18. Jmeter入门(一)使用Jmeter进行简单的性能测试
  19. TCP连接的建立和中止
  20. 包和 jar 文件的创建

热门文章

  1. MyEclipse6.5的SVN插件的安装
  2. 使用VS.NET2003操作SQLServer DTS.
  3. [译]深入 NGINX: 为性能和扩展所做之设计
  4. 通过Docker进程pid获取容器id
  5. Log4j与common-logging联系与区别
  6. java调用短信接口使用实例
  7. 开源软件的商业化策略模型
  8. VB.NET 按键代码 及组合键
  9. C Linux 文件加锁 lock fcntl
  10. 【神经网络】基于RBF神经网络的六关节机械臂无模型控制