Sampler

class paddle.io. Sampler ( data_source=None ) [源代码]

概括数据集采样器行为和方法的基类。

所有数据集采样器必须继承这个基类,并实现以下方法:

__iter__: 迭代返回数据样本下标

__len__: data_source 中的样本数

参数:

  • data_source (Dataset) - 此参数必须是 paddle.io.Datasetpaddle.io.IterableDataset 的一个子类实例或实现了 __len__ 的Python对象,用于生成样本下标。默认值为None。

可见 paddle.io.BatchSamplerpaddle.io.DataLoader

返回:返回样本下标的迭代器。

返回类型: Sampler

代码示例

  1. from paddle.io import Dataset, Sampler
  2. class RandomDataset(Dataset):
  3. def __init__(self, num_samples):
  4. self.num_samples = num_samples
  5. def __getitem__(self, idx):
  6. image = np.random.random([784]).astype('float32')
  7. label = np.random.randint(0, 9, (1, )).astype('int64')
  8. return image, label
  9. def __len__(self):
  10. return self.num_samples
  11. class MySampler(Sampler):
  12. def __init__(self, data_source):
  13. self.data_source = data_source
  14. def __iter__(self):
  15. return iter(range(len(self.data_source)))
  16. def __len__(self):
  17. return len(self.data_source)
  18. sampler = MySampler(data_source=RandomDataset(100))
  19. for index in sampler:
  20. print(index)