torch.scatter Explained in Detail
# Official documentation of torch.scatter
scatter(output, dim, index, src) → Tensor
Writes all values from the tensor src into self at the indices specified in the index tensor. For each value in src, its output index is specified by its index in src for dimension != dim and by the corresponding value in index for dimension = dim.
For a 3-D tensor, self is updated as:
- output[index[i][j][k]][j][k] = src[i][j][k] # if dim == 0
- output[i][index[i][j][k]][k] = src[i][j][k] # if dim == 1
- output[i][j][index[i][j][k]] = src[i][j][k] # if dim == 2
This is the reverse operation of the manner described in gather().
self, index and src (if it is a Tensor) should have same number of dimensions. It is also required that index.size(d) <= src.size(d) for all dimensions d, and that index.size(d) <= self.size(d) for all dimensions d != dim.
Moreover, as for gather(), the values of index must be between 0 and self.size(dim) - 1 inclusive, and all values in a row along the specified dimension dim must be unique.
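A minimal sketch of the dim == 0 rule above, using an arbitrary 2-D example (shapes and values chosen only for illustration):
import torch
src = torch.arange(1., 7.).reshape(2, 3)      # [[1., 2., 3.], [4., 5., 6.]]
index = torch.tensor([[0, 1, 2], [2, 0, 1]])  # target row indices; must lie in [0, output.size(0) - 1]
output = torch.zeros(3, 3)
output.scatter_(0, index, src)                # output[index[i][j]][j] = src[i][j]
print(output)
# tensor([[1., 5., 0.],
#         [0., 2., 6.],
#         [4., 0., 3.]])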
Parameters
- dim (int) – the axis along which to index
- index (LongTensor) – the indices of elements to scatter; can be either empty or the same size as src. When empty, the operation returns self unchanged
- src (Tensor) – the source element(s) to scatter, in case value is not specified
- value (float) – the source element(s) to scatter, in case src is not specified
Summary: scatter redistributes the data in src into output, and index specifies the position in output that each element of src is written to. Positions in output that index does not cover keep their original values (zeros in the examples below).
# Understanding the function through an example
import torch
input = torch.randn(2, 4)
print(input)
output = torch.zeros(2, 5)
index = torch.tensor([[3, 1, 2, 0], [1, 2, 0, 3]])
output = output.scatter(1, index, input)
print(output)
# Output
tensor([[-0.2558, -1.8930, -0.7831, 0.6100],
[ 0.3246, 2.1289, 0.5887, 1.5588]])
tensor([[ 0.6100, -1.8930, -0.7831, -0.2558, 0.0000],
[ 0.5887, 0.3246, 2.1289, 1.5588, 0.0000]])
It is easiest to reason from the input tensor while following the positional mapping given in the official docs above.
All positional changes happen along dimension 1; dimension 0 stays fixed. If dim=0, the same transformation applies to the first-dimension index of input instead (see the sketch after the list below).
- input[0][0] = output[0][index[0][0]] = output[0][3]
- input[0][1] = output[0][index[0][1]] = output[0][1]
- input[0][2] = output[0][index[0][2]] = output[0][2]
- input[0][3] = output[0][index[0][3]] = output[0][0]
- input[1][0] = output[1][index[1][0]] = output[1][1]
- input[1][1] = output[1][index[1][1]] = output[1][2]
- input[1][2] = output[1][index[1][2]] = output[1][0]
- input[1][3] = output[1][index[1][3]] = output[1][3]
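As noted above, with dim=0 the row index is transformed instead. A minimal sketch (values chosen only for illustration):
src = torch.tensor([[1., 2., 3.], [4., 5., 6.]])
index = torch.tensor([[1, 0, 1], [0, 1, 0]])  # target row for every element of src
output = torch.zeros(2, 3).scatter(0, index, src)
print(output)
# tensor([[4., 2., 6.],
#         [1., 5., 3.]])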
scatter is commonly used to build one-hot vectors, as shown below:
index = torch.tensor([[1], [2], [0], [3]])
onehot = torch.zeros(4, 4)
onehot.scatter_(1, index, 1)
print(onehot)
# Output
tensor([[0., 1., 0., 0.],
[0., 0., 1., 0.],
[1., 0., 0., 0.],
[0., 0., 0., 1.]])
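In practice the index column is usually built from a 1-D tensor of class labels; a sketch assuming integer labels and four classes:
labels = torch.tensor([1, 2, 0, 3])            # hypothetical class labels
num_classes = 4
onehot = torch.zeros(labels.size(0), num_classes)
onehot.scatter_(1, labels.unsqueeze(1), 1.0)   # unsqueeze(1) turns shape (4,) into (4, 1)
print(onehot)                                  # same one-hot matrix as above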
# If src is a single number (the value argument), that number is what gets written into output at the indexed positions.
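A minimal sketch of that scalar form (0.9 is an arbitrary value, chosen only to show that any scalar can be written):
index = torch.tensor([[1], [2], [0], [3]])
out = torch.zeros(4, 4).scatter(1, index, 0.9) # writes the scalar 0.9 at every indexed position
print(out)
# tensor([[0.0000, 0.9000, 0.0000, 0.0000],
#         [0.0000, 0.0000, 0.9000, 0.0000],
#         [0.9000, 0.0000, 0.0000, 0.0000],
#         [0.0000, 0.0000, 0.0000, 0.9000]])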