RandomSampler

class paddle.io.RandomSampler ( data_source=None, replacement=False, num_samples=None, generator=None ) [源代码]

随机迭代样本,产生重排下标,如果 replacement = False ,则会采样整个数据集;如果 replacement = True ,则会按照 num_samples 指定的样本数采集。

参数

  • data_source (Dataset) - 此参数必须是 paddle.io.Datasetpaddle.io.IterableDataset 的一个子类实例或实现了 __len__ 的Python对象,用于生成样本下标。默认值为None。

  • replacement (bool) - 如果为 False 则会采样整个数据集,如果为 True 则会按 num_samples 指定的样本数采集。默认值为 False

  • num_samples (int) - 如果 replacement 设置为 True 则按此参数采集对应的样本数。默认值为None。

  • generator (Generator) - 指定采样 data_source 的采样器。默认值为None。

返回

RandomSampler, 返回随机采样下标的采样器

代码示例

  1. from paddle.io import Dataset, RandomSampler
  2. class RandomDataset(Dataset):
  3. def __init__(self, num_samples):
  4. self.num_samples = num_samples
  5. def __getitem__(self, idx):
  6. image = np.random.random([784]).astype('float32')
  7. label = np.random.randint(0, 9, (1, )).astype('int64')
  8. return image, label
  9. def __len__(self):
  10. return self.num_samples
  11. sampler = RandomSampler(data_source=RandomDataset(100))
  12. for index in sampler:
  13. print(index)