PyLayer

class paddle.autograd. PyLayer [源代码]

Paddle通过创建 PyLayer 子类的方式实现Python端自定义算子,这个子类必须遵守以下规则:

  1. 子类必须包含静态的 forwardbackward 函数,它们的第一个参数必须是 PyLayerContext ,如果 backward 的某个返回值在 forward 中对应的 Tensor 是需要梯度,这个返回值必须为 Tensor

  2. backward 除了第一个参数以外,其他参数都是 forward 函数的输出 Tensor 的梯度,因此, backward 输入的 Tensor 的数量必须等于 forward 输出 Tensor 的数量。如果你需在 backward 中使用 forward 的输入 Tensor ,你可以将这些 Tensor 输入到 PyLayerContextsave_for_backward 方法,之后在 backward 中使用这些 Tensor

  3. backward 的输出可以是 Tensor 或者 list/tuple(Tensor) ,这些 Tensorforward 输出 Tensor 的梯度。因此, backward 的输出 Tensor 的个数等于 forward 输入 Tensor 的个数。

构建完自定义算子后,通过 apply 运行算子。

示例代码

  1. import paddle
  2. from paddle.autograd import PyLayer
  3. # Inherit from PyLayer
  4. class cus_tanh(PyLayer):
  5. @staticmethod
  6. def forward(ctx, x, func1, func2=paddle.square):
  7. # ctx is a context object that store some objects for backward.
  8. ctx.func = func2
  9. y = func1(x)
  10. # Pass tensors to backward.
  11. ctx.save_for_backward(y)
  12. return y
  13. @staticmethod
  14. # forward has only one output, so there is only one gradient in the input of backward.
  15. def backward(ctx, dy):
  16. # Get the tensors passed by forward.
  17. y, = ctx.saved_tensor()
  18. grad = dy * (1 - ctx.func(y))
  19. # forward has only one input, so only one gradient tensor is returned.
  20. return grad
  21. data = paddle.randn([2, 3], dtype="float64")
  22. data.stop_gradient = False
  23. z = cus_tanh.apply(data, func1=paddle.tanh)
  24. z.mean().backward()
  25. print(data.grad)

forward ( ctx, args, kwargs* )

forward 函数必须被子类重写,它的第一个参数是 PyLayerContext 的对象,其他输入参数的类型和数量任意。

参数

  • *args (tuple) - 自定义算子的输入

  • **kwargs (dict) - 自定义算子的输入

返回:Tensor或至少包含一个Tensor的list/tuple

示例代码

  1. import paddle
  2. from paddle.autograd import PyLayer
  3. class cus_tanh(PyLayer):
  4. @staticmethod
  5. def forward(ctx, x):
  6. y = paddle.tanh(x)
  7. # Pass tensors to backward.
  8. ctx.save_for_backward(y)
  9. return y
  10. @staticmethod
  11. def backward(ctx, dy):
  12. # Get the tensors passed by forward.
  13. y, = ctx.saved_tensor()
  14. grad = dy * (1 - paddle.square(y))
  15. return grad

backward ( ctx, args, kwargs* )

backward 函数的作用是计算梯度,它必须被子类重写,其第一个参数为 PyLayerContext 的对象,其他输入参数为 forward 输出 Tensor 的梯度。它的输出 Tensorforward 输入 Tensor 的梯度。

参数

  • *args (tuple) - forward 输出 Tensor 的梯度。

  • **kwargs (dict) - forward 输出 Tensor 的梯度。

返回: forward 输入 Tensor 的梯度。

示例代码

  1. import paddle
  2. from paddle.autograd import PyLayer
  3. class cus_tanh(PyLayer):
  4. @staticmethod
  5. def forward(ctx, x):
  6. y = paddle.tanh(x)
  7. # Pass tensors to backward.
  8. ctx.save_for_backward(y)
  9. return y
  10. @staticmethod
  11. def backward(ctx, dy):
  12. # Get the tensors passed by forward.
  13. y, = ctx.saved_tensor()
  14. grad = dy * (1 - paddle.square(y))
  15. return grad

apply ( cls, args, kwargs* )

构建完自定义算子后,通过 apply 运行算子。

参数

  • *args (tuple) - 自定义算子的输入

  • **kwargs (dict) - 自定义算子的输入

返回:Tensor或至少包含一个Tensor的list/tuple

示例代码

  1. import paddle
  2. from paddle.autograd import PyLayer
  3. class cus_tanh(PyLayer):
  4. @staticmethod
  5. def forward(ctx, x, func1, func2=paddle.square):
  6. ctx.func = func2
  7. y = func1(x)
  8. # Pass tensors to backward.
  9. ctx.save_for_backward(y)
  10. return y
  11. @staticmethod
  12. def backward(ctx, dy):
  13. # Get the tensors passed by forward.
  14. y, = ctx.saved_tensor()
  15. grad = dy * (1 - ctx.func(y))
  16. return grad
  17. data = paddle.randn([2, 3], dtype="float64")
  18. data.stop_gradient = False
  19. # run custom Layer.
  20. z = cus_tanh.apply(data, func1=paddle.tanh)