Applying Automatic Data Augmentation

Linux Ascend GPU CPU Data Preparation Intermediate Advanced


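Overview

Automatic data augmentation (AutoAugment) [1] uses a search algorithm to find, within a search space of image augmentation sub-policies, the augmentation scheme best suited to a particular dataset. The c_transforms module of MindSpore provides a rich set of C++ operators for implementing AutoAugment, and users can also implement it with custom functions or operators. For detailed descriptions of the MindSpore operators, see the API documentation.

The AutoAugment operators correspond to MindSpore operators as follows: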

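AutoAugment operator | MindSpore operator | Description
shearX               | RandomAffine       | Horizontal shear
shearY               | RandomAffine       | Vertical shear
translateX           | RandomAffine       | Horizontal translation
translateY           | RandomAffine       | Vertical translation
rotate               | RandomRotation     | Rotation
color                | RandomColor        | Color transformation
posterize            | RandomPosterize    | Reduce the number of bits per color channel
solarize             | RandomSolarize     | Invert all pixels within the specified threshold range
contrast             | RandomColorAdjust  | Adjust contrast
sharpness            | RandomSharpness    | Adjust sharpness
brightness           | RandomColorAdjust  | Adjust brightness
autocontrast         | AutoContrast       | Maximize image contrast
equalize             | Equalize           | Equalize the image histogram
invert               | Invert             | Invert the image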

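AutoAugment on ImageNet

This tutorial uses an implementation of AutoAugment on the ImageNet dataset as its example.

The augmentation policy for ImageNet consists of 25 sub-policies, each made up of two transforms. For every image in a batch, one sub-policy is chosen at random, and each transform in it is then applied or skipped according to its preset probability.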
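
The selection rule described above can be sketched in plain Python, independent of MindSpore. This is an illustration of the semantics only, not the library's implementation: `apply_random_subpolicy`, `invert`, `brighten`, and `toy_policy` are hypothetical names, and images are simplified to flat lists of 8-bit pixel values.

```python
import random


def apply_random_subpolicy(image, policy, rng=random):
    """Pick one sub-policy at random, then apply each transform with its probability."""
    subpolicy = rng.choice(policy)
    for transform, prob in subpolicy:
        # rng.random() is in [0.0, 1.0), so prob=0.0 never fires and prob=1.0 always fires
        if rng.random() < prob:
            image = transform(image)
    return image


# Hypothetical stand-in transforms on a flat list of 8-bit pixel values
def invert(pixels):
    return [255 - p for p in pixels]


def brighten(pixels):
    return [min(255, p + 40) for p in pixels]


# A toy policy: each sub-policy is a list of (transform, probability) pairs
toy_policy = [
    [(invert, 0.6), (brighten, 1.0)],
    [(brighten, 0.0), (invert, 1.0)],
]

rng = random.Random(0)  # seeded so the run is repeatable
batch = [[0, 100, 200], [10, 20, 30]]
# Each image in the batch draws its own sub-policy
augmented = [apply_random_subpolicy(img, toy_policy, rng) for img in batch]
print(augmented)
```

Because the sub-policy is drawn per image, two images in the same batch can receive different transforms, which is what the result grid at the end of this tutorial shows.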

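Users can implement AutoAugment with the RandomSelectSubpolicy interface of the c_transforms module in MindSpore. The standard data augmentation for ImageNet classification training consists of the following steps:

  • RandomCropDecodeResize: decodes the image, takes a random crop, and resizes it.

  • RandomHorizontalFlip: randomly flips the image horizontally.

  • Normalize: normalizes the image.

  • HWC2CHW: changes the image layout from HWC to CHW.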
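
The last two steps in the list above are simple array operations. A minimal pure-Python sketch on nested lists shows what they compute. The functions `normalize` and `hwc_to_chw` are hypothetical illustrations, not the MindSpore operators, and the mean and standard deviation are the ImageNet values used in this tutorial's dataset code.

```python
def normalize(image_hwc, mean, std):
    """Per-channel (value - mean) / std on an H x W x C nested list."""
    return [[[(pixel[c] - mean[c]) / std[c] for c in range(len(pixel))]
             for pixel in row]
            for row in image_hwc]


def hwc_to_chw(image_hwc):
    """Reorder an H x W x C nested list into C x H x W."""
    height, width = len(image_hwc), len(image_hwc[0])
    channels = len(image_hwc[0][0])
    return [[[image_hwc[y][x][c] for x in range(width)] for y in range(height)]
            for c in range(channels)]


# ImageNet statistics scaled to the 0-255 pixel range
mean = [0.485 * 255, 0.456 * 255, 0.406 * 255]
std = [0.229 * 255, 0.224 * 255, 0.225 * 255]

image = [[[255, 0, 0], [0, 255, 0]]]  # a 1 x 2 RGB image in HWC layout
chw = hwc_to_chw(normalize(image, mean, std))
print(len(chw), len(chw[0]), len(chw[0][0]))  # channels, height, width
```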

The AutoAugment transform is inserted after RandomCropDecodeResize, as follows:

  1. Import the MindSpore data augmentation modules.

    import mindspore.common.dtype as mstype
    import mindspore.dataset.engine as de
    import mindspore.dataset.vision.c_transforms as c_vision
    import mindspore.dataset.transforms.c_transforms as c_transforms
    import matplotlib.pyplot as plt
  2. Define the mapping from MindSpore operators to AutoAugment operators:

    # define Auto Augmentation operators
    PARAMETER_MAX = 10

    # Scale a level in [0, PARAMETER_MAX] to a float magnitude in [0, maxval]
    def float_parameter(level, maxval):
        return float(level) * maxval / PARAMETER_MAX

    # Same scaling, truncated to an integer magnitude
    def int_parameter(level, maxval):
        return int(level * maxval / PARAMETER_MAX)

    # Signed transforms pick the negative or positive magnitude at random
    def shear_x(level):
        v = float_parameter(level, 0.3)
        return c_transforms.RandomChoice([
            c_vision.RandomAffine(degrees=0, shear=(-v, -v)),
            c_vision.RandomAffine(degrees=0, shear=(v, v))])

    def shear_y(level):
        v = float_parameter(level, 0.3)
        return c_transforms.RandomChoice([
            c_vision.RandomAffine(degrees=0, shear=(0, 0, -v, -v)),
            c_vision.RandomAffine(degrees=0, shear=(0, 0, v, v))])

    def translate_x(level):
        v = float_parameter(level, 150 / 331)
        return c_transforms.RandomChoice([
            c_vision.RandomAffine(degrees=0, translate=(-v, -v)),
            c_vision.RandomAffine(degrees=0, translate=(v, v))])

    def translate_y(level):
        v = float_parameter(level, 150 / 331)
        return c_transforms.RandomChoice([
            c_vision.RandomAffine(degrees=0, translate=(0, 0, -v, -v)),
            c_vision.RandomAffine(degrees=0, translate=(0, 0, v, v))])

    def color_impl(level):
        v = float_parameter(level, 1.8) + 0.1
        return c_vision.RandomColor(degrees=(v, v))

    def rotate_impl(level):
        v = int_parameter(level, 30)
        return c_transforms.RandomChoice([
            c_vision.RandomRotation(degrees=(-v, -v)),
            c_vision.RandomRotation(degrees=(v, v))])

    def solarize_impl(level):
        level = int_parameter(level, 256)
        v = 256 - level
        return c_vision.RandomSolarize(threshold=(0, v))

    def posterize_impl(level):
        level = int_parameter(level, 4)
        v = 4 - level
        return c_vision.RandomPosterize(bits=(v, v))

    def contrast_impl(level):
        v = float_parameter(level, 1.8) + 0.1
        return c_vision.RandomColorAdjust(contrast=(v, v))

    # level is accepted for a uniform interface but not used
    def autocontrast_impl(level):
        return c_vision.AutoContrast()

    def sharpness_impl(level):
        v = float_parameter(level, 1.8) + 0.1
        return c_vision.RandomSharpness(degrees=(v, v))

    def brightness_impl(level):
        v = float_parameter(level, 1.8) + 0.1
        return c_vision.RandomColorAdjust(brightness=(v, v))
  3. Define the AutoAugment policy for the ImageNet dataset:

    # define the Auto Augmentation policy
    # each sub-policy is a list of (operator, probability) pairs
    imagenet_policy = [
        [(posterize_impl(8), 0.4), (rotate_impl(9), 0.6)],
        [(solarize_impl(5), 0.6), (autocontrast_impl(5), 0.6)],
        [(c_vision.Equalize(), 0.8), (c_vision.Equalize(), 0.6)],
        [(posterize_impl(7), 0.6), (posterize_impl(6), 0.6)],
        [(c_vision.Equalize(), 0.4), (solarize_impl(4), 0.2)],
        [(c_vision.Equalize(), 0.4), (rotate_impl(8), 0.8)],
        [(solarize_impl(3), 0.6), (c_vision.Equalize(), 0.6)],
        [(posterize_impl(5), 0.8), (c_vision.Equalize(), 1.0)],
        [(rotate_impl(3), 0.2), (solarize_impl(8), 0.6)],
        [(c_vision.Equalize(), 0.6), (posterize_impl(6), 0.4)],
        [(rotate_impl(8), 0.8), (color_impl(0), 0.4)],
        [(rotate_impl(9), 0.4), (c_vision.Equalize(), 0.6)],
        [(c_vision.Equalize(), 0.0), (c_vision.Equalize(), 0.8)],
        [(c_vision.Invert(), 0.6), (c_vision.Equalize(), 1.0)],
        [(color_impl(4), 0.6), (contrast_impl(8), 1.0)],
        [(rotate_impl(8), 0.8), (color_impl(2), 1.0)],
        [(color_impl(8), 0.8), (solarize_impl(7), 0.8)],
        [(sharpness_impl(7), 0.4), (c_vision.Invert(), 0.6)],
        [(shear_x(5), 0.6), (c_vision.Equalize(), 1.0)],
        [(color_impl(0), 0.4), (c_vision.Equalize(), 0.6)],
        [(c_vision.Equalize(), 0.4), (solarize_impl(4), 0.2)],
        [(solarize_impl(5), 0.6), (autocontrast_impl(5), 0.6)],
        [(c_vision.Invert(), 0.6), (c_vision.Equalize(), 1.0)],
        [(color_impl(4), 0.6), (contrast_impl(8), 1.0)],
        [(c_vision.Equalize(), 0.8), (c_vision.Equalize(), 0.6)],
    ]
  4. Insert the AutoAugment transform after the RandomCropDecodeResize operation.

    def create_dataset(dataset_path, do_train, repeat_num=1, batch_size=32, shuffle=True, num_samples=5, target="Ascend"):
        # create a train or eval imagenet2012 dataset for ResNet-50
        ds = de.ImageFolderDataset(dataset_path, num_parallel_workers=8,
                                   shuffle=shuffle, num_samples=num_samples)

        image_size = 224
        mean = [0.485 * 255, 0.456 * 255, 0.406 * 255]
        std = [0.229 * 255, 0.224 * 255, 0.225 * 255]

        # define map operations
        if do_train:
            trans = [
                c_vision.RandomCropDecodeResize(image_size, scale=(0.08, 1.0), ratio=(0.75, 1.333)),
            ]

            post_trans = [
                c_vision.RandomHorizontalFlip(prob=0.5),
            ]
        else:
            trans = [
                c_vision.Decode(),
                c_vision.Resize(256),
                c_vision.CenterCrop(image_size),
                c_vision.Normalize(mean=mean, std=std),
                c_vision.HWC2CHW()
            ]
        ds = ds.map(operations=trans, input_columns="image")
        if do_train:
            # AutoAugment runs between the crop and the horizontal flip
            ds = ds.map(operations=c_vision.RandomSelectSubpolicy(imagenet_policy), input_columns=["image"])
            ds = ds.map(operations=post_trans, input_columns="image")
        type_cast_op = c_transforms.TypeCast(mstype.int32)
        ds = ds.map(operations=type_cast_op, input_columns="label")
        # apply the batch operation
        ds = ds.batch(batch_size, drop_remainder=True)
        # apply the repeat operation
        ds = ds.repeat(repeat_num)
        return ds
  5. Verify the effect of automatic data augmentation.

    # Define the path to image folder directory. This directory needs to contain sub-directories which contain the images
    DATA_DIR = "/path/to/imagefolder_directory"
    ds = create_dataset(dataset_path=DATA_DIR, do_train=True, batch_size=5, shuffle=False, num_samples=5)

    epochs = 5
    itr = ds.create_dict_iterator()
    fig = plt.figure(figsize=(8, 8))
    columns = 5
    rows = 5

    step_num = 0
    for ep_num in range(epochs):
        for data in itr:
            step_num += 1
            for index in range(rows):
                # one row of the grid per epoch, one column per image in the batch
                fig.add_subplot(rows, columns, ep_num * rows + index + 1)
                plt.imshow(data['image'].asnumpy()[index])
    plt.show()

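    To demonstrate the effect more clearly, only 5 images are loaded here, they are read without shuffling, and the Normalize and HWC2CHW operations are not applied during automatic data augmentation.

    [Figure: augment]

    The output shows the augmentation applied to each image in a batch: each row holds the 5 images of one batch, and the 5 rows correspond to 5 batches.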
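
The level-to-magnitude helpers from step 2 can also be checked on their own, without MindSpore installed. The sketch below copies `float_parameter` and `int_parameter` from the tutorial and evaluates them for a few of the levels used in the ImageNet policy. The variable names are illustrative only.

```python
PARAMETER_MAX = 10  # same scale as the tutorial's helpers


def float_parameter(level, maxval):
    return float(level) * maxval / PARAMETER_MAX


def int_parameter(level, maxval):
    return int(level * maxval / PARAMETER_MAX)


# Magnitudes that the tutorial's wrappers would pass to the MindSpore operators
shear_magnitude = float_parameter(5, 0.3)       # shear_x(5): about 0.15
rotate_degrees = int_parameter(9, 30)           # rotate_impl(9) -> 27
posterize_bits = 4 - int_parameter(8, 4)        # posterize_impl(8) -> 1
solarize_upper = 256 - int_parameter(5, 256)    # solarize_impl(5) -> 128
color_degree = float_parameter(4, 1.8) + 0.1    # color_impl(4): about 0.82

print(shear_magnitude, rotate_degrees, posterize_bits, solarize_upper, color_degree)
```

Note that for posterize and solarize a higher level yields a smaller operator argument (fewer bits kept, lower threshold bound), which corresponds to a stronger effect.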

References

[1] AutoAugment: Learning Augmentation Policies from Data.