Create a New Operator

Background

What is a Custom Op

OneFlow abstracts all kinds of data processing into ops (operators). An op acts on input tensors and writes the result of its computation to output tensors. OneFlow provides a relatively comprehensive set of ops, which can be found in the ops directory.

When OneFlow’s existing Python operators are not sufficient to build a neural network, or when they do not meet performance requirements, you can develop custom ops in C++.

OneFlow provides a mechanism with which you can create a custom op, register it in OneFlow, and then use it in Python.

The following diagram demonstrates the registration system for a custom op in OneFlow.

(Figure: the OneFlow UserOp registration system)

In the OneFlow framework, there are three types of registries associated with custom ops:

  • OpGradRegistry: manages gradient registrations, used for automatic gradient calculation when building the backward graph.

  • OpRegistry: manages op registrations, used to generate the forward graph and build the Task Graph.

  • OpKernelRegistry: manages kernel registrations, used to execute user logic at runtime.

In practice, we write the custom op in C++ and build it into a dynamic library (an .so file). By loading that .so file in Python, we can then use the custom op.

The data structure of a user op is defined in user_op_conf.proto:

    syntax = "proto2";
    package oneflow;
    import "oneflow/core/framework/user_op_attr.proto";

    message UserOpConf {
      message ListString {
        repeated string s = 1;
      }
      required string op_type_name = 1;
      map<string, ListString> input = 2;
      map<string, ListString> output = 3;
      map<string, UserOpAttrVal> attr = 4;
    }

The op_type_name is a string that identifies the class of the op; it is the globally unique ID of that op class. OneFlow looks up and confirms the op class by its op_type_name, which will appear several times in the rest of this document.

Basic Concepts

  • op_type_name: as mentioned above, op_type_name is the unique ID of an op class. OneFlow looks up the op class by op_type_name and then instantiates ops from it. The relationship between an op class and an op is similar to that between a class and an object.

  • Op: a logical operator. It carries the information used for graph composition and for inferring input and output shapes, but contains no logic for processing data.

  • Kernel: at runtime, the concrete processing logic of a logical op depends on the physical device and the data type; this logic is implemented by the kernel. An op generally has a one-to-many relationship with its kernels, and we need to register a kernel for every physical device and data type the op supports.

  • Registration: registration establishes the link between a custom op and the OneFlow framework. OneFlow provides a series of REGISTER_XXX macros to help with op registration.

  • Loading the dynamic library: the custom op and its kernels are linked into a dynamic library (.so file) that must be loaded before use in Python; OneFlow provides oneflow.config.load_library to load the .so file of a custom op.

  • Python wrapper: calling a custom op implemented at the C++ layer from Python requires a wrapper at the Python layer; OneFlow provides oneflow.user_op_builder for this task.

Process of Writing a Custom Op

  1. Implementation and registration of the op: the op implementation is mainly used for composing the forward graph; it specifies the op's name, inputs, outputs, configuration attributes, and the functions needed to infer the shape and data type of its tensors.

  2. Implementation and registration of the kernel(s) for the op: the kernel is responsible for the concrete computation at runtime, and an op may correspond to multiple kernels.

  3. (Optional) Implementation and registration of the op's gradient: if the custom op needs to support backpropagation, we implement and register a backward generating function for it.

  4. Compile and link to obtain the .so file.

  5. Load the .so file in Python and use oneflow.user_op_builder to wrap the custom op written in C++.

  6. Testing.

Example

We will implement a custom op called “myrelu” that supports both CPU and GPU. For the complete code, please refer to code/extended_topics/create_user_op.

Implementation and Registration of Op

We define the op and complete its registration in myrelu_op.cpp:

    #include "oneflow/core/framework/framework.h"

    namespace oneflow {
    namespace {

    REGISTER_USER_OP("myrelu")
        .Input("in")
        .Output("out")
        .SetTensorDescInferFn(
            [](user_op::InferContext *ctx) -> Maybe<void> {
              *ctx->Shape4ArgNameAndIndex("out", 0) =
                  *ctx->Shape4ArgNameAndIndex("in", 0);
              *ctx->Dtype4ArgNameAndIndex("out", 0) =
                  *ctx->Dtype4ArgNameAndIndex("in", 0);
              return Maybe<void>::Ok();
            });

    }  // namespace
    }  // namespace oneflow

Analysis of the above code:

  • oneflow/core/framework/framework.h contains all the interfaces we need to create an op.

  • Almost all the APIs related to user ops are in the namespace oneflow::user_op, so we open the namespace oneflow to simplify the type names.

  • The macro REGISTER_USER_OP registers the op; it accepts myrelu as the op_type_name.

  • REGISTER_USER_OP actually returns an OpRegistry object (defined in oneflow/core/framework/user_op_registry.h), whose methods are called to complete the configuration of the custom op:

  • Input("in") means that it has an input named “in”.

  • Output("out") means that it has an output named “out”.

  • SetTensorDescInferFn sets the shape and data type inference function, which describes the relationship between the shapes and data types of the op's inputs and outputs. In the code above, the output's shape and data type are the same as the input's.

Implementation and Registration of CPU Kernel

We implemented the CPU kernel in myrelu_cpu_kernel.cpp and registered it:

    #include "oneflow/core/framework/framework.h"

    namespace oneflow {
    namespace {

    template <typename T>
    void MyRelu(DeviceCtx *ctx, const int64_t n, const T *x, T *y) {
      T zero = (T)(0);
      for (int64_t i = 0; i != n; ++i) {
        y[i] = std::max(x[i], zero);
      }
    }

    template <DeviceType device_type, typename T>
    class ReluKernel final : public user_op::OpKernel {
     public:
      ReluKernel() = default;
      ~ReluKernel() = default;

     private:
      void Compute(user_op::KernelComputeContext *ctx) const override {
        const user_op::Tensor *in_tensor = ctx->Tensor4ArgNameAndIndex("in", 0);
        user_op::Tensor *out_tensor = ctx->Tensor4ArgNameAndIndex("out", 0);
        MyRelu<T>(ctx->device_ctx(),
                  in_tensor->shape().elem_cnt(),
                  in_tensor->dptr<T>(),
                  out_tensor->mut_dptr<T>());
      }
      bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }
    };

    #define REGISTER_RELU_KERNEL(device, dtype)        \
      REGISTER_USER_KERNEL("myrelu")                   \
          .SetCreateFn<ReluKernel<device, dtype>>()    \
          .SetIsMatchedHob(                            \
              (user_op::HobDeviceTag() == device) &    \
              (user_op::HobDataType("out", 0)          \
                == GetDataType<dtype>::value));

    REGISTER_RELU_KERNEL(DeviceType::kCPU, float)
    REGISTER_RELU_KERNEL(DeviceType::kCPU, double)

    }  // namespace
    }  // namespace oneflow

To implement a kernel in OneFlow, you must define a class that inherits from oneflow::user_op::OpKernel and override its virtual functions.

In the code above, we override Compute and AlwaysComputeWhenAllOutputsEmpty; their respective meanings are:

  • Compute must be overridden to implement the actual computation logic.

  • AlwaysComputeWhenAllOutputsEmpty should return false in most cases. For the few ops that maintain internal state, and therefore must run the kernel even when all outputs are empty, it should return true.

After implementing the kernel class, call REGISTER_USER_KERNEL to register it. The string parameter accepted by REGISTER_USER_KERNEL("myrelu") is the op_type_name, which is used for registration and lookup; the same op_type_name is also needed when wrapping the op at the Python layer.

REGISTER_USER_KERNEL("myrelu") returns an OpKernelRegistry object. The methods that need to be called to set the registration information are mention in the code above.

  • SetCreateFn<T>(): the template parameter T is our kernel class; OneFlow uses it to create the kernel object.

  • SetIsMatchedHob: because an op may have more than one kernel, SetIsMatchedHob is called to select the specific kernel according to the physical device and data type. The method accepts an expression; when the expression evaluates to true, OneFlow calls this kernel to carry out the computation.

Implementation and Registration of GPU Kernel

We implemented the GPU version of the kernel in myrelu_gpu_kernel.cu and registered it:

    #include "oneflow/core/framework/framework.h"
    #include <cub/cub.cuh>

    namespace oneflow {
    namespace {

    template <typename T>
    __global__ void ReluForwardGpu(const int n, const T *x, T *y) {
      CUDA_1D_KERNEL_LOOP(i, n) { y[i] = x[i] > 0 ? x[i] : 0; }
    }

    class ReluGpuFloatKernel final : public user_op::OpKernel {
     public:
      ReluGpuFloatKernel() = default;
      ~ReluGpuFloatKernel() = default;

     private:
      void Compute(user_op::KernelComputeContext *ctx) const override {
        const user_op::Tensor *in_tensor = ctx->Tensor4ArgNameAndIndex("in", 0);
        user_op::Tensor *out_tensor = ctx->Tensor4ArgNameAndIndex("out", 0);
        int32_t n = in_tensor->shape().elem_cnt();
        const float *in_ptr = in_tensor->dptr<float>();
        float *out_ptr = out_tensor->mut_dptr<float>();
        ReluForwardGpu<float>
            <<<32, 1024, 0, ctx->device_ctx()->cuda_stream()>>>(n, in_ptr, out_ptr);
      }
      bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }
    };

    #define REGISTER_RELU_KERNEL(device, dtype)        \
      REGISTER_USER_KERNEL("myrelu")                   \
          .SetCreateFn<ReluGpuFloatKernel>()           \
          .SetIsMatchedHob(                            \
              (user_op::HobDeviceTag() == device) &    \
              (user_op::HobDataType("out", 0)          \
                == GetDataType<dtype>::value));

    REGISTER_RELU_KERNEL(DeviceType::kGPU, float)
    REGISTER_RELU_KERNEL(DeviceType::kGPU, double)

    }  // namespace
    }  // namespace oneflow

The process of implementing and registering the GPU kernel is almost identical to the CPU kernel. The main differences are:

  • Because CUDA programming is used, the CUDA header files are included.

  • Compute is implemented with CUDA and launches a GPU kernel.

  • SetIsMatchedHob matches on the GPU device.

Besides that, because of the use of CUDA, we need to use the nvcc compiler (instead of g++) to compile the GPU kernel.

Compiling Option Description

The oneflow.sysconfig module contains get_compile_flags, get_include, get_lib, and get_link_flags, which correspond to:

  • Compile options
  • Header file directory
  • Link library directory
  • Link options

For example:

    >>> import oneflow
    >>> oneflow.sysconfig.get_compile_flags()
    ['-I/home/yaochi/oneflow/build/python_scripts/oneflow/include', '-DHALF_ENABLE_CPP11_USER_LITERALS=0', '-DWITH_CUDA', '-D_GLIBCXX_USE_CXX11_ABI=0']
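
The remaining helpers can be queried in the same way. Below is a minimal sketch that uses only the functions listed above; the comments describe their intended use, and the actual paths depend on your installation:

    import oneflow

    # header file directory, typically passed to the compiler with -I
    print(oneflow.sysconfig.get_include())
    # link library directory, typically passed to the linker with -L
    print(oneflow.sysconfig.get_lib())
    # full link options
    print(oneflow.sysconfig.get_link_flags())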

You can also get the compile and link options directly from the command line:

    python -c "import oneflow; print(' '.join(oneflow.sysconfig.get_compile_flags()))"
    python -c "import oneflow; print(' '.join(oneflow.sysconfig.get_link_flags()))"

For the GPU kernel, the cudart library also needs to be specified when linking.

Get Dynamic Library by Compilation and Linking

For this simple example, you can use the following Makefile to build:

    CFLAGS = $(shell python -c "import oneflow; print(' '.join(oneflow.sysconfig.get_compile_flags()))")
    LFLAGS = $(shell python -c "import oneflow; print(' '.join(oneflow.sysconfig.get_link_flags()))")
    CUDAPATH = /usr/local/cuda-10.1/lib64

    all: final_relu.so

    myrelu_op.o: myrelu_op.cpp
        g++ -std=c++11 -c myrelu_op.cpp \
            -o myrelu_op.o \
            -fPIC \
            ${CFLAGS} \
            ${LFLAGS} \
            -O2

    myrelu_cpu_kernel.o: myrelu_cpu_kernel.cpp
        g++ -std=c++11 -c myrelu_cpu_kernel.cpp \
            -o myrelu_cpu_kernel.o \
            $(CFLAGS) -fPIC

    myrelu_gpu_kernel.o: myrelu_gpu_kernel.cu
        nvcc -std=c++11 -c myrelu_gpu_kernel.cu \
            -o myrelu_gpu_kernel.o \
            $(CFLAGS) -x cu -Xcompiler -fPIC

    final_relu.so: myrelu_op.o myrelu_cpu_kernel.o myrelu_gpu_kernel.o
        g++ -std=c++11 myrelu_op.o \
            myrelu_cpu_kernel.o \
            myrelu_gpu_kernel.o \
            -shared -o final_relu.so \
            $(CFLAGS) \
            -fPIC \
            -L$(CUDAPATH) \
            -lcudart \
            $(LFLAGS)

    clean:
        rm -rf *.so *.o

We use g++ to compile myrelu_op.cpp and myrelu_cpu_kernel.cpp, and nvcc to compile myrelu_gpu_kernel.cu. The resulting object (.o) files are then linked into final_relu.so.

Next, we load final_relu.so in Python, wrap the custom op, and use it.

Using the Custom Op in Python

Using a custom op in Python involves the following steps:

  • Load the .so file with oneflow.config.load_library.

  • Use oneflow.user_op_builder to generate a Python wrapper for the custom op.

  • Call the resulting Python wrapper.

The following code wraps myrelu at the Python layer and calls it:

    import oneflow as flow
    import numpy as np
    import oneflow.typing as tp

    # load modules
    flow.config.load_library("final_relu.so")

    # default configuration
    flow.config.gpu_device_num(1)


    # python op wrapper function
    def myrelu(input_blob):
        op = (
            flow.user_op_builder("op_myrelu")
            .Op("myrelu")
            .Input("in", [input_blob])
            .Output("out")
            .Build()
        )
        return op.InferAndTryRun().SoleOutputBlob()


    # network code
    @flow.global_function()
    def MyJob(x: tp.Numpy.Placeholder((5,), dtype=flow.float32)) -> tp.Numpy:
        return myrelu(x)


    if __name__ == "__main__":
        input = np.array([-2, -1, 0, 1, 2], dtype=np.float32)
        output = MyJob(input)
        print(input)
        print(output)

The expected results are:

    [-2. -1. 0. 1. 2.]
    [0. 0. 0. 1. 2.]

In the code above, flow.config.load_library("final_relu.so") loads the .so file.

We now focus on how the Python wrapper in myrelu is built and run.

flow.user_op_builder("op_myrelu") returns a UserOpConfBuilder object for an op named op_myrelu.

    op = (
        flow.user_op_builder("op_myrelu")
        .Op("myrelu")
        .Input("in", [input_blob])
        .Output("out")
        .Build()
    )

This object provides methods such as Op and Input, which are used to wrap the custom op. Detailed explanations are as follows:

  • Op("myrelu"): The parameter must be the op_type_name from the previous C++ registration which OneFlow uses to find the registered op type and instantiate the op.

  • Input("in", [input_blob]): Corresponds to Input when op is registered in C++ that the first parameter must be the same as the string set by Input when op is registered in C++. The second parameter is the blob of the input which is a list. Because an op allows multiple inputs.

  • Output("out"): Corresponds to Output when op registered in C++.

  • Build: after the settings above are complete, call Build to obtain the Python wrapper of the custom op.

The following code obtains the output blob of the custom op:

    return op.InferAndTryRun().SoleOutputBlob()

InferAndTryRun completes the inference and returns the UserOp. If the op has only one output, we can use SoleOutputBlob to get the unique output blob; otherwise, use RemoteBlobList to get a list of blobs.

So far we have built myrelu, a relatively simple op. To build more complex ops, some additional features of the registration process are needed. We introduce them below from the perspectives of op registration, kernel registration, gradient registration, and Python-layer wrapping.

Detailed Introduction of OpRegistry

Attr

Some ops require configuration attributes in addition to inputs and outputs. For example, reshape needs a shape attribute to be configured, and conv needs its padding to be configured. We can use Attr at registration time to declare attributes for an op; its prototype is:

    OpRegistry& Attr<cpp_type>(const std::string& name);

We just need to specify the name and type of the attribute. For example:

  1. REGISTER_USER_OP("reshape")
  2. .Input("in")
  3. .Output("out")
  4. .Attr<shape>("shape")
  1. REGISTER_USER_OP("conv2d")
  2. .Input("in")
  3. .Input("weight")
  4. .Output("out")
  5. .Attr<std::vector<int32_t>>("padding_before")

In OneFlow, the following attribute types and their corresponding C++ data types are currently supported:

UserOpAttrType   Corresponding C++ data type
kAtInt32         int32_t
kAtInt64         int64_t
kAtBool          bool
kAtFloat         float
kAtDouble        double
kAtShape         oneflow::Shape
kAtListInt32     std::vector<int32_t>
kAtListInt64     std::vector<int64_t>
kAtListFloat     std::vector<float>
kAtString        std::string

We can pass an additional parameter to configure a default value for an attribute; the default must be of the corresponding C++ data type in the table above. For example:

    .Attr<bool>("is_transpose", false)
    .Attr<int32_t>("size", 10)
    .Attr<std::vector<int32_t>>("vector_of_size", std::vector<int32_t>{10, 11, 12})

SetCheckAttrFn

Some attributes need their value range constrained more precisely, which can be specified with SetCheckAttrFn when registering the op.

Take the conv op as an example: it has a configuration attribute called data_format, which is a string whose value must be either channels_first or channels_last.

    .Attr<std::string>("data_format", std::string("NCHW"))
    .SetCheckAttrFn(
        [](const user_op::UserOpDefWrapper& def,
           const user_op::UserOpConfWrapper& conf) -> Maybe<void> {
          std::string data_format = conf.attr<std::string>("data_format");
          if (data_format == "channels_first" || data_format == "channels_last") {
            return Maybe<void>::Ok();
          }
          return oneflow::Error::CheckFailed()
                 << "data_format value: "
                 << data_format
                 << " for Conv op is illegal.";
        })

Set a check function that returns Maybe<void>::Ok() when the value of the attribute meets the requirement, and oneflow::Error::CheckFailed() otherwise.

Multiple In/Output

Some ops may have multiple inputs or outputs, and we need to specify their numbers when registering the op.

Input example:

    // input must have exactly 1 blob
    .Input("input")
    // input must have exactly 5 blobs
    .Input("input", 5)
    // input must have at least 5 blobs
    .InputWithMinimum("input", 5)
    // input can have no blob or 1 blob
    .OptionalInput("input")
    // input can have no blob or 5 blobs
    .OptionalInput("input", 5)
    // input can have no blob or at least 5 blobs
    .OptionalInputWithMinimum("input", 5)

Output setting is similar to Input.
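
For reference, here is a hedged sketch of declaring an op with multiple outputs, assuming the Output family of methods (Output with a count, OptionalOutput, and so on) mirrors the Input family shown above; the op name multi_out is hypothetical and used only for illustration:

    // hypothetical op "multi_out", for illustration only
    REGISTER_USER_OP("multi_out")
        .Input("in")
        // output "out" must have exactly 2 blobs
        .Output("out", 2)
        // output "aux" may have no blob or exactly 1 blob
        .OptionalOutput("aux");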

SetGetSbpFn

SetGetSbpFn is used to configure the SBP signatures of the op. Example of add_n:

  1. REGISTER_USER_OP("add_n")
  2. .InputWithMinimum("in", 2)
  3. .Output("out")
  4. .SetGetSbpFn([](user_op::SbpContext* ctx) {
  5. int64_t num_axes = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0).shape().NumAxes();
  6. for (int64_t i = 0; i < num_axes; ++i) {
  7. ctx->NewBuilder().Split(ctx->inputs(), i).Split(user_op::OpArg("out", 0), i).Build();
  8. }
  9. ctx->NewBuilder().PartialSum(ctx->inputs()).PartialSum(user_op::OpArg("out", 0)).Build();
  10. return Maybe<void>::Ok();
  11. });

Detailed Introduction of OpKernelRegistry

SetInferTmpSizeFn

In some kernel implementations of an op, an extra buffer may be required to store temporary data during Compute.

We can specify the buffer size with SetInferTmpSizeFn when registering the kernel, then obtain and use the buffer inside the Compute function.

The following code uses SetInferTmpSizeFn during kernel registration to specify a buffer size of 1024 bytes:

    REGISTER_USER_KERNEL("XOp")
        .SetInferTmpSizeFn(
            [](const oneflow::user_op::InferContext*) {
              return 1024;
            });

Once the buffer size is set with SetInferTmpSizeFn, the buffer can be retrieved in Compute by calling KernelContext::Tensor4ArgNameAndIndex. The buffer is wrapped as a oneflow::user_op::Tensor, which can be converted to a pointer of the desired type by calling dptr or mut_dptr.

    class XKernel final : public oneflow::user_op::OpKernel {
      void Compute(oneflow::user_op::KernelContext* ctx) override {
        oneflow::user_op::Tensor* tmp = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0);
        // The conversion yields a char* buffer of 1024 bytes.
        char* pBuff = tmp->mut_dptr<char>();
        ...
      }
    };

Detailed Introduction of OpGradRegistry

OneFlow obtains gradients automatically when expanding the backward graph; the framework uses automatic differentiation, i.e., it automatically finds the gradient of the entire expression using the chain rule.

To get the gradient of a custom op automatically, we need to register it with REGISTER_USER_OP_GRAD. From a mathematical point of view, the registration specifies the backward derivation for the custom op; from a programming point of view, it sets up a backward generating function for the op. Within that function, we write code that specifies how the gradients of the op's inputs are calculated.

To calculate the gradient of a custom op, we construct the gradients of its inputs based on the op's inputs and outputs. In most cases, the calculation of the input gradients can be expressed as a combination of existing OneFlow operators.

The calculation of the input gradient usually consists of the following steps:

  1. Use ctx->DefineOp() and BackwardOpBuilder to express how the input gradients are calculated. Because the calculation may be a combination of several operations, DefineOp and BackwardOpBuilder may be used multiple times.

  2. After the calculation is defined in the previous step, the required gradient finally ends up in the output of some operator. We call ctx->FwOp().InputGradBind() to bind that result to the corresponding input gradient of the custom op.

The following example (the complete code, including tests, can be found in the myop_grad repository) registers a backward generating function for a custom op called myop. This op is used in this document only to demonstrate the registration process; its forward computation is defined as 3*x*x.

The relationship between its forward and backward propagation is then easy to obtain, as shown below: in the backward pass, the gradient of x is computed as 6*x*dy.

(Figure: the forward and backward computation of myop)

The forward op of myop is defined as follows:

  1. REGISTER_USER_OP("myop").Input("in").Output("out").SetTensorDescInferFn(
  2. [](user_op::InferContext *ctx) -> Maybe<void> {
  3. *ctx->Shape4ArgNameAndIndex("out", 0) =
  4. *ctx->Shape4ArgNameAndIndex("in", 0);
  5. *ctx->Dtype4ArgNameAndIndex("out", 0) =
  6. *ctx->Dtype4ArgNameAndIndex("in", 0);
  7. return Maybe<void>::Ok();
  8. });

That is, myop has a single input in and a single output out.

The gradient registration of myop is as follows:

    REGISTER_USER_OP_GRAD("myop").SetBackwardOpConfGenFn(
        [](user_op::BackwardOpConfContext* ctx) {
          const auto op1_name = ctx->FwOp().op_name() + "_grad1";
          // The operator op1_name is used to calculate the gradient of myop.in
          ctx->DefineOp(op1_name,
            [&ctx](user_op::BackwardOpBuilder& builder) {
              return builder.OpTypeName("multiply")
                  .InputBind("x", ctx->FwOp().input("in", 0))         // multiply.x <- myop.in
                  .InputBind("y", ctx->FwOp().output_grad("out", 0))  // multiply.y <- the gradient of myop.out
                  .Output("out")
                  .Build();
            });

          const auto op2_name = ctx->FwOp().op_name() + "_grad2";
          // The operator op2_name is used to calculate 6*op1_name
          ctx->DefineOp(op2_name,
            [&ctx, &op1_name](user_op::BackwardOpBuilder& builder) {
              return builder.OpTypeName("scalar_mul")
                  .InputBind("in", ctx->GetOp(op1_name).output("out", 0))
                  .Attr("has_float_operand", true)
                  .Attr("has_int_operand", false)
                  .Attr("float_operand", static_cast<double>(6))
                  .Attr("int_operand", static_cast<int64_t>(6))
                  .Output("out")
                  .Build();
            });

          // (the gradient of myop.in) <- op2_name.out
          ctx->FwOp().InputGradBind(user_op::OpArg("in", 0),
            [&ctx, &op2_name]() -> const std::string& {
              return ctx->GetOp(op2_name)
                  .output("out", 0);
            });
        });

The string parameter accepted by REGISTER_USER_OP_GRAD("myop") is the op_type_name, which must be the same as the one registered with REGISTER_USER_OP.

REGISTER_USER_OP_GRAD("myop") returns an oneflow::user_op::OpGradRegistry object that we can call it to set the custom op’s backward generating function.

In the gradient registration above, the gradient of myop's input is the expression 6*x*dy, which the code expresses step by step.

First, op1_name is defined, and x*dy is computed using the existing operator multiply:

    // The operator op1_name is used to calculate the gradient of myop.in
    ctx->DefineOp(op1_name,
      [&ctx](user_op::BackwardOpBuilder& builder) {
        return builder.OpTypeName("multiply")
            .InputBind("x", ctx->FwOp().input("in", 0))         // multiply.x <- myop.in
            .InputBind("y", ctx->FwOp().output_grad("out", 0))  // multiply.y <- the gradient of myop.out
            .Output("out")
            .Build();
      });

Then op2_name is defined, and the existing operator scalar_mul is used to compute 6*op1_name.out:

    // The operator op2_name is used to calculate 6*op1_name
    ctx->DefineOp(op2_name,
      [&ctx, &op1_name](user_op::BackwardOpBuilder& builder) {
        return builder.OpTypeName("scalar_mul")
            .InputBind("in", ctx->GetOp(op1_name).output("out", 0))
            .Attr("has_float_operand", true)
            .Attr("has_int_operand", false)
            .Attr("float_operand", static_cast<double>(6))
            .Attr("int_operand", static_cast<int64_t>(6))
            .Output("out")
            .Build();
      });

Finally, the output of op2_name (i.e., 6*x*dy) is bound to the gradient of myop's input to complete the registration.

    // (the gradient of myop.in) <- op2_name.out
    ctx->FwOp().InputGradBind(user_op::OpArg("in", 0),
      [&ctx, &op2_name]() -> const std::string& {
        return ctx->GetOp(op2_name)
            .output("out", 0);
      });

The code above is the complete process of registering a gradient; the related classes and methods are described below.

SetBackwardOpConfGenFn

We use OpGradRegistry::SetBackwardOpConfGenFn(fn) to set the backward generating function fn, which has the following prototype:

    void fn(BackwardOpConfContext* ctx);

The BackwardOpConfContext* ctx carries all the information needed to generate the backward ops.

BackwardOpConfContext

The commonly used methods of BackwardOpConfContext and their purposes are as follows:

  • UserOpWrapper& FwOp(): gets the forward op.

  • GetOp(op_name): creates and returns the op corresponding to op_name. GetOp uses a lazy-initialization mechanism: the op is not actually created until GetOp is called.

  • void DefineOp(op_name, fn): defines the creation function fn for the op named op_name. When ctx->GetOp(op_name) is called, OneFlow triggers fn to create the op; if the op has already been created, the cached result is returned directly. fn receives a BackwardOpBuilder parameter used to construct the backward op; BackwardOpBuilder is introduced below.

Detailed Introduction of BackwardOpBuilder

BackwardOpBuilder is used to build a backward op. The following fragment of the code above is an example:

    ctx->DefineOp(op1_name,
      [&ctx](user_op::BackwardOpBuilder& builder) {
        return builder.OpTypeName("multiply")
            .InputBind("x", ctx->FwOp().input("in", 0))         // multiply.x <- myop.in
            .InputBind("y", ctx->FwOp().output_grad("out", 0))  // multiply.y <- the gradient of myop.out
            .Output("out")
            .Build();
      });

In this function, we finally call Build to construct a backward op that computes x*dy. The purpose of each method is as follows:

  • OpTypeName("multiply") specifies the op_type_name of the op used to help compute the backward gradient.

  • InputBind(arg_name, blob) binds the input arg_name of multiply to the specified blob; it can be called multiple times. If the arg_name corresponds to multiple blobs, the order of the InputBind calls determines the order of the corresponding indices.

  • Output(arg_name, num) specifies the number of output blobs that correspond to arg_name; num defaults to 1 if omitted.

  • Attr(attr_name, val) sets the value of an attribute, the same attribute as declared at registration.

  • After the configuration above, calling Build() completes the construction of the backward op.

Detailed Introduction of UserOpWrapper

Calling ctx->FwOp() returns the UserOpWrapper of myop; the gradient binding is completed by calling methods on this UserOpWrapper.

    ctx->FwOp().InputGradBind(user_op::OpArg("in", 0),
      [&ctx, &op2_name]() -> const std::string& {
        return ctx->GetOp(op2_name)
            .output("out", 0);
      });

Common methods for UserOpWrapper are:

  • InputGradBind(input, grad_fn): binds the forward op's input to the gradient-generating function grad_fn. OneFlow automatically determines whether the input needs a backward gradient; if it does, grad_fn is triggered and the result is bound to the input.

  • input(arg_name, index): gets the blob corresponding to the input arg_name.

  • output(arg_name, index): gets the blob corresponding to the output arg_name.

  • output_grad(output_arg_name, index): gets the blob holding the backward gradient corresponding to the forward op's output output_arg_name.

  • attr(attr_name): gets the value of the attribute attr_name.

  • arg_tensor_desc(arg_name, index): returns the tensor information (shape, data type, etc.) of the forward op's input/output, as used in the sketch below.
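
As an illustration only, the fragment below sketches how these accessors might be used inside a backward generating function. The op name xop, its alpha attribute, and the exact members of the returned tensor descriptor are assumptions, not part of the myop example above:

    // A hedged sketch of reading forward-op information in a backward
    // generating function. "xop" and its "alpha" attribute are hypothetical.
    REGISTER_USER_OP_GRAD("xop").SetBackwardOpConfGenFn(
        [](user_op::BackwardOpConfContext* ctx) {
          // read an attribute declared when the forward op was registered
          const float alpha = ctx->FwOp().attr<float>("alpha");
          // inspect the tensor description (shape, dtype, ...) of input "in"
          const auto& in_desc = ctx->FwOp().arg_tensor_desc("in", 0);
          const int64_t num_axes = in_desc.shape().NumAxes();
          // alpha and num_axes would then be used when defining the backward
          // ops with ctx->DefineOp and BackwardOpBuilder, as shown earlier
          (void)alpha;
          (void)num_axes;
        });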

Customized Op for Calculating Gradients

As mentioned earlier, in most cases the calculation of a gradient can be represented by a combination of existing ops. However, when it is difficult to express the gradient of a particular forward op with existing ops, we need to design and create operators specifically for the gradient calculation. An example can be found in relu_op.cpp.

Detailed Introduction of UserOpConfBuilder

In OneFlow's Python frontend, UserOpConfBuilder is provided to build the wrapper of a custom op, as used in "Using the Custom Op in Python" above. Here is a summary of how UserOpConfBuilder at the Python layer relates to the C++ layer.

For example, here is a wrapper for cast:

    def cast(x, dtype, name):
        return (
            flow.user_op_builder(name)
            .Op("cast")
            .Input("in", [x])
            .Output("out")
            .Attr("dtype", dtype)
            .Build()
            .InferAndTryRun()
            .RemoteBlobList()[0]
        )

  • Op(op_type_name): the parameter is the op_type_name used when the op was registered in C++.

  • Input(input_name, input_blob_list): input_name should be the same as the first parameter of Input when registering this op in C++.

  • Output(output_name, num=1): output_name and num should match the Output declared when registering the op in C++.

  • Attr(attr_name, attr_value): attr_name corresponds to the attribute declared with OpRegistry::Attr in the C++ registration, and attr_value must have the same type as declared there.

  • Build(): builds the user op for the Python layer.

Inference is done by calling InferAndTryRun on the user op, and the result can be retrieved by calling RemoteBlobList or SoleOutputBlob:

  • RemoteBlobList: gets all outputs as a list; suitable for ops with multiple outputs (as shown in the sketch below).

  • SoleOutputBlob: gets the unique output; suitable for ops with a single output.
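
As an illustration, a Python wrapper for a multi-output op might look like the sketch below; the op_type_name my_multi_out_op and its two output blobs are assumptions used only to show how RemoteBlobList is used:

    import oneflow as flow

    def my_multi_out(x, name):
        # hypothetical op "my_multi_out_op" with one input and two output blobs
        op = (
            flow.user_op_builder(name)
            .Op("my_multi_out_op")
            .Input("in", [x])
            .Output("out", 2)  # two output blobs registered under "out"
            .Build()
        )
        # RemoteBlobList returns all output blobs as a list
        out0, out1 = op.InferAndTryRun().RemoteBlobList()
        return out0, out1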