gather

paddle. gather ( x, index, axis=None, name=None ) [源代码]

根据索引 index 获取输入 x 的指定 aixs 维度的条目,并将它们拼接在一起。

  1. X = [[1, 2],
  2. [3, 4],
  3. [5, 6]]
  4. Index = [1, 2]
  5. axis = 0
  6. Then:
  7. Out = [[3, 4],
  8. [5, 6]]

参数:

  • x (Tensor) - 输入 Tensor, 秩 rank >= 1 , 支持的数据类型包括 int32、int64、float32、float64 和 uint8 (CPU)、float16(GPU) 。

  • index (Tensor) - 索引 Tensor,秩 rank = 1, 数据类型为 int32 或 int64。

  • axis (Tensor) - 指定index 获取输入的维度, axis 的类型可以是int或者Tensor,当 axis 为Tensor的时候其数据类型为int32 或者int64。

  • name (str,可选)- 具体用法请参见 Name ,一般无需设置,默认值为None。

返回:和输入的秩相同的输出Tensor。

代码示例

  1. import numpy as np
  2. import paddle
  3. input_1 = np.array([[1,2],[3,4],[5,6]])
  4. index_1 = np.array([0,1])
  5. input = paddle.to_tensor(input_1)
  6. index = paddle.to_tensor(index_1)
  7. output = paddle.gather(input, index, axis=0)
  8. # expected output: [[1,2],[3,4]]