TracedLayer

注意:该API仅支持【动态图】模式

  • class paddle.fluid.dygraph.TracedLayer(program, parameters, feed_names, fetch_names)[源代码]

TracedLayer用于将前向动态图模型转换为静态图模型,主要用于将动态图保存后做在线C++预测。除此以外,用户也可使用转换后的静态图模型在Python端做预测,通常比原先的动态图性能更好。

TracedLayer使用 ExecutorCompiledProgram 运行静态图模型。转换后的静态图模型与原动态图模型共享参数。

所有的TracedLayer对象均不应通过构造函数创建,而应通过调用静态方法 TracedLayer.trace(layer, inputs) 创建。

TracedLayer只能用于将data independent的动态图模型转换为静态图模型,即待转换的动态图模型不应随tensor数据或维度的变化而变化。

  • static trace(layer, inputs)

创建TracedLayer对象的唯一接口,该接口会调用 layer(*inputs) 方法运行动态图模型并将其转换为静态图模型。

  • 参数:
    • layer (dygraph.Layer) - 待追踪的动态图layer对象。
    • inputs (list(Variable)) - 动态图layer对象的输入变量列表。

返回: 包含2个元素的tuple,其中第一个元素是 layer(*inputs) 的输出结果,第二个元素是转换后得到的TracedLayer对象。

返回类型: tuple

代码示例

  1. import paddle.fluid as fluid
  2. from paddle.fluid.dygraph import Linear, to_variable, TracedLayer
  3. import numpy as np
  4.  
  5. class ExampleLayer(fluid.dygraph.Layer):
  6. def __init__(self):
  7. super(ExampleLayer, self).__init__()
  8. self._fc = Linear(3, 10)
  9.  
  10. def forward(self, input):
  11. return self._fc(input)
  12.  
  13. with fluid.dygraph.guard():
  14. layer = ExampleLayer()
  15. in_np = np.random.random([2, 3]).astype('float32')
  16. in_var = to_variable(in_np)
  17. out_dygraph, static_layer = TracedLayer.trace(layer, inputs=[in_var])
  18.  
  19. # 内部使用Executor运行静态图模型
  20. out_static_graph = static_layer([in_var])
  21. print(len(out_static_graph)) # 1
  22. print(out_static_graph[0].shape) # (2, 10)
  23.  
  24. # 将静态图模型保存为预测模型
  25. static_layer.save_inference_model(dirname='./saved_infer_model')
  • set_strategy(build_strategy=None, exec_strategy=None)

设置构建和执行静态图模型的相关策略。

  • 参数:
    • build_strategy (BuildStrategy, 可选) - TracedLayer内部 CompiledProgram 的构建策略。
    • exec_strategy (ExecutionStrategy, 可选) - TracedLayer内部 CompiledProgram 的执行策略。

返回: 无

代码示例

  1. import paddle.fluid as fluid
  2. from paddle.fluid.dygraph import Linear, to_variable, TracedLayer
  3. import numpy as np
  4.  
  5. class ExampleLayer(fluid.dygraph.Layer):
  6. def __init__(self):
  7. super(ExampleLayer, self).__init__()
  8. self._fc = Linear(3, 10)
  9.  
  10. def forward(self, input):
  11. return self._fc(input)
  12.  
  13. with fluid.dygraph.guard():
  14. layer = ExampleLayer()
  15. in_np = np.random.random([2, 3]).astype('float32')
  16. in_var = to_variable(in_np)
  17.  
  18. out_dygraph, static_layer = TracedLayer.trace(layer, inputs=[in_var])
  19.  
  20. build_strategy = fluid.BuildStrategy()
  21. build_strategy.enable_inplace = True
  22.  
  23. exec_strategy = fluid.ExecutionStrategy()
  24. exec_strategy.num_threads = 2
  25.  
  26. static_layer.set_strategy(build_strategy=build_strategy, exec_strategy=exec_strategy)
  27. out_static_graph = static_layer([in_var])
  • save_inference_model(dirname, feed=None, fetch=None)

将TracedLayer保存为用于预测部署的模型。保存的预测模型可被C++预测接口加载。

  • 参数:
    • dirname (str) - 预测模型的保存目录。
    • feed (list(int), 可选) - 预测模型输入变量的索引。若为None,则TracedLayer的所有输入变量均会作为预测模型的输入。默认值为None。
    • fetch (list(int), 可选) - 预测模型输出变量的索引。若为None,则TracedLayer的所有输出变量均会作为预测模型的输出。默认值为None。

返回: 无

代码示例

  1. import paddle.fluid as fluid
  2. from paddle.fluid.dygraph import Linear, to_variable, TracedLayer
  3. import numpy as np
  4.  
  5. class ExampleLayer(fluid.dygraph.Layer):
  6. def __init__(self):
  7. super(ExampleLayer, self).__init__()
  8. self._fc = Linear(3, 10)
  9.  
  10. def forward(self, input):
  11. return self._fc(input)
  12.  
  13. save_dirname = './saved_infer_model'
  14. in_np = np.random.random([2, 3]).astype('float32')
  15.  
  16. with fluid.dygraph.guard():
  17. layer = ExampleLayer()
  18. in_var = to_variable(in_np)
  19. out_dygraph, static_layer = TracedLayer.trace(layer, inputs=[in_var])
  20. static_layer.save_inference_model(save_dirname, feed=[0], fetch=[0])
  21.  
  22. place = fluid.CPUPlace()
  23. exe = fluid.Executor(place)
  24. program, feed_vars, fetch_vars = fluid.io.load_inference_model(save_dirname,
  25. exe)
  26.  
  27. fetch, = exe.run(program, feed={feed_vars[0]: in_np}, fetch_list=fetch_vars)
  28. print(fetch.shape) # (2, 10)