ProgramTranslator

class paddle.jit.ProgramTranslator [源代码]

将动态图函数转为静态图函数的类。该类是个单例(singleton)。

参数:

无。

返回:ProgramTranslator 单例对象。

示例代码

  1. import paddle
  2. # 以下两种调用方法得到同一个对象,因为ProgramTranslator是个单例
  3. paddle.jit.ProgramTranslator()
  4. paddle.jit.ProgramTranslator.get_instance()

enable ( enable_static )

全局开启或关闭动态图转化为静态图。

参数:

  • enable_static (bool) - 设置True或者False来打开或关闭动静转化 。

返回:None。

示例代码

  1. import paddle
  2. @paddle.jit.to_static
  3. def func(x):
  4. if paddle.mean(x) > 0:
  5. x_v = x - 1
  6. else:
  7. x_v = x + 1
  8. return x_v
  9. prog_trans = paddle.jit.ProgramTranslator()
  10. prog_trans.enable(False)
  11. x = paddle.ones([1, 2])
  12. # ProgramTranslator被关闭所以func会以动态图模式运行
  13. print(func(x)) # [[0. 0.]]

get_output ( dygraph_func, \args, **kwargs* )

返回动态图函数输出的Tensor,但是该动态图函数的数值计算过程会被转化为静态图模式运行。

参数:

  • dygraph_func (callable) - 动态图函数。

  • args, kwargs - 动态图函数的输入。

返回:包含数值结果的Tensor或者Tensor的元组,是输入动态图函数的返回值。

示例代码

  1. import paddle
  2. def func(x):
  3. if paddle.mean(x) > 0:
  4. x_v = x - 1
  5. else:
  6. x_v = x + 1
  7. return x_v
  8. prog_trans = paddle.jit.ProgramTranslator()
  9. x = paddle.ones([1, 2])
  10. x_v = prog_trans.get_output(func, x)
  11. print(x_v) # [[0. 0.]]

get_func ( dygraph_func )

返回一个可调用函数,该函数将输入动态图函数接口转化为静态图组网接口。组网接口不像动态图接口,其并不直接返回数据结果。用户需要自行处理对应的Program和Eexecutor。

参数:

  • dygraph_func (callable) - 动态图函数。

返回:将动态图接口转为静态图组网接口的可调用函数。

示例代码

  1. import paddle
  2. def func(x):
  3. if paddle.mean(x) > 0:
  4. x_v = x - 1
  5. else:
  6. x_v = x + 1
  7. return x_v
  8. prog_trans = paddle.jit.ProgramTranslator()
  9. static_func = prog_trans.get_func(func)
  10. print(callable(static_func)) # True

get_program ( dygraph_func, \args, **kwargs* )

返回动态图函数转化后的静态图Program和输入输出Varaible。用户可以使用Executor来执行该Program。

参数:

  • dygraph_func (callable) - 动态图函数。

  • args, kwargs - 动态图函数的输入。

返回:元组(main_program, startup_program, inputs, outputs)

main_program: 转化后的main program。 startup_program: 转化后的startup program。 inputs: 输入Tensor的列表,这些Tensor可以在执行去feed。 outputs: 输出Tensor的列表,这些Tensor可以在运行时被fetch。

示例代码

  1. import paddle
  2. def func(x):
  3. if paddle.mean(x) > 0:
  4. x_v = x - 1
  5. else:
  6. x_v = x + 1
  7. return x_v
  8. prog_trans = paddle.jit.ProgramTranslator()
  9. x = paddle.ones([1, 2])
  10. main_prog, start_prog, inputs, outputs = prog_trans.get_program(func, x)
  11. print([i.name for i in inputs])
  12. # [u'generated_tensor_0'] 需要被feed的输入Tensor名字,对应x
  13. print([o.name for o in outputs])
  14. # [u'_generated_var_4'] 需要被fetch的输出Tensor名字,对应x_v

get_code ( dygraph_func )

返回动态图函数转化后的静态图代码字符串。

参数:

  • dygraph_func (callable) - 动态图函数。

返回:转化后的静态图代码字符串。

示例代码

  1. import paddle
  2. def func(x):
  3. if paddle.mean(x) > 0:
  4. x_v = x - 1
  5. else:
  6. x_v = x + 1
  7. return x_v
  8. prog_trans = paddle.jit.ProgramTranslator()
  9. code = prog_trans.get_code(func)
  10. print(type(code)) # <class 'str'>

get_program_cache ( )

返回ProgramCache单例。这个方法是PaddlePaddle开发者用来管理ProgramTranslator中的Program缓存,普通用户不需要使用这个方法。

返回:ProgramTranslator中的ProgramCache。

示例代码

  1. import paddle
  2. prog_trans = paddle.jit.ProgramTranslator()
  3. prog_cache = prog_trans.get_program_cache()