torchvision.datasets

译者:BXuan694

所有的数据集都是torch.utils.data.Dataset的子类, 即:它们实现了__getitem____len__方法。因此,它们都可以传递给torch.utils.data.DataLoader,进而通过torch.multiprocessing实现批数据的并行化加载。例如:

  1. imagenet_data = torchvision.datasets.ImageFolder('path/to/imagenet_root/')
  2. data_loader = torch.utils.data.DataLoader(imagenet_data,
  3. batch_size=4,
  4. shuffle=True,
  5. num_workers=args.nThreads)

目前为止,收录的数据集包括:

数据集

以上数据集的接口基本上很相近。它们至少包括两个公共的参数transformtarget_transform,以便分别对输入和和目标做变换。

  1. class torchvision.datasets.MNIST(root, train=True, transform=None, target_transform=None, download=False)

MNIST数据集。

参数:

  • root(string)– 数据集的根目录,其中存放processed/training.ptprocessed/test.pt文件。
  • train(bool, 可选)– 如果设置为True,从training.pt创建数据集,否则从test.pt创建。
  • download(bool, 可选)– 如果设置为True, 从互联网下载数据并放到root文件夹下。如果root目录下已经存在数据,不会再次下载。
  • transform(可被调用 , 可选)– 一种函数或变换,输入PIL图片,返回变换之后的数据。如:transforms.RandomCrop
  • target_transform (可被调用 , 可选)– 一种函数或变换,输入目标,进行变换。
  1. class torchvision.datasets.FashionMNIST(root, train=True, transform=None, target_transform=None, download=False)

Fashion-MNIST数据集。

参数:

  • root(string)– 数据集的根目录,其中存放processed/training.ptprocessed/test.pt文件。
  • train(bool, 可选)– 如果设置为True,从training.pt创建数据集,否则从test.pt创建。
  • download(bool, 可选)– 如果设置为True,从互联网下载数据并放到root文件夹下。如果root目录下已经存在数据,不会再次下载。
  • transform(可被调用 , 可选)– 一种函数或变换,输入PIL图片,返回变换之后的数据。如:transforms.RandomCrop
  • target_transform(可被调用 , 可选)– 一种函数或变换,输入目标,进行变换。
  1. class torchvision.datasets.EMNIST(root, split, **kwargs)

EMNIST数据集。

参数:

  • root(string)– 数据集的根目录,其中存放processed/training.ptprocessed/test.pt文件。
  • split(string)– 该数据集分成6种:byclassbymergebalancedlettersdigitsmnist。这个参数指定了选择其中的哪一种。
  • train(bool, 可选)– 如果设置为True,从training.pt创建数据集,否则从test.pt创建。
  • download(bool, 可选)– 如果设置为True, 从互联网下载数据并放到root文件夹下。如果root目录下已经存在数据,不会再次下载。
  • transform(可被调用 , 可选)– 一种函数或变换,输入PIL图片,返回变换之后的数据。如:transforms.RandomCrop
  • target_transform(可被调用 , 可选) – 一种函数或变换,输入目标,进行变换。

注意:

以下要求预先安装COCO API

  1. class torchvision.datasets.CocoCaptions(root, annFile, transform=None, target_transform=None)

MS Coco Captions数据集。

参数:

  • root(string)– 下载数据的目标目录。
  • annFile(string)– json标注文件的路径。
  • transform(可被调用 , 可选)– 一种函数或变换,输入PIL图片,返回变换之后的数据。如:transforms.ToTensor
  • target_transform(可被调用 , 可选)– 一种函数或变换,输入目标,进行变换。

示例

  1. import torchvision.datasets as dset
  2. import torchvision.transforms as transforms
  3. cap = dset.CocoCaptions(root = 'dir where images are',
  4. annFile = 'json annotation file',
  5. transform=transforms.ToTensor())
  6. print('Number of samples: ', len(cap))
  7. img, target = cap[3] # load 4th sample
  8. print("Image Size: ", img.size())
  9. print(target)

