#torch.scatter函数官方解释

scatter(output, dim, index, src) → Tensor

Writes all values from the tensor src into self at the indices specified in the index tensor. For each value in src, its output index is specified by its index in src for dimension != dim and by the corresponding value in index for dimension = dim.

For a 3-D tensor, self is updated as:

  • output[index[i][j][k]][j][k] = src[i][j][k]  # if dim == 0
  • output[i][index[i][j][k]][k] = src[i][j][k]  # if dim == 1
  • output[i][j][index[i][j][k]] = src[i][j][k]  # if dim == 2

This is the reverse operation of the manner described in gather().

self, index and src (if it is a Tensor) should have same number of dimensions. It is also required that index.size(d) <= src.size(d) for all dimensions d, and that index.size(d) <= self.size(d) for all dimensions d != dim.

Moreover, as for gather(), the values of index must be between 0 and self.size(dim) - 1 inclusive, and all values in a row along the specified dimension dim must be unique.

Parameters

  • dim (int) – the axis along which to index
  • index (LongTensor) – the indices of elements to scatter, can be either empty or the same size of src. When empty, the operation returns identity
  • src (Tensor) – the source element(s) to scatter, incase value is not specified
  • value (float) – the source element(s) to scatter, incase src is not specified

总结:scatter函数就是把src数组中的数据重新分配到output数组当中,index数组中表示了要把src数组中的数据分配到output数组中的位置,若未指定,则填充0.

#通过例子理解函数

import torchinput = torch.randn(2, 4)
print(input)
output = torch.zeros(2, 5)
index = torch.tensor([[3, 1, 2, 0], [1, 2, 0, 3]])
output = output.scatter(1, index, input)
print(output)#得到输出
tensor([[-0.2558, -1.8930, -0.7831,  0.6100],[ 0.3246,  2.1289,  0.5887,  1.5588]])tensor([[ 0.6100, -1.8930, -0.7831, -0.2558,  0.0000],[ 0.5887,  0.3246,  2.1289,  1.5588,  0.0000]])

建议从input数组出发,结合官方给出的位置替换进行理解。

数据位置发生的变化都是在第1维上,第0维不变。若dim=0,则同理变换input第一维的下标。

  • input[0][0] = output[0][index[0][0]] = output[0][3]
  • input[0][1] = output[0][index[0][1]] = output[0][1]
  • input[0][2] = output[0][index[0][2]] = output[0][2]
  • input[0][3] = output[0][index[0][3]] = output[0][0]
  • Input[1][0] = output[1][index[1][0]] = output[1][1]
  • input[1][1] = output[1][index[1][1]] = output[1][2]
  • input[1][2] = output[1][index[1][2]] = output[1][0]
  • input[1][3] = output[1][index[1][3]] = output[1][3]

一般scatter用于生成onehot向量,如下所示:

index = torch.tensor([[1], [2], [0], [3]])
onehot = torch.zeros(4, 4)
onehot.scatter_(1, index, 1)
print(onehot)#输出
tensor([[0., 1., 0., 0.],[0., 0., 1., 0.],[1., 0., 0., 0.],[0., 0., 0., 1.]])#如果input是一个数字的话,代表这用于分配到output的数字是多少。

torch.scatter函数详解相关推荐

  1. Torch.arange函数详解

    torch.arange函数详解 官方文档:torch.arange 函数原型 arange(start=0, end, step=1, *, out=None, dtype=None, layout ...

  2. 【Pytorch】torch.argmax 函数详解

    文章目录 一.一个参数时的 torch.argmax 函数 1. 介绍 2. 实例 二.多个参数时的 torch.argmax 函数 1. 介绍 2. 实例 实例1:二维矩阵 实例2:三维矩阵 实例3 ...

  3. torch.flatten()函数详解

    自己的理解: 介绍torch.flatten()函数的具体使用方法1.首先创建一个三维张量2.调用torch.flatten()函数 import torchx = torch.randn(2, 3, ...

  4. Pytorch的scatter函数详解

    文章目录 前言 1.官方文档解释 2.举个例子 总结 前言  在看FCOS算法源码时,发现获取正样本点用到了scatter这个函数,故记录下. 1.官方文档解释   先贴出链接:scatter官方解读 ...

  5. 【Torch API】pytorch 中torch.ones_like和torch.zeros_like函数详解

    torch.ones_like函数和torch.zeros_like函数的基本功能是根据给定张量,生成与其形状相同的全1张量或全0张量,示例如下: input = torch.rand(2, 3) p ...

  6. torch.zeros() 函数详解

    torch.zeros()函数 返回一个形状为为size,类型为torch.dtype,里面的每一个值都是0的tensor torch.zeros(*size, out=None, dtype=Non ...

  7. PyTorch中torch.norm函数详解

    torch.norm() 是 PyTorch 中的一个函数,用于计算输入张量沿指定维度的范数.具体而言,当给定一个输入张量 x 和一个整数 p 时,torch.norm(x, p) 将返回输入张量 x ...

  8. torch.topk() 函数详解

    作用: 返回 列表中最大的n个值 例子1:m=torch.arange(0,10)              print(m.topk(3)) torch.return_types.topk( val ...

  9. scatter函数详解

    原作者地址https://blog.csdn.net/qiu931110/article/details/68130199

最新文章

  1. 计算机书籍-Go语言并发之道
  2. apk里面的图片不显示是加密了吗_【App】智能电视机视频盒子软件,你们缺吗?...
  3. python3 的 zip
  4. VC控件 Edit Control
  5. 锁定机制和数据并发管理(笔记)
  6. CSS清除默认样式,技术详细介绍
  7. java oauth2登录以及权限_还得看 Java!Gitee 4月最火 Java 项目大盘点
  8. python数组和列表的区别_JS数组方法与python列表方法的比较
  9. Spring MVC 学习笔记2 - 利用Spring Tool Suite创建一个web 项目
  10. iOS 使用自定义字体
  11. 关于安装NTKO Office插件的方法
  12. 前端事件练习之轮播图代码
  13. Deepin 系统下安装VMware并激活.
  14. RabbitMQ的使用(Java语言传统操作)
  15. Alpha、Beta、RC、GA版本的区别
  16. 部署ServletContext的时候报错 Class com.xxxxx.ContextServlet is not a Servlet
  17. 解锁忘记密码的iPhone X
  18. house of cat
  19. POJ - 2955 Brackets (区间DP)
  20. ABAQUS UEL

热门文章

  1. 汉诺塔(Hanoi)移动步骤问题
  2. 碧桂园的“秘密”:只想着安抚“闹事者”
  3. 我们宿舍里那群“禽兽”的极品笑话
  4. 《系统架构设计师教程》 第二章:计算机与网络基础知识
  5. 尊重知识产权,使用正版软件
  6. 关于计算机系统基础debian 11安装(感觉也适用debian10啥的)
  7. borland与microsoft之争
  8. win10禁用驱动签名
  9. 【笔记】Opencv 实现拼图板小游戏
  10. 阿斯达年代记思维导图