新增Pass
本文从三个方面介绍了Lite中的Pass结构:Pass是什么、Pass的实现与接口、Pass的一般注册流程。最后以Fc_fuse_pass为例介绍了fusion_pass的作用与注册方法。
前述:Pass是什么?
CxxPredictor加载模型后,在执行预测前会先优化模型。模型优化过程是通过Pass实现的。
具体调用关系如下: 
CreatePredictor(CxxConfig)函数调用了Predictor->Build(CxxConfig)CxxPredictor的构建过程(Build)分为两步:
Predictor->LoadModel() 加载模型文件到program中
Predicotr->optimizer_.Run() 对Program中的原始图形结构进行优化
- 对图结构的优化是通过调用
Pass->Apply(const std::unique_ptr<SSAGraph>& graph)方法实现的。
- 对图结构的优化是通过调用
每一类Pass定义了一种优化过程,包括:原模型中的kernel选取、OP融合、冗余OP去除、子图创建、内存优化、类型推导、类型转换等。
Pass的实现与接口 :Pass基类、PassManager和Pass注册
1、Pass基类:paddle::lite::mir::Pass
class Pass {public:// Pass的类型,Pass按照作用的不同可以分为三种enum class Kind { //种类的作用不太清楚// 1. 修改模型中的图拓扑结构的PasskProgramWise = 0,// 2. 不修改图结构,修改状态的PasskStmtWise,// 3. 不修改 IR,用于搜集信息和可视化信息的Pass.kDebug,};// 主要实现函数:Apply 函数定义了 Pass 运行时执行的操作virtual void Apply(const std::unique_ptr<SSAGraph>& graph) = 0;bool is_program_pass() const { return kind_ == Kind::kProgramWise; }bool is_stmt_pass() const { return kind_ == Kind::kStmtWise; }virtual ~Pass() = default;private:const Kind kind_; // pass 的种类std::string name_; // pass 的名称std::set<TargetType> bound_targets_; // 指定了Pass运行的硬件平台,模型优化过程会根据当前硬件平台是否匹配筛选Pass。std::unordered_map<std::string, std::set<lite_api::Place>> bound_kernels_; // 绑定的kernel};// Different kinds.class ProgramPass : public Pass {public:ProgramPass() : Pass(Kind::kProgramWise) {}};class StmtPass : public Pass {public:StmtPass() : Pass(Kind::kStmtWise) {}};class DebugPass : public Pass {public:DebugPass() : Pass(Kind::kDebug) {}};
代码位置:lite/core/mir/pass.h
主要类成员:
const Kind kind_: Pass类型。pass 有三种基本基本类型 :修改图结构的ProgramPass、修改状态量的StmtPass和Debug过程采集信息与控制可视化的DebugPass。std::string name_:pass 的名称std::set<TargetType> bound_targets_: Pass运行的硬件平台,optimizer.Run()优化过程会根据硬件平台选择匹配的Pass。——根据硬件平台自动选择需要的passstd::unordered_map<std::string, std::set<lite_api::Place>> bound_kernels_: Pass 绑定的kernel (what’s this used for)
主要接口:
Pass::Apply(const std::unique_ptr& graph): Pass优化过程的具体操作,是新注册Pass需要实现的接口。输入为SSAGraph型指针,是对模型结构的拓扑表示。
2、Pass管理 paddle::lite::mir::PassManager
class PassManager {public:// 内部静态变量PassManager,用来存储使用的Pass和图优化操作static PassManager& Global() {static PassManager x;return x;}// 执行所有的 Passvoid Run(const std::unique_ptr<SSAGraph>& graph) {for (auto& pass : passes_) {LOG(INFO) << "Running MIR pass " << pass->name();pass->Apply(graph);}private:std::list<std::unique_ptr> passes_; //存储所有的 Passstd::map<std::string, mir::Pass*> pass_map_; //使用map变量存储 PassName::Pass}
代码位置:lite/core/mir/pass_manager.h
主要类成员:
std::list:unique_ptr> passes_;: List类型,存储了所有已注册Pass。std::map<std::string, mir::Pass*> pass_map_;: Map类型,存储了所有”Pass名称-Pass类”键对,用于根据名称查找Pass。
主要接口:
static PassManager& Global()返回PassManager全局静态变量,该变量存储了所有已注册的Passbool AddNewPass(const std::string& name, Pass* pass)添加新的Pass到PassManager中
3、 Pass 注册 paddle::lite::mir::PassRegistry
代码位置:lite/core/mir/pass_registry.h
主要接口:
REGISTER_MIR_PASS(name__, class__):宏定义函数,用于注册Pass。注册Pass过程实现的是 -PassManager::Global().AddNewPass(name__, class__),将新注册Pass添加到全局变量PassManager中。
Pass的一般注册流程与使用方法
1. Pass 注册流程
在lite/core/mir或其子目录下继承Pass基类,实现Pass::Apply接口,并使用宏REGISTER_MIR_PASS(name__, class__)将Pass注册到PassManager即完成了新Pass注册。
以新建 new_demo_pass为例,具体流程如下:
(1)在lite/core/mir路径下新建example_pass.cc 和 new_demo_pass.h 文件
(2)在example_pass.h 文件中继承Pass基类(ProgramPass、StmtPass或DebugPass)定义自己的Pass类。
#include "lite/core/mir/pass.h"namespace paddle {namespace lite {namespace mir {class ExamplePass : public ProgramPass {void Apply(const std::unique_ptr<SSAGraph> &graph) override {}...};} // namespace mir} // namespace lite} // namespace paddle
(3)在example_pass.cc 文件中实现ExamplePass::Apply()接口,并注册ExamplePass
#include "lite/core/mir/pass_registry.h"#include "lite/core/mir/example_pass.h"namespace paddle {namespace lite {namespace mir {void ExamplePass::Apply(const std::unique_ptr<SSAGraph>& graph) {...}} // namespace mir} // namespace lite} // namespace paddleREGISTER_MIR_PASS(example_pass, paddle::lite::mir::ExamplePass).BindTargets({TARGET(kARM)}); // Pass执行的目标硬件平台// .BindKernel("conv2d"); //Pass绑定的 kernel
(4)修改lite/core/mir/CMakeLists.txt文件,将example_pass.cc 编译到mir_passes库中
lite_cc_library(mir_passesSRCSdemo_pass.cc // 新建的Pass文件...memory_optimize_pass.ccDEPS mir_pass types context ${mir_fusers} ${subgraph_passes})
2. Pass使用流程
将Pass注册到PassManager后不会自动生效。需要在optimizer->run() 函数中添加该Pass才会在模型优化过程中调用。
(1)在paddle_use_passes.h文件中调用该Pass
#include "paddle_lite_factory_helper.h" // NOLINT...USE_MIR_PASS(new_demo_pass); //调用 new_demo_pass
(2)要想在优化模型时调用该Pass,需要在optimizer->run()函数中手动添加调用。
修改lite/core/optimizer.h文件,添加new_demo_pass到Optimizer::Run()函数;
class Optimizer {public:void Run(...) {...if (passes.empty()) {RunPasses(std::vector<std::string>{{"new_demo_pass" //将新注册的Pass添加在这里...}...}
(3)只有CxxPredictor才会在模型加载后根据Pass优化模型。
...#include "paddle_use_passes.h" // 引用Pass优化模型void RunModel() {// 1. 创建 CxxConfigCxxConfig config;config.set_model_dir(FLAGS_model_dir);config.set_valid_places(Place{TARGET(kARM), PRECISION(kFloat)});// 2. 创建CxxPredictor,该过程包括加载模型和用Pass优化模型std::shared_ptr> predictor =Creat<CxxConfig>(config);}
Fusion Pass的定义与注册
Fusion Pass是一种常见图结构优化Pass,可将多个连续OP融合成单个等效OP,减少数据交换并简化图结构。Pass运行时调用Fuser自动查找并替换指定图结构,所以注册FuserPass时还需要实现对应的Fuser类。
下面以fc_fuse_pass为例,详细说明FusionPass的效果和注册方法。
fc_fuse_pass的作用
将相邻的mul算子和 element_wise add 算子 融合成一个 FC 算子
mul(X) = X * Welementwise_add( mul(x) ) = X * W + Bias//----------> after fusionFC(X) = X * W +Bias
Pass 运行效果如下:
mul和elementwise_add的原有参数映射到FC的参数上: 
fc_fuse_pass的注册方法
1、创建FcFuser
(1)在lite/core/mir/fusion路径下新建fc_fuser.cc 和 fc_fuser.h 文件
(2)在fc_fuser.h 文件中继承FuseBase定义自己的Fuser类。
#include "lite/core/mir/pattern_matcher_high_api.h"namespace paddle {namespace lite {namespace mir {namespace fusion {class FcFuser : public FuseBase {public:void BuildPattern() override;void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override;private:cpp::OpDesc GenOpDesc(const key2nodes_t& matched) override;};} // namespace fusion} // namespace mir} // namespace lite} // namespace paddle
主要接口: FuseBase::BuildPattern : 描述需要替换位置的图结构(pattern),Fuser运行时会自动查找并替换该pattern。 FuseBase::GenOpDesc : 创建融合后的等效Fused_op。 FuseBase::InsertNewNode :用Fused_op替换原始图结构(pattern)。
对于 FcFuser:BuildPattern描述的Pattern是mul+elementwise add,GenOpDesc创建的FC_op,InsertNewNode函数的效果是用新建的FC_op替换模型中的mul+elementwise add pattern。
(3) 在fc_fuser.cc文件中实现 BuildPattern() 、GenOpDesc()、InsertNewNode() 接口
下面以FcFuser为例介绍三种接口的实现:
// 1. BuildPattern函数,描述需要替换的图结构// FcFuser::BuildPattern() 描述了 mul + element_wise add 图结构void FcFuser::BuildPattern() {// (1) 用OpNode描述和VarNode// mul OPauto* mul = OpNode("mul", "mul");// mul OP 的输入和输出auto* x = VarNode("x")->assert_is_op_input("mul", "X");auto* W = VarNode("W")->assert_is_op_input("mul", "Y");auto* mul_out = VarNode("mul_out");// elementwise_add OPauto* add = OpNode("add", "elementwise_add");//elementwise_add 的输入auto* b = VarNode("b")->assert_is_persistable_var();// elementwise_add OP的输出(最终输出)auto* Out = VarNode("Out");//(2) 描述拓扑连接 (Fuse之前mul 和elementwise_add的连接)std::vector<PMNode*> mul_inputs{W, x};std::vector<PMNode*> add_inputs{mul_out, b};mul_inputs >> *mul >> *mul_out;add_inputs >> *add >> *Out;//(3) 声明新的拓扑结构中将会被移除的节点,包括被fuse的OP和OP之间的中间变量mul_out->AsIntermediate();mul->AsIntermediate();add->AsIntermediate();}// 2. GenOpDesc函数新建等效 Fused_op// FcFuser::GenOpDesc() 新建了Fc_opcpp::OpDesc FcFuser::GenOpDesc(const key2nodes_t& matched) {// (1) 得到第一个OP节点的 OpDesc ,并清空输入输出信息cpp::OpDesc op_desc = *matched.at("mul")->stmt()->op_info();op_desc.mutable_inputs()->clear();op_desc.mutable_outputs()->clear();// (2) 修改OpDesc , 将OpType设置为 "fc" (FC OP 的OP_type),op_desc.SetType("fc");// (3) 设置OpDesc中的Input、Output、Attrbute。分别连接到BuildPattern()函数中创建的VarNodeop_desc.SetInput("Input", {matched.at("x")->arg()->name});op_desc.SetInput("W", {matched.at("W")->arg()->name});op_desc.SetInput("Bias", {matched.at("b")->arg()->name});op_desc.SetOutput("Out", {matched.at("Out")->arg()->name});op_desc.SetAttr("in_num_col_dims",matched.at("mul")->stmt()->op_info()->GetAttr<int>("x_num_col_dims"));return op_desc;}// 3. InsertNewNode函数用Fused OP 替换模型图中的原始 Pattern// FcFuser::InsertNewNode() 用Fc_OP替换原始模型图中的 " mul + element_wise add "void FcFuser::InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) {// (1) 创建FC OP的参数(OpDesc)auto op_desc = GenOpDesc(matched);// 创建一个 FC OPauto fc_op = LiteOpRegistry::Global().Create("fc");// 找到原拓扑结构中的scope (作用域)和 valid_places (可支持设备类型)auto mul = matched.at("mul")->stmt()->op();auto* scope = mul->scope();auto& valid_places = mul->valid_places();// (2) 将 FC OP的 scope和 valid_places设置与fuse前相同,并在图中创建该节点(node)fc_op->Attach(op_desc, scope);auto* new_op_node = graph->GraphCreateInstructNode(fc_op, valid_places);// (3) 将FC节点连接到输入输出(var_node)IR_NODE_LINK_TO(matched.at("W"), new_op_node);IR_NODE_LINK_TO(matched.at("x"), new_op_node);IR_NODE_LINK_TO(matched.at("b"), new_op_node);IR_NODE_LINK_TO(new_op_node, matched.at("Out"));}
2、注册fc_fuse_pass
(1)在lite/core/mir/fusion路径下新建fc_fuse_pass.cc 和 fc_fuse_pass.h 文件
(2)在fc_fuse_pass.h 文件中,继承ProgramPass定义FcFusePass。
#include "lite/core/mir/pass.h"namespace paddle {namespace lite {namespace mir {class FcFusePass : public ProgramPass {public:void Apply(const std::unique_ptr<SSAGraph>& graph) override; namespace mir namespace lite namespace paddle
(3)在fc_fuse_pass.cc 文件中实现FcFusePass::Apply()接口,并注册FcFusePass
#include "lite/core/mir/pass_registry.h"#include "lite/core/mir/example_pass.h"namespace paddle {namespace lite {namespace mir {void FcFusePass::Apply(const std::unique_ptr<SSAGraph>& graph) {fusion::FcFuser fuser;fuser(graph.get());namespace mir} // namespace lite} // namespace paddleREGISTER_MIR_PASS(lite_fc_fuse_pass, paddle::lite::mir::FcFusePass).BindTargets({TARGET(kAny)}) // FcFusePass 可以在任何硬件平台执行.BindKernel("fc"); // FcFusePass 绑定 fc_kernel
(4)修改lite/core/mir/fusion/CMakeLists.txt文件,将fc_fuser.cc 编译到mir_fusers库
lite_cc_library(fuse_fcSRCS fc_fuser.ccDEPS pattern_matcher_high_api)set(mir_fusersfuse_fc...CACHE INTERNAL "fusers")
(5)修改lite/core/mir/CMakeLists.txt文件,将fc_fuse_pass.cc 编译到mir_pass库
lite_cc_library(mir_passesSRCSfusion/fc_fuse_pass.cc...DEPS mir_pass types context ${mir_fusers} ${subgraph_passes})
3、使用 fc_fuse_pass
(1) lite/api/paddle_use_passes.h使用USE_LITE_PASS宏来引入新加入的pass
USE_MIR_PASS(lite_fc_fuse_pass);
(2) 在lite/core/optimizer.h文件的Optimizer::Run()函数中添加新注册的pass
class Optimizer {public:void Run(Program&& program,const std::vector<Place>& valid_places,core::KernelPickFactor kernel_pick_factor,const std::vector<std::string>& passes = {}) {...if (passes.empty()) {RunPasses(std::vector<std::string>{{"lite_fc_fuse_pass", // the newly registered pass..."argument_type_display_pass"}});} else {RunPasses(passes);}exec_scope_ = program.exec_scope();}
(3) 以上修改完成后,在CreatePredictor(CxxConfig)创建CxxPredictor时,模型优化过程会调用lite_fc_fuse_pass ,扫描mul + element_wise add结构并替换为等效的Fc_OP。