输出:

  1. Number of samples: 82783
  2. Image Size: (3L, 427L, 640L)
  3. [u'A plane emitting smoke stream flying over a mountain.',
  4. u'A plane darts across a bright blue sky behind a mountain covered in snow',
  5. u'A plane leaves a contrail above the snowy mountain top.',
  6. u'A mountain that has a plane flying overheard in the distance.',
  7. u'A mountain view with a plume of smoke in the background']
  1. __getitem__(index)
参数: index (int) – 索引
返回: 元组(image, target),其中target是列表类型,包含了对图片image的描述。
—- —-
返回类型: tuple
—- —-
  1. class torchvision.datasets.CocoDetection(root, annFile, transform=None, target_transform=None)

MS Coco Detection数据集。

参数:

  • root(string)– 下载数据的目标目录。
  • annFile(string)– json标注文件的路径。
  • transform(可被调用 , 可选)– 一种函数或变换,输入PIL图片,返回变换之后的数据。如:transforms.ToTensor
  • target_transform(可被调用 , 可选)– 一种函数或变换,输入目标,进行变换。
  1. __getitem__(index)
参数: index (int) – 索引
返回: 元组(image, target),其中target是coco.loadAnns返回的对象。
—- —-
返回类型: tuple
—- —-
  1. class torchvision.datasets.LSUN(root, classes='train', transform=None, target_transform=None)

LSUN数据集。

参数:

  • root(string)– 存放数据文件的根目录。
  • classes(string list)– {‘train’, ‘val’, ‘test’}之一,或要加载类别的列表,如[‘bedroom_train’, ‘church_train’]。
  • transform(可被调用 , 可选) – 一种函数或变换,输入PIL图片,返回变换之后的数据。如:transforms.RandomCrop
  • target_transform(可被调用 , 可选)– 一种函数或变换,输入目标,进行变换。
  1. __getitem__(index)
参数: index (int) – 索引
返回: 元组(image, target),其中target是目标类别的索引。
—- —-
Return type: tuple
—- —-
  1. class torchvision.datasets.ImageFolder(root, transform=None, target_transform=None, loader=<function default_loader>)

一种通用数据加载器,其图片应该按照如下的形式保存:

  1. root/dog/xxx.png
  2. root/dog/xxy.png
  3. root/dog/xxz.png
  4. root/cat/123.png
  5. root/cat/nsdf3.png
  6. root/cat/asd932_.png

参数:

  • root(string)– 根目录路径。
  • transform(可被调用 , 可选)– 一种函数或变换,输入PIL图片,返回变换之后的数据。如:transforms.RandomCrop
  • target_transform(可被调用 , 可选)– 一种函数或变换,输入目标,进行变换。
  • loader – 一种函数,可以由给定的路径加载图片。
  1. __getitem__(index)
参数: index (int) – 索引
返回: (sample, target),其中target是目标类的类索引。
—- —-
返回类型: tuple
—- —-
  1. class torchvision.datasets.DatasetFolder(root, loader, extensions, transform=None, target_transform=None)

一种通用数据加载器,其数据应该按照如下的形式保存:

  1. root/class_x/xxx.ext
  2. root/class_x/xxy.ext
  3. root/class_x/xxz.ext
  4. root/class_y/123.ext
  5. root/class_y/nsdf3.ext
  6. root/class_y/asd932_.ext

参数:

  • root(string)– 根目录路径。
  • loader(可被调用)– 一种函数,可以由给定的路径加载数据。
  • extensions(list[string])– 列表,包含允许的扩展。
  • transform(可被调用 , 可选)– 一种函数或变换,输入数据,返回变换之后的数据。如:对于图片有transforms.RandomCrop
  • target_transform – 一种函数或变换,输入目标,进行变换。
  1. __getitem__(index)
