Jittor MNIST Classification Tutorial

Full code: https://github.com/Jittor/mnistclassification-jittor

Introduction to MNIST:

The MNIST dataset, available at http://yann.lecun.com/exdb/mnist/, is a dataset for classifying handwritten digits into the ten classes 0 through 9. It consists of four files:

Training images: train-images-idx3-ubyte.gz (9.9 MB, 47 MB uncompressed, 60,000 samples)

Training labels: train-labels-idx1-ubyte.gz (29 KB, 60 KB uncompressed, 60,000 labels)

Test images: t10k-images-idx3-ubyte.gz (1.6 MB, 7.8 MB uncompressed, 10,000 samples)

Test labels: t10k-labels-idx1-ubyte.gz (5 KB, 10 KB uncompressed, 10,000 labels)
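Each of these files uses the IDX format described on the MNIST page: a 4-byte magic number (two zero bytes, a type code of 0x08 for unsigned bytes, and the number of dimensions), followed by one big-endian 32-bit size per dimension, then the raw data. That is why the loader in step 2 skips 16 bytes for the image files (3 dimensions) and 8 bytes for the label files (1 dimension). Below is a minimal sketch of a reader that derives the offset from the header instead of hard-coding it; `read_idx` is a hypothetical helper, not part of Jittor.

```python
import gzip
import struct

import numpy as np


def read_idx(path):
    """Parse an unsigned-byte IDX file (optionally gzip-compressed) into a NumPy array."""
    opener = gzip.open if path.endswith(".gz") else open
    with opener(path, "rb") as f:
        data = f.read()
    # Magic number: 2 zero bytes, 1-byte type code, 1-byte number of dimensions.
    zero, dtype_code, ndim = struct.unpack(">HBB", data[:4])
    if zero != 0 or dtype_code != 0x08:
        raise ValueError("expected an unsigned-byte IDX file")
    # One big-endian uint32 size per dimension.
    dims = struct.unpack(">" + "I" * ndim, data[4:4 + 4 * ndim])
    offset = 4 + 4 * ndim  # 16 for images (3 dims), 8 for labels (1 dim)
    return np.frombuffer(data, dtype=np.uint8, offset=offset).reshape(dims)
```

For example, `read_idx("./mnist_data/train-images-idx3-ubyte.gz")` would return an array of shape (60000, 28, 28), assuming the file was downloaded to that (hypothetical) directory.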

The figure below shows samples from the MNIST dataset.

Example 2: MNIST image classification, Figure 1

Classifying MNIST with Jittor

1. First, import the required dependencies:

  # classification mnist example
  import jittor as jt  # import Jittor
  from jittor import nn, Module  # import the required modules
  import numpy as np
  import sys, os
  import random
  import math
  from jittor import init
  jt.flags.use_cuda = 1  # jt.flags.use_cuda selects the training device:
  # jt.flags.use_cuda = 1 trains on the GPU, jt.flags.use_cuda = 0 trains on the CPU
  from jittor.dataset.mnist import MNIST
  # MNIST is a common dataset, so Jittor already ships a loader for it that can be imported directly.
  import matplotlib.pyplot as plt
  import pylab as pl  # used to plot the loss curve and the MNIST images

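2. Load the MNIST dataset. A dataset subclasses Jittor's Dataset class and implements its `__init__()` and `__getitem__()` methods. For MNIST, the implementation is shown below (this class replaces the MNIST imported in step 1):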

  import gzip
  import os
  import numpy as np
  from jittor.dataset import Dataset
  from jittor.utils.misc import download_url_to_local
  class MNIST(Dataset):
      # data_root: directory where the MNIST .gz files are stored (and downloaded to)
      def __init__(self, data_root="./mnist_data/", train=True, download=True, batch_size=1, shuffle=False):
          super().__init__()
          self.data_root = data_root
          self.batch_size = batch_size
          self.shuffle = shuffle
          self.is_train = train
          if download:
              self.download_url()
          filesname = [
              "train-images-idx3-ubyte.gz",
              "t10k-images-idx3-ubyte.gz",
              "train-labels-idx1-ubyte.gz",
              "t10k-labels-idx1-ubyte.gz"
          ]
          self.mnist = {}
          # image files start with a 16-byte header, label files with an 8-byte header
          if self.is_train:
              with gzip.open(os.path.join(data_root, filesname[0]), 'rb') as f:
                  self.mnist["images"] = np.frombuffer(f.read(), np.uint8, offset=16).reshape(-1, 28, 28)
              with gzip.open(os.path.join(data_root, filesname[2]), 'rb') as f:
                  self.mnist["labels"] = np.frombuffer(f.read(), np.uint8, offset=8)
          else:
              with gzip.open(os.path.join(data_root, filesname[1]), 'rb') as f:
                  self.mnist["images"] = np.frombuffer(f.read(), np.uint8, offset=16).reshape(-1, 28, 28)
              with gzip.open(os.path.join(data_root, filesname[3]), 'rb') as f:
                  self.mnist["labels"] = np.frombuffer(f.read(), np.uint8, offset=8)
          assert self.mnist["images"].shape[0] == self.mnist["labels"].shape[0]
          self.total_len = self.mnist["images"].shape[0]
          # this call is required: it tells the Dataset its batch size, length and shuffling
          self.set_attrs(batch_size=self.batch_size, total_len=self.total_len, shuffle=self.shuffle)
      def __getitem__(self, index):
          img = self.mnist['images'][index][np.newaxis, :]  # (28, 28) -> (1, 28, 28)
          # scale pixels to [0, 1] as float32; the label is returned unchanged
          return (img / 255.0).astype(np.float32), self.mnist['labels'][index]
      def download_url(self):
          # (url, md5) pairs for the four MNIST files
          resources = [
              ("http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz", "f68b3c2dcbeaaa9fbdd348bbdeb94873"),
              ("http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz", "d53e105ee54ea40749a09fcbcd1e9432"),
              ("http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz", "9fb629c4189551a2d022fa330f9573f3"),
              ("http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz", "ec29112dd5afa0611ce80d1b7f02629c")
          ]
          for url, md5 in resources:
              filename = url.rpartition('/')[2]
              download_url_to_local(url, filename, self.data_root, md5)
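To show what each batch contains, here is a plain-NumPy sketch, independent of Jittor, of the per-sample preprocessing done in `__getitem__` and of the batching and shuffling that `set_attrs()` configures. `preprocess` and `iterate_batches` are hypothetical helpers for illustration only; Jittor's Dataset handles batching internally, and we assume here that a final, smaller batch is kept rather than dropped.

```python
import numpy as np


def preprocess(image_u8):
    """Mirror __getitem__: uint8 (28, 28) -> float32 (1, 28, 28) scaled to [0, 1]."""
    return (image_u8[np.newaxis, :] / 255.0).astype(np.float32)


def iterate_batches(images, labels, batch_size, shuffle, seed=0):
    """Hypothetical stand-in for Dataset batching: yields (inputs, targets) pairs."""
    indices = np.arange(len(images))
    if shuffle:
        np.random.default_rng(seed).shuffle(indices)  # new sample order each call/seed
    for start in range(0, len(indices), batch_size):
        idx = indices[start:start + batch_size]  # last batch may be smaller
        inputs = np.stack([preprocess(images[i]) for i in idx])  # (B, 1, 28, 28)
        yield inputs, labels[idx]
```

With `batch_size=64`, the 60,000 training images give 938 batches of shape (64, 1, 28, 28), except the last, which holds the remaining 32 samples.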

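3. Define the model. A model subclasses Jittor's Module class and implements two methods: `__init__` defines which operations (layers) the model consists of, and `execute` defines the order in which they run and what the model returns.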

  class Model(Module):
      def __init__(self):
          super(Model, self).__init__()
          self.conv1 = nn.Conv(1, 32, 3, 1)  # no padding: 28x28 -> 26x26
          self.conv2 = nn.Conv(32, 64, 3, 1)  # 26x26 -> 24x24
          self.bn = nn.BatchNorm(64)
          self.max_pool = nn.Pool(2, 2)  # 2x2 max pooling: 24x24 -> 12x12
          self.relu = nn.Relu()
          self.fc1 = nn.Linear(64 * 12 * 12, 256)
          self.fc2 = nn.Linear(256, 10)  # one output per digit class
      def execute(self, x):
          # similar to forward() in PyTorch
          x = self.conv1(x)
          x = self.relu(x)
          x = self.conv2(x)
          x = self.bn(x)
          x = self.relu(x)
          x = self.max_pool(x)
          x = jt.reshape(x, [x.shape[0], -1])  # flatten to (batch, 64 * 12 * 12)
          x = self.fc1(x)
          x = self.relu(x)
          x = self.fc2(x)
          return x
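The input size of `fc1`, `64 * 12 * 12`, follows from the layer settings: each convolution or pooling layer maps a spatial size n to floor((n + 2p - k) / s) + 1 for kernel k, stride s and padding p, while BatchNorm and ReLU leave the shape unchanged. A small worked check (the helper names are hypothetical):

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Spatial output size of a convolution or pooling layer."""
    return (size + 2 * padding - kernel) // stride + 1


def flattened_features(input_size=28):
    s = conv_out(input_size, 3)   # conv1, kernel 3, no padding: 28 -> 26
    s = conv_out(s, 3)            # conv2, kernel 3, no padding: 26 -> 24
    s = conv_out(s, 2, stride=2)  # max_pool, 2x2 with stride 2: 24 -> 12
    return 64 * s * s             # conv2 produces 64 channels


print(flattened_features())  # 9216, i.e. 64 * 12 * 12
```

If you change a kernel size or add padding, rerunning this arithmetic gives the new value to pass to `nn.Linear`.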

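4. Train the model. Training requires choosing hyperparameters and defining the training procedure. Training is implemented in the `train` function and evaluation on the test set in the `val` function.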

  def train(model, train_loader, optimizer, epoch, losses, losses_idx):
      model.train()
      lens = len(train_loader)  # number of batches per epoch
      for batch_idx, (inputs, targets) in enumerate(train_loader):
          outputs = model(inputs)
          loss = nn.cross_entropy_loss(outputs, targets)
          optimizer.step(loss)  # computes gradients of loss and updates the parameters
          losses.append(loss.item())
          losses_idx.append(epoch * lens + batch_idx)
          if batch_idx % 10 == 0:
              print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                  epoch, batch_idx, lens,
                  100. * batch_idx / lens, loss.item()))
  def val(model, val_loader, epoch):
      model.eval()
      total_acc = 0
      total_num = 0
      for batch_idx, (inputs, targets) in enumerate(val_loader):
          batch_size = inputs.shape[0]
          outputs = model(inputs)
          pred = np.argmax(outputs.data, axis=1)  # predicted class per sample
          acc = np.sum(targets.data == pred)      # number of correct predictions
          total_acc += acc
          total_num += batch_size
          acc = acc / batch_size
          print(f'Test Epoch: {epoch} [{batch_idx}/{len(val_loader)}]\tAcc: {acc:.6f}')
      print('Test Acc =', total_acc / total_num)
  # hyperparameters
  batch_size = 64
  learning_rate = 0.1
  momentum = 0.9
  weight_decay = 1e-4
  epochs = 1
  losses = []
  losses_idx = []
  train_loader = MNIST(train=True, batch_size=batch_size, shuffle=True)
  val_loader = MNIST(train=False, batch_size=1, shuffle=False)
  model = Model()
  optimizer = nn.SGD(model.parameters(), learning_rate, momentum, weight_decay)
  for epoch in range(epochs):
      train(model, train_loader, optimizer, epoch, losses, losses_idx)
      val(model, val_loader, epoch)
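For intuition about what `train` minimizes and what `val` reports, here is a NumPy sketch of batch-mean softmax cross-entropy and arg-max accuracy. We assume `nn.cross_entropy_loss` computes the standard batch-mean softmax cross-entropy; this is an illustration of that definition, not Jittor's implementation.

```python
import numpy as np


def cross_entropy(logits, targets):
    """Mean softmax cross-entropy over a batch: -mean(log softmax(logits)[i, targets[i]])."""
    shifted = logits - logits.max(axis=1, keepdims=True)  # subtract row max to avoid overflow in exp
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    return -log_probs[np.arange(len(targets)), targets].mean()


def accuracy(logits, targets):
    """Fraction of samples whose arg-max class equals the label, as computed in val()."""
    return float((np.argmax(logits, axis=1) == targets).mean())
```

An untrained model that outputs equal scores for all 10 classes has a loss of log(10), roughly 2.30, which is a useful sanity check for the first printed training loss.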

5. Plot the loss curve to visualize how the training loss evolves.

  pl.plot(losses_idx, losses)
  pl.xlabel('Iterations')
  pl.ylabel('Train_loss')
  pl.show()  # display the figure when running as a script
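Per-batch losses are noisy, so the trend can be easier to see after smoothing. A minimal sketch of a trailing moving average; `moving_average` and the window size are hypothetical choices, not part of the original tutorial.

```python
import numpy as np


def moving_average(values, window=50):
    """Trailing moving average; returns len(values) - window + 1 points."""
    values = np.asarray(values, dtype=np.float64)
    if len(values) < window:
        raise ValueError("need at least `window` values to smooth")
    kernel = np.ones(window) / window
    return np.convolve(values, kernel, mode="valid")  # each output averages `window` inputs
```

To plot it against the same x-axis, align each average with the last iteration in its window, for example `pl.plot(losses_idx[window - 1:], moving_average(losses, window))`.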

6. Save the model: once training is finished, the model should be saved. The code below shows how to save a model in Jittor.

  model_path = './mnist_model.pkl'
  model.save(model_path)

7. Load the model and test it. The code below shows how to load a saved model in Jittor and run it on test data.

  def vis_img(img):
      np_img = img.data.reshape([28, 28])
      plt.imshow(np_img, cmap='gray')
      plt.show()
  new_model = Model()
  new_model.load_parameters(jt.load(model_path))
  new_model.eval()  # use BatchNorm running statistics at test time
  data_iter = iter(val_loader)
  val_data, val_label = next(data_iter)  # first test batch (batch_size=1)
  print(val_label.shape)
  outputs = new_model(val_data)
  prediction = np.argmax(outputs.data, axis=1)
  print(prediction)
  print(val_label)
  vis_img(val_data)