QueueyDataset

class paddle.distributed.QueueDataset [源代码]

QueueyDataset是流式处理数据使用Dataset类。与InmemoryDataset继承自同一父类,用于单机训练,不支持分布式大规模参数服务器相关配置和shuffle。此类由paddle.distributed.QueueDataset直接创建。

代码示例:

  1. import paddle
  2. dataset = paddle.distributed.QueueDataset()

init ( \*kwargs* )

注意:

1. 该API只在非 Dygraph 模式下生效

对QueueDataset的实例进行配置初始化。

参数:

  • kwargs - 可选的关键字参数,由调用者提供, 目前支持以下关键字配置。

  • batch_size (int) - batch size的大小. 默认值为1。

  • thread_num (int) - 用于训练的线程数, 默认值为1。

  • use_var (list) - 用于输入的variable列表,默认值为[]。

  • input_type (int) - 输入到模型训练样本的类型. 0 代表一条样本, 1 代表一个batch。 默认值为0。

  • fs_name (str) - hdfs名称. 默认值为””。

  • fs_ugi (str) - hdfs的ugi. 默认值为””。

  • pipe_command (str) - 在当前的 dataset 中设置的pipe命令用于数据的预处理。pipe命令只能使用UNIX的pipe命令,默认为”cat”。

  • download_cmd (str) - 数据下载pipe命令。 pipe命令只能使用UNIX的pipe命令, 默认为”cat”。

返回:None。

代码示例

  1. import paddle
  2. import os
  3. paddle.enable_static()
  4. with open("test_queue_dataset_run_a.txt", "w") as f:
  5. data = "2 1 2 2 5 4 2 2 7 2 1 3\n"
  6. data += "2 6 2 2 1 4 2 2 4 2 2 3\n"
  7. data += "2 5 2 2 9 9 2 2 7 2 1 3\n"
  8. data += "2 7 2 2 1 9 2 3 7 2 5 3\n"
  9. f.write(data)
  10. with open("test_queue_dataset_run_b.txt", "w") as f:
  11. data = "2 1 2 2 5 4 2 2 7 2 1 3\n"
  12. data += "2 6 2 2 1 4 2 2 4 2 2 3\n"
  13. data += "2 5 2 2 9 9 2 2 7 2 1 3\n"
  14. data += "2 7 2 2 1 9 2 3 7 2 5 3\n"
  15. f.write(data)
  16. slots = ["slot1", "slot2", "slot3", "slot4"]
  17. slots_vars = []
  18. for slot in slots:
  19. var = paddle.static.data(
  20. name=slot, shape=[None, 1], dtype="int64", lod_level=1)
  21. slots_vars.append(var)
  22. dataset = paddle.distributed.QueueDataset()
  23. dataset.init(
  24. batch_size=1,
  25. thread_num=2,
  26. input_type=1,
  27. pipe_command="cat",
  28. use_var=slots_vars)
  29. dataset.set_filelist(
  30. ["test_queue_dataset_run_a.txt", "test_queue_dataset_run_b.txt"])
  31. paddle.enable_static()
  32. place = paddle.CPUPlace()
  33. exe = paddle.static.Executor(place)
  34. startup_program = paddle.static.Program()
  35. main_program = paddle.static.Program()
  36. exe.run(startup_program)
  37. exe.train_from_dataset(main_program, dataset)
  38. os.remove("./test_queue_dataset_run_a.txt")
  39. os.remove("./test_queue_dataset_run_b.txt")

set_filelist ( filelist )

在当前的worker中设置文件列表。

代码示例:

  1. import paddle
  2. import os
  3. paddle.enable_static()
  4. with open("test_queue_dataset_run_a.txt", "w") as f:
  5. data = "2 1 2 2 5 4 2 2 7 2 1 3\n"
  6. data += "2 6 2 2 1 4 2 2 4 2 2 3\n"
  7. data += "2 5 2 2 9 9 2 2 7 2 1 3\n"
  8. data += "2 7 2 2 1 9 2 3 7 2 5 3\n"
  9. f.write(data)
  10. with open("test_queue_dataset_run_b.txt", "w") as f:
  11. data = "2 1 2 2 5 4 2 2 7 2 1 3\n"
  12. data += "2 6 2 2 1 4 2 2 4 2 2 3\n"
  13. data += "2 5 2 2 9 9 2 2 7 2 1 3\n"
  14. data += "2 7 2 2 1 9 2 3 7 2 5 3\n"
  15. f.write(data)
  16. dataset = paddle.distributed.QueueDataset()
  17. slots = ["slot1", "slot2", "slot3", "slot4"]
  18. slots_vars = []
  19. for slot in slots:
  20. var = paddle.static.data(
  21. name=slot, shape=[None, 1], dtype="int64", lod_level=1)
  22. slots_vars.append(var)
  23. dataset.init(
  24. batch_size=1,
  25. thread_num=2,
  26. input_type=1,
  27. pipe_command="cat",
  28. use_var=slots_vars)
  29. filelist = ["a.txt", "b.txt"]
  30. dataset.set_filelist(filelist)
  31. os.remove("./test_queue_dataset_run_a.txt")
  32. os.remove("./test_queue_dataset_run_b.txt")

参数:

  • filelist (list[string]) - 文件列表