参数: index (int) – 索引
返回: (sample, target),其中target是目标类的类索引.
—- —-
返回类型: tuple
—- —-

这个类可以很容易地实现ImageFolder数据集。数据预处理见此处

示例

  1. class torchvision.datasets.CIFAR10(root, train=True, transform=None, target_transform=None, download=False)

CIFAR10数据集。

参数:

  • root(string)– 数据集根目录,要么其中应存在cifar-10-batches-py文件夹,要么当download设置为True时cifar-10-batches-py文件夹保存在此处。
  • train(bool, 可选)– 如果设置为True, 从训练集中创建,否则从测试集中创建。
  • transform(可被调用 , 可选)– 一种函数或变换,输入PIL图片,返回变换之后的数据。如:transforms.RandomCrop
  • target_transform(可被调用 , 可选)– 一种函数或变换,输入目标,进行变换。
  • download(bool, 可选)– 如果设置为True,从互联网下载数据并放到root文件夹下。如果root目录下已经存在数据,不会再次下载。
  1. __getitem__(index)
参数: index (int) – 索引
返回: (image, target),其中target是目标类的类索引。
—- —-
返回类型: tuple
—- —-
  1. class torchvision.datasets.CIFAR100(root, train=True, transform=None, target_transform=None, download=False)

CIFAR100数据集。

这是CIFAR10数据集的一个子集。

  1. class torchvision.datasets.STL10(root, split='train', transform=None, target_transform=None, download=False)

STL10数据集。

参数:

  • root(string)– 数据集根目录,应该包含stl10_binary文件夹。
  • split(string)– {‘train’, ‘test’, ‘unlabeled’, ‘train+unlabeled’}之一,选择相应的数据集。
  • transform(可被调用 , 可选)– 一种函数或变换,输入PIL图片,返回变换之后的数据。如:transforms.RandomCrop
  • target_transform(可被调用 , 可选)– 一种函数或变换,输入目标,进行变换。
  • download(bool, optional)– 如果设置为True,从互联网下载数据并放到root文件夹下。如果root目录下已经存在数据,不会再次下载。
  1. __getitem__(index)
参数: index (int) – 索引
返回: (image, target),其中target应是目标类的类索引。
—- —-
返回类型: tuple
—- —-
  1. class torchvision.datasets.SVHN(root, split='train', transform=None, target_transform=None, download=False)

SVHN数据集。注意:SVHN数据集将10指定为数字0的标签。然而,这里我们将0指定为数字0的标签以兼容PyTorch的损失函数,因为损失函数要求类标签在[0, C-1]的范围内。

参数:

  • root(string)– 数据集根目录,应包含SVHN文件夹。
  • split(string)– {‘train’, ‘test’, ‘extra’}之一,相应的数据集会被选择。‘extra’是extra训练集。
  • transform(可被调用 , 可选)– 一种函数或变换,输入PIL图片,返回变换之后的数据。如:transforms.RandomCrop
  • target_transform(可被调用 , 可选)– 一种函数或变换,输入目标,进行变换。
  • download(bool, 可选)– 如果设置为True,从互联网下载数据并放到root文件夹下。如果root目录下已经存在数据,不会再次下载。
  1. __getitem__(index)
参数: index (int) – 索引
返回: (image, target),其中target是目标类的类索引。
—- —-
返回类型: tuple
—- —-
  1. class torchvision.datasets.PhotoTour(root, name, train=True, transform=None, download=False)

Learning Local Image Descriptors Data数据集。

参数:

  • root(string)– 保存图片的根目录。
  • name(string)– 要加载的数据集。
  • transform(可被调用 , 可选)– 一种函数或变换,输入PIL图片,返回变换之后的数据。
  • download (bool, optional) – 如果设置为True,从互联网下载数据并放到root文件夹下。如果root目录下已经存在数据,不会再次下载。
  1. __getitem__(index)
参数: index (int) – 索引
返回: (data1, data2, matches)
—- —-
返回类型: tuple
—- —-