QUICKSTART

This section uses training on MNIST as an example to briefly show how OneFlow handles common deep learning tasks. Follow the link at the end of each section for a detailed tutorial on that subtask.

Let’s start by importing the necessary libraries:

    import oneflow as flow
    import oneflow.nn as nn
    import oneflow.utils.vision.transforms as transforms

    BATCH_SIZE = 128

Working with Data

OneFlow has two primitives for working with data: Dataset and DataLoader.

The oneflow.utils.vision.datasets module provides a number of real-world datasets (such as MNIST, CIFAR-10, and FashionMNIST).

We can use oneflow.utils.vision.datasets.MNIST to get the MNIST training and test sets.

    mnist_train = flow.utils.vision.datasets.MNIST(
        root="data",
        train=True,
        transform=transforms.ToTensor(),
        download=True,
        source_url="https://oneflow-public.oss-cn-beijing.aliyuncs.com/datasets/mnist/MNIST/",
    )
    mnist_test = flow.utils.vision.datasets.MNIST(
        root="data",
        train=False,
        transform=transforms.ToTensor(),
        download=True,
        source_url="https://oneflow-public.oss-cn-beijing.aliyuncs.com/datasets/mnist/MNIST/",
    )

Out:

    Downloading https://oneflow-public.oss-cn-beijing.aliyuncs.com/datasets/mnist/MNIST/train-images-idx3-ubyte.gz
    Downloading https://oneflow-public.oss-cn-beijing.aliyuncs.com/datasets/mnist/MNIST/train-images-idx3-ubyte.gz to data/MNIST/raw/train-images-idx3-ubyte.gz
    9913344it [00:00, 36066177.85it/s]
    Extracting data/MNIST/raw/train-images-idx3-ubyte.gz to data/MNIST/raw
    ...

The data will be downloaded and extracted to the ./data directory.

The oneflow.utils.data.DataLoader wraps an iterable around the dataset.

    train_iter = flow.utils.data.DataLoader(
        mnist_train, BATCH_SIZE, shuffle=True
    )
    test_iter = flow.utils.data.DataLoader(
        mnist_test, BATCH_SIZE, shuffle=False
    )

    for x, y in train_iter:
        print("x.shape:", x.shape)
        print("y.shape:", y.shape)
        break

Out:

    x.shape: flow.Size([128, 1, 28, 28])
    y.shape: flow.Size([128])
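
The batch dimension of 128 above comes directly from BATCH_SIZE. As a small sketch in plain Python (no OneFlow required), the number of batches per epoch can be worked out by hand. This assumes the standard MNIST split of 60,000 training and 10,000 test images, and it assumes the loader keeps the final partial batch rather than dropping it. The helper name batches_per_epoch is made up for this illustration and is not part of OneFlow.

```python
import math

BATCH_SIZE = 128  # same value as in the quickstart


def batches_per_epoch(num_samples, batch_size, drop_last=False):
    """Number of batches one pass over the data yields (illustrative helper)."""
    if drop_last:
        # Any incomplete final batch is discarded.
        return num_samples // batch_size
    # The incomplete final batch is kept, so round up.
    return math.ceil(num_samples / batch_size)


# Assumed dataset sizes: the standard MNIST split.
train_batches = batches_per_epoch(60_000, BATCH_SIZE)
test_batches = batches_per_epoch(10_000, BATCH_SIZE)

# Size of the last (partial) training batch.
last_train_batch = 60_000 - (train_batches - 1) * BATCH_SIZE

print(train_batches, test_batches, last_train_batch)
```

Under these assumptions the last training batch is smaller than BATCH_SIZE, which is why code that needs an exact sample count should read the batch's own shape instead of assuming 128.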

🔗 Dataset and Dataloader

Building Networks

To define a neural network in OneFlow, we create a class that inherits from nn.Module. We define the layers of the network in the __init__ function and specify how data will pass through the network in the forward function.

    class NeuralNetwork(nn.Module):
        def __init__(self):
            super(NeuralNetwork, self).__init__()
            self.flatten = nn.Flatten()
            self.linear_relu_stack = nn.Sequential(
                nn.Linear(28*28, 512),
                nn.ReLU(),
                nn.Linear(512, 512),
                nn.ReLU(),
                nn.Linear(512, 10),
                nn.ReLU()
            )

        def forward(self, x):
            x = self.flatten(x)
            logits = self.linear_relu_stack(x)
            return logits

    model = NeuralNetwork()
    print(model)

Out:

    NeuralNetwork(
      (flatten): Flatten(start_dim=1, end_dim=-1)
      (linear_relu_stack): Sequential(
        (0): Linear(in_features=784, out_features=512, bias=True)
        (1): ReLU()
        (2): Linear(in_features=512, out_features=512, bias=True)
        (3): ReLU()
        (4): Linear(in_features=512, out_features=10, bias=True)
        (5): ReLU()
      )
    )
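
The printed structure is enough to estimate the size of the model. The following sketch, in plain Python, counts the trainable parameters of the three Linear layers listed above, assuming each layer holds an in_features x out_features weight matrix plus one bias per output (as bias=True in the printout suggests). The helper linear_params is a made-up name for this illustration.

```python
def linear_params(in_features, out_features, bias=True):
    """Parameter count of one fully connected layer (illustrative helper)."""
    weights = in_features * out_features
    biases = out_features if bias else 0
    return weights + biases


# (in_features, out_features) of the three Linear layers printed above.
# Flatten and ReLU have no trainable parameters.
layers = [(28 * 28, 512), (512, 512), (512, 10)]

per_layer = [linear_params(i, o) for i, o in layers]
total = sum(per_layer)

print(per_layer, total)
```

Most of the parameters sit in the first layer, because it maps all 784 input pixels to 512 hidden units.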

🔗 Build Network

Training Models

To train a model, we need a loss function (loss_fn) and an optimizer (optimizer). The loss function measures the difference between the network’s prediction and the true label. The optimizer adjusts the network’s parameters so that predictions move closer to the true labels (the expected answers); the gradients it relies on are computed by backpropagation. Here, we use oneflow.optim.SGD as the optimizer.

    loss_fn = nn.CrossEntropyLoss()
    optimizer = flow.optim.SGD(model.parameters(), lr=1e-3)
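
To make the loss value easier to interpret, here is a minimal sketch of softmax cross-entropy for a single sample in plain Python. It assumes nn.CrossEntropyLoss follows the usual definition (the negative log of the softmax probability assigned to the true class, averaged over the batch); the function below is an illustration of that formula, not OneFlow’s implementation.

```python
import math


def cross_entropy(logits, target):
    """Softmax cross-entropy for one sample: -log(softmax(logits)[target])."""
    m = max(logits)  # subtract the max before exponentiating, for stability
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[target]


# A network that cannot tell the 10 digits apart gives every class the
# same score, so the loss equals ln(10).
uniform = cross_entropy([0.0] * 10, target=3)

# A network that strongly favours the correct class gets a much lower loss.
confident = cross_entropy([0.0] * 9 + [5.0], target=9)

print(f"{uniform:.6f}", confident < uniform)
```

The uniform case gives ln(10), roughly 2.3026, which is close to the loss values of about 2.30 at the start of the training log in this section: an untrained network is essentially guessing among ten classes.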

The train function is defined for training. In each iteration of the training loop, the model performs a forward pass, computes the loss, and backpropagates to update the model’s parameters.

    def train(iter, model, loss_fn, optimizer):
        size = len(iter.dataset)
        for batch, (x, y) in enumerate(iter):
            # Compute prediction error
            pred = model(x)
            loss = loss_fn(pred, y)

            # Backpropagation
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            current = batch * BATCH_SIZE
            if batch % 100 == 0:
                print(f"loss: {loss:>7f} [{current:>5d}/{size:>5d}]")
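
What the backward-and-step pair does can be seen on a toy problem. The sketch below is plain Python and assumes the plain SGD rule without momentum, w <- w - lr * grad; it does not use OneFlow, and the function sgd_minimize and the toy loss are made up for illustration.

```python
def sgd_minimize(grad_fn, w, lr, steps):
    """Repeat the plain SGD update w <- w - lr * grad(w)."""
    for _ in range(steps):
        # Gradient of the loss at the current parameter value;
        # in the train function this is what loss.backward() computes.
        grad = grad_fn(w)
        # Move against the gradient; this is the role of optimizer.step().
        w = w - lr * grad
    return w


# Toy loss (w - 3)^2 with gradient 2 * (w - 3); its minimum is at w = 3.
w_final = sgd_minimize(lambda w: 2.0 * (w - 3.0), w=0.0, lr=0.1, steps=100)

print(w_final)
```

Here the gradient is recomputed from scratch on every step. In the train function, optimizer.zero_grad() serves the same purpose by clearing the gradients left over from the previous batch before loss.backward() runs.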

We also define a test function to verify the accuracy of the model:

    def test(iter, model, loss_fn):
        size = len(iter.dataset)
        num_batches = len(iter)
        model.eval()
        test_loss, correct = 0, 0
        with flow.no_grad():
            for x, y in iter:
                pred = model(x)
                test_loss += loss_fn(pred, y)
                bool_value = (pred.argmax(1).to(dtype=flow.int64)==y)
                correct += float(bool_value.sum().numpy())
        test_loss /= num_batches
        print("test_loss", test_loss, "num_batches ", num_batches)
        correct /= size
        print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}, Avg loss: {test_loss:>8f}")
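
The accuracy bookkeeping in the test function (take the argmax of each prediction, compare it with the label, count the matches, divide by the dataset size) can be sketched in plain Python. The scores and labels below are invented toy values, and the accuracy helper is a made-up name for this illustration.

```python
def accuracy(scores, labels):
    """Fraction of rows whose highest-scoring class equals the label."""
    correct = 0
    for row, label in zip(scores, labels):
        # Index of the largest score, i.e. the argmax over classes.
        predicted = max(range(len(row)), key=row.__getitem__)
        correct += int(predicted == label)
    return correct / len(labels)


# Toy example: 4 samples, 3 classes (made-up numbers).
scores = [
    [0.1, 2.0, 0.3],  # predicts class 1
    [1.5, 0.2, 0.1],  # predicts class 0
    [0.0, 0.4, 0.9],  # predicts class 2
    [0.3, 0.2, 0.1],  # predicts class 0
]
labels = [1, 0, 2, 2]

acc = accuracy(scores, labels)
print(f"Accuracy: {100 * acc:.1f}%")
```

Three of the four toy predictions match their labels, so the sketch reports 75.0%. The test function performs the same comparison batch by batch on tensors and divides by the full dataset size at the end.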

We use the train function to train for several epochs and the test function to assess the network’s accuracy at the end of each epoch:

    epochs = 5
    for t in range(epochs):
        print(f"Epoch {t+1}\n-------------------------------")
        train(train_iter, model, loss_fn, optimizer)
        test(test_iter, model, loss_fn)
    print("Done!")

Out:

    loss: 2.299633
    loss: 2.303208
    loss: 2.298017
    loss: 2.297773
    loss: 2.294673
    loss: 2.295637
    Test Error:
     Accuracy: 22.1%, Avg loss: 2.292105
    Epoch 2
    -------------------------------
    loss: 2.288640
    loss: 2.286367
    ...

🔗 Autograd 🔗 Backpropagation and Optimizer

Saving and Loading Models

Use oneflow.save to save the model. The saved model can then be loaded with oneflow.load to make predictions.

    flow.save(model.state_dict(), "./model")

🔗 Model Load and Save

QQ Group

If you run into problems during installation or usage, you are welcome to join the QQ group to discuss them with OneFlow developers and enthusiasts.

Join QQ group 331883, or scan the QR code below:

OneFlow QQ Group
