slice

  • paddle.fluid.layers.slice(input, axes, starts, ends)[源代码]

该OP沿多个轴生成 input 的切片。与numpy类似: https://docs.scipy.org/doc/numpy/reference/arrays.indexing.html 该OP使用 axesstartsends 属性来指定轴列表中每个轴的起点和终点位置,并使用此信息来对 input 切片。如果向 startsends 传递负值如

slice - 图1 ,则表示该轴的反向第 slice - 图2 个位置(这里以0为初始位置)。如果传递给 startsend 的值大于n(维度中的元素数目),则表示n。当切片一个未知数量的维度时,建议传入 INT_MAXaxesstartsends 三个参数的元素数目必须相等。以下示例将解释切片如何工作:

  1. 示例1
  2. 给定:
  3. data=[[1,2,3,4],[5,6,7,8],]
  4. axes=[0,1]
  5. starts=[1,0]
  6. ends=[2,3]
  7. 则:
  8. result=[[5,6,7],]
  9.  
  10. 示例2
  11. 给定:
  12. data=[[1,2,3,4],[5,6,7,8],]
  13. starts=[0,1]
  14. ends=[-1,1000] # 此处-1表示第0维的反向第0个位置,索引值是1。
  15. 则:
  16. result=[[2,3,4],] # 即 data[0:1, 1:4]
  • 参数:
    • input (Variable)- 多维 TensorLoDTensor,数据类型为 float16float32float64int32,或 int64
    • axes (list|tuple)- 数据类型是 int32。表示进行切片的轴。它是可选的,如果不存在,将被视为 slice - 图3
    • starts (list|tuple|Variable)- 数据类型是 int32。如果 starts 的类型是 list 或 tuple,它的元素可以是整数或者形状为[1]的 TensorLoDTensor。如果 starts 的类型是 Variable,则是1-D TensorLoDTensor。表示在各个轴上切片的起始索引值。
    • ends (list|tuple|Variable)- 数据类型是 int32。如果 ends 的类型是 list 或 tuple,它的元素可以是整数或者形状为[1]的 TensorLoDTensor。如果 ends 的类型是 Variable,则是1-D TensorLoDTensor。表示在各个轴上切片的结束索引值。

返回:多维 TensorLoDTensor,数据类型与 input 相同。

返回类型:Variable。

  • 抛出异常:
    • TypeErrorstarts 的类型应该是 list、tuple 或 Variable。
    • TypeErrorends 的类型应该是 list、tuple 或 Variable。

代码示例:

  1. import paddle.fluid as fluid
  2.  
  3. input = fluid.layers.data(
  4. name="input", shape=[3, 4, 5, 6], dtype='float32')
  5.  
  6. # example 1:
  7. # attr starts is a list which doesn't contain tensor Variable.
  8. axes = [0, 1, 2]
  9. starts = [-3, 0, 2]
  10. ends = [3, 2, 4]
  11. sliced_1 = fluid.layers.slice(input, axes=axes, starts=starts, ends=ends)
  12. # sliced_1 is input[:, 0:3, 0:2, 2:4].
  13.  
  14. # example 2:
  15. # attr starts is a list which contain tensor Variable.
  16. minus_3 = fluid.layers.fill_constant([1], "int32", -3)
  17. sliced_2 = fluid.layers.slice(input, axes=axes, starts=[minus_3, 0, 2], ends=ends)
  18. # sliced_2 is input[:, 0:3, 0:2, 2:4].