create_py_reader_by_data

注意:该API仅支持【静态图】模式

  • paddle.fluid.layers.create_py_reader_by_data(capacity, feed_list, name=None, use_double_buffer=True)[源代码]

创建一个Python端提供数据的reader。该OP与 py_reader 类似,不同点在于它能够从feed变量列表读取数据。

  • 参数:
    • capacity (int) - py_reader 维护的队列缓冲区的容量大小。单位是batch数量。若reader读取速度较快,建议设置较大的 capacity 值。
    • feed_list (list(Variable)) - feed变量列表,这些变量一般由 fluid.data() 创建。
    • name (str,可选) – 具体用法请参见 Name ,一般无需设置,默认值为None。
    • use_double_buffer (bool,可选) - 是否使用双缓冲区,双缓冲区是为了预读下一个batch的数据、异步CPU -> GPU拷贝。默认值为True。

返回:能够从feed变量列表读取数据的reader,数据类型和feed变量列表中变量的数据类型相同。

返回类型:reader

代码示例:

  1. import paddle
  2. import paddle.fluid as fluid
  3. import paddle.dataset.mnist as mnist
  4.  
  5. def network(img, label):
  6. # 用户构建自定义网络,此处以一个简单的线性回归为例。
  7. predict = fluid.layers.fc(input=img, size=10, act='softmax')
  8. loss = fluid.layers.cross_entropy(input=predict, label=label)
  9. return fluid.layers.mean(loss)
  10.  
  11. MEMORY_OPT = False
  12. USE_CUDA = False
  13.  
  14. image = fluid.data(name='image', shape=[None, 1, 28, 28], dtype='float32')
  15. label = fluid.data(name='label', shape=[None, 1], dtype='int64')
  16. reader = fluid.layers.create_py_reader_by_data(capacity=64,
  17. feed_list=[image, label])
  18. reader.decorate_paddle_reader(
  19. paddle.reader.shuffle(paddle.batch(mnist.train(), batch_size=5), buf_size=500))
  20. img, label = fluid.layers.read_file(reader)
  21. loss = network(img, label) # 用户构建自定义网络并返回损失函数
  22.  
  23. place = fluid.CUDAPlace(0) if USE_CUDA else fluid.CPUPlace()
  24. exe = fluid.Executor(place)
  25. exe.run(fluid.default_startup_program())
  26.  
  27. build_strategy = fluid.BuildStrategy()
  28. build_strategy.memory_optimize = True if MEMORY_OPT else False
  29. exec_strategy = fluid.ExecutionStrategy()
  30. compiled_prog = fluid.compiler.CompiledProgram(
  31. fluid.default_main_program()).with_data_parallel(
  32. loss_name=loss.name,
  33. build_strategy=build_strategy,
  34. exec_strategy=exec_strategy)
  35.  
  36. for epoch_id in range(2):
  37. reader.start()
  38. try:
  39. while True:
  40. exe.run(compiled_prog, fetch_list=[loss.name])
  41. except fluid.core.EOFException:
  42. reader.reset()