新增OP

以下以添加argmax为例,详细说明新增op的方法。

1. 添加OpParam 结构体以传导 Op 的输入和输出

  • 这里命名为 ArgmaxParam

  • paddlelite/lite/operators/op_params.h 中添加 ArgmaxParam 结构体,代码如下:

    1. struct ArgmaxParam {
    2. lite::Tensor* X{};
    3. lite::Tensor* Out{};
    4. int Axis{0};
    5. };

2. 添加 Argmax Op 并注册

  • 在paddlelite/lite/operators/目录下新建argmax_op.h文件,主要代码如下:

    1. class ArgmaxOpLite : public OpLite {
    2. public:
    3. ArgmaxOpLite() {}
    4. explicit ArgmaxOpLite(const std::string &op_type) : OpLite(op_type) {}
    5. bool CheckShape() const override;
    6. bool InferShape() const override;
    7. bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
    8. void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
    9. std::string DebugString() const override { return "argmax"; }
    10. private:
    11. mutable ArgmaxParam param_;
    12. };

    ArgmaxOpLite 继承 OpLite ,成员变量包括 ArgmaxParam 结构体,需要实现的接口包括 CheckShape()InferShape()AttachImp()AttachKernel()DebugString() 函数。AttachKernel()DebugString()函数较为简单,此处直接实现;

  • paddlelite/lite/operators/ 目录下新建argmax_op.cc文件,需要具体实现CheckShape()InferShape()AttachImp()函数。CheckShape()函数检查输入是否符合要求,InferShape()函数基于输入推断得到输出的维度,AttachImp()函数绑定Op的输入输出。然后在argmax_op.cc文件中注册argmax,核心代码如下:

    1. bool ArgmaxOpLite::CheckShape() const {
    2. CHECK_OR_FALSE(param_.X);
    3. CHECK_OR_FALSE(param_.Out);
    4. CHECK_OR_FALSE(param_.Axis < (param_.X)->dims().size());
    5. return true;
    6. }
    7. bool ArgmaxOpLite::InferShape() const {
    8. auto x_dims = param_.X->dims();
    9. int x_rank = x_dims.size();
    10. int axis = param_.Axis;
    11. if (axis < 0) axis += x_rank;
    12. std::vector<int64_t> out_dims;
    13. for (int64_t i = 0; i < axis; i++) {
    14. out_dims.push_back(x_dims[i]);
    15. }
    16. for (int64_t i = axis + 1; i < x_rank; i++) {
    17. out_dims.push_back(x_dims[i]);
    18. }
    19. // Set output dims
    20. param_.Out->Resize(lite::DDim(out_dims));
    21. return true;
    22. }
    23. bool ArgmaxOpLite::AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) {
    24. auto x = op_desc.Input("X").front();
    25. auto out = op_desc.Output("Out").front();
    26. param_.X = scope->FindVar(x)->GetMutable<lite::Tensor>();
    27. param_.Out = scope->FindVar(out)->GetMutable<lite::Tensor>();
    28. param_.Axis = op_desc.GetAttr<int>("Axis");
    29. return true;
    30. }
    31. REGISTER_LITE_OP(argmax, paddle::lite::operators::ArgmaxOpLite);
  • 在paddlelite/lite/operators/CMakeLists.txt中添加add_operator(argmax_op basic SRCS argmax_op.cc DEPS ${op_DEPS})

3. 添加Argmax Kernel并绑定

以下以arm端argmax实现为例说明

  • 在paddlelite/lite/kernels/arm/目录下新建argmax_compute.h文件,声明ArgmaxCompute类,并继承KernelLite,主要代码如下:

    1. class ArgmaxCompute : public KernelLite<TARGET(kARM), PRECISION(kFloat)> {
    2. public:
    3. using param_t = operators::ArgmaxParam;
    4. void Run() override;
    5. virtual ~ArgmaxCompute() = default;
    6. };
  • 在paddlelite/lite/kernels/arm/目录下新建argmax_compute.cc文件,主要实现Run函数。Run()函数调用paddlelite/lite/bachends/arm/math/argmax.h中的argmax_func()函数,根据输入计算输出。最后在argmax_compute.cc文件中,我们绑定argmax的输入输出(为tensor的输入参数都需要绑定),代码如下:

    1. void ArgmaxCompute::Run() {
    2. auto& param = Param<operators::ArgmaxParam>();
    3. lite::Tensor* input = param.X;
    4. lite::Tensor* output = param.Out;
    5. int axis = param.Axis;
    6. lite::arm::math::argmax_func(input, axis, output);
    7. return;
    8. }
    9. REGISTER_LITE_KERNEL(
    10. argmax, kARM, kFloat, kNCHW, paddle::lite::kernels::arm::ArgmaxCompute, def)
    11. .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))})
    12. .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
    13. .Finalize();
  • 在paddlelite/lite/kernels/arm/CMakeLists.txt中添加

    1. add_kernel(argmax_compute_arm ARM basic SRCS argmax_compute.cc DEPS ${lite_kernel_deps} math_arm)

