数据读取和预处理

数据预处理是训练神经网络时非常重要的步骤。合适的预处理方法,可以帮助模型更好的收敛并防止过拟合。首先我们需要从磁盘读入数据,然后需要对这些数据进行预处理,为了保证网络运行的速度通常还要对数据预处理进行加速。

数据读取

前面已经将图片的所有描述信息保存在records中了,其中的每一个元素包含了一张图片的描述,下面的程序展示了如何根据records里面的描述读取图片及标注。

  1. ### 数据读取
  2. import cv2
  3. def get_bbox(gt_bbox, gt_class):
  4. # 对于一般的检测任务来说,一张图片上往往会有多个目标物体
  5. # 设置参数MAX_NUM = 50, 即一张图片最多取50个真实框;如果真实
  6. # 框的数目少于50个,则将不足部分的gt_bbox, gt_class和gt_score的各项数值全设置为0
  7. MAX_NUM = 50
  8. gt_bbox2 = np.zeros((MAX_NUM, 4))
  9. gt_class2 = np.zeros((MAX_NUM,))
  10. for i in range(len(gt_bbox)):
  11. gt_bbox2[i, :] = gt_bbox[i, :]
  12. gt_class2[i] = gt_class[i]
  13. if i >= MAX_NUM:
  14. break
  15. return gt_bbox2, gt_class2
  16. def get_img_data_from_file(record):
  17. """
  18. record is a dict as following,
  19. record = {
  20. 'im_file': img_file,
  21. 'im_id': im_id,
  22. 'h': im_h,
  23. 'w': im_w,
  24. 'is_crowd': is_crowd,
  25. 'gt_class': gt_class,
  26. 'gt_bbox': gt_bbox,
  27. 'gt_poly': [],
  28. 'difficult': difficult
  29. }
  30. """
  31. im_file = record['im_file']
  32. h = record['h']
  33. w = record['w']
  34. is_crowd = record['is_crowd']
  35. gt_class = record['gt_class']
  36. gt_bbox = record['gt_bbox']
  37. difficult = record['difficult']
  38. img = cv2.imread(im_file)
  39. img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
  40. # check if h and w in record equals that read from img
  41. assert img.shape[0] == int(h), \
  42. "image height of {} inconsistent in record({}) and img file({})".format(
  43. im_file, h, img.shape[0])
  44. assert img.shape[1] == int(w), \
  45. "image width of {} inconsistent in record({}) and img file({})".format(
  46. im_file, w, img.shape[1])
  47. gt_boxes, gt_labels = get_bbox(gt_bbox, gt_class)
  48. # gt_bbox 用相对值
  49. gt_boxes[:, 0] = gt_boxes[:, 0] / float(w)
  50. gt_boxes[:, 1] = gt_boxes[:, 1] / float(h)
  51. gt_boxes[:, 2] = gt_boxes[:, 2] / float(w)
  52. gt_boxes[:, 3] = gt_boxes[:, 3] / float(h)
  53. return img, gt_boxes, gt_labels, (h, w)
  1. record = records[0]
  2. img, gt_boxes, gt_labels, scales = get_img_data_from_file(record)
  1. img.shape
  1. (1268, 1268, 3)
  1. gt_boxes.shape
  1. (50, 4)
  1. gt_labels
  1. array([1., 0., 2., 3., 4., 5., 5., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
  2. 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
  3. 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])
  1. scales
  1. (1268.0, 1268.0)

get_img_data_from_file()函数可以返回图片数据的数据,它们是图像数据img, 真实框坐标gt_boxes, 真实框包含的物体类别gt_labels, 图像尺寸scales。

数据预处理

在计算机视觉中,通常会对图像做一些随机的变化,产生相似但又不完全相同的样本。主要作用是扩大训练数据集,抑制过拟合,提升模型的泛化能力,常用的方法见下面的程序。

随机改变亮暗、对比度和颜色等

  1. import numpy as np
  2. import cv2
  3. from PIL import Image, ImageEnhance
  4. import random
  5. # 随机改变亮暗、对比度和颜色等
  6. def random_distort(img):
  7. # 随机改变亮度
  8. def random_brightness(img, lower=0.5, upper=1.5):
  9. e = np.random.uniform(lower, upper)
  10. return ImageEnhance.Brightness(img).enhance(e)
  11. # 随机改变对比度
  12. def random_contrast(img, lower=0.5, upper=1.5):
  13. e = np.random.uniform(lower, upper)
  14. return ImageEnhance.Contrast(img).enhance(e)
  15. # 随机改变颜色
  16. def random_color(img, lower=0.5, upper=1.5):
  17. e = np.random.uniform(lower, upper)
  18. return ImageEnhance.Color(img).enhance(e)
  19. ops = [random_brightness, random_contrast, random_color]
  20. np.random.shuffle(ops)
  21. img = Image.fromarray(img)
  22. img = ops[0](img)
  23. img = ops[1](img)
  24. img = ops[2](img)
  25. img = np.asarray(img)
  26. return img

