diff --git a/ppdet/data/transform/batch_operators.py b/ppdet/data/transform/batch_operators.py index 113aff0803..9136938817 100644 --- a/ppdet/data/transform/batch_operators.py +++ b/ppdet/data/transform/batch_operators.py @@ -194,6 +194,26 @@ def __call__(self, samples, context=None): sample['gt_score'] = np.ones( (gt_bbox.shape[0], 1), dtype=np.float32) gt_score = sample['gt_score'] + + # find best match anchor index + best_anchor_indices = [] + best_anchor_ious = [] + for b in range(gt_bbox.shape[0]): + gx, gy, gw, gh = gt_bbox[b, :] + if gw <= 0. or gh <= 0. or gt_score[b] <= 0.: + best_anchor_indices.append(-1) + best_anchor_ious.append(0.) + continue + best_iou = 0. + best_idx = -1 + for an_idx in range(an_hw.shape[0]): + iou = jaccard_overlap([0., 0., gw, gh], [0., 0., an_hw[an_idx, 0], an_hw[an_idx, 1]]) + if iou > best_iou: + best_iou = iou + best_idx = an_idx + best_anchor_indices.append(best_idx) + best_anchor_ious.append(best_iou) + for i, ( mask, downsample_ratio ) in enumerate(zip(self.anchor_masks, self.downsample_ratios)): @@ -203,23 +223,12 @@ def __call__(self, samples, context=None): (len(mask), 6 + self.num_classes, grid_h, grid_w), dtype=np.float32) for b in range(gt_bbox.shape[0]): + if best_anchor_indices[b] == -1: + continue gx, gy, gw, gh = gt_bbox[b, :] cls = gt_class[b] score = gt_score[b] - if gw <= 0. or gh <= 0. or score <= 0.: - continue - - # find best match anchor index - best_iou = 0. - best_idx = -1 - for an_idx in range(an_hw.shape[0]): - iou = jaccard_overlap( - [0., 0., gw, gh], - [0., 0., an_hw[an_idx, 0], an_hw[an_idx, 1]]) - if iou > best_iou: - best_iou = iou - best_idx = an_idx - + best_idx = best_anchor_indices[b] gi = int(gx * grid_w) gj = int(gy * grid_h)