
Tensor.scatter_(dim, index, src) → Tensor   (out-of-place variants: Tensor.scatter(dim, index, src) and torch.scatter(input, dim, index, src))

Writes all values from the tensor src into self at the indices specified in the index tensor. For each value in src, its output index is specified by its index in src for dimension != dim and by the corresponding value in index for dimension = dim.

For a 3-D tensor, self is updated as:

  • output[index[i][j][k]][j][k] = src[i][j][k]  # if dim == 0
  • output[i][index[i][j][k]][k] = src[i][j][k]  # if dim == 1
  • output[i][j][index[i][j][k]] = src[i][j][k]  # if dim == 2
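The three update rules above can be checked against a naive loop. The sketch below uses a hypothetical helper, `scatter_reference`, which is not part of PyTorch. It draws random permutations with `torch.argsort` so every row of `index` holds unique, in-range indices. It only illustrates the rule and is not how PyTorch implements scatter internally.

```python
import itertools
import torch


def scatter_reference(out, dim, index, src):
    """Hypothetical naive-loop version of out.scatter(dim, index, src).

    Visits every position of `index`. The target position is that same
    position with coordinate `dim` replaced by index[pos], as in the rules above.
    """
    result = out.clone()
    for pos in itertools.product(*(range(n) for n in index.shape)):
        target = list(pos)
        target[dim] = int(index[pos])
        result[tuple(target)] = src[pos]
    return result


torch.manual_seed(0)
src = torch.randn(2, 3, 4)
out = torch.zeros(2, 3, 4)
for dim in range(3):
    # A random permutation along `dim` keeps indices unique and in range.
    index = torch.argsort(torch.rand(2, 3, 4), dim=dim)
    expected = out.scatter(dim, index, src)
    assert torch.equal(scatter_reference(out, dim, index, src), expected)
```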

This is the reverse of the indexing pattern used by gather(): scatter writes src[i][j][k] to the same location that gather would read from.
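Here is a small illustration of the gather() relationship. It assumes each row of index is a permutation, so every position along dim 1 is written exactly once. Gathering with the same index then reads back the original src.

```python
import torch

src = torch.tensor([[10., 20., 30.], [40., 50., 60.]])
# Each row of `index` is a permutation of 0..2, so every position is written once.
index = torch.tensor([[2, 0, 1], [1, 2, 0]])

scattered = torch.zeros(2, 3).scatter(1, index, src)
# gather reads scattered[i][index[i][j]], the exact spot where scatter put src[i][j].
recovered = scattered.gather(1, index)

print(scattered.tolist())           # [[20.0, 30.0, 10.0], [60.0, 40.0, 50.0]]
print(torch.equal(recovered, src))  # True
```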

self, index and src (if it is a Tensor) must have the same number of dimensions. It is also required that index.size(d) <= src.size(d) for all dimensions d, and that index.size(d) <= self.size(d) for all dimensions d != dim.
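The next sketch illustrates the size constraint with an index that is smaller than src in both dimensions. In this example, only the src elements at positions covered by index are read, and the rest of src is ignored.

```python
import torch

src = torch.arange(1., 9.).view(2, 4)  # [[1, 2, 3, 4], [5, 6, 7, 8]]
index = torch.tensor([[0, 2]])         # shape (1, 2): smaller than src in both dims
out = torch.zeros(3, 3)                # index.size(1) = 2 <= out.size(1) = 3 (d != dim)

result = out.scatter(0, index, src)
# Only src[0][0] and src[0][1] are read: result[0][0] = 1 and result[2][1] = 2.
print(result.tolist())  # [[1.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 2.0, 0.0]]
```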

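Moreover, as for gather(), the values of index must be between 0 and self.size(dim) - 1 inclusive. The values in a row along the specified dimension dim should also be unique. When several indices point to the same location, which src value ends up there is non-deterministic.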
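The range and uniqueness requirements can be probed as below. This sketch is CPU-only. It assumes the out-of-range case surfaces as a Python exception and catches either RuntimeError or IndexError, because the exact exception type is not stated above. For duplicates, the only safe check is that the kept value is one of the candidates.

```python
import torch

out = torch.zeros(5)
src = torch.tensor([1., 2.])

# 5 == out.size(0) is outside the valid range 0..4, so the call is rejected.
try:
    out.scatter(0, torch.tensor([5, 0]), src)
except (RuntimeError, IndexError) as err:
    print("rejected:", err)

# Duplicate indices: both src values target position 0, and which one is kept
# is not guaranteed, so only membership is checked.
dup = out.scatter(0, torch.tensor([0, 0]), src)
print(dup[0].item() in (1.0, 2.0))  # True
```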

Parameters:
  • dim (int) – the axis along which to index
  • index (LongTensor) – the indices of elements to scatter. Its shape must satisfy the size constraints above. It may also be empty, in which case the operation returns self unchanged.
  • src (Tensor) – the source element(s) to scatter, if value is not specified
  • value (float) – a single scalar written at every indexed position, if src is not specified
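The following minimal sketch covers the `value` and empty-index cases from the parameter list. It passes the scalar through the `value=` keyword of `Tensor.scatter`. For the second call, it builds an empty 2-D LongTensor so that index has the same number of dimensions as self.

```python
import torch

out = torch.zeros(2, 3)
index = torch.tensor([[2], [0]])

# Scalar form: 7.0 is written at each indexed position instead of src elements.
filled = out.scatter(1, index, value=7.0)
print(filled.tolist())  # [[0.0, 0.0, 7.0], [7.0, 0.0, 0.0]]

# Empty index: nothing is written, so the result equals the original tensor.
empty = torch.empty(0, 0, dtype=torch.long)
unchanged = out.scatter(1, empty, torch.ones(2, 3))
print(torch.equal(unchanged, out))  # True
```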



import torch
input = torch.randn(2, 4)  # random source values, so yours will differ
output = torch.zeros(2, 5)
index = torch.tensor([[3, 1, 2, 0], [1, 2, 0, 3]])
output = output.scatter(1, index, input)  # out-of-place scatter along dim 1
print(input)   # e.g. tensor([[-0.2558, -1.8930, -0.7831,  0.6100], [ 0.3246,  2.1289,  0.5887,  1.5588]])
print(output)  # e.g. tensor([[ 0.6100, -1.8930, -0.7831, -0.2558,  0.0000], [ 0.5887,  0.3246,  2.1289,  1.5588,  0.0000]])



Element by element, the call writes:

  • output[0][index[0][0]] = output[0][3] = input[0][0]
  • output[0][index[0][1]] = output[0][1] = input[0][1]
  • output[0][index[0][2]] = output[0][2] = input[0][2]
  • output[0][index[0][3]] = output[0][0] = input[0][3]
  • output[1][index[1][0]] = output[1][1] = input[1][0]
  • output[1][index[1][1]] = output[1][2] = input[1][1]
  • output[1][index[1][2]] = output[1][0] = input[1][2]
  • output[1][index[1][3]] = output[1][3] = input[1][3]
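Because randn changes the output above on every run, the following variant swaps in fixed values from torch.arange, chosen only for readability. It replays the eight writes listed above with an explicit loop and then compares the result with scatter.

```python
import torch

input = torch.arange(1., 9.).view(2, 4)  # [[1, 2, 3, 4], [5, 6, 7, 8]] instead of randn
index = torch.tensor([[3, 1, 2, 0], [1, 2, 0, 3]])

# The eight writes listed above, performed one at a time.
manual = torch.zeros(2, 5)
for i in range(2):
    for j in range(4):
        manual[i, int(index[i, j])] = input[i, j]

output = torch.zeros(2, 5).scatter(1, index, input)
print(output.tolist())              # [[4.0, 2.0, 3.0, 1.0, 0.0], [7.0, 5.0, 6.0, 8.0, 0.0]]
print(torch.equal(manual, output))  # True
```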


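When src is a number rather than a tensor, that number is the value written into the output at every indexed position. A common use is building one-hot encodings:

import torch
index = torch.tensor([[1], [2], [0], [3]])  # column to set in each of the 4 rows
onehot = torch.zeros(4, 4)
onehot.scatter_(1, index, 1)  # in place: writes the scalar 1 at onehot[i][index[i][0]]
print(onehot)  # tensor([[0., 1., 0., 0.], [0., 0., 1., 0.], [1., 0., 0., 0.], [0., 0., 0., 1.]])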
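The one-hot trick generalizes to any 1-D tensor of class labels. `one_hot_scatter` below is a hypothetical helper for illustration. It is checked against `torch.nn.functional.one_hot`, which returns an int64 tensor, so that result is cast to float before comparing.

```python
import torch
import torch.nn.functional as F


def one_hot_scatter(labels, num_classes):
    """Hypothetical helper: one-hot encode a 1-D LongTensor of class labels."""
    out = torch.zeros(labels.size(0), num_classes)
    # unsqueeze(1) turns shape (N,) into (N, 1), so index has the same ndim as out.
    return out.scatter_(1, labels.unsqueeze(1), 1.0)


labels = torch.tensor([1, 2, 0, 3])
encoded = one_hot_scatter(labels, 4)
print(torch.equal(encoded, F.one_hot(labels, 4).float()))  # True
```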