随机填充

  1. # 随机填充
  2. def random_expand(img,
  3. gtboxes,
  4. max_ratio=4.,
  5. fill=None,
  6. keep_ratio=True,
  7. thresh=0.5):
  8. if random.random() > thresh:
  9. return img, gtboxes
  10. if max_ratio < 1.0:
  11. return img, gtboxes
  12. h, w, c = img.shape
  13. ratio_x = random.uniform(1, max_ratio)
  14. if keep_ratio:
  15. ratio_y = ratio_x
  16. else:
  17. ratio_y = random.uniform(1, max_ratio)
  18. oh = int(h * ratio_y)
  19. ow = int(w * ratio_x)
  20. off_x = random.randint(0, ow - w)
  21. off_y = random.randint(0, oh - h)
  22. out_img = np.zeros((oh, ow, c))
  23. if fill and len(fill) == c:
  24. for i in range(c):
  25. out_img[:, :, i] = fill[i] * 255.0
  26. out_img[off_y:off_y + h, off_x:off_x + w, :] = img
  27. gtboxes[:, 0] = ((gtboxes[:, 0] * w) + off_x) / float(ow)
  28. gtboxes[:, 1] = ((gtboxes[:, 1] * h) + off_y) / float(oh)
  29. gtboxes[:, 2] = gtboxes[:, 2] / ratio_x
  30. gtboxes[:, 3] = gtboxes[:, 3] / ratio_y
  31. return out_img.astype('uint8'), gtboxes

随机裁剪

