where

paddle.fluid.layers. where ( condition ) [源代码]

该OP计算输入元素中为True的元素在输入中的坐标(index)。

参数:

  • condition (Variable)– 输入秩至少为1的多维Tensor,数据类型是bool类型。

返回:输出condition元素为True的坐标(index),将所有的坐标(index)组成一个2-D的Tensor。

返回类型:Variable,数据类型是int64。

代码示例

  1. import paddle.fluid as fluid
  2. import paddle.fluid.layers as layers
  3. import numpy as np
  4. # tensor 为 [True, False, True]
  5. condition = layers.assign(np.array([1, 0, 1], dtype='int32'))
  6. condition = layers.cast(condition, 'bool')
  7. out = layers.where(condition) # [[0], [2]]
  8. # tensor 为 [[True, False], [False, True]]
  9. condition = layers.assign(np.array([[1, 0], [0, 1]], dtype='int32'))
  10. condition = layers.cast(condition, 'bool')
  11. out = layers.where(condition) # [[0, 0], [1, 1]]
  12. # tensor 为 [False, False, False]
  13. condition = layers.assign(np.array([0, 0, 0], dtype='int32'))
  14. condition = layers.cast(condition, 'bool')
  15. out = layers.where(condition) # [[]]