新增Layout

Paddle-Lite中的Place包含Target、Layout、Precision三类信息,用于注册和选择模型中的具体Kernel。下面以在Place中新增ImageDefault、ImageFolder、ImageNW三种Layout为例,讲解如何增加新的Layout。

在lite/core和lite/api目录下以NHWC为关键词检索代码,可以发现需要在以下文件中分别加入新Layout的相关内容:

  1. lite/api/paddle_place.h
  2. lite/api/paddle_place.cc
  3. lite/api/python/pybind/pybind.cc
  4. lite/core/op_registry.h
  5. lite/core/op_registry.cc

1. lite/api/paddle_place.h

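在enum class DataLayoutType中加入对应的Layout。注意已有Layout的值不能改变(这些值会被用作字符串表的下标),新增Layout的值在当前最大值的基础上依次递增即可。同时要把NUM更新为枚举值的总数,因为DataLayoutToStr等函数用它做越界检查,KernelRegistry也用它计算注册表的大小: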

  // Existing values must never change: the numeric value is used directly as
  // an index into the name tables in paddle_place.cc (DataLayoutToStr /
  // DataLayoutRepr). Give each new layout the next unused value and bump NUM,
  // which bounds those index checks and sizes KernelRegistry::registries_.
  // Note that kAny keeps its historical value 2 even though it is declared
  // after the image layouts.
  enum class DataLayoutType : int {
    kUnk = 0,
    kNCHW = 1,
    kNHWC = 3,
    kImageDefault = 4,  // for opencl image2d
    kImageFolder = 5,   // for opencl image2d
    kImageNW = 6,       // for opencl image2d
    kAny = 2,           // any data layout
    NUM = 7,            // number of fields.
  };

2. lite/api/paddle_place.cc

本文件有3处修改。注意在DataLayoutToStr和DataLayoutRepr函数的字符串数组中加入对应Layout的名字。数组以枚举的数值作为下标访问,所以字符串必须按lite/api/paddle_place.h中枚举值从小到大的顺序排列(例如值为2的kAny对应第3个元素),而不是按枚举在头文件中的声明顺序:

  1. // 该文件第1处
  2. const std::string& DataLayoutToStr(DataLayoutType layout) {
  3. static const std::string datalayout2string[] = {
  4. "unk", "NCHW", "any", "NHWC", "ImageDefault", "ImageFolder", "ImageNW"};
  5. auto x = static_cast<int>(layout);
  6. CHECK_LT(x, static_cast<int>(DATALAYOUT(NUM)));
  7. return datalayout2string[x];
  8. }
  9. // 该文件第2处
  10. const std::string& DataLayoutRepr(DataLayoutType layout) {
  11. static const std::string datalayout2string[] = {"kUnk",
  12. "kNCHW",
  13. "kAny",
  14. "kNHWC",
  15. "kImageDefault",
  16. "kImageFolder",
  17. "kImageNW"};
  18. auto x = static_cast<int>(layout);
  19. CHECK_LT(x, static_cast<int>(DATALAYOUT(NUM)));
  20. return datalayout2string[x];
  21. }
  22. // 该文件第3处
  23. std::set<DataLayoutType> ExpandValidLayouts(DataLayoutType layout) {
  24. static const std::set<DataLayoutType> valid_set({DATALAYOUT(kNCHW),
  25. DATALAYOUT(kAny),
  26. DATALAYOUT(kNHWC),
  27. DATALAYOUT(kImageDefault),
  28. DATALAYOUT(kImageFolder),
  29. DATALAYOUT(kImageNW)});
  30. if (layout == DATALAYOUT(kAny)) {
  31. return valid_set;
  32. }
  33. return std::set<DataLayoutType>({layout});
  34. }

3. lite/api/python/pybind/pybind.cc

  1. // DataLayoutType
  2. py::enum_<DataLayoutType>(*m, "DataLayoutType")
  3. .value("NCHW", DataLayoutType::kNCHW)
  4. .value("NHWC", DataLayoutType::kNHWC)
  5. .value("ImageDefault", DataLayoutType::kImageDefault)
  6. .value("ImageFolder", DataLayoutType::kImageFolder)
  7. .value("ImageNW", DataLayoutType::kImageNW)
  8. .value("Any", DataLayoutType::kAny);

4. lite/core/op_registry.h

找到class KernelRegistry final中的using any_kernel_registor_t =,在其类型列表中加入下面的内容:

  1. // 找到KernelRegister final中的`using any_kernel_registor_t =`
  2. // 加入如下内容:
  3. KernelRegistryForTarget<TARGET(kOpenCL),
  4. PRECISION(kFP16),
  5. DATALAYOUT(kNCHW)> *, //
  6. KernelRegistryForTarget<TARGET(kOpenCL),
  7. PRECISION(kFP16),
  8. DATALAYOUT(kNHWC)> *, //
  9. KernelRegistryForTarget<TARGET(kOpenCL),
  10. PRECISION(kFP16),
  11. DATALAYOUT(kImageDefault)> *, //
  12. KernelRegistryForTarget<TARGET(kOpenCL),
  13. PRECISION(kFP16),
  14. DATALAYOUT(kImageFolder)> *, //
  15. KernelRegistryForTarget<TARGET(kOpenCL),
  16. PRECISION(kFP16),
  17. DATALAYOUT(kImageNW)> *, //
  18. KernelRegistryForTarget<TARGET(kOpenCL),
  19. PRECISION(kFloat),
  20. DATALAYOUT(kImageDefault)> *, //
  21. KernelRegistryForTarget<TARGET(kOpenCL),
  22. PRECISION(kFloat),
  23. DATALAYOUT(kImageFolder)> *, //
  24. KernelRegistryForTarget<TARGET(kOpenCL),
  25. PRECISION(kFloat),
  26. DATALAYOUT(kImageNW)> *, //
  27. KernelRegistryForTarget<TARGET(kOpenCL),
  28. PRECISION(kAny),
  29. DATALAYOUT(kImageDefault)> *, //
  30. KernelRegistryForTarget<TARGET(kOpenCL),
  31. PRECISION(kAny),
  32. DATALAYOUT(kImageFolder)> *, //
  33. KernelRegistryForTarget<TARGET(kOpenCL),
  34. PRECISION(kAny),
  35. DATALAYOUT(kImageNW)> *, //

5. lite/core/op_registry.cc

该文件有2处修改:

  // Change 1 in this file: add one `case` per new layout so that the runtime
  // layout value is dispatched to the matching Create<...> instantiation.
  // Unlisted layouts fall through to LOG(FATAL).
  #define CREATE_KERNEL1(target__, precision__)                                \
    switch (layout) {                                                          \
      case DATALAYOUT(kNCHW):                                                  \
        return Create<TARGET(target__),                                        \
                      PRECISION(precision__),                                  \
                      DATALAYOUT(kNCHW)>(op_type);                             \
      case DATALAYOUT(kAny):                                                   \
        return Create<TARGET(target__),                                        \
                      PRECISION(precision__),                                  \
                      DATALAYOUT(kAny)>(op_type);                              \
      case DATALAYOUT(kNHWC):                                                  \
        return Create<TARGET(target__),                                        \
                      PRECISION(precision__),                                  \
                      DATALAYOUT(kNHWC)>(op_type);                             \
      case DATALAYOUT(kImageDefault):                                          \
        return Create<TARGET(target__),                                        \
                      PRECISION(precision__),                                  \
                      DATALAYOUT(kImageDefault)>(op_type);                     \
      case DATALAYOUT(kImageFolder):                                           \
        return Create<TARGET(target__),                                        \
                      PRECISION(precision__),                                  \
                      DATALAYOUT(kImageFolder)>(op_type);                      \
      case DATALAYOUT(kImageNW):                                               \
        return Create<TARGET(target__),                                        \
                      PRECISION(precision__),                                  \
                      DATALAYOUT(kImageNW)>(op_type);                          \
      default:                                                                 \
        LOG(FATAL) << "unsupported kernel layout " << DataLayoutToStr(layout); \
    }
  31. // 该文件第2处
  32. // 找到文件中的下面的函数
  33. KernelRegistry::KernelRegistry()
  34. : registries_(static_cast<int>(TARGET(NUM)) *
  35. static_cast<int>(PRECISION(NUM)) *
  36. static_cast<int>(DATALAYOUT(NUM)))
  37. // 在该函数中加入新增Layout的下面内容
  38. INIT_FOR(kOpenCL, kFP16, kNCHW);
  39. INIT_FOR(kOpenCL, kFP16, kNHWC);
  40. INIT_FOR(kOpenCL, kFP16, kImageDefault);
  41. INIT_FOR(kOpenCL, kFP16, kImageFolder);
  42. INIT_FOR(kOpenCL, kFP16, kImageNW);
  43. INIT_FOR(kOpenCL, kFloat, kImageDefault);
  44. INIT_FOR(kOpenCL, kFloat, kImageFolder);
  45. INIT_FOR(kOpenCL, kFloat, kImageNW);
  46. INIT_FOR(kOpenCL, kAny, kImageDefault);
  47. INIT_FOR(kOpenCL, kAny, kImageFolder);
  48. INIT_FOR(kOpenCL, kAny, kImageNW);