随机裁剪之前需要先定义两个函数,multi_box_iou_xywh和box_crop这两个函数将被保存在box_utils.py文件中。

  1. import numpy as np
  2. def multi_box_iou_xywh(box1, box2):
  3. """
  4. In this case, box1 or box2 can contain multi boxes.
  5. Only two cases can be processed in this method:
  6. 1, box1 and box2 have the same shape, box1.shape == box2.shape
  7. 2, either box1 or box2 contains only one box, len(box1) == 1 or len(box2) == 1
  8. If the shape of box1 and box2 does not match, and both of them contain multi boxes, it will be wrong.
  9. """
  10. assert box1.shape[-1] == 4, "Box1 shape[-1] should be 4."
  11. assert box2.shape[-1] == 4, "Box2 shape[-1] should be 4."
  12. b1_x1, b1_x2 = box1[:, 0] - box1[:, 2] / 2, box1[:, 0] + box1[:, 2] / 2
  13. b1_y1, b1_y2 = box1[:, 1] - box1[:, 3] / 2, box1[:, 1] + box1[:, 3] / 2
  14. b2_x1, b2_x2 = box2[:, 0] - box2[:, 2] / 2, box2[:, 0] + box2[:, 2] / 2
  15. b2_y1, b2_y2 = box2[:, 1] - box2[:, 3] / 2, box2[:, 1] + box2[:, 3] / 2
  16. inter_x1 = np.maximum(b1_x1, b2_x1)
  17. inter_x2 = np.minimum(b1_x2, b2_x2)
  18. inter_y1 = np.maximum(b1_y1, b2_y1)
  19. inter_y2 = np.minimum(b1_y2, b2_y2)
  20. inter_w = inter_x2 - inter_x1
  21. inter_h = inter_y2 - inter_y1
  22. inter_w = np.clip(inter_w, a_min=0., a_max=None)
  23. inter_h = np.clip(inter_h, a_min=0., a_max=None)
  24. inter_area = inter_w * inter_h
  25. b1_area = (b1_x2 - b1_x1) * (b1_y2 - b1_y1)
  26. b2_area = (b2_x2 - b2_x1) * (b2_y2 - b2_y1)
  27. return inter_area / (b1_area + b2_area - inter_area)
  28. def box_crop(boxes, labels, crop, img_shape):
  29. x, y, w, h = map(float, crop)
  30. im_w, im_h = map(float, img_shape)
  31. boxes = boxes.copy()
  32. boxes[:, 0], boxes[:, 2] = (boxes[:, 0] - boxes[:, 2] / 2) * im_w, (
  33. boxes[:, 0] + boxes[:, 2] / 2) * im_w
  34. boxes[:, 1], boxes[:, 3] = (boxes[:, 1] - boxes[:, 3] / 2) * im_h, (
  35. boxes[:, 1] + boxes[:, 3] / 2) * im_h
  36. crop_box = np.array([x, y, x + w, y + h])
  37. centers = (boxes[:, :2] + boxes[:, 2:]) / 2.0
  38. mask = np.logical_and(crop_box[:2] <= centers, centers <= crop_box[2:]).all(
  39. axis=1)
  40. boxes[:, :2] = np.maximum(boxes[:, :2], crop_box[:2])
  41. boxes[:, 2:] = np.minimum(boxes[:, 2:], crop_box[2:])
  42. boxes[:, :2] -= crop_box[:2]
  43. boxes[:, 2:] -= crop_box[:2]
  44. mask = np.logical_and(mask, (boxes[:, :2] < boxes[:, 2:]).all(axis=1))
  45. boxes = boxes * np.expand_dims(mask.astype('float32'), axis=1)
  46. labels = labels * mask.astype('float32')
  47. boxes[:, 0], boxes[:, 2] = (boxes[:, 0] + boxes[:, 2]) / 2 / w, (
  48. boxes[:, 2] - boxes[:, 0]) / w
  49. boxes[:, 1], boxes[:, 3] = (boxes[:, 1] + boxes[:, 3]) / 2 / h, (
  50. boxes[:, 3] - boxes[:, 1]) / h
  51. return boxes, labels, mask.sum()
  1. # 随机裁剪
  2. def random_crop(img,
  3. boxes,
  4. labels,
  5. scales=[0.3, 1.0],
  6. max_ratio=2.0,
  7. constraints=None,
  8. max_trial=50):
  9. if len(boxes) == 0:
  10. return img, boxes
  11. if not constraints:
  12. constraints = [(0.1, 1.0), (0.3, 1.0), (0.5, 1.0), (0.7, 1.0),
  13. (0.9, 1.0), (0.0, 1.0)]
  14. img = Image.fromarray(img)
  15. w, h = img.size
  16. crops = [(0, 0, w, h)]
  17. for min_iou, max_iou in constraints:
  18. for _ in range(max_trial):
  19. scale = random.uniform(scales[0], scales[1])
  20. aspect_ratio = random.uniform(max(1 / max_ratio, scale * scale), \
  21. min(max_ratio, 1 / scale / scale))
  22. crop_h = int(h * scale / np.sqrt(aspect_ratio))
  23. crop_w = int(w * scale * np.sqrt(aspect_ratio))
  24. crop_x = random.randrange(w - crop_w)
  25. crop_y = random.randrange(h - crop_h)
  26. crop_box = np.array([[(crop_x + crop_w / 2.0) / w,
  27. (crop_y + crop_h / 2.0) / h,
  28. crop_w / float(w), crop_h / float(h)]])
  29. iou = multi_box_iou_xywh(crop_box, boxes)
  30. if min_iou <= iou.min() and max_iou >= iou.max():
  31. crops.append((crop_x, crop_y, crop_w, crop_h))
  32. break
  33. while crops:
  34. crop = crops.pop(np.random.randint(0, len(crops)))
  35. crop_boxes, crop_labels, box_num = box_crop(boxes, labels, crop, (w, h))
  36. if box_num < 1:
  37. continue
  38. img = img.crop((crop[0], crop[1], crop[0] + crop[2],
  39. crop[1] + crop[3])).resize(img.size, Image.LANCZOS)
  40. img = np.asarray(img)
  41. return img, crop_boxes, crop_labels
  42. img = np.asarray(img)
  43. return img, boxes, labels

随机缩放

  1. # 随机缩放
  2. def random_interp(img, size, interp=None):
  3. interp_method = [
  4. cv2.INTER_NEAREST,
  5. cv2.INTER_LINEAR,
  6. cv2.INTER_AREA,
  7. cv2.INTER_CUBIC,
  8. cv2.INTER_LANCZOS4,
  9. ]
  10. if not interp or interp not in interp_method:
  11. interp = interp_method[random.randint(0, len(interp_method) - 1)]
  12. h, w, _ = img.shape
  13. im_scale_x = size / float(w)
  14. im_scale_y = size / float(h)
  15. img = cv2.resize(
  16. img, None, None, fx=im_scale_x, fy=im_scale_y, interpolation=interp)
  17. return img

