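Semantic segmentation is an important computer vision task. This tutorial implements the DeepLabV3+ semantic segmentation model with the Jittor framework.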

DeepLabV3+ paper: https://arxiv.org/pdf/1802.02611.pdf

Full code: https://github.com/Jittor/deeplab-jittor

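1. Dataset

1.1 Data preparation

VOC2012 is one of the most commonly used datasets for tasks such as object detection and semantic segmentation. This tutorial uses the 2012 trainaug set (train + sbd set) of VOC as the training set and the 2012 val set as the test set.

Objects in the VOC dataset fall into 20 foreground classes, 'aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car', 'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse', 'motorbike', 'person', 'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor', plus a background class.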
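As a rough illustration of how these classes usually show up in the annotation images: in the standard VOC segmentation masks, index 0 is the background, indices 1 to 20 follow the class order listed above, and 255 marks void or boundary pixels that should be ignored. The sketch below assumes the annotations used in this tutorial follow that convention; the names `VOC_CLASSES`, `IGNORE_INDEX`, `class_to_index`, and `index_to_class` are made up for this example and are not part of the tutorial's repository.

```python
# Hypothetical helper names. Assumed convention: 0 = background,
# 1..20 = classes in the order listed above, 255 = ignore.
VOC_CLASSES = (
    'background',
    'aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car', 'cat',
    'chair', 'cow', 'diningtable', 'dog', 'horse', 'motorbike', 'person',
    'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor',
)
IGNORE_INDEX = 255  # pixels with this label are skipped by the loss


def class_to_index(name):
    """Return the integer label assumed for a class name."""
    return VOC_CLASSES.index(name)


def index_to_class(index):
    """Return the class name for a label, or 'ignore' for 255."""
    if index == IGNORE_INDEX:
        return 'ignore'
    return VOC_CLASSES[index]


print(len(VOC_CLASSES))            # classes including background
print(class_to_index('person'))
print(index_to_class(255))
```

This is also why the model later in the tutorial is built with num_classes=21: 20 foreground classes plus background.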

[Figure 1: Example 4, semantic segmentation with DeepLabV3+]

The final dataset is organized as follows.

    # File organization
    root directory
    |----voc_aug
    |    |----datalist
    |    |    |----train.txt
    |    |    |----val.txt
    |    |----images
    |    |----annotations
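Given this layout, the image and annotation paths for a split can be derived from the corresponding list file. The following is a minimal sketch, assuming each line of train.txt or val.txt holds one file stem without extension, images are .jpg and annotations are .png; `build_paths` and the sample IDs are hypothetical and only serve the illustration.

```python
import os
import tempfile


def build_paths(data_root, split):
    """Read datalist/<split>.txt and return (image_path, label_path) pairs."""
    list_path = os.path.join(data_root, 'datalist', split + '.txt')
    with open(list_path, 'r') as f:
        names = [line.strip() for line in f if line.strip()]
    return [
        (os.path.join(data_root, 'images', name + '.jpg'),
         os.path.join(data_root, 'annotations', name + '.png'))
        for name in names
    ]


# Tiny self-contained demo on a temporary directory (sample IDs are made up).
demo_root = tempfile.mkdtemp()
os.makedirs(os.path.join(demo_root, 'datalist'))
with open(os.path.join(demo_root, 'datalist', 'train.txt'), 'w') as f:
    f.write('2007_000032\n2007_000039\n')

pairs = build_paths(demo_root, 'train')
print(pairs[0])

# Pitfall: a component that starts with '/' makes os.path.join
# discard everything before it, so keep path components relative.
print(os.path.join('voc_aug', '/datalist/train.txt'))
```

Passing each path component separately to os.path.join avoids the leading-slash pitfall shown in the last line.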

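1.2 Data loading

You can build your own dataset by subclassing the Dataset base class in jittor.dataset.dataset. Two functions need to be implemented: __init__ and __getitem__.

  • __init__: defines the data paths. Here data_root should be set to the voc_aug directory you prepared earlier, and split is one of train, val, or test, selecting the training, validation, or test set. It must also call self.set_attr to specify the parameters the dataset loader needs: batch_size, total_len, and shuffle.
  • __getitem__: returns the data of a single item.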
The data-loading code is as follows.

    import os
    import random

    import numpy as np
    from PIL import Image, ImageOps
    import jittor as jt
    from jittor.dataset.dataset import Dataset

    def fetch(image_path, label_path):
        # Load the image as RGB and the annotation as a palette ('P') image.
        with open(image_path, 'rb') as fp:
            image = Image.open(fp).convert('RGB')
        with open(label_path, 'rb') as fp:
            label = Image.open(fp).convert('P')
        return image, label

    def scale(image, label):
        # Random rescaling; labels use nearest-neighbour to keep class ids intact.
        SCALES = (0.5, 0.75, 1.0, 1.25, 1.5, 1.75, 2.0)
        ratio = np.random.choice(SCALES)
        w, h = image.size
        nw = int(w * ratio)
        nh = int(h * ratio)
        image = image.resize((nw, nh), Image.BILINEAR)
        label = label.resize((nw, nh), Image.NEAREST)
        return image, label

    def pad(image, label):
        # Pad on the right/bottom up to crop_size; padded label pixels are 255 (ignored).
        w, h = image.size
        crop_size = 513
        pad_h = max(crop_size - h, 0)
        pad_w = max(crop_size - w, 0)
        image = ImageOps.expand(image, border=(0, 0, pad_w, pad_h), fill=0)
        label = ImageOps.expand(label, border=(0, 0, pad_w, pad_h), fill=255)
        return image, label

    def crop(image, label):
        # Random crop_size x crop_size crop (the input is at least that large after pad).
        w, h = image.size
        crop_size = 513
        x1 = random.randint(0, w - crop_size)
        y1 = random.randint(0, h - crop_size)
        image = image.crop((x1, y1, x1 + crop_size, y1 + crop_size))
        label = label.crop((x1, y1, x1 + crop_size, y1 + crop_size))
        return image, label

    def normalize(image, label):
        # ImageNet mean and standard deviation, per RGB channel.
        mean = np.array((0.485, 0.456, 0.406), dtype=np.float32)
        std = np.array((0.229, 0.224, 0.225), dtype=np.float32)
        image = np.array(image).astype(np.float32)
        label = np.array(label).astype(np.float32)
        image /= 255.0
        image -= mean
        image /= std
        return image, label

    def flip(image, label):
        # Horizontal flip with probability 0.5.
        if random.random() < 0.5:
            image = image.transpose(Image.FLIP_LEFT_RIGHT)
            label = label.transpose(Image.FLIP_LEFT_RIGHT)
        return image, label

    class BaseDataset(Dataset):
        def __init__(self, data_root='/voc/', split='train', batch_size=1, shuffle=False):
            super().__init__()
            # total_len, batch_size and shuffle must be set
            self.data_root = data_root
            self.split = split
            self.batch_size = batch_size
            self.shuffle = shuffle

            self.image_root = os.path.join(data_root, 'images')
            self.label_root = os.path.join(data_root, 'annotations')
            # No leading '/' here: os.path.join would otherwise drop data_root.
            self.data_list_path = os.path.join(self.data_root, 'datalist', self.split + '.txt')
            self.image_path = []
            self.label_path = []

            with open(self.data_list_path, "r") as f:
                lines = f.read().splitlines()

            for line in lines:
                _img_path = os.path.join(self.image_root, line + '.jpg')
                _label_path = os.path.join(self.label_root, line + '.png')
                assert os.path.isfile(_img_path)
                assert os.path.isfile(_label_path)
                self.image_path.append(_img_path)
                self.label_path.append(_label_path)
            self.total_len = len(self.image_path)

            # set_attr must be called to set batch_size, total_len and shuffle
            # (total_len plays the role of __len__ in PyTorch)
            self.set_attr(batch_size=self.batch_size, total_len=self.total_len, shuffle=self.shuffle)

        def __getitem__(self, image_id):
            raise NotImplementedError

    class TrainDataset(BaseDataset):
        def __init__(self, data_root='/voc/', split='train', batch_size=1, shuffle=False):
            super(TrainDataset, self).__init__(data_root, split, batch_size, shuffle)

        def __getitem__(self, image_id):
            image_path = self.image_path[image_id]
            label_path = self.label_path[image_id]
            image, label = fetch(image_path, label_path)
            image, label = scale(image, label)
            image, label = pad(image, label)
            image, label = crop(image, label)
            image, label = flip(image, label)
            image, label = normalize(image, label)

            # HWC -> CHW
            image = np.array(image).astype(np.float32).transpose(2, 0, 1)
            image = jt.array(image)
            label = jt.array(np.array(label).astype(np.int32))
            return image, label

    class ValDataset(BaseDataset):
        def __init__(self, data_root='/voc/', split='train', batch_size=1, shuffle=False):
            super(ValDataset, self).__init__(data_root, split, batch_size, shuffle)

        def __getitem__(self, image_id):
            image_path = self.image_path[image_id]
            label_path = self.label_path[image_id]
            image, label = fetch(image_path, label_path)
            # No augmentation at validation time, only normalization.
            image, label = normalize(image, label)

            image = np.array(image).astype(np.float32).transpose(2, 0, 1)
            image = jt.array(image)
            label = jt.array(np.array(label).astype(np.int32))
            return image, label
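To see why every training sample ends up at 513×513, the size bookkeeping of the scale, pad, and crop steps can be traced without loading any image. The sketch below only tracks sizes, not pixels; the function names (`scaled_size`, `padded_size`, `crop_box`, `trace`) are hypothetical, and the 500×375 input is just an illustrative size.

```python
import random

CROP_SIZE = 513
SCALES = (0.5, 0.75, 1.0, 1.25, 1.5, 1.75, 2.0)


def scaled_size(w, h, ratio):
    # Same truncation as int(w * ratio) in scale().
    return int(w * ratio), int(h * ratio)


def padded_size(w, h, crop_size=CROP_SIZE):
    # pad() only grows sides that are shorter than crop_size.
    return max(w, crop_size), max(h, crop_size)


def crop_box(w, h, crop_size=CROP_SIZE, rng=random):
    # Random top-left corner such that the box stays inside the image.
    x1 = rng.randint(0, w - crop_size)
    y1 = rng.randint(0, h - crop_size)
    return x1, y1, x1 + crop_size, y1 + crop_size


def trace(w, h, ratio, rng=random):
    """Return the (width, height) of the final crop for a given input size."""
    w, h = scaled_size(w, h, ratio)
    w, h = padded_size(w, h)
    x1, y1, x2, y2 = crop_box(w, h, rng=rng)
    return x2 - x1, y2 - y1


for ratio in SCALES:
    print(ratio, trace(500, 375, ratio))
```

Because padding happens before cropping, the random crop is always valid, whatever scale was drawn. The padded label region is filled with 255, which is the value the loss later ignores.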

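2. Model definition

[Figure 2: Example 4, semantic segmentation with DeepLabV3+]

The figure above is the network architecture given in the DeepLabV3+ paper. This tutorial uses ResNet as the backbone, and the input image size is 513*513.

The whole network can be divided into three parts: backbone, ASPP, and decoder.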


2.1 Backbone

Here we use the most common choice, ResNet, as the backbone, and use dilated (atrous) convolution in the last ResNet stages to enlarge the receptive field (the last stage for output_stride=16, the last two stages for output_stride=8). The full definition is as follows:

    import jittor as jt
    from jittor import nn
    from jittor import Module
    from jittor.contrib import argmax_pool

    class Bottleneck(Module):
        expansion = 4

        def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None):
            super(Bottleneck, self).__init__()
            self.conv1 = nn.Conv(inplanes, planes, kernel_size=1, bias=False)
            self.bn1 = nn.BatchNorm(planes)
            # padding == dilation keeps the spatial size of the 3x3 conv at stride 1
            self.conv2 = nn.Conv(planes, planes, kernel_size=3, stride=stride,
                                 dilation=dilation, padding=dilation, bias=False)
            self.bn2 = nn.BatchNorm(planes)
            self.conv3 = nn.Conv(planes, planes * 4, kernel_size=1, bias=False)
            self.bn3 = nn.BatchNorm(planes * 4)
            self.relu = nn.ReLU()
            self.downsample = downsample
            self.stride = stride
            self.dilation = dilation

        def execute(self, x):
            residual = x

            out = self.conv1(x)
            out = self.bn1(out)
            out = self.relu(out)

            out = self.conv2(out)
            out = self.bn2(out)
            out = self.relu(out)

            out = self.conv3(out)
            out = self.bn3(out)

            if self.downsample is not None:
                residual = self.downsample(x)

            out += residual
            out = self.relu(out)
            return out

    class ResNet(Module):
        def __init__(self, block, layers, output_stride):
            super(ResNet, self).__init__()
            self.inplanes = 64
            blocks = [1, 2, 4]  # multi-grid factors for layer4
            if output_stride == 16:
                strides = [1, 2, 2, 1]
                dilations = [1, 1, 1, 2]
            elif output_stride == 8:
                strides = [1, 2, 1, 1]
                dilations = [1, 1, 2, 4]
            else:
                raise NotImplementedError

            # Modules
            self.conv1 = nn.Conv(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
            self.bn1 = nn.BatchNorm(64)
            self.relu = nn.ReLU()
            # self.maxpool = nn.Pool(kernel_size=3, stride=2, padding=1)

            self.layer1 = self._make_layer(block, 64, layers[0], stride=strides[0], dilation=dilations[0])
            self.layer2 = self._make_layer(block, 128, layers[1], stride=strides[1], dilation=dilations[1])
            self.layer3 = self._make_layer(block, 256, layers[2], stride=strides[2], dilation=dilations[2])
            self.layer4 = self._make_MG_unit(block, 512, blocks=blocks, stride=strides[3], dilation=dilations[3])

        def _make_layer(self, block, planes, blocks, stride=1, dilation=1):
            downsample = None
            if stride != 1 or self.inplanes != planes * block.expansion:
                downsample = nn.Sequential(
                    nn.Conv(self.inplanes, planes * block.expansion,
                            kernel_size=1, stride=stride, bias=False),
                    nn.BatchNorm(planes * block.expansion),
                )

            layers = []
            layers.append(block(self.inplanes, planes, stride, dilation, downsample))
            self.inplanes = planes * block.expansion
            for i in range(1, blocks):
                layers.append(block(self.inplanes, planes, dilation=dilation))

            return nn.Sequential(*layers)

        def _make_MG_unit(self, block, planes, blocks, stride=1, dilation=1):
            downsample = None
            if stride != 1 or self.inplanes != planes * block.expansion:
                downsample = nn.Sequential(
                    nn.Conv(self.inplanes, planes * block.expansion,
                            kernel_size=1, stride=stride, bias=False),
                    nn.BatchNorm(planes * block.expansion),
                )

            layers = []
            layers.append(block(self.inplanes, planes, stride, dilation=blocks[0]*dilation,
                                downsample=downsample))
            self.inplanes = planes * block.expansion
            for i in range(1, len(blocks)):
                layers.append(block(self.inplanes, planes, stride=1,
                                    dilation=blocks[i]*dilation))

            return nn.Sequential(*layers)

        def execute(self, input):
            x = self.conv1(input)
            x = self.bn1(x)
            x = self.relu(x)
            x = argmax_pool(x, 2, 2)
            x = self.layer1(x)
            low_level_feat = x  # 256-channel feature used later by the decoder
            x = self.layer2(x)
            x = self.layer3(x)
            x = self.layer4(x)
            return x, low_level_feat

    def resnet50(output_stride):
        model = ResNet(Bottleneck, [3,4,6,3], output_stride)
        return model

    def resnet101(output_stride):
        model = ResNet(Bottleneck, [3,4,23,3], output_stride)
        return model
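The stride and dilation tables in ResNet.__init__ can be checked with a little arithmetic. The sketch below uses the standard convolution output-size formula to show that a 3×3 convolution with padding equal to its dilation keeps the spatial size at stride 1, multiplies the strides of the stem convolution, the pooling step, and the four stages to recover the output stride, and lists the multi-grid dilations used in layer4. The helper names are hypothetical; only the numbers come from the code above.

```python
def conv_out_size(n, kernel, stride=1, padding=0, dilation=1):
    """Standard convolution output-size formula along one spatial axis."""
    effective_kernel = dilation * (kernel - 1) + 1
    return (n + 2 * padding - effective_kernel) // stride + 1


def output_stride(layer_strides, stem_stride=2, pool_stride=2):
    """Total downsampling factor: stem conv, pooling, then the four stages."""
    total = stem_stride * pool_stride
    for s in layer_strides:
        total *= s
    return total


def multigrid_dilations(base_dilation, grid=(1, 2, 4)):
    """Dilation of each block in layer4 (base dilation times the grid factor)."""
    return [g * base_dilation for g in grid]


# A 3x3 conv with padding == dilation keeps the size at stride 1.
print([conv_out_size(33, 3, stride=1, padding=d, dilation=d) for d in (1, 2, 4)])

# Stride tables from ResNet.__init__ for output_stride 16 and 8.
print(output_stride([1, 2, 2, 1]), output_stride([1, 2, 1, 1]))

# layer4 dilations for base dilation 2 (OS 16) and 4 (OS 8).
print(multigrid_dilations(2), multigrid_dilations(4))
```

Replacing a stride of 2 by a dilation of 2 is what lets the later stages keep a larger feature map while still seeing a wide context.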

2.2 ASPP

ASPP applies dilated convolutions with different dilation rates to the feature map produced by the backbone, then concatenates the results and fuses them into a new feature.

    import jittor as jt
    from jittor import nn
    from jittor import Module
    from jittor.contrib import concat

    class Single_ASPPModule(Module):
        def __init__(self, inplanes, planes, kernel_size, padding, dilation):
            super(Single_ASPPModule, self).__init__()
            self.atrous_conv = nn.Conv(inplanes, planes, kernel_size=kernel_size,
                                       stride=1, padding=padding, dilation=dilation, bias=False)
            self.bn = nn.BatchNorm(planes)
            self.relu = nn.ReLU()

        def execute(self, x):
            x = self.atrous_conv(x)
            x = self.bn(x)
            x = self.relu(x)
            return x

    class ASPP(Module):
        def __init__(self, output_stride):
            super(ASPP, self).__init__()
            inplanes = 2048  # channels of the ResNet layer4 output
            if output_stride == 16:
                dilations = [1, 6, 12, 18]
            elif output_stride == 8:
                dilations = [1, 12, 24, 36]
            else:
                raise NotImplementedError

            # One 1x1 branch and three dilated 3x3 branches
            self.aspp1 = Single_ASPPModule(inplanes, 256, 1, padding=0, dilation=dilations[0])
            self.aspp2 = Single_ASPPModule(inplanes, 256, 3, padding=dilations[1], dilation=dilations[1])
            self.aspp3 = Single_ASPPModule(inplanes, 256, 3, padding=dilations[2], dilation=dilations[2])
            self.aspp4 = Single_ASPPModule(inplanes, 256, 3, padding=dilations[3], dilation=dilations[3])

            # Image-level branch: global average pooling followed by a 1x1 conv
            self.global_avg_pool = nn.Sequential(GlobalPooling(),
                                                 nn.Conv(inplanes, 256, 1, stride=1, bias=False),
                                                 nn.BatchNorm(256),
                                                 nn.ReLU())
            # 5 branches x 256 channels = 1280 input channels
            self.conv1 = nn.Conv(1280, 256, 1, bias=False)
            self.bn1 = nn.BatchNorm(256)
            self.relu = nn.ReLU()
            self.dropout = nn.Dropout(0.5)

        def execute(self, x):
            x1 = self.aspp1(x)
            x2 = self.aspp2(x)
            x3 = self.aspp3(x)
            x4 = self.aspp4(x)
            x5 = self.global_avg_pool(x)
            # Expand the pooled 1x1 feature to the spatial size of the other branches
            x5 = x5.broadcast((1,1,x4.shape[2],x4.shape[3]))
            x = concat((x1, x2, x3, x4, x5), dim=1)
            x = self.conv1(x)
            x = self.bn1(x)
            x = self.relu(x)
            x = self.dropout(x)
            return x

    class GlobalPooling (Module):
        def __init__(self):
            super(GlobalPooling, self).__init__()

        def execute (self, x):
            return jt.mean(x, dims=[2,3], keepdims=1)
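Two numbers in the ASPP code are easy to verify by hand: the spatial extent each dilated kernel covers, and the 1280 input channels of the fusing 1×1 convolution. The following sketch, with hypothetical helper names, reproduces both from the dilation tables above; it does not run any convolution.

```python
def effective_kernel(kernel, dilation):
    """Spatial extent covered by a dilated kernel along one axis."""
    return dilation * (kernel - 1) + 1


def aspp_branches(output_stride):
    """(kernel, dilation) of the four convolution branches, as in ASPP above."""
    if output_stride == 16:
        dilations = [1, 6, 12, 18]
    elif output_stride == 8:
        dilations = [1, 12, 24, 36]
    else:
        raise NotImplementedError
    kernels = [1, 3, 3, 3]
    return list(zip(kernels, dilations))


branches = aspp_branches(16)
extents = [effective_kernel(k, d) for k, d in branches]

# Four conv branches plus one image-level pooling branch, 256 channels each.
concat_channels = (len(branches) + 1) * 256

print(extents, concat_channels)
```

With output_stride=8 the feature map is twice as large, so the dilation rates are doubled to cover a comparable region of the input image.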

2.3 Decoder

The decoder upsamples the ASPP feature, concatenates it with an intermediate ResNet feature, and produces the feature used for the final segmentation.

    import jittor as jt
    from jittor import nn
    from jittor import Module
    from jittor.contrib import concat

    class Decoder(nn.Module):
        def __init__(self, num_classes):
            super(Decoder, self).__init__()
            low_level_inplanes = 256  # channels of the ResNet layer1 output

            # Reduce the low-level feature to 48 channels
            self.conv1 = nn.Conv(low_level_inplanes, 48, 1, bias=False)
            self.bn1 = nn.BatchNorm(48)
            self.relu = nn.ReLU()
            # 304 = 256 (ASPP output) + 48 (reduced low-level feature)
            self.last_conv = nn.Sequential(nn.Conv(304, 256, kernel_size=3, stride=1, padding=1, bias=False),
                                           nn.BatchNorm(256),
                                           nn.ReLU(),
                                           nn.Dropout(0.5),
                                           nn.Conv(256, 256, kernel_size=3, stride=1, padding=1, bias=False),
                                           nn.BatchNorm(256),
                                           nn.ReLU(),
                                           nn.Dropout(0.1),
                                           nn.Conv(256, num_classes, kernel_size=1, stride=1, bias=True))

        def execute(self, x, low_level_feat):
            low_level_feat = self.conv1(low_level_feat)
            low_level_feat = self.bn1(low_level_feat)
            low_level_feat = self.relu(low_level_feat)

            # Upsample the ASPP output to the spatial size of the low-level feature
            x_inter = nn.resize(x, size=(low_level_feat.shape[2], low_level_feat.shape[3]) , mode='bilinear')
            x_concat = concat((x_inter, low_level_feat), dim=1)
            x = self.last_conv(x_concat)
            return x
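The channel count 304 in last_conv follows from the concatenation: 256 channels from ASPP plus 48 from the reduced low-level feature. A shape-only sketch of the decoder is given below; `concat_shape` and `decoder_shapes` are hypothetical helpers, and the spatial sizes in the demo are illustrative rather than measured from the real network.

```python
def concat_shape(shapes, dim=1):
    """Shape of concatenating NCHW tensors along `dim`; other dims must match."""
    first = shapes[0]
    for s in shapes[1:]:
        for axis, (a, b) in enumerate(zip(first, s)):
            if axis != dim and a != b:
                raise ValueError('shape mismatch on axis %d: %s vs %s' % (axis, first, s))
    out = list(first)
    out[dim] = sum(s[dim] for s in shapes)
    return tuple(out)


def decoder_shapes(aspp_shape, low_level_shape, num_classes=21):
    """Return the shapes of the fused feature and of the decoder output."""
    n, _, h, w = low_level_shape
    reduced = (n, 48, h, w)                # 1x1 conv: 256 -> 48 channels
    upsampled = (n, aspp_shape[1], h, w)   # ASPP output resized to low-level size
    fused = concat_shape([upsampled, reduced], dim=1)
    logits = (n, num_classes, h, w)        # last_conv keeps the spatial size
    return fused, logits


print(decoder_shapes((2, 256, 32, 32), (2, 256, 128, 128)))
```

Concatenation only works once both inputs share the same height and width, which is why the ASPP output is resized first.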

2.4 Full model

The complete model connects the parts above in a single class:

    import jittor as jt
    from jittor import nn
    from jittor import Module
    from decoder import Decoder
    from aspp import ASPP
    from backbone import resnet50, resnet101

    class DeepLab(Module):
        def __init__(self, output_stride=16, num_classes=21):
            super(DeepLab, self).__init__()
            self.backbone = resnet101(output_stride=output_stride)
            self.aspp = ASPP(output_stride)
            self.decoder = Decoder(num_classes)

        def execute(self, input):
            x, low_level_feat = self.backbone(input)
            x = self.aspp(x)
            x = self.decoder(x, low_level_feat)
            # Upsample the class scores back to the input resolution
            x = nn.resize(x, size=(input.shape[2], input.shape[3]), mode='bilinear')
            return x
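The model returns one score per class and pixel, with shape (N, num_classes, H, W). A label map is obtained by taking, for each pixel, the class with the highest score. The pure-Python sketch below shows this on a toy 3-class, 2×2 example; `argmax_over_classes` is a hypothetical helper and the scores are made up. On real arrays the same step is an argmax along the class axis.

```python
def argmax_over_classes(scores):
    """scores: nested list [C][H][W] of class scores for one image.

    Returns an [H][W] nested list holding, for each pixel, the index of the
    class with the highest score (the first one on ties).
    """
    num_classes = len(scores)
    height, width = len(scores[0]), len(scores[0][0])
    return [
        [max(range(num_classes), key=lambda c: scores[c][y][x]) for x in range(width)]
        for y in range(height)
    ]


# Toy example: 3 classes on a 2x2 image (numbers are made up).
scores = [
    [[0.9, 0.1], [0.2, 0.3]],   # class 0
    [[0.0, 0.8], [0.1, 0.3]],   # class 1
    [[0.1, 0.1], [0.7, 0.4]],   # class 2
]
print(argmax_over_classes(scores))
```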

3. Model training


3.1 The training parameters are set as follows:


    # Learning parameters
    batch_size = 8
    learning_rate = 0.005
    momentum = 0.9
    weight_decay = 1e-4
    epochs = 50

3.2 Define the model, the optimizer, and the data loaders.


    model = DeepLab(output_stride=16, num_classes=21)
    optimizer = nn.SGD(model.parameters(),
                       learning_rate,
                       momentum=momentum,
                       weight_decay=weight_decay)
    train_loader = TrainDataset(data_root='/vocdata/',
                                split='train',
                                batch_size=batch_size,
                                shuffle=True)
    val_loader = ValDataset(data_root='/vocdata/',
                            split='val',
                            batch_size=1,
                            shuffle=False)

3.3 Model training and validation


    import numpy as np
    from jittor import nn

    # lr scheduler
    def poly_lr_scheduler(opt, init_lr, iter, epoch, max_iter, max_epoch):
        new_lr = init_lr * (1 - float(epoch * max_iter + iter) / (max_epoch * max_iter)) ** 0.9
        opt.lr = new_lr

    # train function
    def train(model, train_loader, optimizer, epoch, init_lr):
        model.train()
        max_iter = len(train_loader)

        for idx, (image, target) in enumerate(train_loader):
            # poly learning-rate schedule; 50 is the total number of epochs
            poly_lr_scheduler(optimizer, init_lr, idx, epoch, max_iter, 50)
            image = image.float32()
            pred = model(image)
            # label 255 (padding / void pixels) does not contribute to the loss
            loss = nn.cross_entropy_loss(pred, target, ignore_index=255)
            optimizer.step(loss)
            print('Training in epoch {} iteration {} loss = {}'.format(epoch, idx, loss.data[0]))

    # val function
    # The Evaluator class is given in section 3.4.
    best_miou = 0.0  # best validation mIoU so far, kept across epochs

    def val(model, val_loader, epoch, evaluator):
        global best_miou
        model.eval()
        evaluator.reset()
        for idx, (image, target) in enumerate(val_loader):
            image = image.float32()
            output = model(image)
            pred = output.data
            target = target.data
            pred = np.argmax(pred, axis=1)
            evaluator.add_batch(target, pred)
            print('Test in epoch {} iteration {}'.format(epoch, idx))
        Acc = evaluator.Pixel_Accuracy()
        Acc_class = evaluator.Pixel_Accuracy_Class()
        mIoU = evaluator.Mean_Intersection_over_Union()
        FWIoU = evaluator.Frequency_Weighted_Intersection_over_Union()

        if mIoU > best_miou:
            best_miou = mIoU
        print('Testing result of epoch {} miou = {} Acc = {} Acc_class = {} '
              'FWIoU = {} Best Miou = {}'.format(epoch, mIoU, Acc, Acc_class, FWIoU, best_miou))
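The poly schedule used in poly_lr_scheduler decays the learning rate as init_lr * (1 - progress) ** 0.9, where progress is the fraction of all training iterations completed so far. The sketch below evaluates the same formula without an optimizer; `poly_lr` is a hypothetical name, and 100 iterations per epoch is an arbitrary value chosen for the illustration.

```python
def poly_lr(init_lr, iteration, epoch, iters_per_epoch, max_epoch, power=0.9):
    """Polynomial decay: init_lr * (1 - progress) ** power."""
    progress = float(epoch * iters_per_epoch + iteration) / (max_epoch * iters_per_epoch)
    return init_lr * (1.0 - progress) ** power


# Learning rate at the start of a few epochs of a 50-epoch run.
for epoch in (0, 25, 49):
    print(epoch, poly_lr(0.005, 0, epoch, 100, 50))
```

The rate starts at init_lr, decreases monotonically, and approaches zero at the last iteration of the last epoch.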

3.4 Evaluator

The evaluator uses a confusion matrix to compute pixel accuracy and mIoU.


    import numpy as np

    class Evaluator(object):
        def __init__(self, num_class):
            self.num_class = num_class
            # rows: ground-truth class, columns: predicted class
            self.confusion_matrix = np.zeros((self.num_class,)*2)

        def Pixel_Accuracy(self):
            Acc = np.diag(self.confusion_matrix).sum() / self.confusion_matrix.sum()
            return Acc

        def Pixel_Accuracy_Class(self):
            Acc = np.diag(self.confusion_matrix) / self.confusion_matrix.sum(axis=1)
            Acc = np.nanmean(Acc)  # classes absent from the ground truth give NaN and are skipped
            return Acc

        def Mean_Intersection_over_Union(self):
            MIoU = np.diag(self.confusion_matrix) / (
                        np.sum(self.confusion_matrix, axis=1) +
                        np.sum(self.confusion_matrix, axis=0) -
                        np.diag(self.confusion_matrix))
            MIoU = np.nanmean(MIoU)
            return MIoU

        def Frequency_Weighted_Intersection_over_Union(self):
            freq = np.sum(self.confusion_matrix, axis=1) / np.sum(self.confusion_matrix)
            iu = np.diag(self.confusion_matrix) / (
                        np.sum(self.confusion_matrix, axis=1) + np.sum(self.confusion_matrix, axis=0) -
                        np.diag(self.confusion_matrix))
            FWIoU = (freq[freq > 0] * iu[freq > 0]).sum()
            return FWIoU

        def _generate_matrix(self, gt_image, pre_image):
            # Keep only pixels with a valid ground-truth label (this drops label 255)
            mask = (gt_image >= 0) & (gt_image < self.num_class)
            label = self.num_class * gt_image[mask].astype('int') + pre_image[mask]
            count = np.bincount(label, minlength=self.num_class**2)
            confusion_matrix = count.reshape(self.num_class, self.num_class)
            return confusion_matrix

        def add_batch(self, gt_image, pre_image):
            assert gt_image.shape == pre_image.shape
            self.confusion_matrix += self._generate_matrix(gt_image, pre_image)

        def reset(self):
            self.confusion_matrix = np.zeros((self.num_class,) * 2)
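A small worked example may make the confusion-matrix bookkeeping concrete. The pure-Python sketch below mirrors the logic of the evaluator on flat lists of labels for two classes; the function names and the toy labels are made up for the illustration, and absent classes are skipped in the mean just as np.nanmean skips NaN entries.

```python
def confusion_matrix(gt, pred, num_class):
    """Rows are ground-truth classes, columns are predicted classes.

    Pixels whose ground truth lies outside [0, num_class), such as 255,
    are skipped.
    """
    matrix = [[0] * num_class for _ in range(num_class)]
    for g, p in zip(gt, pred):
        if 0 <= g < num_class:
            matrix[g][p] += 1
    return matrix


def per_class_iou(matrix):
    n = len(matrix)
    ious = []
    for c in range(n):
        tp = matrix[c][c]
        gt_total = sum(matrix[c])                         # row sum
        pred_total = sum(matrix[r][c] for r in range(n))  # column sum
        union = gt_total + pred_total - tp
        ious.append(tp / union if union > 0 else None)    # None: class absent
    return ious


def mean_iou(matrix):
    valid = [v for v in per_class_iou(matrix) if v is not None]
    return sum(valid) / len(valid)


gt = [0, 0, 1, 1, 1, 255]
pred = [0, 1, 1, 1, 0, 1]
m = confusion_matrix(gt, pred, num_class=2)
print(m)                 # [[1, 1], [1, 2]]
print(per_class_iou(m))  # IoU of class 0 is 1/3, of class 1 is 1/2
print(mean_iou(m))
```

The last pixel has ground truth 255 and is ignored, so only five pixels are counted; three of them lie on the diagonal, giving a pixel accuracy of 3/5.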

3.5 Training entry point


    epochs = 50
    evaluator = Evaluator(21)
    model = DeepLab(output_stride=16, num_classes=21)
    train_loader = TrainDataset(data_root='/voc/data/path/', split='train', batch_size=8, shuffle=True)
    val_loader = ValDataset(data_root='/voc/data/path/', split='val', batch_size=1, shuffle=False)
    learning_rate = 0.005
    momentum = 0.9
    weight_decay = 1e-4
    optimizer = nn.SGD(model.parameters(), learning_rate, momentum, weight_decay)
    for epoch in range(epochs):
        train(model, train_loader, optimizer, epoch, learning_rate)
        val(model, val_loader, epoch, evaluator)

4. References