where

paddle. where ( condition, x, y, name=None ) [源代码]

该OP返回一个根据输入 condition, 选择 xy 的元素组成的多维 Tensor

where - 图1

参数:

  • condition (Tensor)- 选择 xy 元素的条件 。

  • x (Tensor)- 多维 Tensor ,数据类型为 float32float64int32int64

  • y (Tensor)- 多维 Tensor ,数据类型为 float32float64int32int64

  • name (str,可选)- 具体用法请参见 Name ,一般无需设置,默认值为None。

返回:数据类型与 x 相同的 Tensor

返回类型:Tensor。

代码示例:

  1. import paddle
  2. x = paddle.to_tensor([0.9383, 0.1983, 3.2, 1.2])
  3. y = paddle.to_tensor([1.0, 1.0, 1.0, 1.0])
  4. out = paddle.where(x>1, x, y)
  5. print(out)
  6. #out: [1.0, 1.0, 3.2, 1.2]