scatter

paddle.scatter ( x, index, updates, overwrite=True, name=None ) [源代码]

通过基于 updates 来更新选定索引 index 上的输入来获得输出。具体行为如下:

  1. import numpy as np
  2. #input:
  3. x = np.array([[1, 1], [2, 2], [3, 3]])
  4. index = np.array([2, 1, 0, 1])
  5. # shape of updates should be the same as x
  6. # shape of updates with dim > 1 should be the same as input
  7. updates = np.array([[1, 1], [2, 2], [3, 3], [4, 4]])
  8. overwrite = False
  9. # calculation:
  10. if not overwrite:
  11. for i in range(len(index)):
  12. x[index[i]] = np.zeros((2))
  13. for i in range(len(index)):
  14. if (overwrite):
  15. x[index[i]] = updates[i]
  16. else:
  17. x[index[i]] += updates[i]
  18. # output:
  19. out = np.array([[3, 3], [6, 6], [1, 1]])
  20. out.shape # [3, 2]

Notice: 因为 updates 的应用顺序是不确定的,因此,如果索引 index 包含重复项,则输出将具有不确定性。

参数

  • x (Tensor) - ndim> = 1的输入N-D张量。 数据类型可以是float32,float64。

  • index (Tensor)- 一维Tensor。 数据类型可以是int32,int64。 index 的长度不能超过 updates 的长度,并且 index 中的值不能超过输入的长度。

  • updates (Tensor)- 根据 index 使用 update 参数更新输入 x 。 形状应与输入 x 相同,并且dim>1的dim值应与输入 x 相同。

  • overwrite (bool,可选)- 指定索引 index 相同时,更新输出的方式。如果为True,则使用覆盖模式更新相同索引的输出,如果为False,则使用累加模式更新相同索引的输出。默认值为True。

  • name (str,可选)- 具体用法请参见 Name ,一般无需设置,默认值为None。

返回

Tensor,与x有相同形状和数据类型。

代码示例

  1. import paddle
  2. import numpy as np
  3. x_data = np.array([[1, 1], [2, 2], [3, 3]]).astype(np.float32)
  4. index_data = np.array([2, 1, 0, 1]).astype(np.int64)
  5. updates_data = np.array([[1, 1], [2, 2], [3, 3], [4, 4]]).astype(np.float32)
  6. x = paddle.to_tensor(x_data)
  7. index = paddle.to_tensor(index_data)
  8. updates = paddle.to_tensor(updates_data)
  9. output1 = paddle.scatter(x, index, updates, overwrite=False)
  10. # [[3., 3.],
  11. # [6., 6.],
  12. # [1., 1.]]
  13. output2 = paddle.scatter(x, index, updates, overwrite=True)
  14. # CPU device:
  15. # [[3., 3.],
  16. # [4., 4.],
  17. # [1., 1.]]
  18. # GPU device maybe have two results because of the repeated numbers in index
  19. # result 1:
  20. # [[3., 3.],
  21. # [4., 4.],
  22. # [1., 1.]]
  23. # result 2:
  24. # [[3., 3.],
  25. # [2., 2.],
  26. # [1., 1.]]