随机翻转

  1. # 随机翻转
  2. def random_flip(img, gtboxes, thresh=0.5):
  3. if random.random() > thresh:
  4. img = img[:, ::-1, :]
  5. gtboxes[:, 0] = 1.0 - gtboxes[:, 0]
  6. return img, gtboxes

随机打乱真实框排列顺序

  1. # 随机打乱真实框排列顺序
  2. def shuffle_gtbox(gtbox, gtlabel):
  3. gt = np.concatenate(
  4. [gtbox, gtlabel[:, np.newaxis]], axis=1)
  5. idx = np.arange(gt.shape[0])
  6. np.random.shuffle(idx)
  7. gt = gt[idx, :]
  8. return gt[:, :4], gt[:, 4]

图像增广方法

  1. # 图像增广方法汇总
  2. def image_augment(img, gtboxes, gtlabels, size, means=None):
  3. # 随机改变亮暗、对比度和颜色等
  4. img = random_distort(img)
  5. # 随机填充
  6. img, gtboxes = random_expand(img, gtboxes, fill=means)
  7. # 随机裁剪
  8. img, gtboxes, gtlabels, = random_crop(img, gtboxes, gtlabels)
  9. # 随机缩放
  10. img = random_interp(img, size)
  11. # 随机翻转
  12. img, gtboxes = random_flip(img, gtboxes)
  13. # 随机打乱真实框排列顺序
  14. gtboxes, gtlabels = shuffle_gtbox(gtboxes, gtlabels)
  15. return img.astype('float32'), gtboxes.astype('float32'), gtlabels.astype('int32')
  1. img, gt_boxes, gt_labels, scales = get_img_data_from_file(record)
  2. size = 512
  3. img, gt_boxes, gt_labels = image_augment(img, gt_boxes, gt_labels, size)
  1. img.shape
  1. (512, 512, 3)
  1. gt_boxes.shape
  1. (50, 4)
  1. gt_labels.shape
  1. (50,)

这里得到的img数据数值需要调整,需要除以255.,并且减去均值和方差,再将维度从[H, W, C]调整为[C, H, W]

  1. img, gt_boxes, gt_labels, scales = get_img_data_from_file(record)
  2. size = 512
  3. img, gt_boxes, gt_labels = image_augment(img, gt_boxes, gt_labels, size)
  4. mean = [0.485, 0.456, 0.406]
  5. std = [0.229, 0.224, 0.225]
  6. mean = np.array(mean).reshape((1, 1, -1))
  7. std = np.array(std).reshape((1, 1, -1))
  8. img = (img / 255.0 - mean) / std
  9. img = img.astype('float32').transpose((2, 0, 1))
  10. img
  1. array([[[ 0.99880135, 0.99880135, 0.99880135, ..., -2.117904 ,
  2. -2.117904 , -2.117904 ],
  3. [ 0.99880135, 0.99880135, 0.99880135, ..., -2.117904 ,
  4. -2.117904 , -2.117904 ],
  5. [ 0.99880135, 0.99880135, 0.99880135, ..., -2.117904 ,
  6. -2.117904 , -2.117904 ],
  7. ...,
  8. [-2.117904 , -2.117904 , -2.117904 , ..., -2.117904 ,
  9. -2.117904 , -2.117904 ],
  10. [-2.117904 , -2.117904 , -2.117904 , ..., -2.117904 ,
  11. -2.117904 , -2.117904 ],
  12. [-2.117904 , -2.117904 , -2.117904 , ..., -2.117904 ,
  13. -2.117904 , -2.117904 ]],
  14.  
  15. [[ 1.1505603 , 1.1505603 , 1.1505603 , ..., -2.0357144 ,
  16. -2.0357144 , -2.0357144 ],
  17. [ 1.1505603 , 1.1505603 , 1.1505603 , ..., -2.0357144 ,
  18. -2.0357144 , -2.0357144 ],
  19. [ 1.1505603 , 1.1505603 , 1.1505603 , ..., -2.0357144 ,
  20. -2.0357144 , -2.0357144 ],
  21. ...,
  22. [-2.0357144 , -2.0357144 , -2.0357144 , ..., -2.0357144 ,
  23. -2.0357144 , -2.0357144 ],
  24. [-2.0357144 , -2.0357144 , -2.0357144 , ..., -2.0357144 ,
  25. -2.0357144 , -2.0357144 ],
  26. [-2.0357144 , -2.0357144 , -2.0357144 , ..., -2.0357144 ,
  27. -2.0357144 , -2.0357144 ]],
  28.  
  29. [[ 1.3676689 , 1.3676689 , 1.3676689 , ..., -1.8044444 ,
  30. -1.8044444 , -1.8044444 ],
  31. [ 1.3676689 , 1.3676689 , 1.3502398 , ..., -1.8044444 ,
  32. -1.8044444 , -1.8044444 ],
  33. [ 1.3676689 , 1.3676689 , 1.3676689 , ..., -1.8044444 ,
  34. -1.8044444 , -1.8044444 ],
  35. ...,
  36. [-1.8044444 , -1.8044444 , -1.8044444 , ..., -1.8044444 ,
  37. -1.8044444 , -1.8044444 ],
  38. [-1.8044444 , -1.8044444 , -1.8044444 , ..., -1.8044444 ,
  39. -1.8044444 , -1.8044444 ],
  40. [-1.8044444 , -1.8044444 , -1.8044444 , ..., -1.8044444 ,
  41. -1.8044444 , -1.8044444 ]]], dtype=float32)

