create_py_reader_by_data

paddle.fluid.layers. create_py_reader_by_data ( capacity, feed_list, name=None, use_double_buffer=True ) [源代码]

创建一个Python端提供数据的reader。该OP与 py_reader 类似,不同点在于它能够从feed变量列表读取数据。

参数:

  • capacity (int) - py_reader 维护的队列缓冲区的容量大小。单位是batch数量。若reader读取速度较快,建议设置较大的 capacity 值。

  • feed_list (list(Variable)) - feed变量列表,这些变量一般由 fluid.data() 创建。

  • name (str,可选) – 具体用法请参见 Name ,一般无需设置,默认值为None。

  • use_double_buffer (bool,可选) - 是否使用双缓冲区,双缓冲区是为了预读下一个batch的数据、异步CPU -> GPU拷贝。默认值为True。

返回:能够从feed变量列表读取数据的reader,数据类型和feed变量列表中变量的数据类型相同。

返回类型:reader

代码示例:

  1. import paddle
  2. import paddle.fluid as fluid
  3. import paddle.dataset.mnist as mnist
  4. def network(img, label):
  5. # 用户构建自定义网络,此处以一个简单的线性回归为例。
  6. predict = fluid.layers.fc(input=img, size=10, act='softmax')
  7. loss = fluid.layers.cross_entropy(input=predict, label=label)
  8. return fluid.layers.mean(loss)
  9. MEMORY_OPT = False
  10. USE_CUDA = False
  11. image = fluid.data(name='image', shape=[None, 1, 28, 28], dtype='float32')
  12. label = fluid.data(name='label', shape=[None, 1], dtype='int64')
  13. reader = fluid.layers.create_py_reader_by_data(capacity=64,
  14. feed_list=[image, label])
  15. reader.decorate_paddle_reader(
  16. paddle.reader.shuffle(paddle.batch(mnist.train(), batch_size=5), buf_size=500))
  17. img, label = fluid.layers.read_file(reader)
  18. loss = network(img, label) # 用户构建自定义网络并返回损失函数
  19. place = fluid.CUDAPlace(0) if USE_CUDA else fluid.CPUPlace()
  20. exe = fluid.Executor(place)
  21. exe.run(fluid.default_startup_program())
  22. build_strategy = fluid.BuildStrategy()
  23. build_strategy.memory_optimize = True if MEMORY_OPT else False
  24. exec_strategy = fluid.ExecutionStrategy()
  25. compiled_prog = fluid.compiler.CompiledProgram(
  26. fluid.default_main_program()).with_data_parallel(
  27. loss_name=loss.name,
  28. build_strategy=build_strategy,
  29. exec_strategy=exec_strategy)
  30. for epoch_id in range(2):
  31. reader.start()
  32. try:
  33. while True:
  34. exe.run(compiled_prog, fetch_list=[loss.name])
  35. except fluid.core.EOFException:
  36. reader.reset()