DATASETS & DATALOADERS

OneFlow’s Dataset and DataLoader behave the same way as their PyTorch counterparts. Both are designed to decouple dataset management from model training.

oneflow.utils.vision.datasets provides a number of classes that automatically download and load popular datasets (such as Fashion-MNIST).

DataLoader wraps a Dataset in an iterator, which makes it easy to iterate over and access samples during training.

    import matplotlib.pyplot as plt
    import oneflow as flow
    import oneflow.nn as nn
    from oneflow.utils.vision.transforms import ToTensor
    from oneflow.utils.data import Dataset
    import oneflow.utils.vision.datasets as datasets

Loading a Dataset

Here is an example of how to load data with one of the built-in Dataset classes. The parameters are:

  • root: the path where the train/test data is stored;
  • train: True for the training dataset, False for the test dataset;
  • download=True: downloads the data from the internet if it is not available at root;
  • transform: the transformation applied to each sample’s features (here ToTensor(), which converts the image to a tensor).

    # Training split of Fashion-MNIST, downloaded to ./data if missing
    training_data = datasets.FashionMNIST(
        root="data",
        train=True,
        download=True,
        transform=ToTensor()
    )
    # Test split, stored under the same root
    test_data = datasets.FashionMNIST(
        root="data",
        train=False,
        download=True,
        transform=ToTensor()
    )

The first time this code runs, it downloads the dataset and prints output like the following:

    Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz
    Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to data/FashionMNIST/raw/train-images-idx3-ubyte.gz
    26422272it [00:17, 1504123.86it/s]
    Extracting data/FashionMNIST/raw/train-images-idx3-ubyte.gz to data/FashionMNIST/raw
    Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz
    Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to data/FashionMNIST/raw/train-labels-idx1-ubyte.gz
    29696it [00:00, 98468.01it/s]
    Extracting data/FashionMNIST/raw/train-labels-idx1-ubyte.gz to data/FashionMNIST/raw
    Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz
    Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz
    4422656it [00:07, 620608.04it/s]
    Extracting data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz to data/FashionMNIST/raw
    Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz
    Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz
    6144it [00:00, 19231196.85it/s]
    Extracting data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz to data/FashionMNIST/raw

Iterating the Dataset

We can index a Dataset manually like a list: training_data[index]. The following example randomly picks 9 images from training_data and visualizes them.

    from random import randint

    # Maps the integer class label to a human-readable name
    labels_map = {
        0: "T-Shirt",
        1: "Trouser",
        2: "Pullover",
        3: "Dress",
        4: "Coat",
        5: "Sandal",
        6: "Shirt",
        7: "Sneaker",
        8: "Bag",
        9: "Ankle Boot",
    }
    figure = plt.figure(figsize=(8, 8))
    cols, rows = 3, 3
    for i in range(1, cols * rows + 1):
        # randint includes both endpoints, so the upper bound is len - 1
        sample_idx = randint(0, len(training_data) - 1)
        img, label = training_data[sample_idx]
        figure.add_subplot(rows, cols, i)
        plt.title(labels_map[label])
        plt.axis("off")
        # squeeze() drops the channel dimension: (1, 28, 28) -> (28, 28)
        plt.imshow(img.squeeze().numpy(), cmap="gray")
    plt.show()

[Figure: a 3x3 grid of randomly selected FashionMNIST images, each titled with its label]
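Valid indices for a dataset run from 0 to len(dataset) - 1. The following is a minimal sketch of drawing random indices in that range using only the Python standard library. The toy_dataset list is a hypothetical stand-in for training_data, used so that the snippet runs without downloading anything.

```python
import random

# Hypothetical stand-in for a dataset: anything that supports len() and
# integer indexing. With OneFlow this would be the training_data object.
toy_dataset = [(f"image_{i}", i % 10) for i in range(100)]

# random.randrange(n) returns a value in [0, n), so it is always a valid index.
# random.randint(a, b) includes b, so it would need len(toy_dataset) - 1.
one_idx = random.randrange(len(toy_dataset))

# random.sample draws k distinct indices, so no sample is picked twice.
sample_indices = random.sample(range(len(toy_dataset)), k=9)

for idx in sample_indices:
    img, label = toy_dataset[idx]
    print(idx, img, label)
```

Drawing distinct indices is a design choice for visualization. Independent draws, as in the example above, may occasionally show the same image twice.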

Creating a Custom Dataset for Your Files

A custom dataset can be defined by inheriting from oneflow.utils.data.Dataset. A custom Dataset can be used with the DataLoader introduced in the next section to simplify data processing.

Here is an example of how to create a custom Dataset. The key steps are:

  • Inherit from oneflow.utils.data.Dataset.
  • Implement the __len__ method, which returns the number of samples in the dataset.
  • Implement the __getitem__ method, which loads and returns a sample from the dataset when users call dataset_obj[idx].

    import numpy as np

    class CustomDataset(Dataset):
        # Toy data stored as class attributes: 4 samples with 2 features each
        raw_data_x = np.array([[1, 2], [2, 3], [4, 6], [3, 1]], dtype=np.float32)
        raw_label = np.array([[8], [13], [26], [9]], dtype=np.float32)

        def __init__(self, transform=None, target_transform=None):
            self.transform = transform
            self.target_transform = target_transform

        def __len__(self):
            # Class attributes must be accessed through the class (or self)
            return len(CustomDataset.raw_label)

        def __getitem__(self, idx):
            x = CustomDataset.raw_data_x[idx]
            label = CustomDataset.raw_label[idx]
            if self.transform:
                x = self.transform(x)
            if self.target_transform:
                label = self.target_transform(label)
            return x, label

    custom_dataset = CustomDataset()
    print(custom_dataset[0])
    print(custom_dataset[1])

Output:

    (array([1., 2.], dtype=float32), array([8.], dtype=float32))
    (array([2., 3.], dtype=float32), array([13.], dtype=float32))
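The example above leaves transform and target_transform at None. The sketch below shows how those two hooks change what __getitem__ returns. To stay dependency-free, ToyDataset is a hypothetical plain-Python class that follows the same __len__/__getitem__ protocol rather than inheriting from oneflow.utils.data.Dataset, and halve is an illustrative helper, not part of any library.

```python
class ToyDataset:
    """Plain-Python stand-in that follows the __len__/__getitem__ protocol."""

    features = [[1.0, 2.0], [2.0, 3.0], [4.0, 6.0], [3.0, 1.0]]
    labels = [8.0, 13.0, 26.0, 9.0]

    def __init__(self, transform=None, target_transform=None):
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(ToyDataset.labels)

    def __getitem__(self, idx):
        x = ToyDataset.features[idx]
        label = ToyDataset.labels[idx]
        if self.transform:
            x = self.transform(x)                # applied to the features only
        if self.target_transform:
            label = self.target_transform(label)  # applied to the label only
        return x, label


def halve(x):
    """Illustrative feature transform: scale every feature by 0.5."""
    return [v / 2 for v in x]


scaled = ToyDataset(transform=halve, target_transform=int)
print(len(scaled))  # 4
print(scaled[2])    # ([2.0, 3.0], 26)
```

Because the transforms are applied inside __getitem__, the stored data is never modified. Each access produces a freshly transformed sample.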

Using DataLoader

The Dataset retrieves the features and labels of our dataset one sample at a time. While training a model, we typically want to pass samples in “minibatches”, that is, to load batch_size samples at a time, and to reshuffle the data at every epoch to reduce model overfitting.
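To make the idea of minibatches and per-epoch reshuffling concrete, here is a minimal plain-Python sketch. It is not how OneFlow implements it: iter_minibatches is a hypothetical helper, and it skips the step where a real loader stacks the samples of a batch into tensors.

```python
import random


def iter_minibatches(dataset, batch_size, shuffle=True):
    """Yield lists of samples, batch_size at a time.

    The last batch may be smaller when len(dataset) is not a multiple
    of batch_size.
    """
    indices = list(range(len(dataset)))
    if shuffle:
        # A fresh order on every call, i.e. on every epoch
        random.shuffle(indices)
    for start in range(0, len(indices), batch_size):
        yield [dataset[i] for i in indices[start:start + batch_size]]


toy_dataset = list(range(10))  # stand-in for a real dataset
for epoch in range(2):
    batches = list(iter_minibatches(toy_dataset, batch_size=4))
    print(epoch, batches)
```

Each epoch visits every sample exactly once, but in a different order, so the model does not see the same batches repeatedly.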

This is what DataLoader is for. DataLoader wraps a Dataset in an iterator that provides access to the data during the training loop. Here is an example:

  • batch_size=64: the number of samples in each batch;
  • shuffle: whether the data is reshuffled at every epoch, that is, each time we iterate over all batches.

    from oneflow.utils.data import DataLoader

    train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)
    # Fetch a single batch of features and labels
    x, label = next(iter(train_dataloader))
    print(f"shape of x:{x.shape}, shape of label: {label.shape}")

Output:

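    shape of x:flow.Size([64, 1, 28, 28]), shape of label: flow.Size([64])

We can then display the first image of the batch and print its label:

    # Drop the channel dimension and convert to a NumPy array for plotting
    img = x[0].squeeze().numpy()
    label = label[0]
    plt.imshow(img, cmap="gray")
    plt.show()
    print(label)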

Output (the picture varies from run to run because the data is shuffled):

[Figure: the first image of the batch, displayed in grayscale]

    tensor(9, dtype=oneflow.int64)

We can also iterate over the DataLoader directly in the training loop.

    for x, label in train_dataloader:
        print(x.shape, label.shape)
        # training...
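A loop like the one above runs once per batch, so it helps to know how many batches one epoch contains. The arithmetic below is a sketch under two assumptions: the Fashion-MNIST training split holds 60,000 samples, and the loader keeps the final, smaller batch (as PyTorch does by default; this was not verified against OneFlow here).

```python
import math

num_samples = 60000  # size of the Fashion-MNIST training split
batch_size = 64

# Assumption: the incomplete final batch is kept rather than dropped
num_batches = math.ceil(num_samples / batch_size)
last_batch_size = num_samples - (num_batches - 1) * batch_size

print(num_batches)      # 938
print(last_batch_size)  # 32
```

Under these assumptions the last batch of each epoch contains 32 samples instead of 64, so code inside the loop should not hard-code the batch dimension.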