将上面的过程整理成一个函数get_img_data

  1. def get_img_data(record, size=640):
  2. img, gt_boxes, gt_labels, scales = get_img_data_from_file(record)
  3. img, gt_boxes, gt_labels = image_augment(img, gt_boxes, gt_labels, size)
  4. mean = [0.485, 0.456, 0.406]
  5. std = [0.229, 0.224, 0.225]
  6. mean = np.array(mean).reshape((1, 1, -1))
  7. std = np.array(std).reshape((1, 1, -1))
  8. img = (img / 255.0 - mean) / std
  9. img = img.astype('float32').transpose((2, 0, 1))
  10. return img, gt_boxes, gt_labels, scales
  1. TRAINDIR = '/home/aistudio/work/insects/train'
  2. TESTDIR = '/home/aistudio/work/insects/test'
  3. VALIDDIR = '/home/aistudio/work/insects/val'
  4. cname2cid = get_insect_names()
  5. records = get_annotations(cname2cid, TRAINDIR)
  6. record = records[0]
  7. img, gt_boxes, gt_labels, scales = get_img_data(record, size=480)
  1. img.shape
  1. (3, 480, 480)
  1. gt_boxes.shape
  1. (50, 4)
  1. gt_labels
  1. array([0, 0, 0, 0, 0, 0, 2, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 0, 0, 0, 0,
  2. 0, 0, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
  3. 4, 0, 1, 0, 0, 0], dtype=int32)
  1. scales
  1. (1268.0, 1268.0)

批量数据读取与加速

上面的程序展示了如何读取一张图片的数据并加速,下面的代码实现了批量数据读取。

  1. # 获取一个批次内样本随机缩放的尺寸
  2. def get_img_size(mode):
  3. if (mode == 'train') or (mode == 'valid'):
  4. inds = np.array([0,1,2,3,4,5,6,7,8,9])
  5. ii = np.random.choice(inds)
  6. img_size = 320 + ii * 32
  7. else:
  8. img_size = 608
  9. return img_size
  10. # 将 list形式的batch数据 转化成多个array构成的tuple
  11. def make_array(batch_data):
  12. img_array = np.array([item[0] for item in batch_data], dtype = 'float32')
  13. gt_box_array = np.array([item[1] for item in batch_data], dtype = 'float32')
  14. gt_labels_array = np.array([item[2] for item in batch_data], dtype = 'int32')
  15. img_scale = np.array([item[3] for item in batch_data], dtype='int32')
  16. return img_array, gt_box_array, gt_labels_array, img_scale
  17. # 批量读取数据,同一批次内图像的尺寸大小必须是一样的,
  18. # 不同批次之间的大小是随机的,
  19. # 由上面定义的get_img_size函数产生
  20. def data_loader(datadir, batch_size= 10, mode='train'):
  21. cname2cid = get_insect_names()
  22. records = get_annotations(cname2cid, datadir)
  23. def reader():
  24. if mode == 'train':
  25. np.random.shuffle(records)
  26. batch_data = []
  27. img_size = get_img_size(mode)
  28. for record in records:
  29. #print(record)
  30. img, gt_bbox, gt_labels, im_shape = get_img_data(record,
  31. size=img_size)
  32. batch_data.append((img, gt_bbox, gt_labels, im_shape))
  33. if len(batch_data) == batch_size:
  34. yield make_array(batch_data)
  35. batch_data = []
  36. img_size = get_img_size(mode)
  37. if len(batch_data) > 0:
  38. yield make_array(batch_data)
  39. return reader
  1. d = data_loader('/home/aistudio/work/insects/train', batch_size=2, mode='train')
  1. img, gt_boxes, gt_labels, im_shape = next(d())
  1. img.shape, gt_boxes.shape, gt_labels.shape, im_shape.shape
  1. ((2, 3, 544, 544), (2, 50, 4), (2, 50), (2, 2))