4. 添加Argmax实现

  • 在paddlelite/lite/backends/arm/math/目录下新建argmax.h文件,声明argmax_func()函数,代码如下:

    1. void argmax_func(const lite::Tensor* input, const int axis, lite::Tensor* output);
  • 在paddlelite/lite/backends/arm/math/目录下新建argmax.cc文件,具体实现argmax_func()函数,代码如下:

    1. void argmax_func(const lite::Tensor *input,
    2. const int axis,
    3. lite::Tensor *output) {
    4. auto input_ddim = input->dims();
    5. auto output_ddim = output->dims();
    6. const int size = input_ddim[axis];
    7. const int in_channel = input_ddim.count(axis, input_ddim.size());
    8. const int out_channel = output_ddim.count(axis, output_ddim.size());
    9. const int in_stride = input_ddim.count(axis + 1, input_ddim.size());
    10. const int out_stride = input_ddim.count(0, axis);
    11. for (int n = 0; n < out_stride; n++) {
    12. for (int k = 0; k < in_stride; k++) {
    13. const float *in_ptr = input->data<float>() + n * in_channel + k;
    14. std::vector<std::pair<float, int>> vec;
    15. vec.resize(size);
    16. for (int i = 0; i < size; i++) {
    17. vec[i] = std::make_pair(in_ptr[i * in_stride], i);
    18. }
    19. // sort
    20. std::partial_sort(vec.begin(),
    21. vec.begin() + 1,
    22. vec.end(),
    23. std::greater<std::pair<float, int>>());
    24. // out
    25. float *out_ptr = output->mutable_data<float>() + n * out_channel + k;
    26. *out_ptr = vec[0].second;
    27. }
    28. }
    29. }
  • 在paddlelite/lite/backends/arm/math/CMakeFile.txt中的math_arm library中添加argmax.cc,在paddlelite/lite/backends/arm/math/funcs.h中添加#include "lite/arm/math/argmax.h"

5. 添加Argmax单测

  • 在paddlelite/lite/tests/kernels目录下新建argmax_compute_test.cc文件,声明并实现ArgmaxComputeTester类;

  • ArgmaxComputeTester类中主要包括PrepareOpDesc、PrepareData和RunBaseline函数。PrepareOpDesc函数设定单测op的类型和输入输出参数,PrepareData函数对输入tensor进行初始化,RunBaseline是基于输入计算得到输出,用于和框架计算的输出进行对比;

  • 使用gtest添加单测,代码如下:

    1. TEST(Argmax, precision) {
    2. #ifdef LITE_WITH_ARM
    3. LOG(INFO) << "test argmax arm";
    4. Place place(TARGET(kARM));
    5. for (int axis : {0, 1, 2, 3}) {
    6. for (int n : {1, 3}) {
    7. for (int c : {3, 6}) {
    8. for (int h : {9, 18}) {
    9. for (int w : {9, 18}) {
    10. std::unique_ptr<arena::TestCase> tester(
    11. new ArgmaxComputeTester(place, "def", axis, n, c, h, w));
    12. arena::Arena arena(std::move(tester), place, 2e-5);
    13. arena.TestPrecision();
    14. }
    15. }
    16. }
    17. }
    18. }
    19. #endif
    20. }
  • 在paddlelite/lite/tests/kernels/CMakeLists.txt中添加

    1. lite_cc_test(test_kernel_argmax_compute SRCS argmax_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})

6. 编译运行

  • 在paddlelite目录中,执行./lite/tools/ci_build.sh build_test_arm,该脚本会创建手机模拟器,并编译运行所有单测(花费时间较久)。如果运行无误,则表明添加argmax成功。