space_to_depth

  • paddle.fluid.layers.space_to_depth(x, blocksize, name=None)[源代码]

该OP对成块的空间数据进行重组,输出一个输入张量的拷贝,其高度和宽度维度上的值移至通道维度。

重组时,依据 blocksize 指明的数据块大小, 对形为 [batch, channel, height, width] 的输入张量进行space_to_depth(广度至深度)运算,生成形为 [batch, channel * blocksize * blocksize, height/blocksize, width/blocksize] 的输出:

  • 在各位置上,不重叠的,大小为 blocksize * blocksize 的块重组入深度depth
  • 输入各个块中的Y, X坐标变为输出张量通道索引的高序部位
  • 输入 channel 需可以被 blocksize 的平方整除
  • 输入的高度和宽度需可以被 blocksize 整除

该OP适用于在卷积间重放缩激活函数,并保持所有的数据。

范例如下:

  1. 给定形状为[1, 1, 4, 4]的输入 x
  2. x.data = [[[[1, 2, 5, 6],
  3. [3, 4, 7, 8],
  4. [9, 10, 13, 14],
  5. [11, 12, 15, 16]]]]
  6. 设置 blocksize = 2
  7.  
  8. 得到形状为[1, 4, 2, 2]的输出 out
  9. out.data = [[[[1, 2], [3, 4]],
  10. [[5, 6], [7, 8]],
  11. [[9, 10], [11, 12]],
  12. [[13, 14], [15, 16]]]]
  • 参数:
    • x (Variable) – 输入,形状为 [batch, channel, height, width] 的4维Tensor或LoD Tensor。数据类型支持int32,int64,float32或float64。
    • blocksize (int) – 在每个特征图上选择元素时采用的块大小,应该 >= 2
    • name (str,可选) - 具体用法请参见 Name ,一般无需设置,默认值为None。

返回:输出,形状为 [batch, channel * blocksize * blocksize, height/blocksize, width/blocksize] 的4维Tensor或LoD Tensor。数据类型与输入 x 一致。

返回类型:Variable

  • 抛出异常:
    • TypeError - blocksize 必须是int64类型

代码示例

  1. import paddle.fluid as fluid
  2. import numpy as np
  3.  
  4. data = fluid.data(
  5. name='data', shape=[1, 4, 2, 2], dtype='float32')
  6. space_to_depthed = fluid.layers.space_to_depth(
  7. x=data, blocksize=2)
  8.  
  9. exe = fluid.Executor(fluid.CPUPlace())
  10. data_np = np.arange(0,16).reshape((1,4,2,2)).astype('float32')
  11.  
  12. print(data_np)
  13. #array([[[[ 0., 1.], [ 2., 3.]],
  14. # [[ 4., 5.], [ 6., 7.]],
  15. # [[ 8., 9.], [10., 11.]],
  16. # [[12., 13.], [14., 15.]]]], dtype=float32)
  17.  
  18. out_main = exe.run(fluid.default_main_program(),
  19. feed={'data': data_np},
  20. fetch_list=[space_to_depthed])
  21.  
  22. print(out_main)
  23. #[array([[[[ 0.]], [[ 4.]], [[ 1.]], [[ 5.]],
  24. # [[ 8.]], [[12.]], [[ 9.]], [[13.]],
  25. # [[ 2.]], [[ 6.]], [[ 3.]], [[ 7.]],
  26. # [[10.]], [[14.]], [[11.]], [[15.]]]], dtype=float32)]