由于在数据预处理耗时较长,可能会成为网络训练速度的瓶颈,所以需要对预处理部分进行优化。通过使用Paddle提供的API paddle.reader.xmap_readers可以开启多线程读取数据,具体实现代码如下。

  1. import functools
  2. import paddle
  3. # 使用paddle.reader.xmap_readers实现多线程读取数据
  4. def multithread_loader(datadir, batch_size= 10, mode='train'):
  5. cname2cid = get_insect_names()
  6. records = get_annotations(cname2cid, datadir)
  7. def reader():
  8. if mode == 'train':
  9. np.random.shuffle(records)
  10. img_size = get_img_size(mode)
  11. batch_data = []
  12. for record in records:
  13. batch_data.append((record, img_size))
  14. if len(batch_data) == batch_size:
  15. yield batch_data
  16. batch_data = []
  17. img_size = get_img_size(mode)
  18. if len(batch_data) > 0:
  19. yield batch_data
  20. def get_data(samples):
  21. batch_data = []
  22. for sample in samples:
  23. record = sample[0]
  24. img_size = sample[1]
  25. img, gt_bbox, gt_labels, im_shape = get_img_data(record, size=img_size)
  26. batch_data.append((img, gt_bbox, gt_labels, im_shape))
  27. return make_array(batch_data)
  28. mapper = functools.partial(get_data, )
  29. return paddle.reader.xmap_readers(mapper, reader, 8, 10)
  1. d = multithread_loader('/home/aistudio/work/insects/train', batch_size=2, mode='train')
  1. img, gt_boxes, gt_labels, im_shape = next(d())
  1. img.shape, gt_boxes.shape, gt_labels.shape, im_shape.shape
  1. ((2, 3, 352, 352), (2, 50, 4), (2, 50), (2, 2))

至此,我们完成了如何查看数据集中的数据、提取数据标注信息、从文件读取图像和标注数据、数据增多、批量读取和加速等过程,通过multithread_loader可以返回img, gt_boxes, gt_labels, im_shape等数据,接下来就可以将它们输入神经网络应用在具体算法上面了。

在开始具体的算法讲解之前,先补充一下测试数据的读取代码,测试数据没有标注信息,也不需要做图像增广,代码如下所示。

  1. # 测试数据读取
  2. # 将 list形式的batch数据 转化成多个array构成的tuple
  3. def make_test_array(batch_data):
  4. img_name_array = np.array([item[0] for item in batch_data])
  5. img_data_array = np.array([item[1] for item in batch_data], dtype = 'float32')
  6. img_scale_array = np.array([item[2] for item in batch_data], dtype='int32')
  7. return img_name_array, img_data_array, img_scale_array
  8. # 测试数据读取
  9. def test_data_loader(datadir, batch_size= 10, test_image_size=608, mode='test'):
  10. """
  11. 加载测试用的图片,测试数据没有groundtruth标签
  12. """
  13. image_names = os.listdir(datadir)
  14. def reader():
  15. batch_data = []
  16. img_size = test_image_size
  17. for image_name in image_names:
  18. file_path = os.path.join(datadir, image_name)
  19. img = cv2.imread(file_path)
  20. img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
  21. H = img.shape[0]
  22. W = img.shape[1]
  23. img = cv2.resize(img, (img_size, img_size))
  24. mean = [0.485, 0.456, 0.406]
  25. std = [0.229, 0.224, 0.225]
  26. mean = np.array(mean).reshape((1, 1, -1))
  27. std = np.array(std).reshape((1, 1, -1))
  28. out_img = (img / 255.0 - mean) / std
  29. out_img = out_img.astype('float32').transpose((2, 0, 1))
  30. img = out_img #np.transpose(out_img, (2,0,1))
  31. im_shape = [H, W]
  32. batch_data.append((image_name.split('.')[0], img, im_shape))
  33. if len(batch_data) == batch_size:
  34. yield make_test_array(batch_data)
  35. batch_data = []
  36. if len(batch_data) > 0:
  37. yield make_test_array(batch_data)
  38. return reader