load_inference_model

注意:该API仅支持【静态图】模式

  • paddle.fluid.io.load_inference_model(dirname, executor, model_filename=None, params_filename=None, pserver_endpoints=None)[源代码]

从指定文件路径中加载预测模型(Inference Model),即调用该接口可获得模型结构(Inference Program)和模型参数。若只想加载预训练后的模型参数,请使用 load_params 接口。更多细节请参考 模型保存与加载

  • 参数:
    • dirname (str) – 待加载模型的存储路径。
    • executor (Executor) – 运行 Inference Model 的 executor ,详见 执行引擎
    • modelfilename (str,可选) – 存储Inference Program结构的文件名称。如果设置为None,则使用 \_model__ 作为默认的文件名。默认值为None。
    • params_filename (str,可选) – 存储所有模型参数的文件名称。当且仅当所有模型参数被保存在一个单独的二进制文件中,它才需要被指定。如果模型参数是存储在各自分离的文件中,设置它的值为None。默认值为None。
    • pserver_endpoints (list,可选) – 只有在分布式预测时才需要用到。当训练过程中使用分布式查找表(distributed lookup table)时, 预测时需要指定pserver_endpoints的值。它是 pserver endpoints 的列表,默认值为None。
  • 返回:该接口返回一个包含三个元素的列表(program,feed_target_names, fetch_targets)。它们的含义描述如下:
    • program (Program)– Program (详见 基础概念 )类的实例。此处它被用于预测,因此可被称为Inference Program。
    • feed_target_names (list)– 字符串列表,包含着Inference Program预测时所需提供数据的所有变量名称(即所有输入变量的名称)。
    • fetch_targets (list)– Variable (详见 基础概念 )类型列表,包含着模型的所有输出变量。通过这些输出变量即可得到模型的预测结果。

返回类型: 列表(list)

  • 抛出异常:
    • ValueError – 如果接口参数 dirname 指向一个不存在的文件路径,则抛出异常。

代码示例

  1. import paddle.fluid as fluid
  2. import numpy as np
  3.  
  4. # 构建模型
  5. main_prog = fluid.Program()
  6. startup_prog = fluid.Program()
  7. with fluid.program_guard(main_prog, startup_prog):
  8. data = fluid.layers.data(name="img", shape=[64, 784], append_batch_size=False)
  9. w = fluid.layers.create_parameter(shape=[784, 200], dtype='float32')
  10. b = fluid.layers.create_parameter(shape=[200], dtype='float32')
  11. hidden_w = fluid.layers.matmul(x=data, y=w)
  12. hidden_b = fluid.layers.elementwise_add(hidden_w, b)
  13. place = fluid.CPUPlace()
  14. exe = fluid.Executor(place)
  15. exe.run(startup_prog)
  16.  
  17. # 保存预测模型
  18. path = "./infer_model"
  19. fluid.io.save_inference_model(dirname=path, feeded_var_names=['img'],target_vars=[hidden_b], executor=exe, main_program=main_prog)
  20.  
  21. # 示例一: 不需要指定分布式查找表的模型加载示例,即训练时未用到distributed lookup table。
  22. [inference_program, feed_target_names, fetch_targets] = (fluid.io.load_inference_model(dirname=path, executor=exe))
  23. tensor_img = np.array(np.random.random((1, 64, 784)), dtype=np.float32)
  24. results = exe.run(inference_program,
  25. feed={feed_target_names[0]: tensor_img},
  26. fetch_list=fetch_targets)
  27.  
  28. # 示例二: 若训练时使用了distributed lookup table,则模型加载时需要通过endpoints参数指定pserver服务器结点列表。
  29. # pserver服务器结点列表主要用于分布式查找表进行ID查找时使用。下面的["127.0.0.1:2023","127.0.0.1:2024"]仅为一个样例。
  30. endpoints = ["127.0.0.1:2023","127.0.0.1:2024"]
  31. [dist_inference_program, dist_feed_target_names, dist_fetch_targets] = (
  32. fluid.io.load_inference_model(dirname=path,
  33. executor=exe,
  34. pserver_endpoints=endpoints))
  35.  
  36. # 在上述示例中,inference program 被保存在“ ./infer_model/__model__”文件内,
  37. # 参数保存在“./infer_mode ”单独的若干文件内。
  38. # 加载 inference program 后, executor可使用 fetch_targets 和 feed_target_names 执行Program,并得到预测结果。